/ common / websocket / server.go
server.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  // @title AI-Infra-Guard 任务API
 20  // @version 1.0
 21  // @description API for managing AI security scanning tasks
 22  // @BasePath /
 23  package websocket
 24  
 25  import (
 26  	"embed"
 27  	"mime"
 28  	"net/http"
 29  	"os"
 30  	"path/filepath"
 31  	"strings"
 32  
 33  	"github.com/Tencent/AI-Infra-Guard/common/trpc"
 34  	_ "github.com/Tencent/AI-Infra-Guard/docs"
 35  	version "github.com/Tencent/AI-Infra-Guard/internal/options"
 36  	"github.com/Tencent/AI-Infra-Guard/pkg/database"
 37  	"github.com/gin-gonic/gin"
 38  	swaggerFiles "github.com/swaggo/files"
 39  	ginSwagger "github.com/swaggo/gin-swagger"
 40  	"trpc.group/trpc-go/trpc-go/log"
 41  )
 42  
// staticFS holds the static/ directory (pattern static/*) embedded into the
// binary. RunWebServer serves it from its NoRoute handler, using
// static/index.html for "/" and as the fallback for paths with no asset.
//
//go:embed static/*
var staticFS embed.FS
 45  
 46  func RunWebServer(options *version.Options) {
 47  	// 1. 初始化trpc-go
 48  	if err := trpc.InitTrpc("./trpc_go.yaml"); err != nil {
 49  		log.Fatalf("Trpc-go初始化失败: %v", err)
 50  	}
 51  	log.Infof("Trpc-go initialized successfully: trace_id=system_startup")
 52  
 53  	r := gin.Default()
 54  	// 2. 添加中间件
 55  	//r.Use(middleware.TrpcMiddleware())
 56  	//r.Use(middleware.RequestLoggerMiddleware()) // 添加请求参数日志中间件
 57  	// r.Use(middleware.MetricsMiddleware()) // 移除HTTP监控中间件,依赖TRPC自动监控
 58  
 59  	// 3. 初始化数据库和Agentmanager
 60  	dbConfig := database.LoadConfigFromEnv() // 从环境变量加载数据库配置
 61  	db, err := database.InitDB(dbConfig)
 62  	if err != nil {
 63  		log.Errorf("数据库初始化失败: trace_id=system_startup, error=%v", err)
 64  
 65  	}
 66  	taskStore := database.NewTaskStore(db)
 67  	if err := taskStore.Init(); err != nil {
 68  		log.Errorf("初始化tasks表失败: trace_id=system_startup, error=%v", err)
 69  		log.Fatalf("初始化tasks表失败: %v", err)
 70  	}
 71  
 72  	// 初始化模型存储
 73  	modelStore := database.NewModelStore(db)
 74  	if err := modelStore.Init(); err != nil {
 75  		log.Errorf("初始化models表失败: trace_id=system_startup, error=%v", err)
 76  
 77  	}
 78  	// 自动添加模型
 79  	modelStore.AutoAddModels()
 80  
 81  	// 初始化AgentManager
 82  	agentManager := NewAgentManager()
 83  
 84  	// 初始化ModelManager
 85  	modelManager := NewModelManager(modelStore)
 86  
 87  	// 初始化文件上传配置(支持环境变量)
 88  	fileConfig := LoadFileUploadConfigFromEnv()
 89  
 90  	// 验证文件上传配置
 91  	if err := fileConfig.ValidateConfig(); err != nil {
 92  		log.Errorf("文件上传配置验证失败: trace_id=system_startup, error=%v", err)
 93  
 94  	}
 95  
 96  	// 初始化SSE管理器
 97  	sseManager := NewSSEManager()
 98  
 99  	taskManager := NewTaskManager(agentManager, taskStore, modelStore, fileConfig, sseManager)
100  	err = taskManager.taskStore.ResetRunningTasks()
101  	if err != nil {
102  		log.Fatalf("重置运行中的任务失败: %v", err)
103  	}
104  
105  	// 将 TaskManager 注入到 AgentManager
106  	agentManager.SetTaskManager(taskManager)
107  
108  	// API 版本分组
109  	v1 := r.Group("/api/v1")
110  	{
111  		v1.GET("/images/:path", func(context *gin.Context) {
112  			path := context.Param("path")
113  			if strings.Contains(path, "..") {
114  				context.String(403, "Forbidden")
115  				return
116  			}
117  			context.File(filepath.Join("uploads", path))
118  		})
119  		// 1. 知识库模块
120  		knowledge := v1.Group("/knowledge")
121  		knowledge.Use(setupIdentityMiddleware())
122  		{
123  			// AI应用指纹
124  			fingerprints := knowledge.Group("/fingerprints")
125  			{
126  				// 管理功能
127  				fingerprints.GET("", HandleListFingerprints)
128  				fingerprints.POST("", HandleCreateFingerprint)
129  				fingerprints.PUT("/:name", HandleEditFingerprint)
130  				fingerprints.DELETE("", HandleDeleteFingerprint)
131  			}
132  			// 漏洞库
133  			vulnerabilities := knowledge.Group("/vulnerabilities")
134  			{
135  				// 管理功能
136  				vulnerabilities.GET("", HandleListVulnerabilities())
137  				vulnerabilities.POST("", HandleCreateVulnerability())
138  				vulnerabilities.PUT("/:cve", HandleEditVulnerability)
139  				vulnerabilities.DELETE("", HandleBatchDeleteVulnerabilities)
140  			}
141  			// 评测集
142  			evaluations := knowledge.Group("/evaluations")
143  			{
144  				// 管理功能
145  				evaluations.GET("/:name", HandleGetEvaluationDetail)
146  				evaluations.GET("", HandleListEvaluations)
147  				evaluations.POST("", HandleCreateEvaluation)
148  				evaluations.PUT("/:name", HandleEditEvaluation)
149  				evaluations.DELETE("", HandleDeleteEvaluation)
150  			}
151  			// MCP
152  			mcp := knowledge.Group("/mcp")
153  			{
154  				mcp.GET("names", GetMcpPluginList)
155  				mcp.GET("", HandleList(MCPROOT, McpLoadFile))
156  				mcp.POST("", HandleCreate(mcpReadAndSave))
157  				mcp.PUT("/:id", HandleEdit(mcpUpdateFunc))
158  				mcp.DELETE("/:id", HandleDelete(mcpDeleteFunc))
159  			}
160  			// Prompt Collections
161  			collections := knowledge.Group("/prompt_collections")
162  			{
163  				collections.GET("", HandleList(PromptCollectionsRoot, promptCollectionLoadFile))
164  				collections.POST("", HandleCreate(promptCollectionReadAndSave))
165  				collections.PUT("/:id", HandleEdit(promptCollectionUpdateFunc))
166  				collections.DELETE("", HandleDelete(promptCollectionDeleteFunc))
167  			}
168  			agentConfigs := knowledge.Group("/agent")
169  			{
170  				agentConfigs.GET("/names", HandleListAgentNames)
171  				agentConfigs.GET("/:name", HandleGetAgentConfig)
172  				agentConfigs.POST("/:name", HandleSaveAgentConfig)
173  				agentConfigs.DELETE("/:name", HandleDeleteAgentConfig)
174  				agentConfigs.POST("/connect", HandleAgentConnect)
175  				agentConfigs.POST("/prompt_test", HandleAgentPromptTest)
176  				agentConfigs.GET("/template", HandleAgentTemplate)
177  			}
178  			// 算子列表
179  			knowledge.GET("/jailbreak", GetJailBreak)
180  		}
181  		appSecurity := v1.Group("/app")
182  		{
183  			appSecurity.Use(setupIdentityMiddleware())
184  			// 任务管理
185  			tasks := appSecurity.Group("/tasks")
186  			{
187  				// 获取任务列表接口
188  				tasks.GET("", func(c *gin.Context) {
189  					HandleGetTaskList(c, taskManager)
190  				})
191  				// 获取任务详情接口
192  				tasks.GET("/:sessionId", func(c *gin.Context) {
193  					HandleGetTaskDetail(c, taskManager)
194  				})
195  				// 分享任务接口
196  				tasks.POST("/share", func(c *gin.Context) {
197  					HandleShare(c, taskManager)
198  				})
199  				// SSE接口
200  				tasks.GET("/sse/:sessionId", func(c *gin.Context) {
201  					HandleTaskSSE(c, taskManager)
202  				})
203  				// 新建任务接口
204  				tasks.POST("", func(c *gin.Context) {
205  					HandleTaskCreate(c, taskManager)
206  				})
207  				// 文件上传接口(完整文件上传)
208  				tasks.POST("/uploadFile", func(c *gin.Context) {
209  					HandleUploadFile(c, taskManager)
210  				})
211  				// 分片上传接口
212  				tasks.POST("/uploadChunk", func(c *gin.Context) {
213  					HandleUploadFileChunk(c, taskManager)
214  				})
215  				// 合并分片接口
216  				tasks.POST("/mergeChunks", func(c *gin.Context) {
217  					HandleMergeFileChunks(c, taskManager)
218  				})
219  				// 文件下载接口
220  				tasks.POST("/:sessionId/downloadFile", func(c *gin.Context) {
221  					HandleDownloadFile(c, taskManager)
222  				})
223  				// 编辑任务接口
224  				tasks.PUT("/:sessionId", func(c *gin.Context) {
225  					HandleUpdateTask(c, taskManager)
226  				})
227  				// 删除任务接口
228  				tasks.DELETE("/:sessionId", func(c *gin.Context) {
229  					HandleDeleteTask(c, taskManager)
230  				})
231  				// 终止任务接口
232  				tasks.POST("/:sessionId/terminate", func(c *gin.Context) {
233  					HandleTerminateTask(c, taskManager)
234  				})
235  			}
236  			// 模型管理
237  			models := appSecurity.Group("/models")
238  			{
239  				// 获取模型列表接口
240  				models.GET("", func(c *gin.Context) {
241  					HandleGetModelList(c, modelManager)
242  				})
243  				// 获取模型详情接口
244  				models.GET("/:modelId", func(c *gin.Context) {
245  					HandleGetModelDetail(c, modelManager)
246  				})
247  				// 创建模型接口
248  				models.POST("", func(c *gin.Context) {
249  					HandleCreateModel(c, modelManager)
250  				})
251  				// 更新模型接口
252  				models.PUT("/:modelId", func(c *gin.Context) {
253  					HandleUpdateModel(c, modelManager)
254  				})
255  				// 删除模型接口(支持单个和批量)
256  				models.DELETE("", func(c *gin.Context) {
257  					HandleDeleteModel(c, modelManager)
258  				})
259  			}
260  		}
261  		// 4. Agent 管理
262  		agents := v1.Group("/agents")
263  		{
264  			// 只需要WebSocket入口
265  			agents.GET("/ws", agentManager.HandleAgentWebSocket())
266  		}
267  		// 提供给第三方的api
268  		taskApi := appSecurity.Group("/taskapi")
269  		{
270  			// 创建任务
271  			taskApi.POST("/tasks", func(c *gin.Context) {
272  				SubmitTask(c, taskManager)
273  			})
274  			// 获取任务状态
275  			taskApi.GET("/status/:id", func(c *gin.Context) {
276  				GetTaskStatus(c, taskManager)
277  			})
278  			// 获取任务结果
279  			taskApi.GET("/result/:id", func(c *gin.Context) {
280  				GetTaskResult(c, taskManager)
281  			})
282  			taskApi.POST("/upload", func(c *gin.Context) {
283  				HandleUploadFile(c, taskManager)
284  			})
285  			// 分片上传接口
286  			taskApi.POST("/uploadChunk", func(c *gin.Context) {
287  				HandleUploadFileChunk(c, taskManager)
288  			})
289  			// 合并分片接口
290  			taskApi.POST("/mergeChunks", func(c *gin.Context) {
291  				HandleMergeFileChunks(c, taskManager)
292  			})
293  		}
294  		// version
295  		v1.GET("/version", func(c *gin.Context) {
296  			filename := "CHANGELOG.md"
297  			data, err := os.ReadFile(filename)
298  			if err != nil {
299  				data = []byte("")
300  			}
301  			c.JSON(http.StatusOK, gin.H{
302  				"version":   version.GetVersion(),
303  				"changelog": string(data),
304  			})
305  		})
306  
307  		// system — data directory auto-sync
308  		system := v1.Group("/system")
309  		system.Use(setupIdentityMiddleware())
310  		{
311  			system.POST("/update-data", HandleTriggerDataUpdate)
312  			system.GET("/update-data", HandleGetUpdateStatus)
313  		}
314  	}
315  
316  	// Swagger UI - 必须在 NoRoute 之前注册
317  	r.GET("/docs/*any", func(c *gin.Context) {
318  		if c.Request.URL.Path == "/docs/" {
319  			c.Redirect(302, "/docs/index.html")
320  		} else {
321  			ginSwagger.WrapHandler(swaggerFiles.Handler)(c)
322  		}
323  	})
324  
325  	// 静态文件处理
326  	r.NoRoute(func(c *gin.Context) {
327  		assetPath := "static" + c.Request.URL.Path
328  		if c.Request.URL.Path == "/" {
329  			assetPath = "static/index.html"
330  		}
331  
332  		assetData, err := staticFS.ReadFile(assetPath)
333  		if err != nil {
334  			assetData, err = staticFS.ReadFile("static/index.html")
335  			if err != nil {
336  				c.String(500, "Internal Server Error")
337  				return
338  			}
339  			c.Header("Content-Type", "text/html")
340  			c.Data(200, "text/html", assetData)
341  			return
342  		}
343  
344  		mimeType := mime.TypeByExtension(filepath.Ext(assetPath))
345  		if mimeType == "" {
346  			mimeType = "text/plain"
347  		}
348  		c.Header("Content-Type", mimeType)
349  		c.Data(200, mimeType, assetData)
350  	})
351  
352  	log.Infof("Starting WebServer: trace_id=system_startup, addr=%s", options.WebServerAddr)
353  	if err := r.Run(options.WebServerAddr); err != nil {
354  		log.Errorf("Could not start WebSocket server: trace_id=system_startup, error=%s", err)
355  	}
356  }
357  
358  // 配置身份认证中间件
359  func setupIdentityMiddleware() gin.HandlerFunc {
360  	return func(c *gin.Context) {
361  		// 优先从请求头获取username字段
362  		username := c.GetHeader("username")
363  
364  		// 如果都没有,使用默认的公共用户
365  		if username == "" {
366  			username = "public_user"
367  		}
368  		// 存储到gin上下文
369  		c.Set("username", username)
370  		c.Next()
371  	}
372  }