/ internal / mcp / scanner.go
scanner.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package mcp
 20  
 21  import (
 22  	"context"
 23  	"fmt"
 24  	"net/http"
 25  	"os"
 26  	"path/filepath"
 27  	"strings"
 28  	"sync"
 29  	"time"
 30  
 31  	utils2 "github.com/Tencent/AI-Infra-Guard/common/utils"
 32  	"github.com/Tencent/AI-Infra-Guard/common/utils/models"
 33  	"github.com/Tencent/AI-Infra-Guard/internal/gologger"
 34  	"github.com/Tencent/AI-Infra-Guard/internal/mcp/utils"
 35  	"github.com/mark3labs/mcp-go/client"
 36  	"github.com/mark3labs/mcp-go/client/transport"
 37  	"github.com/mark3labs/mcp-go/mcp"
 38  )
 39  
// Scanner holds the state of one MCP scan: the loaded plugin configurations,
// the connected MCP client, the scan target and the collected results.
type Scanner struct {
	mutex         sync.Mutex          // not used in the visible part of this file; presumably guards results — verify
	results       []*Issue            // initialised empty by NewScanner
	PluginConfigs []*PluginConfig     // plugin configurations appended by RegisterPlugin
	aiModel       *models.OpenAI      // AI model configuration handed to NewScanner
	client        *client.Client      // MCP client set by InputCommand / InputSSELink / InputStreamLink
	csvResult     [][]string          // rows returned by GetCsvResult
	codePath      string              // absolute source path set by InputCodePath
	url           string              // not assigned in the visible part of this file
	callback      func(data interface{}) // optional callback installed by SetCallback
	language      string              // language code; defaults to "zh" (see NewScanner, SetLanguage)
	logger        *gologger.Logger    // logger used for plugin loading messages
}
 53  
// McpCallbackProcessing carries a progress pair (current of total).
// It serialises to JSON with the keys "current" and "total".
// NOTE(review): presumably delivered through the Scanner callback — confirm against callers.
type McpCallbackProcessing struct {
	Current int `json:"current"`
	Total   int `json:"total"`
}
 58  
// McpCallbackReadMe carries a block of text content; it serialises to JSON
// with the key "content".
// NOTE(review): presumably a README payload for the Scanner callback — confirm against callers.
type McpCallbackReadMe struct {
	Content string `json:"content"`
}
 62  
// McpModuleStart identifies a module by name.
// NOTE(review): presumably emitted when a module begins running — confirm against callers.
type McpModuleStart struct {
	ModuleName string
}
 66  
// McpModuleEnd pairs a module name with a result string.
// NOTE(review): presumably emitted when a module finishes — confirm against callers.
type McpModuleEnd struct {
	ModuleName string
	Result     string
}
 71  
 72  func NewScanner(aiConfig *models.OpenAI, logger *gologger.Logger) *Scanner {
 73  	if logger == nil {
 74  		logger = gologger.NewLogger()
 75  	}
 76  	s := &Scanner{
 77  		results:       make([]*Issue, 0),
 78  		PluginConfigs: make([]*PluginConfig, 0),
 79  		aiModel:       aiConfig,
 80  		csvResult:     make([][]string, 0),
 81  		language:      "zh",
 82  		logger:        logger,
 83  	}
 84  	return s
 85  }
 86  
// SetCallback installs the function stored in the scanner's callback field,
// replacing any previously installed one. Passing nil clears it.
func (s *Scanner) SetCallback(callback func(data interface{})) {
	s.callback = callback
}
 90  
 91  func (s *Scanner) GetPluginsByCategory(category string) []*PluginConfig {
 92  	plugins := make([]*PluginConfig, 0)
 93  	for _, plugin := range s.PluginConfigs {
 94  		for _, c := range plugin.Info.Category {
 95  			if c == category {
 96  				p := plugin
 97  				plugins = append(plugins, p)
 98  			}
 99  		}
100  	}
101  	return plugins
102  }
103  
104  func (s *Scanner) GetAllPluginNames() ([]string, error) {
105  	names := make([]string, 0)
106  	execPath, err := os.Executable()
107  	if err != nil {
108  		return nil, err
109  	}
110  
111  	// 构建data/mcp目录路径
112  	dataDir := filepath.Join(filepath.Dir(execPath), "data", "mcp")
113  
114  	// 如果相对路径不存在,尝试从工作目录查找
115  	if _, err := os.Stat(dataDir); os.IsNotExist(err) {
116  		wd, _ := os.Getwd()
117  		dataDir = filepath.Join(wd, "data", "mcp")
118  	}
119  	files, err := utils2.ScanDir(dataDir)
120  	if err != nil {
121  		return nil, err
122  	}
123  
124  	for _, configPath := range files {
125  		if !strings.HasSuffix(configPath, ".yaml") {
126  			continue
127  		}
128  		plugin, err := NewYAMLPlugin(configPath)
129  		if err != nil {
130  			s.logger.Errorf("加载插件配置失败 %s: %v", configPath, err)
131  			continue
132  		}
133  		names = append(names, plugin.Info.ID)
134  	}
135  	return names, nil
136  }
137  
138  func (s *Scanner) RegisterPlugin(plugins []string) error {
139  	// 获取当前执行文件的目录
140  	execPath, err := os.Executable()
141  	if err != nil {
142  		return err
143  	}
144  
145  	// 构建data/mcp目录路径
146  	dataDir := filepath.Join(filepath.Dir(execPath), "data", "mcp")
147  
148  	// 如果相对路径不存在,尝试从工作目录查找
149  	if _, err := os.Stat(dataDir); os.IsNotExist(err) {
150  		wd, _ := os.Getwd()
151  		dataDir = filepath.Join(wd, "data", "mcp")
152  	}
153  	files, err := utils2.ScanDir(dataDir)
154  	if err != nil {
155  		return err
156  	}
157  
158  	for _, configPath := range files {
159  		if !strings.HasSuffix(configPath, ".yaml") {
160  			continue
161  		}
162  		plugin, err := NewYAMLPlugin(configPath)
163  		if err != nil {
164  			s.logger.Errorf("加载插件配置失败 %s: %v", configPath, err)
165  			continue
166  		}
167  		id := plugin.Info.ID
168  		if len(plugins) > 0 {
169  			for _, p := range plugins {
170  				if p == id {
171  					s.logger.Infof("加载插件 %s", plugin.Info.Name)
172  					s.PluginConfigs = append(s.PluginConfigs, plugin)
173  					break
174  				}
175  			}
176  		} else {
177  			s.logger.Infof("加载插件 %s", plugin.Info.Name)
178  			s.PluginConfigs = append(s.PluginConfigs, plugin)
179  		}
180  	}
181  	if len(s.PluginConfigs) == 0 {
182  		return fmt.Errorf("未加载任何插件")
183  	}
184  	return nil
185  }
186  
187  func (s *Scanner) InputCommand(ctx context.Context, command string, argv []string) (*mcp.InitializeResult, error) {
188  	mcpClient, err := client.NewStdioMCPClient(
189  		command,
190  		argv,
191  	)
192  	if err != nil {
193  		return nil, err
194  	}
195  	s.client = mcpClient
196  	return utils.InitMcpClient(ctx, s.client)
197  }
198  
199  func (s *Scanner) InputUrl(ctx context.Context, url string) (*mcp.InitializeResult, error) {
200  	dirs := []string{"", "/mcp", "/sse"}
201  	url = strings.TrimRight(url, "/")
202  	scan := func(ctx context.Context, url string) (*mcp.InitializeResult, error) {
203  		r, err := s.InputStreamLink(ctx, url)
204  		if err != nil {
205  			r, err = s.InputSSELink(ctx, url)
206  			if err != nil {
207  				return nil, err
208  			}
209  		}
210  		return r, nil
211  	}
212  	var err error
213  	for _, u := range dirs {
214  		link := url + u
215  		r, err := scan(ctx, link)
216  		if err == nil {
217  			return r, nil
218  		}
219  	}
220  	return nil, err
221  }
222  
223  func (s *Scanner) InputSSELink(ctx context.Context, link string) (*mcp.InitializeResult, error) {
224  	opt := client.WithHTTPClient(&http.Client{Timeout: 10 * time.Second})
225  	mcpClient, err := client.NewSSEMCPClient(link, opt)
226  	if err != nil {
227  		return nil, err
228  	}
229  	r, err := utils.InitMcpClient(ctx, mcpClient)
230  	if err != nil {
231  		return nil, err
232  	}
233  	s.client = mcpClient
234  	return r, err
235  }
236  
237  func (s *Scanner) InputStreamLink(ctx context.Context, link string) (*mcp.InitializeResult, error) {
238  	mcpClient, err := client.NewStreamableHttpClient(link, transport.WithHTTPTimeout(10*time.Second))
239  	if err != nil {
240  		return nil, err
241  	}
242  	r, err := utils.InitMcpClient(ctx, mcpClient)
243  	if err != nil {
244  		return nil, err
245  	}
246  	s.client = mcpClient
247  	return r, nil
248  }
249  
250  func (s *Scanner) InputCodePath(codePath string) error {
251  	c, err := filepath.Abs(codePath)
252  	if err != nil {
253  		return err
254  	}
255  	s.codePath = c
256  	return nil
257  }
258  
259  func (s *Scanner) SetLanguage(language string) error {
260  	if language == "" {
261  		s.language = "zh"
262  		return nil
263  	}
264  	if language == "zh-CN" {
265  		s.language = "zh"
266  	} else {
267  		s.language = language
268  	}
269  	return nil
270  }
271  
// McpTemplate groups the string values describing a scan target.
// NOTE(review): presumably used as the data for prompt templates — confirm against callers.
type McpTemplate struct {
	CodePath              string
	DirectoryStructure    string
	StaticAnalysisResults string
	OriginalReports       string
	McpStructure          string
}
279  
// McpResult holds two lists of Issue values.
// NOTE(review): the distinction between Issues and Report is not visible in
// this part of the file — confirm against callers.
type McpResult struct {
	Issues []Issue
	Report []Issue
}
284  
// CallbackWriteLog is the payload tmpWriter passes to its callback: a chunk
// of written bytes and the name of the module that produced it.
// tmpWriter builds it with a positional literal, so the field order matters.
type CallbackWriteLog struct {
	Text       []byte
	ModuleName string
}
289  
// tmpWriter buffers written bytes and forwards them to Callback as
// CallbackWriteLog values, one per newline-terminated line (see Write), with
// any unterminated remainder delivered by Finally.
type tmpWriter struct {
	Callback   func(data interface{}) // receives CallbackWriteLog values; may be nil
	Mux        sync.Mutex             // held by Write while it touches cache
	cache      []byte                 // bytes of the current, not yet delivered line
	ModuleName string                 // copied into every CallbackWriteLog
}
296  
297  func (t *tmpWriter) Write(p []byte) (n int, err error) {
298  	t.Mux.Lock()
299  	defer t.Mux.Unlock()
300  	for _, word := range p {
301  		t.cache = append(t.cache, word)
302  		if word == '\n' {
303  			if t.Callback != nil {
304  				t.Callback(CallbackWriteLog{t.cache, t.ModuleName})
305  			}
306  			t.cache = []byte{}
307  		}
308  	}
309  	return len(p), nil
310  }
311  
312  func (t *tmpWriter) Finally() {
313  	if len(t.cache) > 0 {
314  		if t.Callback != nil {
315  			t.Callback(CallbackWriteLog{t.cache, t.ModuleName})
316  		}
317  		t.cache = []byte{}
318  	}
319  }
320  
321  func (s *Scanner) getPluginByID(id string) (*PluginConfig, error) {
322  	for _, plugin := range s.PluginConfigs {
323  		if plugin.Info.ID == id {
324  			return plugin, nil
325  		}
326  	}
327  	return nil, fmt.Errorf("插件 %s 未找到", id)
328  }
329  
// GetCsvResult returns the scanner's CSV rows. The internal slice is returned
// directly, not a copy, so changes made by the caller are visible to the scanner.
func (s *Scanner) GetCsvResult() [][]string {
	return s.csvResult
}