server.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 // @title AI-Infra-Guard 任务API 20 // @version 1.0 21 // @description API for managing AI security scanning tasks 22 // @BasePath / 23 package websocket 24 25 import ( 26 "embed" 27 "mime" 28 "net/http" 29 "os" 30 "path/filepath" 31 "strings" 32 33 "github.com/Tencent/AI-Infra-Guard/common/trpc" 34 _ "github.com/Tencent/AI-Infra-Guard/docs" 35 version "github.com/Tencent/AI-Infra-Guard/internal/options" 36 "github.com/Tencent/AI-Infra-Guard/pkg/database" 37 "github.com/gin-gonic/gin" 38 swaggerFiles "github.com/swaggo/files" 39 ginSwagger "github.com/swaggo/gin-swagger" 40 "trpc.group/trpc-go/trpc-go/log" 41 ) 42 43 //go:embed static/* 44 var staticFS embed.FS 45 46 func RunWebServer(options *version.Options) { 47 // 1. 初始化trpc-go 48 if err := trpc.InitTrpc("./trpc_go.yaml"); err != nil { 49 log.Fatalf("Trpc-go初始化失败: %v", err) 50 } 51 log.Infof("Trpc-go initialized successfully: trace_id=system_startup") 52 53 r := gin.Default() 54 // 2. 添加中间件 55 //r.Use(middleware.TrpcMiddleware()) 56 //r.Use(middleware.RequestLoggerMiddleware()) // 添加请求参数日志中间件 57 // r.Use(middleware.MetricsMiddleware()) // 移除HTTP监控中间件,依赖TRPC自动监控 58 59 // 3. 初始化数据库和Agentmanager 60 dbConfig := database.LoadConfigFromEnv() // 从环境变量加载数据库配置 61 db, err := database.InitDB(dbConfig) 62 if err != nil { 63 log.Errorf("数据库初始化失败: trace_id=system_startup, error=%v", err) 64 65 } 66 taskStore := database.NewTaskStore(db) 67 if err := taskStore.Init(); err != nil { 68 log.Errorf("初始化tasks表失败: trace_id=system_startup, error=%v", err) 69 log.Fatalf("初始化tasks表失败: %v", err) 70 } 71 72 // 初始化模型存储 73 modelStore := database.NewModelStore(db) 74 if err := modelStore.Init(); err != nil { 75 log.Errorf("初始化models表失败: trace_id=system_startup, error=%v", err) 76 77 } 78 // 自动添加模型 79 modelStore.AutoAddModels() 80 81 // 初始化AgentManager 82 agentManager := NewAgentManager() 83 84 // 初始化ModelManager 85 modelManager := NewModelManager(modelStore) 86 87 // 初始化文件上传配置(支持环境变量) 88 fileConfig := LoadFileUploadConfigFromEnv() 89 90 // 验证文件上传配置 91 if err := fileConfig.ValidateConfig(); err != nil { 92 log.Errorf("文件上传配置验证失败: trace_id=system_startup, error=%v", err) 93 94 } 95 96 // 初始化SSE管理器 97 sseManager := NewSSEManager() 98 99 taskManager := NewTaskManager(agentManager, taskStore, modelStore, fileConfig, sseManager) 100 err = taskManager.taskStore.ResetRunningTasks() 101 if err != nil { 102 log.Fatalf("重置运行中的任务失败: %v", err) 103 } 104 105 // 将 TaskManager 注入到 AgentManager 106 agentManager.SetTaskManager(taskManager) 107 108 // API 版本分组 109 v1 := r.Group("/api/v1") 110 { 111 v1.GET("/images/:path", func(context *gin.Context) { 112 path := context.Param("path") 113 if strings.Contains(path, "..") { 114 context.String(403, "Forbidden") 115 return 116 } 117 context.File(filepath.Join("uploads", path)) 118 }) 119 // 1. 知识库模块 120 knowledge := v1.Group("/knowledge") 121 knowledge.Use(setupIdentityMiddleware()) 122 { 123 // AI应用指纹 124 fingerprints := knowledge.Group("/fingerprints") 125 { 126 // 管理功能 127 fingerprints.GET("", HandleListFingerprints) 128 fingerprints.POST("", HandleCreateFingerprint) 129 fingerprints.PUT("/:name", HandleEditFingerprint) 130 fingerprints.DELETE("", HandleDeleteFingerprint) 131 } 132 // 漏洞库 133 vulnerabilities := knowledge.Group("/vulnerabilities") 134 { 135 // 管理功能 136 vulnerabilities.GET("", HandleListVulnerabilities()) 137 vulnerabilities.POST("", HandleCreateVulnerability()) 138 vulnerabilities.PUT("/:cve", HandleEditVulnerability) 139 vulnerabilities.DELETE("", HandleBatchDeleteVulnerabilities) 140 } 141 // 评测集 142 evaluations := knowledge.Group("/evaluations") 143 { 144 // 管理功能 145 evaluations.GET("/:name", HandleGetEvaluationDetail) 146 evaluations.GET("", HandleListEvaluations) 147 evaluations.POST("", HandleCreateEvaluation) 148 evaluations.PUT("/:name", HandleEditEvaluation) 149 evaluations.DELETE("", HandleDeleteEvaluation) 150 } 151 // MCP 152 mcp := knowledge.Group("/mcp") 153 { 154 mcp.GET("names", GetMcpPluginList) 155 mcp.GET("", HandleList(MCPROOT, McpLoadFile)) 156 mcp.POST("", HandleCreate(mcpReadAndSave)) 157 mcp.PUT("/:id", HandleEdit(mcpUpdateFunc)) 158 mcp.DELETE("/:id", HandleDelete(mcpDeleteFunc)) 159 } 160 // Prompt Collections 161 collections := knowledge.Group("/prompt_collections") 162 { 163 collections.GET("", HandleList(PromptCollectionsRoot, promptCollectionLoadFile)) 164 collections.POST("", HandleCreate(promptCollectionReadAndSave)) 165 collections.PUT("/:id", HandleEdit(promptCollectionUpdateFunc)) 166 collections.DELETE("", HandleDelete(promptCollectionDeleteFunc)) 167 } 168 agentConfigs := knowledge.Group("/agent") 169 { 170 agentConfigs.GET("/names", HandleListAgentNames) 171 agentConfigs.GET("/:name", HandleGetAgentConfig) 172 agentConfigs.POST("/:name", HandleSaveAgentConfig) 173 agentConfigs.DELETE("/:name", HandleDeleteAgentConfig) 174 agentConfigs.POST("/connect", HandleAgentConnect) 175 agentConfigs.POST("/prompt_test", HandleAgentPromptTest) 176 agentConfigs.GET("/template", HandleAgentTemplate) 177 } 178 // 算子列表 179 knowledge.GET("/jailbreak", GetJailBreak) 180 } 181 appSecurity := v1.Group("/app") 182 { 183 appSecurity.Use(setupIdentityMiddleware()) 184 // 任务管理 185 tasks := appSecurity.Group("/tasks") 186 { 187 // 获取任务列表接口 188 tasks.GET("", func(c *gin.Context) { 189 HandleGetTaskList(c, taskManager) 190 }) 191 // 获取任务详情接口 192 tasks.GET("/:sessionId", func(c *gin.Context) { 193 HandleGetTaskDetail(c, taskManager) 194 }) 195 // 分享任务接口 196 tasks.POST("/share", func(c *gin.Context) { 197 HandleShare(c, taskManager) 198 }) 199 // SSE接口 200 tasks.GET("/sse/:sessionId", func(c *gin.Context) { 201 HandleTaskSSE(c, taskManager) 202 }) 203 // 新建任务接口 204 tasks.POST("", func(c *gin.Context) { 205 HandleTaskCreate(c, taskManager) 206 }) 207 // 文件上传接口(完整文件上传) 208 tasks.POST("/uploadFile", func(c *gin.Context) { 209 HandleUploadFile(c, taskManager) 210 }) 211 // 分片上传接口 212 tasks.POST("/uploadChunk", func(c *gin.Context) { 213 HandleUploadFileChunk(c, taskManager) 214 }) 215 // 合并分片接口 216 tasks.POST("/mergeChunks", func(c *gin.Context) { 217 HandleMergeFileChunks(c, taskManager) 218 }) 219 // 文件下载接口 220 tasks.POST("/:sessionId/downloadFile", func(c *gin.Context) { 221 HandleDownloadFile(c, taskManager) 222 }) 223 // 编辑任务接口 224 tasks.PUT("/:sessionId", func(c *gin.Context) { 225 HandleUpdateTask(c, taskManager) 226 }) 227 // 删除任务接口 228 tasks.DELETE("/:sessionId", func(c *gin.Context) { 229 HandleDeleteTask(c, taskManager) 230 }) 231 // 终止任务接口 232 tasks.POST("/:sessionId/terminate", func(c *gin.Context) { 233 HandleTerminateTask(c, taskManager) 234 }) 235 } 236 // 模型管理 237 models := appSecurity.Group("/models") 238 { 239 // 获取模型列表接口 240 models.GET("", func(c *gin.Context) { 241 HandleGetModelList(c, modelManager) 242 }) 243 // 获取模型详情接口 244 models.GET("/:modelId", func(c *gin.Context) { 245 HandleGetModelDetail(c, modelManager) 246 }) 247 // 创建模型接口 248 models.POST("", func(c *gin.Context) { 249 HandleCreateModel(c, modelManager) 250 }) 251 // 更新模型接口 252 models.PUT("/:modelId", func(c *gin.Context) { 253 HandleUpdateModel(c, modelManager) 254 }) 255 // 删除模型接口(支持单个和批量) 256 models.DELETE("", func(c *gin.Context) { 257 HandleDeleteModel(c, modelManager) 258 }) 259 } 260 } 261 // 4. Agent 管理 262 agents := v1.Group("/agents") 263 { 264 // 只需要WebSocket入口 265 agents.GET("/ws", agentManager.HandleAgentWebSocket()) 266 } 267 // 提供给第三方的api 268 taskApi := appSecurity.Group("/taskapi") 269 { 270 // 创建任务 271 taskApi.POST("/tasks", func(c *gin.Context) { 272 SubmitTask(c, taskManager) 273 }) 274 // 获取任务状态 275 taskApi.GET("/status/:id", func(c *gin.Context) { 276 GetTaskStatus(c, taskManager) 277 }) 278 // 获取任务结果 279 taskApi.GET("/result/:id", func(c *gin.Context) { 280 GetTaskResult(c, taskManager) 281 }) 282 taskApi.POST("/upload", func(c *gin.Context) { 283 HandleUploadFile(c, taskManager) 284 }) 285 // 分片上传接口 286 taskApi.POST("/uploadChunk", func(c *gin.Context) { 287 HandleUploadFileChunk(c, taskManager) 288 }) 289 // 合并分片接口 290 taskApi.POST("/mergeChunks", func(c *gin.Context) { 291 HandleMergeFileChunks(c, taskManager) 292 }) 293 } 294 // version 295 v1.GET("/version", func(c *gin.Context) { 296 filename := "CHANGELOG.md" 297 data, err := os.ReadFile(filename) 298 if err != nil { 299 data = []byte("") 300 } 301 c.JSON(http.StatusOK, gin.H{ 302 "version": version.GetVersion(), 303 "changelog": string(data), 304 }) 305 }) 306 307 // system — data directory auto-sync 308 system := v1.Group("/system") 309 system.Use(setupIdentityMiddleware()) 310 { 311 system.POST("/update-data", HandleTriggerDataUpdate) 312 system.GET("/update-data", HandleGetUpdateStatus) 313 } 314 } 315 316 // Swagger UI - 必须在 NoRoute 之前注册 317 r.GET("/docs/*any", func(c *gin.Context) { 318 if c.Request.URL.Path == "/docs/" { 319 c.Redirect(302, "/docs/index.html") 320 } else { 321 ginSwagger.WrapHandler(swaggerFiles.Handler)(c) 322 } 323 }) 324 325 // 静态文件处理 326 r.NoRoute(func(c *gin.Context) { 327 assetPath := "static" + c.Request.URL.Path 328 if c.Request.URL.Path == "/" { 329 assetPath = "static/index.html" 330 } 331 332 assetData, err := staticFS.ReadFile(assetPath) 333 if err != nil { 334 assetData, err = staticFS.ReadFile("static/index.html") 335 if err != nil { 336 c.String(500, "Internal Server Error") 337 return 338 } 339 c.Header("Content-Type", "text/html") 340 c.Data(200, "text/html", assetData) 341 return 342 } 343 344 mimeType := mime.TypeByExtension(filepath.Ext(assetPath)) 345 if mimeType == "" { 346 mimeType = "text/plain" 347 } 348 c.Header("Content-Type", mimeType) 349 c.Data(200, mimeType, assetData) 350 }) 351 352 log.Infof("Starting WebServer: trace_id=system_startup, addr=%s", options.WebServerAddr) 353 if err := r.Run(options.WebServerAddr); err != nil { 354 log.Errorf("Could not start WebSocket server: trace_id=system_startup, error=%s", err) 355 } 356 } 357 358 // 配置身份认证中间件 359 func setupIdentityMiddleware() gin.HandlerFunc { 360 return func(c *gin.Context) { 361 // 优先从请求头获取username字段 362 username := c.GetHeader("username") 363 364 // 如果都没有,使用默认的公共用户 365 if username == "" { 366 username = "public_user" 367 } 368 // 存储到gin上下文 369 c.Set("username", username) 370 c.Next() 371 } 372 }