utils.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 // Package utils 工具集合 20 package utils 21 22 import ( 23 "archive/tar" 24 "archive/zip" 25 "bufio" 26 "bytes" 27 "compress/gzip" 28 "context" 29 "errors" 30 "encoding/base64" 31 "fmt" 32 "io" 33 "net" 34 "os" 35 "os/exec" 36 "path/filepath" 37 "runtime" 38 "strconv" 39 "strings" 40 "time" 41 42 "github.com/Tencent/AI-Infra-Guard/internal/gologger" 43 44 "github.com/spaolacci/murmur3" 45 ) 46 47 // Duration2String 将时间段转换为可读的字符串格式 48 // 如果时间超过60秒则返回分钟,否则返回秒 49 func Duration2String(t time.Duration) string { 50 sceond := t.Seconds() 51 if sceond >= 60 { 52 return fmt.Sprintf("%.2f min", t.Minutes()) 53 } else { 54 return fmt.Sprintf("%.2f s", sceond) 55 } 56 } 57 58 // InsertInto 在字符串中每隔指定间隔插入分隔符 59 // s: 源字符串 60 // interval: 插入间隔 61 // sep: 分隔符 62 func InsertInto(s string, interval int, sep rune) string { 63 var buffer bytes.Buffer 64 before := interval - 1 65 last := len(s) - 1 66 for i, char := range s { 67 buffer.WriteRune(char) 68 if i%interval == before && i != last { 69 buffer.WriteRune(sep) 70 } 71 } 72 buffer.WriteRune(sep) 73 return buffer.String() 74 } 75 76 // FaviconHash 计算网站图标的哈希值 77 // 将数据进行base64编码后使用murmur3哈希算法计算 78 func FaviconHash(data []byte) int32 { 79 stdBase64 := base64.StdEncoding.EncodeToString(data) 80 stdBase64 = InsertInto(stdBase64, 76, '\n') 81 hasher := murmur3.New32WithSeed(0) 82 hasher.Write([]byte(stdBase64)) 83 return int32(hasher.Sum32()) 84 } 85 86 // ScanDir 递归扫描目录,返回所有文件的完整路径 87 // path: 要扫描的目录路径 88 // 返回文件路径列表和可能的错误 89 func ScanDir(path string) ([]string, error) { 90 files := make([]string, 0) 91 dir, err := os.ReadDir(path) 92 if err != nil { 93 return nil, err 94 } 95 for _, fi := range dir { 96 if fi.IsDir() { 97 newDir, err := ScanDir(filepath.Join(path, fi.Name())) 98 if err != nil { 99 return files, err 100 } 101 files = append(files, newDir...) 102 } else { 103 files = append(files, filepath.Join(path, fi.Name())) 104 } 105 } 106 return files, nil 107 } 108 109 // IsCIDR 检查给定的字符串是否为有效的CIDR格式 110 func IsCIDR(cidr string) bool { 111 _, _, err := net.ParseCIDR(cidr) 112 return err == nil 113 } 114 115 // IsFileExists 检查文件是否存在 116 // path: 文件路径 117 // 返回布尔值表示文件是否存在 118 func IsFileExists(path string) bool { 119 _, err := os.Stat(path) 120 if err != nil { 121 if os.IsExist(err) { 122 return true 123 } 124 return false 125 } 126 return true 127 } 128 129 // IsDir 判断给定路径是否为目录 130 // path: 待检查的路径 131 // 返回布尔值表示是否为目录 132 func IsDir(path string) bool { 133 s, err := os.Stat(path) 134 if err != nil { 135 return false 136 } 137 return s.IsDir() 138 } 139 140 // TrimProtocol 移除URL中的HTTP/HTTPS协议前缀 141 // targetURL: 目标URL 142 // 返回去除协议前缀后的URL 143 func TrimProtocol(targetURL string) string { 144 URL := strings.TrimSpace(targetURL) 145 if strings.HasPrefix(strings.ToLower(URL), "http://") || strings.HasPrefix(strings.ToLower(URL), "https://") { 146 URL = URL[strings.Index(URL, "//")+2:] 147 } 148 URL = strings.TrimRight(URL, "/") 149 return URL 150 } 151 152 // CompareVersions 比较两个版本号字符串 153 // version1, version2: 待比较的版本号 154 // 返回值: 1 表示 version1 大于 version2 155 // 156 // -1 表示 version1 小于 version2 157 // 0 表示两个版本号相等 158 func CompareVersions(version1, version2 string) int { 159 v1Parts := strings.Split(version1, ".") 160 v2Parts := strings.Split(version2, ".") 161 162 // Determine the max length to iterate over 163 maxLen := len(v1Parts) 164 if len(v2Parts) > maxLen { 165 maxLen = len(v2Parts) 166 } 167 168 for i := 0; i < maxLen; i++ { 169 var num1, num2 int 170 171 if i < len(v1Parts) { 172 num1, _ = strconv.Atoi(v1Parts[i]) 173 } 174 175 if i < len(v2Parts) { 176 num2, _ = strconv.Atoi(v2Parts[i]) 177 } 178 179 if num1 > num2 { 180 return 1 181 } else if num1 < num2 { 182 return -1 183 } 184 } 185 186 return 0 187 } 188 189 // GetMiddleText 获取两个字符串之间的文本内容 190 // left: 左边界字符串 191 // right: 右边界字符串 192 // html: 源文本 193 // 返回左右边界之间的文本,如果未找到则返回空字符串 194 func GetMiddleText(left, right, html string) string { 195 start := strings.Index(html, left) 196 if start == -1 { 197 return "" // 如果找不到 left,返回空字符串 198 } 199 start += len(left) 200 201 end := strings.Index(html[start:], right) 202 if end == -1 { 203 return "" // 如果找不到 right,返回空字符串 204 } 205 end += start 206 207 return html[start:end] 208 } 209 210 // PortInfo 存储端口和地址信息 211 type PortInfo struct { 212 Port int 213 Address string 214 } 215 216 // GetLocalOpenPorts 获取本地开放的端口及其地址信息 217 func GetLocalOpenPorts() ([]PortInfo, error) { 218 var portInfos []PortInfo 219 switch runtime.GOOS { 220 case "windows": 221 cmd := exec.Command("netstat", "-an") 222 output, err := cmd.Output() 223 if err != nil { 224 return nil, fmt.Errorf("执行netstat命令失败: %v", err) 225 } 226 227 scanner := bufio.NewScanner(strings.NewReader(string(output))) 228 for scanner.Scan() { 229 line := scanner.Text() 230 if strings.Contains(line, "LISTENING") { 231 parts := strings.Fields(line) 232 if len(parts) >= 2 { 233 addrPort := strings.Split(parts[1], ":") 234 if len(addrPort) == 2 { 235 port, err := strconv.Atoi(addrPort[1]) 236 if err == nil { 237 addr := addrPort[0] 238 portInfos = append(portInfos, PortInfo{ 239 Port: port, 240 Address: addr, 241 }) 242 } 243 } 244 } 245 } 246 } 247 248 case "darwin", "linux": 249 cmd := exec.Command("lsof", "-i", "-P", "-n") 250 output, err := cmd.Output() 251 if err != nil { 252 return nil, fmt.Errorf("执行lsof命令失败: %v", err) 253 } 254 255 scanner := bufio.NewScanner(strings.NewReader(string(output))) 256 for scanner.Scan() { 257 line := scanner.Text() 258 if strings.Contains(line, "LISTEN") { 259 parts := strings.Fields(line) 260 for _, part := range parts { 261 if strings.Contains(part, ":") { 262 addrPort := strings.Split(part, ":") 263 if len(addrPort) == 2 { 264 port, err := strconv.Atoi(addrPort[1]) 265 if err == nil { 266 addr := addrPort[0] 267 if addr == "*" || addr == "0.0.0.0" { 268 addr = "0.0.0.0" 269 } else if addr == "127.0.0.1" || addr == "localhost" { 270 addr = "127.0.0.1" 271 } 272 portInfos = append(portInfos, PortInfo{ 273 Port: port, 274 Address: addr, 275 }) 276 } 277 } 278 } 279 } 280 } 281 } 282 283 default: 284 return nil, fmt.Errorf("不支持的操作系统: %s", runtime.GOOS) 285 } 286 287 // 去重 288 seen := make(map[string]bool) 289 var result []PortInfo 290 for _, info := range portInfos { 291 key := fmt.Sprintf("%s:%d", info.Address, info.Port) 292 if !seen[key] { 293 seen[key] = true 294 result = append(result, info) 295 } 296 } 297 298 return result, nil 299 } 300 301 // ExtractZipFile 解压ZIP文件 302 func ExtractZipFile(zipFile string, destPath string) error { 303 // 打开ZIP文件 304 reader, err := zip.OpenReader(zipFile) 305 if err != nil { 306 return fmt.Errorf("打开ZIP文件失败: %v", err) 307 } 308 defer reader.Close() 309 310 // 确保目标目录存在 311 if err := os.MkdirAll(destPath, 0755); err != nil { 312 return fmt.Errorf("创建目标目录失败: %v", err) 313 } 314 315 // 解压文件 316 for _, file := range reader.File { 317 // 检查文件路径是否安全 318 filePath := filepath.Join(destPath, file.Name) 319 if !strings.HasPrefix(filePath, filepath.Clean(destPath)+string(os.PathSeparator)) { 320 gologger.Errorln(fmt.Sprintf("不安全的路径: %s", file.Name)) 321 continue 322 } 323 324 // 创建目录 325 if file.FileInfo().IsDir() { 326 if err := os.MkdirAll(filePath, 0755); err != nil { 327 return fmt.Errorf("创建目录失败: %v", err) 328 } 329 continue 330 } 331 332 // 确保文件的父目录存在 333 if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { 334 return fmt.Errorf("创建父目录失败: %v", err) 335 } 336 337 // 创建文件 338 outFile, err := os.Create(filePath) 339 if err != nil { 340 return fmt.Errorf("创建文件失败: %v", err) 341 } 342 defer outFile.Close() 343 344 // 打开文件内容 345 rc, err := file.Open() 346 if err != nil { 347 return fmt.Errorf("打开压缩文件内容失败: %v", err) 348 } 349 defer rc.Close() 350 351 // 复制内容 352 if _, err := io.Copy(outFile, rc); err != nil { 353 return fmt.Errorf("复制文件内容失败: %v", err) 354 } 355 } 356 357 return nil 358 } 359 360 // ExtractTGZ 文件解压 361 func ExtractTGZ(src, dest string) error { 362 // 打开 .tgz 文件 363 file, err := os.Open(src) 364 if err != nil { 365 return err 366 } 367 defer file.Close() 368 369 // 创建 gzip Reader 370 gzr, err := gzip.NewReader(file) 371 if err != nil { 372 return err 373 } 374 defer gzr.Close() 375 376 // 创建 tar Reader 377 tr := tar.NewReader(gzr) 378 379 // 遍历 tar 文件中的每个条目 380 for { 381 header, err := tr.Next() 382 if err == io.EOF { 383 break // 读取完毕 384 } 385 if err != nil { 386 return err 387 } 388 389 // 安全处理目标路径,防止路径穿越攻击 390 targetPath, err := safePath(dest, header.Name) 391 if err != nil { 392 return err 393 } 394 395 // 根据文件类型处理 396 switch header.Typeflag { 397 case tar.TypeDir: // 目录 398 if err := os.MkdirAll(targetPath, 0755); err != nil { 399 return err 400 } 401 case tar.TypeReg: // 普通文件 402 if err := writeFile(targetPath, tr, header.Mode); err != nil { 403 return err 404 } 405 // 可选:处理符号链接等其他类型 406 default: 407 fmt.Printf("未处理类型: %v in %s\n", header.Typeflag, header.Name) 408 } 409 } 410 return nil 411 } 412 413 // 安全路径检查,防止路径穿越 414 func safePath(dest, name string) (string, error) { 415 cleanedDest := filepath.Clean(dest) 416 cleanedPath := filepath.Clean(filepath.Join(cleanedDest, name)) 417 418 // 检查目标路径是否在目标目录下,防止路径穿越(path traversal)攻击 419 if !strings.HasPrefix(cleanedPath, cleanedDest+string(os.PathSeparator)) && cleanedPath != cleanedDest { 420 return "", fmt.Errorf("非法路径: %s", name) 421 } 422 return cleanedPath, nil 423 } 424 425 // 写入文件内容 426 func writeFile(path string, r io.Reader, mode int64) error { 427 // 确保目录存在 428 if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { 429 return err 430 } 431 432 // 创建文件并设置权限 433 file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(mode)) 434 if err != nil { 435 return err 436 } 437 defer file.Close() 438 439 // 复制内容 440 if _, err := io.Copy(file, r); err != nil { 441 return err 442 } 443 return nil 444 } 445 446 // GitClone 克隆Git仓库 447 func GitClone(repoURL, targetDir string, timeout time.Duration) error { 448 var err error 449 for i := 0; i < 3; i++ { 450 err = func() error { 451 ctx := context.Background() 452 ctx, cancel := context.WithTimeout(ctx, timeout) 453 defer cancel() 454 cmd := exec.CommandContext(ctx, "git", "clone", "--", repoURL, targetDir) 455 done := make(chan error) 456 go func() { 457 _, err := cmd.CombinedOutput() 458 done <- err 459 }() 460 461 select { 462 case <-ctx.Done(): 463 _ = cmd.Process.Kill() 464 return fmt.Errorf("操作超时") 465 case err = <-done: 466 return err 467 } 468 }() 469 if err == nil { 470 return nil 471 } 472 } 473 return err 474 } 475 476 func RunCmd(dir, name string, arg []string, callback func(line string)) error { 477 return RunCmdWithContext(context.Background(), dir, name, arg, callback) 478 } 479 480 func RunCmdWithContext(ctx context.Context, dir, name string, arg []string, callback func(line string)) error { 481 if ctx == nil { 482 ctx = context.Background() 483 } 484 485 // 命令行执行,stdio读取 486 cmd := exec.CommandContext(ctx, name, arg...) 487 cmd.Dir = dir 488 cmd.Env = os.Environ() 489 // 获取命令行 490 cmdStr := name + " " + strings.Join(arg, " ") 491 gologger.Infof("开始执行命令: %s", cmdStr) 492 // 使用管道获取标准输出 493 stdout, err := cmd.StdoutPipe() 494 if err != nil { 495 return err 496 } 497 cmd.Stderr = cmd.Stdout // 将错误输出合并到标准输出 498 499 // 启动扫描器goroutine 500 scanner := bufio.NewScanner(stdout) 501 // 设置更大的缓冲区以处理超长文本行 502 // 默认64KB,这里设置为1MB 503 const maxCapacity = 1024 * 1024 * 10 // 1MB 504 buf := make([]byte, 0, 64*1024) 505 scanner.Buffer(buf, maxCapacity) 506 507 done := make(chan error) // 改为传递错误信息 508 go func() { 509 defer close(done) 510 for scanner.Scan() { 511 line := scanner.Text() 512 callback(line) 513 } 514 // 检查扫描器是否遇到错误 515 if err := scanner.Err(); err != nil { 516 // 管道关闭是正常的结束条件,不应视为错误 517 if strings.Contains(err.Error(), "file already closed") || 518 strings.Contains(err.Error(), "broken pipe") { 519 done <- nil 520 return 521 } 522 done <- fmt.Errorf("读取输出时发生错误: %v", err) 523 return 524 } 525 done <- nil 526 }() 527 528 // 启动命令 529 if err = cmd.Start(); err != nil { 530 return err 531 } 532 533 // 等待命令执行完成 534 cmdErr := cmd.Wait() 535 536 // 等待读取完成并检查读取错误 537 readErr := <-done 538 539 // 优先返回读取错误,其次返回命令执行错误 540 if errors.Is(ctx.Err(), context.Canceled) { 541 return ctx.Err() 542 } 543 if readErr != nil { 544 return readErr 545 } 546 if cmdErr != nil { 547 return cmdErr 548 } 549 550 return nil 551 } 552 553 func IsHostname(hostname string) bool { 554 ips := strings.Split(hostname, ":") 555 if len(ips) != 2 { 556 return false 557 } 558 p := net.ParseIP(strings.TrimSpace(ips[0])) 559 if p == nil { 560 return false 561 } 562 return true 563 } 564 565 // StrInSlice checks if a string is in a slice of strings. 566 func StrInSlice(str string, list []string) bool { 567 for _, v := range list { 568 if v == str { 569 return true 570 } 571 } 572 return false 573 }