/ internal / mcp / plugins.go
plugins.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package mcp
 20  
 21  import (
 22  	"context"
 23  	"fmt"
 24  	"os"
 25  	"regexp"
 26  	"strings"
 27  
 28  	"github.com/Tencent/AI-Infra-Guard/common/utils/models"
 29  	"github.com/Tencent/AI-Infra-Guard/internal/gologger"
 30  	"github.com/Tencent/AI-Infra-Guard/internal/mcp/utils"
 31  	"github.com/mark3labs/mcp-go/client"
 32  	"gopkg.in/yaml.v3"
 33  )
 34  
 35  type PluginConfig struct {
 36  	Info struct {
 37  		ID          string   `yaml:"id" json:"id"`
 38  		Name        string   `yaml:"name" json:"name"`
 39  		Description string   `yaml:"description" json:"description"`
 40  		Author      string   `yaml:"author" json:"author"`
 41  		Category    []string `yaml:"categories" json:"category"`
 42  	} `yaml:"info" json:"info"`
 43  	Rules          []Rule `yaml:"rules,omitempty" json:"rules,omitempty"`
 44  	PromptTemplate string `yaml:"prompt_template" json:"prompt_template"`
 45  }
 46  
 47  type Rule struct {
 48  	Name        string `yaml:"name"`
 49  	Pattern     string `yaml:"pattern"`
 50  	Description string `yaml:"description"`
 51  }
 52  
 53  func NewYAMLPlugin(configPath string) (*PluginConfig, error) {
 54  	data, err := os.ReadFile(configPath)
 55  	if err != nil {
 56  		return nil, err
 57  	}
 58  
 59  	var config PluginConfig
 60  	err = yaml.Unmarshal(data, &config)
 61  	if err != nil {
 62  		return nil, err
 63  	}
 64  
 65  	return &config, nil
 66  }
 67  
 68  // 威胁级别常量
 69  type Level string
 70  
 71  const (
 72  	LevelLow      Level = "low"
 73  	LevelMedium   Level = "medium"
 74  	LevelHigh     Level = "high"
 75  	LevelCritical Level = "critical"
 76  )
 77  
 78  type MCPType string
 79  
 80  const (
 81  	MCPTypeCommand MCPType = "command"
 82  	MCPTypeSSE     MCPType = "sse"
 83  	MCPTypeSTREAM  MCPType = "stream"
 84  	MCPTypeCode    MCPType = "code"
 85  )
 86  
 87  // Issue 安全问题
 88  type Issue struct {
 89  	Title       string `json:"title"`
 90  	Description string `json:"description"`
 91  	Level       Level  `json:"level"`
 92  	Suggestion  string `json:"suggestion"`
 93  	RiskType    string `json:"risk_type"`
 94  }
 95  
 96  type McpInput struct {
 97  	Input string
 98  	Type  MCPType // 输入类型:命令行、SSE链接、Stream链接、MCP代码
 99  }
100  
101  type McpPluginConfig struct {
102  	Client       *client.Client
103  	CodePath     string
104  	McpStructure string
105  	AIModel      *models.OpenAI
106  	Language     string // zh / en
107  	Logger       *gologger.Logger
108  }
109  
110  // ExtractBatchResults 从文本中提取结果
111  func ParseIssues(input string) []Issue {
112  	var vulns []Issue
113  	// 解析漏洞数据的正则表达式
114  	var (
115  		blockRegex    = regexp.MustCompile(`(?s)<result>(.*?)</result>`)
116  		titleRegex    = regexp.MustCompile(`<title>(.*?)</title>`)
117  		descRegex     = regexp.MustCompile(`(?s)<desc>(.*?)</desc>`)
118  		levelRegex    = regexp.MustCompile(`<level>(.*?)</level>`)
119  		riskTypeRegex = regexp.MustCompile(`<risk_type>(.*?)</risk_type>`)
120  		suggesRegex   = regexp.MustCompile(`(?s)<suggestion>(.*?)</suggestion>`)
121  	)
122  	blocks := blockRegex.FindAllStringSubmatch(input, -1)
123  	for _, block := range blocks {
124  		var vuln Issue
125  		// 提取各个字段
126  		if title := titleRegex.FindStringSubmatch(block[1]); len(title) > 1 {
127  			vuln.Title = strings.TrimSpace(title[1])
128  		}
129  		if desc := descRegex.FindStringSubmatch(block[1]); len(desc) > 1 {
130  			vuln.Description = strings.TrimSpace(desc[1])
131  			if vuln.Description == "" {
132  				continue
133  			}
134  		}
135  		if level := levelRegex.FindStringSubmatch(block[1]); len(level) > 1 {
136  			vuln.Level = Level(strings.TrimSpace(level[1]))
137  		}
138  		if sugges := suggesRegex.FindStringSubmatch(block[1]); len(sugges) > 1 {
139  			vuln.Suggestion = strings.TrimSpace(sugges[1])
140  		}
141  		if riskType := riskTypeRegex.FindStringSubmatch(block[1]); len(riskType) > 1 {
142  			vuln.RiskType = strings.TrimSpace(riskType[1])
143  		}
144  		vulns = append(vulns, vuln)
145  	}
146  	return vulns
147  }
148  
149  func SummaryResult(ctx context.Context, agent utils.Agent, config *McpPluginConfig) ([]Issue, error) {
150  	history := agent.GetHistory()
151  	const summaryPrompt = `
152  The task is now complete, and the discovered vulnerabilities are being returned.
153  **Return Format**
154  All valid results must be wrapped in <arg> tags (e.g., <arg>[RESULTS]</arg>). 
155  If no vulnerabilities are found, return <arg></arg>.  
156  Multiple <result> entries are supported, but only vulnerabilities with severity levels critical, high, or medium should be included.
157  **Rules**
158  1. You must ensure that the vulnerability truly exists. if no vulnerability is found, return empty.
159  2. The desc field in the vulnerability description should include a detailed evidence chain for the vulnerability.
160  3. Determine the severity 'level'' of the vulnerability based on its title and description: critical, high, medium, low.
161  %s
162  **EXAMPLE**
163  1. if no vulnerabilities are found, return <arg></arg>.
164  2. if vulnerabilities are found, return:
165  <arg>
166  	<result>
167  	<title>Vulnerability Name</title>
168  	<desc>Detailed description in Markdown format, including code paths, file locations, code snippets, relevant context, and technical analysis (using professional terminology to explain the vulnerability's principle and potential impact).</desc>
169  	<risk_type>Vulnerability risk type</risk_type>
170  	<level>Severity level (critical, high, medium,low)</level>
171  	<suggestion>Step-by-step remediation guidance</suggestion>
172  	</result>
173  	<!-- Additional <result> entries can be added -->
174  </arg>
175  
176  **请注意,必须是漏洞输出详情,没有漏洞则只输出<arg></arg>**
177  `
178  	history = append(history, map[string]string{
179  		"role":    "user",
180  		"content": fmt.Sprintf(summaryPrompt, utils.LanguagePrompt(config.Language)),
181  	})
182  	var result string = ""
183  	config.Logger.Infoln("generate summary result")
184  	for word := range config.AIModel.ChatStream(ctx, history) {
185  		result += word
186  		config.Logger.Print(word)
187  	}
188  	history = append(history, map[string]string{
189  		"role":    "assistant",
190  		"content": result,
191  	})
192  	// 保存模型输出
193  	return ParseIssues(result), nil
194  }
195  
196  func SummaryChat(ctx context.Context, agent utils.Agent, config *McpPluginConfig, prompt string) (string, error) {
197  	history := agent.GetHistory()
198  	history = append(history, map[string]string{
199  		"role":    "user",
200  		"content": fmt.Sprintf(prompt, utils.LanguagePrompt(config.Language)),
201  	})
202  	var result string = ""
203  	config.Logger.Infoln("generate summary result")
204  	for word := range config.AIModel.ChatStream(ctx, history) {
205  		result += word
206  		config.Logger.Print(word)
207  	}
208  	history = append(history, map[string]string{
209  		"role":    "assistant",
210  		"content": result,
211  	})
212  	// 保存模型输出
213  	return result, nil
214  }
215  
216  func SummaryReport(ctx context.Context, agent utils.Agent, config *McpPluginConfig) (string, error) {
217  	prompt := `
218  You have performed a complete vulnerability scanning process on the target system but ultimately found no reportable vulnerabilities. Now, a brief technical analysis report explaining the reasons needs to be generated. Please output according to the following structure:
219  
220  # Task Role  
221  Cybersecurity Analysis Report Writing Expert  
222  
223  # Core Requirements  
224  1. Provide a structured explanation of the technical reasons why no vulnerabilities were found.  
225  2. Include an analysis of potential possibilities.  
226  3. Propose follow-up action recommendations.  
227  4. Use professional security terminology but avoid excessive jargon.  
228  
229  # Report Framework (Markdown Format)  
230  - Rephrase the core objective of the scan.  
231  - Briefly describe the scanning process and the key components covered (files/interfaces/code scope).  
232  - Reasons why no vulnerabilities were found:  
233    - Provide a detailed explanation of why no vulnerabilities were detected.  
234    - Explain possible reasons.  
235    - Discuss potential opportunities for vulnerability discovery.  
236  
237  **Return Format**  
238  All valid results must be wrapped in <arg> tags (e.g., <arg>[RESULTS]</arg>).  
239  If no vulnerabilities are found, return <arg></arg>.  
240  Multiple <result> entries are supported, but only vulnerabilities with severity levels critical, high, or medium should be included.  
241  
242  **EXAMPLE**  
243  <arg>  
244  	<result>  
245  	<title>No [Vulnerability Type] Found</title>  
246  	<desc>Technical analysis report content in Markdown format</desc>  
247  	</result>  
248  </arg>  
249  
250  If none, return:  
251  <arg></arg>
252  `
253  	return SummaryChat(ctx, agent, config, prompt)
254  }