scanner.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package mcp 20 21 import ( 22 "context" 23 "fmt" 24 "net/http" 25 "os" 26 "path/filepath" 27 "strings" 28 "sync" 29 "time" 30 31 utils2 "github.com/Tencent/AI-Infra-Guard/common/utils" 32 "github.com/Tencent/AI-Infra-Guard/common/utils/models" 33 "github.com/Tencent/AI-Infra-Guard/internal/gologger" 34 "github.com/Tencent/AI-Infra-Guard/internal/mcp/utils" 35 "github.com/mark3labs/mcp-go/client" 36 "github.com/mark3labs/mcp-go/client/transport" 37 "github.com/mark3labs/mcp-go/mcp" 38 ) 39 40 type Scanner struct { 41 mutex sync.Mutex 42 results []*Issue 43 PluginConfigs []*PluginConfig 44 aiModel *models.OpenAI 45 client *client.Client 46 csvResult [][]string 47 codePath string 48 url string 49 callback func(data interface{}) 50 language string 51 logger *gologger.Logger 52 } 53 54 type McpCallbackProcessing struct { 55 Current int `json:"current"` 56 Total int `json:"total"` 57 } 58 59 type McpCallbackReadMe struct { 60 Content string `json:"content"` 61 } 62 63 type McpModuleStart struct { 64 ModuleName string 65 } 66 67 type McpModuleEnd struct { 68 ModuleName string 69 Result string 70 } 71 72 func NewScanner(aiConfig *models.OpenAI, logger *gologger.Logger) *Scanner { 73 if logger == nil { 74 logger = gologger.NewLogger() 75 } 76 s := &Scanner{ 77 results: make([]*Issue, 0), 78 PluginConfigs: make([]*PluginConfig, 0), 79 aiModel: aiConfig, 80 csvResult: make([][]string, 0), 81 language: "zh", 82 logger: logger, 83 } 84 return s 85 } 86 87 func (s *Scanner) SetCallback(callback func(data interface{})) { 88 s.callback = callback 89 } 90 91 func (s *Scanner) GetPluginsByCategory(category string) []*PluginConfig { 92 plugins := make([]*PluginConfig, 0) 93 for _, plugin := range s.PluginConfigs { 94 for _, c := range plugin.Info.Category { 95 if c == category { 96 p := plugin 97 plugins = append(plugins, p) 98 } 99 } 100 } 101 return plugins 102 } 103 104 func (s *Scanner) GetAllPluginNames() ([]string, error) { 105 names := make([]string, 0) 106 execPath, err := os.Executable() 107 if err != nil { 108 return nil, err 109 } 110 111 // 构建data/mcp目录路径 112 dataDir := filepath.Join(filepath.Dir(execPath), "data", "mcp") 113 114 // 如果相对路径不存在,尝试从工作目录查找 115 if _, err := os.Stat(dataDir); os.IsNotExist(err) { 116 wd, _ := os.Getwd() 117 dataDir = filepath.Join(wd, "data", "mcp") 118 } 119 files, err := utils2.ScanDir(dataDir) 120 if err != nil { 121 return nil, err 122 } 123 124 for _, configPath := range files { 125 if !strings.HasSuffix(configPath, ".yaml") { 126 continue 127 } 128 plugin, err := NewYAMLPlugin(configPath) 129 if err != nil { 130 s.logger.Errorf("加载插件配置失败 %s: %v", configPath, err) 131 continue 132 } 133 names = append(names, plugin.Info.ID) 134 } 135 return names, nil 136 } 137 138 func (s *Scanner) RegisterPlugin(plugins []string) error { 139 // 获取当前执行文件的目录 140 execPath, err := os.Executable() 141 if err != nil { 142 return err 143 } 144 145 // 构建data/mcp目录路径 146 dataDir := filepath.Join(filepath.Dir(execPath), "data", "mcp") 147 148 // 如果相对路径不存在,尝试从工作目录查找 149 if _, err := os.Stat(dataDir); os.IsNotExist(err) { 150 wd, _ := os.Getwd() 151 dataDir = filepath.Join(wd, "data", "mcp") 152 } 153 files, err := utils2.ScanDir(dataDir) 154 if err != nil { 155 return err 156 } 157 158 for _, configPath := range files { 159 if !strings.HasSuffix(configPath, ".yaml") { 160 continue 161 } 162 plugin, err := NewYAMLPlugin(configPath) 163 if err != nil { 164 s.logger.Errorf("加载插件配置失败 %s: %v", configPath, err) 165 continue 166 } 167 id := plugin.Info.ID 168 if len(plugins) > 0 { 169 for _, p := range plugins { 170 if p == id { 171 s.logger.Infof("加载插件 %s", plugin.Info.Name) 172 s.PluginConfigs = append(s.PluginConfigs, plugin) 173 break 174 } 175 } 176 } else { 177 s.logger.Infof("加载插件 %s", plugin.Info.Name) 178 s.PluginConfigs = append(s.PluginConfigs, plugin) 179 } 180 } 181 if len(s.PluginConfigs) == 0 { 182 return fmt.Errorf("未加载任何插件") 183 } 184 return nil 185 } 186 187 func (s *Scanner) InputCommand(ctx context.Context, command string, argv []string) (*mcp.InitializeResult, error) { 188 mcpClient, err := client.NewStdioMCPClient( 189 command, 190 argv, 191 ) 192 if err != nil { 193 return nil, err 194 } 195 s.client = mcpClient 196 return utils.InitMcpClient(ctx, s.client) 197 } 198 199 func (s *Scanner) InputUrl(ctx context.Context, url string) (*mcp.InitializeResult, error) { 200 dirs := []string{"", "/mcp", "/sse"} 201 url = strings.TrimRight(url, "/") 202 scan := func(ctx context.Context, url string) (*mcp.InitializeResult, error) { 203 r, err := s.InputStreamLink(ctx, url) 204 if err != nil { 205 r, err = s.InputSSELink(ctx, url) 206 if err != nil { 207 return nil, err 208 } 209 } 210 return r, nil 211 } 212 var err error 213 for _, u := range dirs { 214 link := url + u 215 r, err := scan(ctx, link) 216 if err == nil { 217 return r, nil 218 } 219 } 220 return nil, err 221 } 222 223 func (s *Scanner) InputSSELink(ctx context.Context, link string) (*mcp.InitializeResult, error) { 224 opt := client.WithHTTPClient(&http.Client{Timeout: 10 * time.Second}) 225 mcpClient, err := client.NewSSEMCPClient(link, opt) 226 if err != nil { 227 return nil, err 228 } 229 r, err := utils.InitMcpClient(ctx, mcpClient) 230 if err != nil { 231 return nil, err 232 } 233 s.client = mcpClient 234 return r, err 235 } 236 237 func (s *Scanner) InputStreamLink(ctx context.Context, link string) (*mcp.InitializeResult, error) { 238 mcpClient, err := client.NewStreamableHttpClient(link, transport.WithHTTPTimeout(10*time.Second)) 239 if err != nil { 240 return nil, err 241 } 242 r, err := utils.InitMcpClient(ctx, mcpClient) 243 if err != nil { 244 return nil, err 245 } 246 s.client = mcpClient 247 return r, nil 248 } 249 250 func (s *Scanner) InputCodePath(codePath string) error { 251 c, err := filepath.Abs(codePath) 252 if err != nil { 253 return err 254 } 255 s.codePath = c 256 return nil 257 } 258 259 func (s *Scanner) SetLanguage(language string) error { 260 if language == "" { 261 s.language = "zh" 262 return nil 263 } 264 if language == "zh-CN" { 265 s.language = "zh" 266 } else { 267 s.language = language 268 } 269 return nil 270 } 271 272 type McpTemplate struct { 273 CodePath string 274 DirectoryStructure string 275 StaticAnalysisResults string 276 OriginalReports string 277 McpStructure string 278 } 279 280 type McpResult struct { 281 Issues []Issue 282 Report []Issue 283 } 284 285 type CallbackWriteLog struct { 286 Text []byte 287 ModuleName string 288 } 289 290 type tmpWriter struct { 291 Callback func(data interface{}) 292 Mux sync.Mutex 293 cache []byte 294 ModuleName string 295 } 296 297 func (t *tmpWriter) Write(p []byte) (n int, err error) { 298 t.Mux.Lock() 299 defer t.Mux.Unlock() 300 for _, word := range p { 301 t.cache = append(t.cache, word) 302 if word == '\n' { 303 if t.Callback != nil { 304 t.Callback(CallbackWriteLog{t.cache, t.ModuleName}) 305 } 306 t.cache = []byte{} 307 } 308 } 309 return len(p), nil 310 } 311 312 func (t *tmpWriter) Finally() { 313 if len(t.cache) > 0 { 314 if t.Callback != nil { 315 t.Callback(CallbackWriteLog{t.cache, t.ModuleName}) 316 } 317 t.cache = []byte{} 318 } 319 } 320 321 func (s *Scanner) getPluginByID(id string) (*PluginConfig, error) { 322 for _, plugin := range s.PluginConfigs { 323 if plugin.Info.ID == id { 324 return plugin, nil 325 } 326 } 327 return nil, fmt.Errorf("插件 %s 未找到", id) 328 } 329 330 func (s *Scanner) GetCsvResult() [][]string { 331 return s.csvResult 332 }