download.go
1 package util 2 3 import ( 4 "fmt" 5 "go.uber.org/zap" 6 "io" 7 "krillin-ai/config" 8 "krillin-ai/log" 9 "net/http" 10 "os" 11 "time" 12 ) 13 14 // 用于显示下载进度,实现io.Writer 15 type progressWriter struct { 16 Total uint64 17 Downloaded uint64 18 StartTime time.Time 19 } 20 21 func (pw *progressWriter) Write(p []byte) (int, error) { 22 n := len(p) 23 pw.Downloaded += uint64(n) 24 25 // 初始化开始时间 26 if pw.StartTime.IsZero() { 27 pw.StartTime = time.Now() 28 } 29 30 percent := float64(pw.Downloaded) / float64(pw.Total) * 100 31 elapsed := time.Since(pw.StartTime).Seconds() 32 speed := float64(pw.Downloaded) / 1024 / 1024 / elapsed 33 34 fmt.Printf("\r下载进度: %.2f%% (%.2f MB / %.2f MB) | 速度: %.2f MB/s", 35 percent, 36 float64(pw.Downloaded)/1024/1024, 37 float64(pw.Total)/1024/1024, 38 speed) 39 40 return n, nil 41 } 42 43 // DownloadFile 下载文件并保存到指定路径,支持代理 44 func DownloadFile(urlStr, filepath, proxyAddr string) error { 45 log.GetLogger().Info("开始下载文件", zap.String("url", urlStr)) 46 client := &http.Client{} 47 if proxyAddr != "" { 48 client.Transport = &http.Transport{ 49 Proxy: http.ProxyURL(config.Conf.App.ParsedProxy), 50 } 51 } 52 53 resp, err := client.Get(urlStr) 54 if err != nil { 55 return err 56 } 57 defer resp.Body.Close() 58 59 size := resp.ContentLength 60 fmt.Printf("文件大小: %.2f MB\n", float64(size)/1024/1024) 61 62 out, err := os.Create(filepath) 63 if err != nil { 64 return err 65 } 66 defer out.Close() 67 68 // 带有进度的 Reader 69 progress := &progressWriter{ 70 Total: uint64(size), 71 } 72 reader := io.TeeReader(resp.Body, progress) 73 74 _, err = io.Copy(out, reader) 75 if err != nil { 76 return err 77 } 78 fmt.Printf("\n") // 进度信息结束,换新行 79 80 log.GetLogger().Info("文件下载完成", zap.String("路径", filepath)) 81 return nil 82 }