/ common / agent / mcp_task.go
mcp_task.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package agent
 20  
 21  import (
 22  	"context"
 23  	"encoding/json"
 24  	"errors"
 25  	"fmt"
 26  	"os"
 27  	"path/filepath"
 28  	"strconv"
 29  	"strings"
 30  	"time"
 31  
 32  	"github.com/Tencent/AI-Infra-Guard/common/utils"
 33  	"github.com/Tencent/AI-Infra-Guard/internal/gologger"
 34  )
 35  
 36  type McpTask struct {
 37  	Server string
 38  }
 39  
 40  func (m *McpTask) GetName() string {
 41  	return TaskTypeMcpScan
 42  }
 43  
 44  func (m *McpTask) Execute(ctx context.Context, request TaskRequest, callbacks TaskCallbacks) error {
 45  	type ScanMcpRequest struct {
 46  		Content string `json:"-"`
 47  		Model   struct {
 48  			Model   string `json:"model"`
 49  			Token   string `json:"token"`
 50  			BaseUrl string `json:"base_url"`
 51  		} `json:"model"`
 52  		Headers map[string]string `json:"headers"`
 53  	}
 54  
 55  	var params ScanMcpRequest
 56  	if err := json.Unmarshal(request.Params, &params); err != nil {
 57  		return err
 58  	}
 59  	params.Content = request.Content
 60  	files := request.Attachments
 61  	transport := "code" // code or url
 62  	if len(files) > 0 || strings.Contains(request.Content, "github.com") {
 63  		transport = "code"
 64  	} else {
 65  		transport = "url"
 66  	}
 67  	language := request.Language
 68  	if language == "" {
 69  		language = "zh"
 70  	}
 71  
 72  	var folder string
 73  	var serverUrl string
 74  	if transport == "code" {
 75  		// 创建临时目录用于存储上传的文件
 76  		tempDir := "uploads"
 77  		if err := os.MkdirAll(tempDir, 0755); err != nil {
 78  			gologger.Errorf("%s: %v", "createTempDir", err)
 79  			return err
 80  		}
 81  		if len(files) > 0 {
 82  			// 远程下载
 83  			for _, file := range files {
 84  				// 下载文件
 85  				ext := ""
 86  				supports := []string{".zip", ".tar.gz", ".tgz", ".whl"}
 87  				for _, support := range supports {
 88  					if strings.HasSuffix(file, support) {
 89  						ext = support
 90  						break
 91  					}
 92  				}
 93  				if ext == "" {
 94  					gologger.Errorln("Unsupported file type", strings.Join(supports, ","))
 95  					continue
 96  				}
 97  
 98  				fileName := filepath.Join(tempDir, fmt.Sprintf("tmp-%d%s", time.Now().UnixMicro(), ext))
 99  				err := utils.DownloadFile(m.Server, request.SessionId, file, fileName)
100  				if err != nil {
101  					return fmt.Errorf("download failed: %v", err)
102  				}
103  				extractPath, _ := filepath.Abs(filepath.Join(tempDir, fmt.Sprintf("tmp-%d", time.Now().UnixMicro())))
104  				switch ext {
105  				case ".zip", ".whl":
106  					err = utils.ExtractZipFile(fileName, extractPath)
107  				case ".tgz", ".tar.gz":
108  					err = utils.ExtractTGZ(fileName, extractPath)
109  				default:
110  					return errors.New("Unsupported file type: " + strings.Join(supports, ","))
111  				}
112  				if err != nil {
113  					return errors.New(fmt.Sprintf("extract failed: %v", err))
114  				}
115  				folder = extractPath
116  			}
117  		} else {
118  			extractPath, _ := filepath.Abs(filepath.Join(tempDir, fmt.Sprintf("tmp-%d", time.Now().UnixMicro())))
119  			err := utils.GitClone(params.Content, extractPath, 10*time.Minute)
120  			if err != nil {
121  				return fmt.Errorf("clone failed: %v", err)
122  			}
123  			folder = extractPath
124  		}
125  
126  		// 判断文件夹是否存在
127  		if info, err := os.Stat(folder); os.IsNotExist(err) || !info.IsDir() {
128  			return fmt.Errorf("folder does not exist or is not a directory: %s", folder)
129  		}
130  	} else if transport == "url" {
131  		serverUrl = params.Content
132  	}
133  
134  	var argv []string = make([]string, 0)
135  	argv = append(argv, "run", "main.py")
136  	argv = append(argv, "--model", params.Model.Model)
137  	argv = append(argv, "--base_url", params.Model.BaseUrl)
138  	argv = append(argv, "--api_key", params.Model.Token)
139  	argv = append(argv, "--prompt", params.Content)
140  	argv = append(argv, "--debug")
141  	argv = append(argv, "--language", language)
142  	if params.Headers != nil {
143  		for k, v := range params.Headers {
144  			argv = append(argv, "--header", fmt.Sprintf("%s:%s", k, v))
145  		}
146  	}
147  
148  	var taskTitles []string
149  	if transport == "code" {
150  		argv = append(argv, "--repo", folder)
151  		taskTitles = []string{
152  			"Info Collection",
153  			"Code Audit",
154  			"Vulnerability Review",
155  		}
156  	} else if transport == "url" {
157  		argv = append(argv, "--server_url", serverUrl)
158  		taskTitles = []string{
159  			"Info Collection",
160  			"Malicious Testing",
161  			"Vulnerability Testing",
162  			"Vulnerability Review",
163  		}
164  	}
165  
166  	var tasks []SubTask
167  	//taskTitles := []string{
168  	//	"信息收集",
169  	//	"代码审计",
170  	//	"漏洞整理",
171  	//}
172  
173  	for i, title := range taskTitles {
174  		tasks = append(tasks, CreateSubTask(SubTaskStatusTodo, title, 0, strconv.Itoa(i+1)))
175  	}
176  	callbacks.PlanUpdateCallback(tasks)
177  	config := CmdConfig{StatusId: ""}
178  	mcpDir, err := utils.ResolveMcpScanDir()
179  	if err != nil {
180  		return fmt.Errorf("resolve mcp-scan directory: %v", err)
181  	}
182  	uvBin, err := utils.ResolveUvBin()
183  	if err != nil {
184  		return fmt.Errorf("resolve uv binary: %v", err)
185  	}
186  	err = utils.RunCmdWithContext(ctx, mcpDir, uvBin, argv, func(line string) {
187  		ParseStdoutLine(m.Server, mcpDir, tasks, line, callbacks, &config, false)
188  	})
189  	return err
190  }