// api.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 // Package websocket provides API endpoints for AI Infrastructure Guard task management 20 // 21 // This package implements RESTful APIs for: 22 // - Task submission and management 23 // - Task status monitoring 24 // - Task result retrieval 25 // - Support for multiple task types: MCP scan, AI infra scan, and model redteam testing 26 // 27 // API Endpoints: 28 // - POST /api/v1/app/taskapi/tasks - Create new tasks 29 // - GET /api/v1/app/taskapi/status/{id} - Get task status and logs 30 // - GET /api/v1/app/taskapi/result/{id} - Get task results 31 package websocket 32 33 import ( 34 "encoding/json" 35 "net/http" 36 "strings" 37 "time" 38 39 "github.com/Tencent/AI-Infra-Guard/common/agent" 40 "github.com/Tencent/AI-Infra-Guard/pkg/database" 41 "github.com/gin-gonic/gin" 42 "github.com/google/uuid" 43 "trpc.group/trpc-go/trpc-go/log" 44 ) 45 46 // ModelParams represents model configuration parameters 47 type ModelParams struct { 48 BaseUrl string `json:"base_url" example:"https://api.openai.com/v1"` // Model API base URL 49 Token string `json:"token" example:"sk-xxx"` // API access token 50 
Model string `json:"model" example:"gpt-4"` // Model name 51 Limit int `json:"limit,omitempty" example:"1000"` // Request limit 52 } 53 54 // MCPTaskRequest represents MCP task request structure 55 // @Description MCP (Model Context Protocol) security scan task parameters 56 type MCPTaskRequest struct { 57 Prompt string `json:"prompt,omitempty" example:"Enter a URL for remote MCP scan, or leave empty for source-code scan"` // Scan description or MCP server URL 58 Model struct { 59 Model string `json:"model" example:"gpt-4"` // Model name - required 60 Token string `json:"token" example:"sk-xxx"` // API key - required 61 BaseUrl string `json:"base_url,omitempty" example:"https://api.openai.com/v1"` // Base URL - optional 62 } `json:"model"` // Model configuration - required 63 Thread int `json:"thread,omitempty" example:"4"` // Concurrent thread count 64 Language string `json:"language,omitempty" example:"zh"` // Language code - optional 65 Attachments string `json:"attachments,omitempty" example:"file1.zip"` // Attachment file path (upload first) 66 Headers map[string]string `json:"headers,omitempty" example:"{\"Authorization\":\"Bearer token\"}"` 67 } 68 69 // AIInfraScanTaskRequest represents AI infrastructure scan task request structure 70 // @Description AI infrastructure security scan task parameters: target URLs, custom headers, and optional model config for result analysis 71 type AIInfraScanTaskRequest struct { 72 Target []string `json:"target" example:"https://example.com"` // List of scan target URLs 73 Headers map[string]string `json:"headers" example:"{\"Authorization\":\"Bearer token\"}"` // Custom request headers 74 Timeout int `json:"timeout" example:"30"` // Request timeout in seconds 75 Model struct { 76 Model string `json:"model" binding:"required" example:"gpt-4"` // Model name - required 77 Token string `json:"token" binding:"required" example:"sk-xxx"` // API key - required 78 BaseUrl string `json:"base_url,omitempty" 
example:"https://api.openai.com/v1"` // Base URL - optional 79 } `json:"model,omitempty"` // Model configuration - optional, used for assisted vulnerability analysis 80 } 81 82 // PromptSecurityTaskRequest represents prompt security test task request structure 83 // @Description Prompt security (red team) task parameters. Supports dataset selection or manual prompt input. 84 // @Description Supported datasets: 85 // @Description - JailBench-Tiny: small jailbreak benchmark dataset 86 // @Description - JailbreakPrompts-Tiny: small jailbreak prompt dataset 87 // @Description - ChatGPT-Jailbreak-Prompts: ChatGPT jailbreak prompt dataset 88 // @Description - JADE-db-v3.0: JADE database v3.0 89 // @Description - HarmfulEvalBenchmark: harmful content evaluation benchmark dataset 90 type PromptSecurityTaskRequest struct { 91 Model []ModelParams `json:"model"` // List of models under test 92 EvalModel ModelParams `json:"eval_model"` // Evaluation model configuration 93 Datasets struct { 94 DataFile []string `json:"dataFile" example:"[\"JailBench-Tiny\",\"JailbreakPrompts-Tiny\"]"` // Dataset file list 95 NumPrompts int `json:"numPrompts" example:"100"` // Number of prompts 96 RandomSeed int `json:"randomSeed" example:"42"` // Random seed 97 } `json:"dataset"` // Dataset configuration 98 Prompt string `json:"prompt"` // Custom test prompt - optional 99 Techniques []string `json:"techniques"` // Attack technique list - optional 100 } 101 102 // AgentScanTaskRequest represents Agent security scan task request structure 103 // @Description Agent security scan task parameters. agent_id and agent_config are mutually exclusive: 104 // agent_id references a config pre-saved on the server; agent_config passes YAML content inline without prior saving. 
105 type AgentScanTaskRequest struct { 106 AgentID string `json:"agent_id,omitempty" example:"demo-agent"` // Agent config name (mutually exclusive with agent_config) 107 AgentConfig string `json:"agent_config,omitempty" example:"provider: dify\nbase_url: ..."` // Inline YAML config content (mutually exclusive with agent_id) 108 EvalModel ModelParams `json:"eval_model"` // Evaluation model config - optional, falls back to system default 109 Language string `json:"language,omitempty" example:"zh"` // Language code - optional 110 Prompt string `json:"prompt,omitempty" example:"Focus on privilege escalation and data leakage risks"` // Additional scan instructions - optional 111 } 112 113 // APIResponse is the common API response structure 114 type APIResponse struct { 115 Status int `json:"status" example:"0"` // Status code: 0=success, 1=failure 116 Message string `json:"message" example:"ok"` // Response message 117 Data interface{} `json:"data"` // Response data 118 } 119 120 // TaskStatusResponse holds the task status response 121 type TaskStatusResponse struct { 122 SessionID string `json:"session_id" example:"550e8400-e29b-41d4-a716-446655440000"` // Task session ID 123 Status string `json:"status" example:"running"` // Task status: pending, running, completed, failed 124 Title string `json:"title" example:"MCP Scan Task"` // Task title 125 CreatedAt int64 `json:"created_at" example:"1640995200000"` // Creation timestamp (ms) 126 UpdatedAt int64 `json:"updated_at" example:"1640995200000"` // Last update timestamp (ms) 127 Log string `json:"log" example:"Task execution log..."` // Task execution log 128 } 129 130 // TaskCreateResponse holds the task creation response 131 type TaskCreateResponse struct { 132 SessionID string `json:"session_id" example:"550e8400-e29b-41d4-a716-446655440000"` // Task session ID 133 } 134 135 func resolveTaskAPIUsername(c *gin.Context) string { 136 username := strings.TrimSpace(c.GetString("api_user")) 137 if username != "" { 138 
return username 139 } 140 141 username = strings.TrimSpace(c.GetHeader("username")) 142 if username != "" { 143 return username 144 } 145 146 return "api_user" 147 } 148 149 func resolveDefaultTaskAPIModel(tm *TaskManager, username string) (*database.ModelParams, error) { 150 if tm == nil || tm.modelStore == nil { 151 return nil, nil 152 } 153 154 models, err := tm.modelStore.GetUserModels(username) 155 if err != nil { 156 return nil, err 157 } 158 if len(models) == 0 { 159 return nil, nil 160 } 161 162 model := models[0] 163 return &database.ModelParams{ 164 Model: model.ModelName, 165 Token: model.Token, 166 BaseUrl: model.BaseURL, 167 Limit: model.Limit, 168 }, nil 169 } 170 171 // SubmitTask is the task creation handler 172 // @Summary Create a new task 173 // @Description Submit a new task for processing. Supports three types of tasks: 174 // @Description 1. MCP Scan (mcp_scan): Model Context Protocol security scanning 175 // @Description 2. AI Infra Scan (ai_infra_scan): AI infrastructure security scanning 176 // @Description 3. 
Model Redteam Report (model_redteam_report): AI model red team testing 177 // @Description 178 // @Description Request Body Examples: 179 // @Description 180 // @Description MCP Scan Task: 181 // @Description { 182 // @Description "type": "mcp_scan", 183 // @Description "content": { 184 // @Description "prompt": "Custom prompt for scan", 185 // @Description "model": { 186 // @Description "model": "gpt-4", 187 // @Description "token": "sk-xxx", 188 // @Description "base_url": "https://api.openai.com/v1" 189 // @Description }, 190 // @Description "thread": 4, 191 // @Description "language": "zh", 192 // @Description "attachments": "file.zip", 193 // @Description "headers": { 194 // @Description "Authorization": "Bearer token" 195 // @Description } 196 // @Description } 197 // @Description } 198 // @Description 199 // @Description AI Infra Scan Task: 200 // @Description { 201 // @Description "type": "ai_infra_scan", 202 // @Description "content": { 203 // @Description "target": ["https://example.com"], 204 // @Description "headers": { 205 // @Description "Authorization": "Bearer token" 206 // @Description }, 207 // @Description "timeout": 30, 208 // @Description "model": { 209 // @Description "model": "gpt-4", 210 // @Description "token": "sk-xxx", 211 // @Description "base_url": "https://api.openai.com/v1" 212 // @Description } 213 // @Description } 214 // @Description } 215 // @Description 216 // @Description Model Redteam Task: 217 // @Description { 218 // @Description "type": "model_redteam_report", 219 // @Description "content": { 220 // @Description "model": [{ 221 // @Description "model": "gpt-4", 222 // @Description "token": "sk-xxx", 223 // @Description "base_url": "https://api.openai.com/v1" 224 // @Description }], 225 // @Description "eval_model": { 226 // @Description "model": "gpt-4", 227 // @Description "token": "sk-xxx" 228 // @Description }, 229 // @Description "dataset": { 230 // @Description "dataFile": ["JailBench-Tiny", "JailbreakPrompts-Tiny"], 
231 // @Description "numPrompts": 100, 232 // @Description "randomSeed": 42 233 // @Description }, 234 // @Description "prompt": "How to make a bomb?", 235 // @Description "techniques": [""] 236 // @Description } 237 // @Description } 238 // @Tags taskapi 239 // @Accept json 240 // @Produce json 241 // @Param request body object{content=object,type=string} true "Task request body. Content should be JSON object containing task-specific parameters based on type" 242 // @Success 200 {object} APIResponse{data=TaskCreateResponse} "Task created successfully" 243 // @Failure 400 {object} APIResponse "Invalid request parameters" 244 // @Failure 500 {object} APIResponse "Internal server error" 245 // @Router /api/v1/app/taskapi/tasks [post] 246 func SubmitTask(c *gin.Context, tm *TaskManager) { 247 var content struct { 248 Content json.RawMessage `json:"content"` 249 Type string `json:"type"` 250 } 251 if err := c.ShouldBindJSON(&content); err != nil { 252 c.JSON(http.StatusOK, gin.H{ 253 "status": 1, 254 "message": "invalid parameters: " + err.Error(), 255 "data": nil, 256 }) 257 return 258 } 259 // Generate session and message IDs 260 sessionId := uuid.New().String() 261 messageId := uuid.New().String() 262 263 // Resolve username: prefer auth middleware, fall back to explicit header 264 username := resolveTaskAPIUsername(c) 265 266 var taskReq TaskCreateRequest 267 // content interface to byte 268 269 switch content.Type { 270 case "mcp_scan": 271 var req MCPTaskRequest 272 err := json.Unmarshal(content.Content, &req) 273 if err != nil { 274 c.JSON(http.StatusOK, gin.H{ 275 "status": 1, 276 "message": "invalid parameters: " + err.Error(), 277 "data": nil, 278 }) 279 return 280 } 281 if strings.TrimSpace(req.Model.Model) == "" || strings.TrimSpace(req.Model.Token) == "" { 282 c.JSON(http.StatusOK, gin.H{ 283 "status": 1, 284 "message": "invalid parameters: mcp_scan requires model.model and model.token", 285 "data": nil, 286 }) 287 return 288 } 289 // Build task params 290 
params := map[string]interface{}{ 291 "model": map[string]interface{}{ 292 "model": req.Model.Model, 293 "token": req.Model.Token, 294 "base_url": req.Model.BaseUrl, 295 }, 296 "headers": req.Headers, 297 } 298 var attachments []string 299 if req.Attachments != "" { 300 attachments = append(attachments, req.Attachments) 301 } 302 303 // Build TaskCreateRequest 304 taskReq = TaskCreateRequest{ 305 ID: messageId, 306 SessionID: sessionId, 307 Username: username, 308 Task: agent.TaskTypeMcpScan, 309 Timestamp: time.Now().UnixMilli(), 310 Content: req.Prompt, 311 Params: params, 312 Attachments: attachments, 313 } 314 case "ai_infra_scan": 315 var req AIInfraScanTaskRequest 316 err := json.Unmarshal(content.Content, &req) 317 if err != nil { 318 c.JSON(http.StatusOK, gin.H{ 319 "status": 1, 320 "message": "invalid parameters: " + err.Error(), 321 "data": nil, 322 }) 323 return 324 } 325 scanParams := map[string]interface{}{ 326 "headers": req.Headers, 327 "timeout": req.Timeout, 328 "model": map[string]interface{}{ 329 "model": req.Model.Model, 330 "token": req.Model.Token, 331 "base_url": req.Model.BaseUrl, 332 }, 333 } 334 335 taskReq = TaskCreateRequest{ 336 ID: messageId, 337 SessionID: sessionId, 338 Username: username, 339 Task: agent.TaskTypeAIInfraScan, 340 Timestamp: time.Now().UnixMilli(), 341 Params: scanParams, 342 Content: strings.Join(req.Target, "\n"), 343 Attachments: []string{}, 344 } 345 case "model_redteam_report": 346 var req PromptSecurityTaskRequest 347 err := json.Unmarshal(content.Content, &req) 348 if err != nil { 349 c.JSON(http.StatusOK, gin.H{ 350 "status": 1, 351 "message": "invalid parameters: " + err.Error(), 352 "data": nil, 353 }) 354 return 355 } 356 params := map[string]interface{}{ 357 "model": req.Model, 358 "eval_model": req.EvalModel, 359 "dataset": req.Datasets, 360 "techniques": req.Techniques, 361 } 362 taskReq = TaskCreateRequest{ 363 ID: messageId, 364 SessionID: sessionId, 365 Username: username, 366 Task: 
agent.TaskTypeModelRedteamReport, 367 Timestamp: time.Now().UnixMilli(), 368 Content: req.Prompt, 369 Attachments: []string{}, 370 Params: params, 371 } 372 case "agent_scan": 373 var req AgentScanTaskRequest 374 err := json.Unmarshal(content.Content, &req) 375 if err != nil { 376 c.JSON(http.StatusOK, gin.H{ 377 "status": 1, 378 "message": "invalid parameters: " + err.Error(), 379 "data": nil, 380 }) 381 return 382 } 383 384 // Resolve agent YAML: inline content takes priority over stored config. 385 var agentData []byte 386 if strings.TrimSpace(req.AgentConfig) != "" { 387 // Method A: caller supplies YAML inline — no file lookup needed. 388 agentData = []byte(strings.TrimSpace(req.AgentConfig)) 389 } else if strings.TrimSpace(req.AgentID) != "" { 390 // Method B: look up pre-saved config by agent_id. 391 agentData, err = readAgentConfigContent(username, req.AgentID) 392 if err != nil { 393 c.JSON(http.StatusOK, gin.H{ 394 "status": 1, 395 "message": "invalid parameters: failed to load agent config: " + err.Error(), 396 "data": nil, 397 }) 398 return 399 } 400 } else { 401 c.JSON(http.StatusOK, gin.H{ 402 "status": 1, 403 "message": "invalid parameters: agent_id or agent_config must be provided", 404 "data": nil, 405 }) 406 return 407 } 408 409 evalModel := req.EvalModel 410 if strings.TrimSpace(evalModel.Model) == "" || 411 strings.TrimSpace(evalModel.Token) == "" || 412 strings.TrimSpace(evalModel.BaseUrl) == "" { 413 defaultModel, err := resolveDefaultTaskAPIModel(tm, username) 414 if err != nil { 415 c.JSON(http.StatusOK, gin.H{ 416 "status": 1, 417 "message": "invalid parameters: failed to resolve default model: " + err.Error(), 418 "data": nil, 419 }) 420 return 421 } 422 if defaultModel == nil { 423 c.JSON(http.StatusOK, gin.H{ 424 "status": 1, 425 "message": "invalid parameters: no default model configured", 426 "data": nil, 427 }) 428 return 429 } 430 evalModel = ModelParams{ 431 Model: defaultModel.Model, 432 Token: defaultModel.Token, 433 BaseUrl: 
defaultModel.BaseUrl, 434 Limit: defaultModel.Limit, 435 } 436 } 437 438 params := map[string]interface{}{ 439 "agent_id": req.AgentID, 440 "agent_data": string(agentData), 441 "eval_model": map[string]interface{}{ 442 "model": evalModel.Model, 443 "token": evalModel.Token, 444 "base_url": evalModel.BaseUrl, 445 "limit": evalModel.Limit, 446 }, 447 } 448 449 taskReq = TaskCreateRequest{ 450 ID: messageId, 451 SessionID: sessionId, 452 Username: username, 453 Task: agent.TaskTypeAgentScan, 454 Timestamp: time.Now().UnixMilli(), 455 Content: req.Prompt, 456 Attachments: []string{}, 457 Params: params, 458 CountryIsoCode: req.Language, 459 } 460 default: 461 c.JSON(http.StatusOK, gin.H{ 462 "status": 1, 463 "message": "unsupported task type", 464 "data": nil, 465 }) 466 return 467 } 468 err := tm.AddTaskApi(&taskReq) 469 if err != nil { 470 log.Errorf("task creation failed: sessionId=%s, error=%v", sessionId, err) 471 c.JSON(http.StatusOK, gin.H{ 472 "status": 1, 473 "message": "task creation failed: " + err.Error(), 474 "data": nil, 475 }) 476 return 477 } 478 c.JSON(http.StatusOK, gin.H{ 479 "status": 0, 480 "message": "task created successfully", 481 "data": gin.H{ 482 "session_id": sessionId, 483 }, 484 }) 485 } 486 487 // GetTaskStatus retrieves task status (developer API) 488 // @Summary Get task status 489 // @Description Retrieve the current status and logs of a task by session ID. Returns task metadata and execution logs. 
490 // @Tags taskapi 491 // @Produce json 492 // @Param id path string true "Task Session ID" example:"550e8400-e29b-41d4-a716-446655440000" 493 // @Success 200 {object} APIResponse{data=TaskStatusResponse} "Task status retrieved successfully" 494 // @Failure 400 {object} APIResponse "Invalid session ID format" 495 // @Failure 404 {object} APIResponse "Task not found" 496 // @Failure 500 {object} APIResponse "Internal server error" 497 // @Router /api/v1/app/taskapi/status/{id} [get] 498 func GetTaskStatus(c *gin.Context, tm *TaskManager) { 499 sessionId := c.Param("id") 500 501 if sessionId == "" { 502 c.JSON(http.StatusOK, gin.H{ 503 "status": 1, 504 "message": "session ID is required", 505 "data": nil, 506 }) 507 return 508 } 509 510 // Validate session ID format 511 if !isValidSessionID(sessionId) { 512 c.JSON(http.StatusOK, gin.H{ 513 "status": 1, 514 "message": "invalid session ID format", 515 "data": nil, 516 }) 517 return 518 } 519 520 // Fetch task from store 521 session, err := tm.taskStore.GetSession(sessionId) 522 if err != nil { 523 c.JSON(http.StatusOK, gin.H{ 524 "status": 1, 525 "message": "task not found", 526 "data": nil, 527 }) 528 return 529 } 530 531 // Fetch all task events 532 messages, err := tm.taskStore.GetSessionEventsByType(sessionId, "actionLog") 533 if err != nil { 534 c.JSON(http.StatusOK, gin.H{ 535 "status": 1, 536 "message": "failed to retrieve task data", 537 "data": nil, 538 }) 539 return 540 } 541 542 msg := "" 543 type logStruct struct { 544 ActionLog string `json:"actionLog"` 545 } 546 for _, m := range messages { 547 var x logStruct 548 err = json.Unmarshal([]byte(m.EventData.String()), &x) 549 if err != nil { 550 continue 551 } 552 msg += x.ActionLog 553 } 554 555 // Build status response 556 statusData := gin.H{ 557 "session_id": session.ID, 558 "status": session.Status, 559 "title": session.Title, 560 "created_at": session.CreatedAt, 561 "updated_at": session.UpdatedAt, 562 "log": msg, 563 } 564 565 c.JSON(http.StatusOK, 
gin.H{ 566 "status": 0, 567 "message": "ok", 568 "data": statusData, 569 }) 570 } 571 572 // GetTaskResult retrieves task result (developer API) 573 // @Summary Get task result 574 // @Description Retrieve the final result of a completed task. Returns detailed scan results, vulnerabilities found, and security assessment data. 575 // @Tags taskapi 576 // @Produce json 577 // @Param id path string true "Task Session ID" example:"550e8400-e29b-41d4-a716-446655440000" 578 // @Success 200 {object} APIResponse "Task result retrieved successfully. Data contains scan results, vulnerabilities, and security findings" 579 // @Failure 400 {object} APIResponse "Invalid session ID format" 580 // @Failure 404 {object} APIResponse "Task not found or not completed" 581 // @Failure 500 {object} APIResponse "Internal server error" 582 // @Router /api/v1/app/taskapi/result/{id} [get] 583 func GetTaskResult(c *gin.Context, tm *TaskManager) { 584 traceID := getTraceID(c) 585 sessionId := c.Param("id") 586 587 if sessionId == "" { 588 c.JSON(http.StatusOK, gin.H{ 589 "status": 1, 590 "message": "session ID is required", 591 "data": nil, 592 }) 593 return 594 } 595 596 // Validate session ID format 597 if !isValidSessionID(sessionId) { 598 c.JSON(http.StatusOK, gin.H{ 599 "status": 1, 600 "message": "invalid session ID format", 601 "data": nil, 602 }) 603 return 604 } 605 606 log.Infof("fetching task result: trace_id=%s, sessionId=%s", traceID, sessionId) 607 608 // Fetch all task events 609 messages, err := tm.taskStore.GetSessionEventsByType(sessionId, "resultUpdate") 610 if err != nil || len(messages) == 0 { 611 c.JSON(http.StatusOK, gin.H{ 612 "status": 1, 613 "message": "task result not available yet", 614 "data": nil, 615 }) 616 return 617 } 618 msg := messages[0] 619 // Parse event data 620 var eventData map[string]interface{} 621 if err := json.Unmarshal(msg.EventData, &eventData); err != nil { 622 c.JSON(http.StatusOK, gin.H{ 623 "status": 1, 624 "message": "failed to retrieve 
task result", 625 "data": nil, 626 }) 627 return 628 } 629 c.JSON(http.StatusOK, gin.H{ 630 "status": 0, 631 "message": "ok", 632 "data": eventData, 633 }) 634 }