mcp_task.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package agent 20 21 import ( 22 "context" 23 "encoding/json" 24 "errors" 25 "fmt" 26 "os" 27 "path/filepath" 28 "strconv" 29 "strings" 30 "time" 31 32 "github.com/Tencent/AI-Infra-Guard/common/utils" 33 "github.com/Tencent/AI-Infra-Guard/internal/gologger" 34 ) 35 36 type McpTask struct { 37 Server string 38 } 39 40 func (m *McpTask) GetName() string { 41 return TaskTypeMcpScan 42 } 43 44 func (m *McpTask) Execute(ctx context.Context, request TaskRequest, callbacks TaskCallbacks) error { 45 type ScanMcpRequest struct { 46 Content string `json:"-"` 47 Model struct { 48 Model string `json:"model"` 49 Token string `json:"token"` 50 BaseUrl string `json:"base_url"` 51 } `json:"model"` 52 Headers map[string]string `json:"headers"` 53 } 54 55 var params ScanMcpRequest 56 if err := json.Unmarshal(request.Params, ¶ms); err != nil { 57 return err 58 } 59 params.Content = request.Content 60 files := request.Attachments 61 transport := "code" // code or url 62 if len(files) > 0 || strings.Contains(request.Content, "github.com") { 63 transport = "code" 64 } else { 65 transport = "url" 66 } 67 language := request.Language 68 if language == "" { 69 language = "zh" 70 } 71 72 var folder string 73 var serverUrl string 74 if transport == "code" { 75 // 创建临时目录用于存储上传的文件 76 tempDir := "uploads" 77 if err := os.MkdirAll(tempDir, 0755); err != nil { 78 gologger.Errorf("%s: %v", "createTempDir", err) 79 return err 80 } 81 if len(files) > 0 { 82 // 远程下载 83 for _, file := range files { 84 // 下载文件 85 ext := "" 86 supports := []string{".zip", ".tar.gz", ".tgz", ".whl"} 87 for _, support := range supports { 88 if strings.HasSuffix(file, support) { 89 ext = support 90 break 91 } 92 } 93 if ext == "" { 94 gologger.Errorln("Unsupported file type", strings.Join(supports, ",")) 95 continue 96 } 97 98 fileName := filepath.Join(tempDir, fmt.Sprintf("tmp-%d%s", time.Now().UnixMicro(), ext)) 99 err := utils.DownloadFile(m.Server, request.SessionId, file, fileName) 100 if err != nil { 101 return fmt.Errorf("download failed: %v", err) 102 } 103 extractPath, _ := filepath.Abs(filepath.Join(tempDir, fmt.Sprintf("tmp-%d", time.Now().UnixMicro()))) 104 switch ext { 105 case ".zip", ".whl": 106 err = utils.ExtractZipFile(fileName, extractPath) 107 case ".tgz", ".tar.gz": 108 err = utils.ExtractTGZ(fileName, extractPath) 109 default: 110 return errors.New("Unsupported file type: " + strings.Join(supports, ",")) 111 } 112 if err != nil { 113 return errors.New(fmt.Sprintf("extract failed: %v", err)) 114 } 115 folder = extractPath 116 } 117 } else { 118 extractPath, _ := filepath.Abs(filepath.Join(tempDir, fmt.Sprintf("tmp-%d", time.Now().UnixMicro()))) 119 err := utils.GitClone(params.Content, extractPath, 10*time.Minute) 120 if err != nil { 121 return fmt.Errorf("clone failed: %v", err) 122 } 123 folder = extractPath 124 } 125 126 // 判断文件夹是否存在 127 if info, err := os.Stat(folder); os.IsNotExist(err) || !info.IsDir() { 128 return fmt.Errorf("folder does not exist or is not a directory: %s", folder) 129 } 130 } else if transport == "url" { 131 serverUrl = params.Content 132 } 133 134 var argv []string = make([]string, 0) 135 argv = append(argv, "run", "main.py") 136 argv = append(argv, "--model", params.Model.Model) 137 argv = append(argv, "--base_url", params.Model.BaseUrl) 138 argv = append(argv, "--api_key", params.Model.Token) 139 argv = append(argv, "--prompt", params.Content) 140 argv = append(argv, "--debug") 141 argv = append(argv, "--language", language) 142 if params.Headers != nil { 143 for k, v := range params.Headers { 144 argv = append(argv, "--header", fmt.Sprintf("%s:%s", k, v)) 145 } 146 } 147 148 var taskTitles []string 149 if transport == "code" { 150 argv = append(argv, "--repo", folder) 151 taskTitles = []string{ 152 "Info Collection", 153 "Code Audit", 154 "Vulnerability Review", 155 } 156 } else if transport == "url" { 157 argv = append(argv, "--server_url", serverUrl) 158 taskTitles = []string{ 159 "Info Collection", 160 "Malicious Testing", 161 "Vulnerability Testing", 162 "Vulnerability Review", 163 } 164 } 165 166 var tasks []SubTask 167 //taskTitles := []string{ 168 // "信息收集", 169 // "代码审计", 170 // "漏洞整理", 171 //} 172 173 for i, title := range taskTitles { 174 tasks = append(tasks, CreateSubTask(SubTaskStatusTodo, title, 0, strconv.Itoa(i+1))) 175 } 176 callbacks.PlanUpdateCallback(tasks) 177 config := CmdConfig{StatusId: ""} 178 mcpDir, err := utils.ResolveMcpScanDir() 179 if err != nil { 180 return fmt.Errorf("resolve mcp-scan directory: %v", err) 181 } 182 uvBin, err := utils.ResolveUvBin() 183 if err != nil { 184 return fmt.Errorf("resolve uv binary: %v", err) 185 } 186 err = utils.RunCmdWithContext(ctx, mcpDir, uvBin, argv, func(line string) { 187 ParseStdoutLine(m.Server, mcpDir, tasks, line, callbacks, &config, false) 188 }) 189 return err 190 }