/ pkg / util / download.go
download.go
 1  package util
 2  
 3  import (
 4  	"fmt"
 5  	"go.uber.org/zap"
 6  	"io"
 7  	"krillin-ai/config"
 8  	"krillin-ai/log"
 9  	"net/http"
10  	"os"
11  	"time"
12  )
13  
// progressWriter is an io.Writer that only counts the bytes passing through it
// and redraws a one-line download progress report on stdout after each write.
// It is intended to sit on the tee side of an io.TeeReader wrapped around a
// response body.
type progressWriter struct {
	// Total is the expected size in bytes; 0 means the size is unknown and the
	// percentage is omitted from the report.
	Total uint64
	// Downloaded is the number of bytes written so far.
	Downloaded uint64
	// StartTime is when the transfer began. If left zero it is set on the first
	// Write, so the clock then starts at the first chunk.
	StartTime time.Time
}

// Write records len(p) bytes of progress and redraws the progress line.
// It never fails: it always returns len(p), nil, so an io.TeeReader feeding it
// never aborts the underlying copy.
func (pw *progressWriter) Write(p []byte) (int, error) {
	n := len(p)
	pw.Downloaded += uint64(n)

	if pw.StartTime.IsZero() {
		pw.StartTime = time.Now()
	}

	fmt.Print(formatDownloadProgress(pw.Downloaded, pw.Total, time.Since(pw.StartTime).Seconds()))
	return n, nil
}

// formatDownloadProgress renders the carriage-return-prefixed progress line for
// `downloaded` of `total` bytes after `elapsed` seconds.
// A total of 0 means the size is unknown: the percentage (which would be
// +Inf/NaN) is omitted. A non-positive elapsed time reports a speed of 0
// instead of dividing by zero.
func formatDownloadProgress(downloaded, total uint64, elapsed float64) string {
	const mb = 1024 * 1024
	downloadedMB := float64(downloaded) / mb

	var speed float64
	if elapsed > 0 {
		speed = downloadedMB / elapsed
	}

	if total == 0 {
		return fmt.Sprintf("\r下载进度: %.2f MB | 速度: %.2f MB/s", downloadedMB, speed)
	}

	percent := float64(downloaded) / float64(total) * 100
	return fmt.Sprintf("\r下载进度: %.2f%% (%.2f MB / %.2f MB) | 速度: %.2f MB/s",
		percent, downloadedMB, float64(total)/mb, speed)
}
42  
43  // DownloadFile 下载文件并保存到指定路径,支持代理
44  func DownloadFile(urlStr, filepath, proxyAddr string) error {
45  	log.GetLogger().Info("开始下载文件", zap.String("url", urlStr))
46  	client := &http.Client{}
47  	if proxyAddr != "" {
48  		client.Transport = &http.Transport{
49  			Proxy: http.ProxyURL(config.Conf.App.ParsedProxy),
50  		}
51  	}
52  
53  	resp, err := client.Get(urlStr)
54  	if err != nil {
55  		return err
56  	}
57  	defer resp.Body.Close()
58  
59  	size := resp.ContentLength
60  	fmt.Printf("文件大小: %.2f MB\n", float64(size)/1024/1024)
61  
62  	out, err := os.Create(filepath)
63  	if err != nil {
64  		return err
65  	}
66  	defer out.Close()
67  
68  	// 带有进度的 Reader
69  	progress := &progressWriter{
70  		Total: uint64(size),
71  	}
72  	reader := io.TeeReader(resp.Body, progress)
73  
74  	_, err = io.Copy(out, reader)
75  	if err != nil {
76  		return err
77  	}
78  	fmt.Printf("\n") // 进度信息结束,换新行
79  
80  	log.GetLogger().Info("文件下载完成", zap.String("路径", filepath))
81  	return nil
82  }