/ common / utils / utils.go
utils.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  // Package utils 工具集合
 20  package utils
 21  
 22  import (
 23  	"archive/tar"
 24  	"archive/zip"
 25  	"bufio"
 26  	"bytes"
 27  	"compress/gzip"
 28  	"context"
 29  	"errors"
 30  	"encoding/base64"
 31  	"fmt"
 32  	"io"
 33  	"net"
 34  	"os"
 35  	"os/exec"
 36  	"path/filepath"
 37  	"runtime"
 38  	"strconv"
 39  	"strings"
 40  	"time"
 41  
 42  	"github.com/Tencent/AI-Infra-Guard/internal/gologger"
 43  
 44  	"github.com/spaolacci/murmur3"
 45  )
 46  
 47  // Duration2String 将时间段转换为可读的字符串格式
 48  // 如果时间超过60秒则返回分钟,否则返回秒
 49  func Duration2String(t time.Duration) string {
 50  	sceond := t.Seconds()
 51  	if sceond >= 60 {
 52  		return fmt.Sprintf("%.2f min", t.Minutes())
 53  	} else {
 54  		return fmt.Sprintf("%.2f s", sceond)
 55  	}
 56  }
 57  
 58  // InsertInto 在字符串中每隔指定间隔插入分隔符
 59  // s: 源字符串
 60  // interval: 插入间隔
 61  // sep: 分隔符
 62  func InsertInto(s string, interval int, sep rune) string {
 63  	var buffer bytes.Buffer
 64  	before := interval - 1
 65  	last := len(s) - 1
 66  	for i, char := range s {
 67  		buffer.WriteRune(char)
 68  		if i%interval == before && i != last {
 69  			buffer.WriteRune(sep)
 70  		}
 71  	}
 72  	buffer.WriteRune(sep)
 73  	return buffer.String()
 74  }
 75  
 76  // FaviconHash 计算网站图标的哈希值
 77  // 将数据进行base64编码后使用murmur3哈希算法计算
 78  func FaviconHash(data []byte) int32 {
 79  	stdBase64 := base64.StdEncoding.EncodeToString(data)
 80  	stdBase64 = InsertInto(stdBase64, 76, '\n')
 81  	hasher := murmur3.New32WithSeed(0)
 82  	hasher.Write([]byte(stdBase64))
 83  	return int32(hasher.Sum32())
 84  }
 85  
 86  // ScanDir 递归扫描目录,返回所有文件的完整路径
 87  // path: 要扫描的目录路径
 88  // 返回文件路径列表和可能的错误
 89  func ScanDir(path string) ([]string, error) {
 90  	files := make([]string, 0)
 91  	dir, err := os.ReadDir(path)
 92  	if err != nil {
 93  		return nil, err
 94  	}
 95  	for _, fi := range dir {
 96  		if fi.IsDir() {
 97  			newDir, err := ScanDir(filepath.Join(path, fi.Name()))
 98  			if err != nil {
 99  				return files, err
100  			}
101  			files = append(files, newDir...)
102  		} else {
103  			files = append(files, filepath.Join(path, fi.Name()))
104  		}
105  	}
106  	return files, nil
107  }
108  
109  // IsCIDR 检查给定的字符串是否为有效的CIDR格式
110  func IsCIDR(cidr string) bool {
111  	_, _, err := net.ParseCIDR(cidr)
112  	return err == nil
113  }
114  
115  // IsFileExists 检查文件是否存在
116  // path: 文件路径
117  // 返回布尔值表示文件是否存在
118  func IsFileExists(path string) bool {
119  	_, err := os.Stat(path)
120  	if err != nil {
121  		if os.IsExist(err) {
122  			return true
123  		}
124  		return false
125  	}
126  	return true
127  }
128  
129  // IsDir 判断给定路径是否为目录
130  // path: 待检查的路径
131  // 返回布尔值表示是否为目录
132  func IsDir(path string) bool {
133  	s, err := os.Stat(path)
134  	if err != nil {
135  		return false
136  	}
137  	return s.IsDir()
138  }
139  
140  // TrimProtocol 移除URL中的HTTP/HTTPS协议前缀
141  // targetURL: 目标URL
142  // 返回去除协议前缀后的URL
143  func TrimProtocol(targetURL string) string {
144  	URL := strings.TrimSpace(targetURL)
145  	if strings.HasPrefix(strings.ToLower(URL), "http://") || strings.HasPrefix(strings.ToLower(URL), "https://") {
146  		URL = URL[strings.Index(URL, "//")+2:]
147  	}
148  	URL = strings.TrimRight(URL, "/")
149  	return URL
150  }
151  
152  // CompareVersions 比较两个版本号字符串
153  // version1, version2: 待比较的版本号
154  // 返回值: 1 表示 version1 大于 version2
155  //
156  //	-1 表示 version1 小于 version2
157  //	 0 表示两个版本号相等
158  func CompareVersions(version1, version2 string) int {
159  	v1Parts := strings.Split(version1, ".")
160  	v2Parts := strings.Split(version2, ".")
161  
162  	// Determine the max length to iterate over
163  	maxLen := len(v1Parts)
164  	if len(v2Parts) > maxLen {
165  		maxLen = len(v2Parts)
166  	}
167  
168  	for i := 0; i < maxLen; i++ {
169  		var num1, num2 int
170  
171  		if i < len(v1Parts) {
172  			num1, _ = strconv.Atoi(v1Parts[i])
173  		}
174  
175  		if i < len(v2Parts) {
176  			num2, _ = strconv.Atoi(v2Parts[i])
177  		}
178  
179  		if num1 > num2 {
180  			return 1
181  		} else if num1 < num2 {
182  			return -1
183  		}
184  	}
185  
186  	return 0
187  }
188  
189  // GetMiddleText 获取两个字符串之间的文本内容
190  // left: 左边界字符串
191  // right: 右边界字符串
192  // html: 源文本
193  // 返回左右边界之间的文本,如果未找到则返回空字符串
194  func GetMiddleText(left, right, html string) string {
195  	start := strings.Index(html, left)
196  	if start == -1 {
197  		return "" // 如果找不到 left,返回空字符串
198  	}
199  	start += len(left)
200  
201  	end := strings.Index(html[start:], right)
202  	if end == -1 {
203  		return "" // 如果找不到 right,返回空字符串
204  	}
205  	end += start
206  
207  	return html[start:end]
208  }
209  
210  // PortInfo 存储端口和地址信息
211  type PortInfo struct {
212  	Port    int
213  	Address string
214  }
215  
216  // GetLocalOpenPorts 获取本地开放的端口及其地址信息
217  func GetLocalOpenPorts() ([]PortInfo, error) {
218  	var portInfos []PortInfo
219  	switch runtime.GOOS {
220  	case "windows":
221  		cmd := exec.Command("netstat", "-an")
222  		output, err := cmd.Output()
223  		if err != nil {
224  			return nil, fmt.Errorf("执行netstat命令失败: %v", err)
225  		}
226  
227  		scanner := bufio.NewScanner(strings.NewReader(string(output)))
228  		for scanner.Scan() {
229  			line := scanner.Text()
230  			if strings.Contains(line, "LISTENING") {
231  				parts := strings.Fields(line)
232  				if len(parts) >= 2 {
233  					addrPort := strings.Split(parts[1], ":")
234  					if len(addrPort) == 2 {
235  						port, err := strconv.Atoi(addrPort[1])
236  						if err == nil {
237  							addr := addrPort[0]
238  							portInfos = append(portInfos, PortInfo{
239  								Port:    port,
240  								Address: addr,
241  							})
242  						}
243  					}
244  				}
245  			}
246  		}
247  
248  	case "darwin", "linux":
249  		cmd := exec.Command("lsof", "-i", "-P", "-n")
250  		output, err := cmd.Output()
251  		if err != nil {
252  			return nil, fmt.Errorf("执行lsof命令失败: %v", err)
253  		}
254  
255  		scanner := bufio.NewScanner(strings.NewReader(string(output)))
256  		for scanner.Scan() {
257  			line := scanner.Text()
258  			if strings.Contains(line, "LISTEN") {
259  				parts := strings.Fields(line)
260  				for _, part := range parts {
261  					if strings.Contains(part, ":") {
262  						addrPort := strings.Split(part, ":")
263  						if len(addrPort) == 2 {
264  							port, err := strconv.Atoi(addrPort[1])
265  							if err == nil {
266  								addr := addrPort[0]
267  								if addr == "*" || addr == "0.0.0.0" {
268  									addr = "0.0.0.0"
269  								} else if addr == "127.0.0.1" || addr == "localhost" {
270  									addr = "127.0.0.1"
271  								}
272  								portInfos = append(portInfos, PortInfo{
273  									Port:    port,
274  									Address: addr,
275  								})
276  							}
277  						}
278  					}
279  				}
280  			}
281  		}
282  
283  	default:
284  		return nil, fmt.Errorf("不支持的操作系统: %s", runtime.GOOS)
285  	}
286  
287  	// 去重
288  	seen := make(map[string]bool)
289  	var result []PortInfo
290  	for _, info := range portInfos {
291  		key := fmt.Sprintf("%s:%d", info.Address, info.Port)
292  		if !seen[key] {
293  			seen[key] = true
294  			result = append(result, info)
295  		}
296  	}
297  
298  	return result, nil
299  }
300  
301  // ExtractZipFile 解压ZIP文件
302  func ExtractZipFile(zipFile string, destPath string) error {
303  	// 打开ZIP文件
304  	reader, err := zip.OpenReader(zipFile)
305  	if err != nil {
306  		return fmt.Errorf("打开ZIP文件失败: %v", err)
307  	}
308  	defer reader.Close()
309  
310  	// 确保目标目录存在
311  	if err := os.MkdirAll(destPath, 0755); err != nil {
312  		return fmt.Errorf("创建目标目录失败: %v", err)
313  	}
314  
315  	// 解压文件
316  	for _, file := range reader.File {
317  		// 检查文件路径是否安全
318  		filePath := filepath.Join(destPath, file.Name)
319  		if !strings.HasPrefix(filePath, filepath.Clean(destPath)+string(os.PathSeparator)) {
320  			gologger.Errorln(fmt.Sprintf("不安全的路径: %s", file.Name))
321  			continue
322  		}
323  
324  		// 创建目录
325  		if file.FileInfo().IsDir() {
326  			if err := os.MkdirAll(filePath, 0755); err != nil {
327  				return fmt.Errorf("创建目录失败: %v", err)
328  			}
329  			continue
330  		}
331  
332  		// 确保文件的父目录存在
333  		if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
334  			return fmt.Errorf("创建父目录失败: %v", err)
335  		}
336  
337  		// 创建文件
338  		outFile, err := os.Create(filePath)
339  		if err != nil {
340  			return fmt.Errorf("创建文件失败: %v", err)
341  		}
342  		defer outFile.Close()
343  
344  		// 打开文件内容
345  		rc, err := file.Open()
346  		if err != nil {
347  			return fmt.Errorf("打开压缩文件内容失败: %v", err)
348  		}
349  		defer rc.Close()
350  
351  		// 复制内容
352  		if _, err := io.Copy(outFile, rc); err != nil {
353  			return fmt.Errorf("复制文件内容失败: %v", err)
354  		}
355  	}
356  
357  	return nil
358  }
359  
360  // ExtractTGZ 文件解压
361  func ExtractTGZ(src, dest string) error {
362  	// 打开 .tgz 文件
363  	file, err := os.Open(src)
364  	if err != nil {
365  		return err
366  	}
367  	defer file.Close()
368  
369  	// 创建 gzip Reader
370  	gzr, err := gzip.NewReader(file)
371  	if err != nil {
372  		return err
373  	}
374  	defer gzr.Close()
375  
376  	// 创建 tar Reader
377  	tr := tar.NewReader(gzr)
378  
379  	// 遍历 tar 文件中的每个条目
380  	for {
381  		header, err := tr.Next()
382  		if err == io.EOF {
383  			break // 读取完毕
384  		}
385  		if err != nil {
386  			return err
387  		}
388  
389  		// 安全处理目标路径,防止路径穿越攻击
390  		targetPath, err := safePath(dest, header.Name)
391  		if err != nil {
392  			return err
393  		}
394  
395  		// 根据文件类型处理
396  		switch header.Typeflag {
397  		case tar.TypeDir: // 目录
398  			if err := os.MkdirAll(targetPath, 0755); err != nil {
399  				return err
400  			}
401  		case tar.TypeReg: // 普通文件
402  			if err := writeFile(targetPath, tr, header.Mode); err != nil {
403  				return err
404  			}
405  		// 可选:处理符号链接等其他类型
406  		default:
407  			fmt.Printf("未处理类型: %v in %s\n", header.Typeflag, header.Name)
408  		}
409  	}
410  	return nil
411  }
412  
413  // 安全路径检查,防止路径穿越
414  func safePath(dest, name string) (string, error) {
415  	cleanedDest := filepath.Clean(dest)
416  	cleanedPath := filepath.Clean(filepath.Join(cleanedDest, name))
417  
418  	// 检查目标路径是否在目标目录下,防止路径穿越(path traversal)攻击
419  	if !strings.HasPrefix(cleanedPath, cleanedDest+string(os.PathSeparator)) && cleanedPath != cleanedDest {
420  		return "", fmt.Errorf("非法路径: %s", name)
421  	}
422  	return cleanedPath, nil
423  }
424  
425  // 写入文件内容
426  func writeFile(path string, r io.Reader, mode int64) error {
427  	// 确保目录存在
428  	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
429  		return err
430  	}
431  
432  	// 创建文件并设置权限
433  	file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(mode))
434  	if err != nil {
435  		return err
436  	}
437  	defer file.Close()
438  
439  	// 复制内容
440  	if _, err := io.Copy(file, r); err != nil {
441  		return err
442  	}
443  	return nil
444  }
445  
446  // GitClone 克隆Git仓库
447  func GitClone(repoURL, targetDir string, timeout time.Duration) error {
448  	var err error
449  	for i := 0; i < 3; i++ {
450  		err = func() error {
451  			ctx := context.Background()
452  			ctx, cancel := context.WithTimeout(ctx, timeout)
453  			defer cancel()
454  			cmd := exec.CommandContext(ctx, "git", "clone", "--", repoURL, targetDir)
455  			done := make(chan error)
456  			go func() {
457  				_, err := cmd.CombinedOutput()
458  				done <- err
459  			}()
460  
461  			select {
462  			case <-ctx.Done():
463  				_ = cmd.Process.Kill()
464  				return fmt.Errorf("操作超时")
465  			case err = <-done:
466  				return err
467  			}
468  		}()
469  		if err == nil {
470  			return nil
471  		}
472  	}
473  	return err
474  }
475  
476  func RunCmd(dir, name string, arg []string, callback func(line string)) error {
477  	return RunCmdWithContext(context.Background(), dir, name, arg, callback)
478  }
479  
480  func RunCmdWithContext(ctx context.Context, dir, name string, arg []string, callback func(line string)) error {
481  	if ctx == nil {
482  		ctx = context.Background()
483  	}
484  
485  	// 命令行执行,stdio读取
486  	cmd := exec.CommandContext(ctx, name, arg...)
487  	cmd.Dir = dir
488  	cmd.Env = os.Environ()
489  	// 获取命令行
490  	cmdStr := name + " " + strings.Join(arg, " ")
491  	gologger.Infof("开始执行命令: %s", cmdStr)
492  	// 使用管道获取标准输出
493  	stdout, err := cmd.StdoutPipe()
494  	if err != nil {
495  		return err
496  	}
497  	cmd.Stderr = cmd.Stdout // 将错误输出合并到标准输出
498  
499  	// 启动扫描器goroutine
500  	scanner := bufio.NewScanner(stdout)
501  	// 设置更大的缓冲区以处理超长文本行
502  	// 默认64KB,这里设置为1MB
503  	const maxCapacity = 1024 * 1024 * 10 // 1MB
504  	buf := make([]byte, 0, 64*1024)
505  	scanner.Buffer(buf, maxCapacity)
506  
507  	done := make(chan error) // 改为传递错误信息
508  	go func() {
509  		defer close(done)
510  		for scanner.Scan() {
511  			line := scanner.Text()
512  			callback(line)
513  		}
514  		// 检查扫描器是否遇到错误
515  		if err := scanner.Err(); err != nil {
516  			// 管道关闭是正常的结束条件,不应视为错误
517  			if strings.Contains(err.Error(), "file already closed") ||
518  				strings.Contains(err.Error(), "broken pipe") {
519  				done <- nil
520  				return
521  			}
522  			done <- fmt.Errorf("读取输出时发生错误: %v", err)
523  			return
524  		}
525  		done <- nil
526  	}()
527  
528  	// 启动命令
529  	if err = cmd.Start(); err != nil {
530  		return err
531  	}
532  
533  	// 等待命令执行完成
534  	cmdErr := cmd.Wait()
535  
536  	// 等待读取完成并检查读取错误
537  	readErr := <-done
538  
539  	// 优先返回读取错误,其次返回命令执行错误
540  	if errors.Is(ctx.Err(), context.Canceled) {
541  		return ctx.Err()
542  	}
543  	if readErr != nil {
544  		return readErr
545  	}
546  	if cmdErr != nil {
547  		return cmdErr
548  	}
549  
550  	return nil
551  }
552  
553  func IsHostname(hostname string) bool {
554  	ips := strings.Split(hostname, ":")
555  	if len(ips) != 2 {
556  		return false
557  	}
558  	p := net.ParseIP(strings.TrimSpace(ips[0]))
559  	if p == nil {
560  		return false
561  	}
562  	return true
563  }
564  
565  // StrInSlice checks if a string is in a slice of strings.
566  func StrInSlice(str string, list []string) bool {
567  	for _, v := range list {
568  		if v == str {
569  			return true
570  		}
571  	}
572  	return false
573  }