/ internal / mcp / utils / utils.go
utils.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package utils
 20  
 21  import (
 22  	"bufio"
 23  	"context"
 24  	"fmt"
 25  	"io"
 26  	"io/fs"
 27  	"os"
 28  	"path/filepath"
 29  	"regexp"
 30  	"strings"
 31  
 32  	"github.com/mark3labs/mcp-go/client"
 33  	"github.com/mark3labs/mcp-go/mcp"
 34  )
 35  
 36  var ignoreFiles = []string{
 37  	// 系统文件
 38  	".DS_Store", "Thumbs.db",
 39  
 40  	// 版本控制相关
 41  	".gitignore", ".gitattributes", ".gitmodules", ".gitkeep", ".git", ".svn",
 42  
 43  	// 环境配置文件
 44  	".env", "env", ".env.local", ".env.example", ".env.test", ".env.production",
 45  
 46  	// Node.js/npm相关
 47  	"package.json", "package-lock.json", "yarn.lock", "pnpm-lock.yaml", "uv.lock",
 48  	".npmrc", ".yarnrc", ".yarn-integrity",
 49  
 50  	// Python相关
 51  	"Pipfile", "Pipfile.lock", "poetry.lock", "requirements.txt", "setup.py",
 52  
 53  	// Java相关
 54  	"pom.xml", "build.gradle", "gradle.properties",
 55  
 56  	// Ruby相关
 57  	"Gemfile", "Gemfile.lock",
 58  
 59  	// IDE和编辑器配置
 60  	".idea", ".vscode", ".editorconfig", ".project",
 61  
 62  	// 构建工具配置
 63  	"webpack.config.js", "rollup.config.js", "gulpfile.js", "gruntfile.js",
 64  	"tsconfig.json", "jsconfig.json", "babel.config.js", ".babelrc",
 65  
 66  	// 测试相关
 67  	"jest.config.js", "karma.conf.js", ".mocharc.json",
 68  
 69  	// 其他常见配置文件
 70  	"dockerfile", ".dockerignore", "composer.json", "composer.lock",
 71  	"Makefile", "CMakeLists.txt",
 72  }
 73  
 74  func IsIgnoreFile(path string) bool {
 75  	for _, ignoreFile := range ignoreFiles {
 76  		if ignoreFile == filepath.Base(path) {
 77  			return true
 78  		}
 79  	}
 80  	return false
 81  }
 82  
 83  type Agent interface {
 84  	GetHistory() []map[string]string
 85  }
 86  
 87  // ListDir 递归列出目录结构并生成树形图
 88  // dir: 要列出的目录路径
 89  // maxLevel: 最大递归深度(0表示不限制)
 90  func ListDir(dir string, maxLevel int, exts string) (string, error) {
 91  	var builder strings.Builder
 92  	err := listDirRecursive(dir, 0, &builder, []bool{}, maxLevel, strings.Split(exts, ","))
 93  	if err != nil {
 94  		return "", err
 95  	}
 96  	return builder.String(), nil
 97  }
 98  
 99  // formatFileSize 格式化文件大小,返回带单位的字符串
100  func formatFileSize(size int64) string {
101  	if size < 1024 {
102  		return fmt.Sprintf("%dB", size)
103  	} else if size < 1024*1024 {
104  		return fmt.Sprintf("%.1fKB", float64(size)/1024)
105  	} else if size < 1024*1024*1024 {
106  		return fmt.Sprintf("%.1fMB", float64(size)/(1024*1024))
107  	} else {
108  		return fmt.Sprintf("%.1fGB", float64(size)/(1024*1024*1024))
109  	}
110  }
111  
112  func listDirRecursive(dir string, depth int, builder *strings.Builder, hasLast []bool, maxLevel int, exts []string) error {
113  	if maxLevel > 0 && depth >= maxLevel {
114  		return nil
115  	}
116  
117  	entries, err := os.ReadDir(dir)
118  	if err != nil {
119  		return err
120  	}
121  
122  	// 过滤忽略文件
123  	var validEntries []fs.DirEntry
124  	for _, entry := range entries {
125  		filename := filepath.Join(dir, entry.Name())
126  		if IsIgnoreFile(filename) {
127  			continue
128  		}
129  		if entry.IsDir() {
130  			validEntries = append(validEntries, entry)
131  			continue
132  		}
133  		if len(exts) > 0 {
134  			isSkip := true
135  			for _, ext := range exts {
136  				if strings.HasSuffix(entry.Name(), ext) {
137  					isSkip = false
138  					break
139  				}
140  			}
141  			if isSkip {
142  				continue
143  			}
144  		}
145  		validEntries = append(validEntries, entry)
146  	}
147  
148  	for i, entry := range validEntries {
149  		// 绘制树形结构线
150  		for d := 0; d < depth; d++ {
151  			if hasLast[d] {
152  				builder.WriteString("    ")
153  			} else {
154  				builder.WriteString("│   ")
155  			}
156  		}
157  
158  		// 判断是否是最后一项
159  		isLastEntry := i == len(validEntries)-1
160  		if isLastEntry {
161  			builder.WriteString("└── ")
162  		} else {
163  			builder.WriteString("├── ")
164  		}
165  
166  		// 添加条目名称和类型和权限以及文件大小
167  		var sizeInfo string
168  		if !entry.IsDir() {
169  			// 获取文件信息
170  			entryPath := filepath.Join(dir, entry.Name())
171  			if fileInfo, err := os.Stat(entryPath); err == nil {
172  				sizeInfo = fmt.Sprintf(" [%s]", formatFileSize(fileInfo.Size()))
173  			}
174  		}
175  		builder.WriteString(fmt.Sprintf("%s (%s)%s\n", entry.Name(), getSimpleType(dir, entry), sizeInfo))
176  
177  		// 递归处理子目录(不超过最大深度时)
178  		if entry.IsDir() && (maxLevel <= 0 || depth < maxLevel) {
179  			newHasLast := append(hasLast, isLastEntry)
180  			err = listDirRecursive(
181  				filepath.Join(dir, entry.Name()),
182  				depth+1,
183  				builder,
184  				newHasLast,
185  				maxLevel,
186  				exts,
187  			)
188  			if err != nil {
189  				return err
190  			}
191  		}
192  	}
193  	return nil
194  }
195  
196  // getSimpleType 简化文件类型显示
197  func getSimpleType(dir string, entry fs.DirEntry) string {
198  	if entry.IsDir() {
199  		return "dir"
200  	}
201  	fsPath := filepath.Join(dir, entry.Name())
202  	if entry.Type().IsRegular() {
203  		if IsTextFile(fsPath) {
204  			return "file"
205  		} else {
206  			return "binary"
207  		}
208  	}
209  	return entry.Type().String()
210  }
211  
212  func InitMcpClient(ctx context.Context, client *client.Client) (*mcp.InitializeResult, error) {
213  	err := client.Start(ctx)
214  	if err != nil {
215  		return nil, err
216  	}
217  	r, err := client.Initialize(context.Background(), mcp.InitializeRequest{})
218  	if err != nil {
219  		return nil, err
220  	}
221  	return r, err
222  }
223  
224  func ListMcpTools(ctx context.Context, client *client.Client) (*mcp.ListToolsResult, error) {
225  	result, err := client.ListTools(ctx, mcp.ListToolsRequest{})
226  	if err != nil {
227  		return nil, err
228  	}
229  	client.CallTool(ctx, mcp.CallToolRequest{})
230  	return result, nil
231  }
232  
// LanguagePrompt returns the prompt sentence that selects the response
// language: the Chinese variant when language is exactly "zh", and the
// English variant for every other value.
func LanguagePrompt(language string) string {
	if language == "zh" {
		return "Response in Chinese."
	}
	return "Response in English."
}
242  
// IsTextFile reports whether filename looks like a text file.
//
// It inspects the bytes returned by a single read of up to 512 bytes from
// the start of the file and returns false if any of them is NUL or another
// control character (0x00-0x08, 0x0B, 0x0C, 0x0E-0x1F, 0x7F). Tab, line
// feed and carriage return are allowed. A file that cannot be opened or read
// yields false; an empty file yields true.
func IsTextFile(filename string) bool {
	f, openErr := os.Open(filename)
	if openErr != nil {
		return false
	}
	defer f.Close()

	var head [512]byte
	n, readErr := f.Read(head[:])
	if readErr != nil && readErr != io.EOF {
		return false
	}

	for _, b := range head[:n] {
		switch {
		case b <= 0x08, b == 0x0B, b == 0x0C, b >= 0x0E && b <= 0x1F, b == 0x7F:
			// A control character or NUL marks the file as binary.
			return false
		}
	}
	return true
}
267  
268  // Grep 在文件或目录中搜索特定模式并返回匹配行及其上下文
269  func Grep(path string, pattern string, contextLines int) (string, error) {
270  	// 检查路径是文件还是目录
271  	fileInfo, err := os.Stat(path)
272  	if err != nil {
273  		return "", err
274  	}
275  
276  	// 支持多个表达式,通过逗号分隔
277  	patterns := strings.Split(pattern, ",")
278  	if len(patterns) == 0 {
279  		return "", fmt.Errorf("未提供搜索模式")
280  	}
281  
282  	// 编译所有正则表达式
283  	regexps := make([]*regexp.Regexp, 0, len(patterns))
284  	for _, p := range patterns {
285  		p = strings.TrimSpace(p)
286  		if p == "" {
287  			continue
288  		}
289  		re, err := regexp.Compile(p)
290  		if err != nil {
291  			return "", fmt.Errorf("正则表达式无效 '%s': %v", p, err)
292  		}
293  		regexps = append(regexps, re)
294  	}
295  
296  	if len(regexps) == 0 {
297  		return "", fmt.Errorf("没有有效的正则表达式")
298  	}
299  
300  	var results []string
301  	if fileInfo.IsDir() {
302  		// 如果是目录,遍历目录中的所有文件
303  		patternStr := strings.Join(patterns, "', '")
304  		results = append(results, fmt.Sprintf("在目录 '%s' 中搜索模式 ['%s']:\n", path, patternStr))
305  		err = grepDirectoryMulti(path, regexps, contextLines, &results)
306  		if err != nil {
307  			return "", err
308  		}
309  	} else {
310  		// 如果是文件,直接搜索文件
311  		fileResults, err := grepFileMulti(path, regexps, contextLines)
312  		if err != nil {
313  			return "", err
314  		}
315  		if fileResults != "" {
316  			results = append(results, fmt.Sprintf("文件: %s\n", path))
317  			results = append(results, fileResults)
318  		}
319  	}
320  
321  	if len(results) == 0 || (len(results) == 1 && strings.HasPrefix(results[0], "在目录")) {
322  		patternStr := strings.Join(patterns, "', '")
323  		return fmt.Sprintf("未找到匹配模式 ['%s'] 的内容", patternStr), nil
324  	}
325  
326  	return strings.Join(results, "\n"), nil
327  }
328  
329  // grepDirectoryMulti 在目录中搜索多个模式
330  func grepDirectoryMulti(dirPath string, regexps []*regexp.Regexp, contextLines int, results *[]string) error {
331  	entries, err := os.ReadDir(dirPath)
332  	if err != nil {
333  		return err
334  	}
335  
336  	foundMatches := false
337  	for _, entry := range entries {
338  		entryPath := fmt.Sprintf("%s/%s", dirPath, entry.Name())
339  
340  		// 跳过隐藏文件和目录
341  		if strings.HasPrefix(entry.Name(), ".") {
342  			continue
343  		}
344  
345  		if entry.IsDir() {
346  			// 递归搜索子目录
347  			err := grepDirectoryMulti(entryPath, regexps, contextLines, results)
348  			if err != nil {
349  				// 只记录错误,继续处理其他文件
350  				*results = append(*results, fmt.Sprintf("搜索目录 %s 时出错: %v", entryPath, err))
351  			}
352  		} else {
353  			// 只处理常见文本文件类型
354  			if IsTextFile(entryPath) {
355  				fileResults, err := grepFileMulti(entryPath, regexps, contextLines)
356  				if err != nil {
357  					// 只记录错误,继续处理其他文件
358  					continue
359  				}
360  
361  				if fileResults != "" {
362  					if !foundMatches {
363  						foundMatches = true
364  					}
365  					*results = append(*results, fmt.Sprintf("\n文件: %s", entryPath))
366  					*results = append(*results, fileResults)
367  				}
368  			}
369  		}
370  	}
371  
372  	return nil
373  }
374  
// grepFileMulti searches one file for every expression in regexps and
// returns a formatted report of the matches with surrounding context.
//
// The whole file is read and matched as a single string, so an expression
// may match across line boundaries. A byte range matched by more than one
// expression is reported once. For each match the report contains the
// 1-based line range of the matched text, a preview of the matched text
// (newlines and tabs escaped, cut to at most 100 bytes on a UTF-8 character
// boundary), and the lines from contextLines before the match to
// contextLines after it, with matched lines prefixed by ">".
//
// It returns an empty string when nothing matches, and an error only when
// the file cannot be read.
func grepFileMulti(filename string, regexps []*regexp.Regexp, contextLines int) (string, error) {
	content, err := os.ReadFile(filename)
	if err != nil {
		return "", err
	}
	fileContent := string(content)

	// Line view of the file, used only to print context.
	lines := strings.Split(fileContent, "\n")

	const maxDisplayBytes = 100

	var results []string
	seen := make(map[[2]int]bool) // byte ranges already reported

	for _, re := range regexps {
		for _, match := range re.FindAllStringIndex(fileContent, -1) {
			startPos, endPos := match[0], match[1]

			span := [2]int{startPos, endPos}
			if seen[span] {
				continue
			}
			seen[span] = true

			// Line of the first matched byte and of the last matched byte.
			// Using the last byte (not the position after it) keeps a match
			// that ends with "\n" from spilling onto the following line.
			startLineNum := strings.Count(fileContent[:startPos], "\n")
			lastPos := endPos
			if endPos > startPos {
				lastPos = endPos - 1
			}
			endLineNum := strings.Count(fileContent[:lastPos], "\n")

			contextStart := startLineNum - contextLines
			if contextStart < 0 {
				contextStart = 0
			}
			contextEnd := endLineNum + contextLines
			if contextEnd >= len(lines) {
				contextEnd = len(lines) - 1
			}

			// Escape control characters so the preview stays on one line.
			displayContent := strings.ReplaceAll(fileContent[startPos:endPos], "\n", "\\n")
			displayContent = strings.ReplaceAll(displayContent, "\t", "\\t")
			if len(displayContent) > maxDisplayBytes {
				cut := maxDisplayBytes
				// Step back over UTF-8 continuation bytes (10xxxxxx) so a
				// multi-byte character is never cut in half. A valid
				// character has at most three continuation bytes.
				for back := 0; back < 3 && displayContent[cut]&0xC0 == 0x80; back++ {
					cut--
				}
				displayContent = displayContent[:cut] + "..."
			}

			results = append(results,
				fmt.Sprintf("=== 匹配范围: 行 %d-%d ===", startLineNum+1, endLineNum+1),
				fmt.Sprintf("匹配内容: %s", displayContent),
				"",
			)

			for i := contextStart; i <= contextEnd; i++ {
				prefix := "  "
				if i >= startLineNum && i <= endLineNum {
					prefix = ">" // marks a line that is part of the match
				}
				results = append(results, fmt.Sprintf("%s %d: %s", prefix, i+1, lines[i]))
			}
			results = append(results, "")
		}
	}

	if len(results) == 0 {
		return "", nil
	}
	return strings.Join(results, "\n"), nil
}
460  
// ReadFileChunk reads a range of lines from filename, bounded by a byte
// budget, and appends a short status note to the returned text.
//
// startLine and endLines are 1-based, inclusive line numbers. When both are
// less than or equal to 1 the whole file is selected. maxBytes is a soft
// limit: selected lines are appended as long as the number of bytes already
// collected is below maxBytes, so the line that crosses the limit is still
// included. Every collected line is terminated with "\n".
//
// The file is always scanned to the end in order to count its lines. Lines
// up to 64 MiB long are supported. If nothing was collected the result is an
// empty string; otherwise a trailer is appended that says either that more
// content remains or that the file has been read to its last line.
//
// An error is returned when the file cannot be opened or scanned.
func ReadFileChunk(filename string, startLine int, endLines int, maxBytes int) (string, error) {
	file, err := os.Open(filename)
	if err != nil {
		return "", err
	}
	defer file.Close()

	// bufio.Scanner rejects lines longer than its buffer limit (64 KiB by
	// default) with bufio.ErrTooLong; raise the limit so long lines are read
	// instead of failing the whole call.
	const maxLineBytes = 64 * 1024 * 1024
	scanner := bufio.NewScanner(file)
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxLineBytes)

	readAll := startLine <= 1 && endLines <= 1

	var sb strings.Builder
	bytesRead := 0  // bytes collected so far, including added newlines
	linesRead := 0  // number of the last line that was collected
	totalLines := 0 // number of lines seen in the file

	for scanner.Scan() {
		totalLines++
		selected := readAll || (totalLines >= startLine && totalLines <= endLines)
		if !selected || bytesRead >= maxBytes {
			continue
		}
		line := scanner.Text()
		sb.WriteString(line)
		sb.WriteByte('\n')
		bytesRead += len(line) + 1
		linesRead = totalLines
	}
	if err := scanner.Err(); err != nil {
		return "", err
	}

	if sb.Len() == 0 {
		return "", nil
	}

	result := sb.String()
	if linesRead < totalLines {
		if startLine == 0 && endLines == 0 {
			// No explicit range was requested: report what was covered.
			endLines = linesRead
		}
		result += fmt.Sprintf("\n----\n (文件还有更多内容,共 %d 行,准备读取 %d-%d,行当前已读取到第 %d 行,约 %d 字节) 请自行判断是否读取接下来的行\n",
			totalLines, startLine, endLines, linesRead, bytesRead)
	} else {
		result += fmt.Sprintf("\n----\n (文件已读取完毕,最后一行为第 %d 行)\n", linesRead)
	}
	return result, nil
}