api_test.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package websocket 16 17 import ( 18 "bytes" 19 "encoding/json" 20 "net/http" 21 "net/http/httptest" 22 "os" 23 "testing" 24 25 "github.com/Tencent/AI-Infra-Guard/pkg/database" 26 "github.com/gin-gonic/gin" 27 "github.com/stretchr/testify/assert" 28 "github.com/stretchr/testify/require" 29 "gorm.io/datatypes" 30 ) 31 32 func init() { 33 gin.SetMode(gin.TestMode) 34 } 35 36 // --------------------------------------------------------------------------- 37 // helpers 38 // --------------------------------------------------------------------------- 39 40 // newTestTaskManager builds a minimal TaskManager backed by an in-memory DB. 41 func newTestTaskManager(t *testing.T) (*TaskManager, func()) { 42 t.Helper() 43 44 f, err := os.CreateTemp("", "ws-testdb-*.db") 45 require.NoError(t, err) 46 dbPath := f.Name() 47 f.Close() 48 49 cfg := database.NewConfig(dbPath) 50 db, err := database.InitDB(cfg) 51 require.NoError(t, err) 52 53 ts := database.NewTaskStore(db) 54 require.NoError(t, ts.Init()) 55 56 ms := database.NewModelStore(db) 57 require.NoError(t, ms.Init()) 58 59 am := NewAgentManager() 60 sseM := NewSSEManager() 61 tm := NewTaskManager(am, ts, ms, nil, sseM) 62 63 cleanup := func() { 64 sqlDB, _ := db.DB() 65 if sqlDB != nil { 66 sqlDB.Close() 67 } 68 os.Remove(dbPath) 69 } 70 return tm, cleanup 71 } 72 73 // newRouter wires the three task-API handlers onto a minimal gin engine. 74 func newRouter(tm *TaskManager) *gin.Engine { 75 r := gin.New() 76 r.POST("/api/v1/app/taskapi/tasks", func(c *gin.Context) { 77 SubmitTask(c, tm) 78 }) 79 r.GET("/api/v1/app/taskapi/status/:id", func(c *gin.Context) { 80 GetTaskStatus(c, tm) 81 }) 82 r.GET("/api/v1/app/taskapi/result/:id", func(c *gin.Context) { 83 GetTaskResult(c, tm) 84 }) 85 return r 86 } 87 88 // postJSON sends a JSON POST and returns the recorder. 89 func postJSON(t *testing.T, r *gin.Engine, path string, body interface{}) *httptest.ResponseRecorder { 90 t.Helper() 91 b, err := json.Marshal(body) 92 require.NoError(t, err) 93 req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(b)) 94 req.Header.Set("Content-Type", "application/json") 95 w := httptest.NewRecorder() 96 r.ServeHTTP(w, req) 97 return w 98 } 99 100 // decodeAPIResponse decodes the response body into an APIResponse-like map. 101 func decodeAPIResponse(t *testing.T, w *httptest.ResponseRecorder) map[string]interface{} { 102 t.Helper() 103 var resp map[string]interface{} 104 require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) 105 return resp 106 } 107 108 // --------------------------------------------------------------------------- 109 // isValidSessionID (internal helper, white-box test) 110 // --------------------------------------------------------------------------- 111 112 func TestIsValidSessionID(t *testing.T) { 113 cases := []struct { 114 id string 115 valid bool 116 }{ 117 {"abc123", true}, 118 {"abc-123_XYZ", true}, 119 {"a", true}, 120 // Exactly 50 chars – valid 121 {"12345678901234567890123456789012345678901234567890", true}, 122 // 51 chars – invalid 123 {"123456789012345678901234567890123456789012345678901", false}, 124 {"", false}, 125 {"has space", false}, 126 {"has/slash", false}, 127 {"has.dot", false}, 128 {"has@at", false}, 129 } 130 for _, tc := range cases { 131 t.Run(tc.id, func(t *testing.T) { 132 assert.Equal(t, tc.valid, isValidSessionID(tc.id)) 133 }) 134 } 135 } 136 137 // --------------------------------------------------------------------------- 138 // SubmitTask – invalid / missing body 139 // --------------------------------------------------------------------------- 140 141 func TestSubmitTask_InvalidJSON(t *testing.T) { 142 tm, cleanup := newTestTaskManager(t) 143 defer cleanup() 144 r := newRouter(tm) 145 146 req := httptest.NewRequest(http.MethodPost, "/api/v1/app/taskapi/tasks", 147 bytes.NewBufferString("not-json")) 148 req.Header.Set("Content-Type", "application/json") 149 w := httptest.NewRecorder() 150 r.ServeHTTP(w, req) 151 152 assert.Equal(t, http.StatusOK, w.Code) 153 resp := decodeAPIResponse(t, w) 154 assert.Equal(t, float64(1), resp["status"]) 155 } 156 157 func TestSubmitTask_InvalidTaskType(t *testing.T) { 158 tm, cleanup := newTestTaskManager(t) 159 defer cleanup() 160 r := newRouter(tm) 161 162 body := map[string]interface{}{ 163 "type": "unknown_type", 164 "content": map[string]interface{}{}, 165 } 166 w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body) 167 assert.Equal(t, http.StatusOK, w.Code) 168 resp := decodeAPIResponse(t, w) 169 assert.Equal(t, float64(1), resp["status"]) 170 assert.Contains(t, resp["message"], "无效的任务类型") 171 } 172 173 func TestSubmitTask_MCPScan_MissingModelFields(t *testing.T) { 174 tm, cleanup := newTestTaskManager(t) 175 defer cleanup() 176 r := newRouter(tm) 177 178 // model.model and model.token are required for mcp_scan 179 body := map[string]interface{}{ 180 "type": "mcp_scan", 181 "content": map[string]interface{}{ 182 "prompt": "scan this", 183 "model": map[string]interface{}{}, // empty – missing required fields 184 }, 185 } 186 w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body) 187 assert.Equal(t, http.StatusOK, w.Code) 188 resp := decodeAPIResponse(t, w) 189 assert.Equal(t, float64(1), resp["status"]) 190 assert.Contains(t, resp["message"], "model.model") 191 } 192 193 func TestSubmitTask_MCPScan_MissingToken(t *testing.T) { 194 tm, cleanup := newTestTaskManager(t) 195 defer cleanup() 196 r := newRouter(tm) 197 198 body := map[string]interface{}{ 199 "type": "mcp_scan", 200 "content": map[string]interface{}{ 201 "prompt": "scan this", 202 "model": map[string]interface{}{ 203 "model": "gpt-4", 204 // token intentionally omitted 205 }, 206 }, 207 } 208 w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body) 209 resp := decodeAPIResponse(t, w) 210 assert.Equal(t, float64(1), resp["status"]) 211 } 212 213 func TestSubmitTask_AIInfraScan_NoAgents(t *testing.T) { 214 tm, cleanup := newTestTaskManager(t) 215 defer cleanup() 216 r := newRouter(tm) 217 218 // ai_infra_scan is valid but no agents are registered → should fail at dispatch 219 body := map[string]interface{}{ 220 "type": "ai_infra_scan", 221 "content": map[string]interface{}{ 222 "target": []string{"http://127.0.0.1:9999"}, 223 "timeout": 5, 224 }, 225 } 226 w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body) 227 assert.Equal(t, http.StatusOK, w.Code) 228 resp := decodeAPIResponse(t, w) 229 // No agents available → status=1, message contains "Agent" 230 assert.Equal(t, float64(1), resp["status"]) 231 } 232 233 func TestSubmitTask_ModelRedteam_NoAgents(t *testing.T) { 234 tm, cleanup := newTestTaskManager(t) 235 defer cleanup() 236 r := newRouter(tm) 237 238 body := map[string]interface{}{ 239 "type": "model_redteam_report", 240 "content": map[string]interface{}{ 241 "model": []map[string]interface{}{ 242 {"model": "gpt-4", "token": "sk-x", "base_url": "https://api.openai.com/v1"}, 243 }, 244 "eval_model": map[string]interface{}{ 245 "model": "gpt-4", "token": "sk-x", 246 }, 247 "dataset": map[string]interface{}{ 248 "dataFile": []string{"JailBench-Tiny"}, "numPrompts": 10, "randomSeed": 42, 249 }, 250 }, 251 } 252 w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body) 253 assert.Equal(t, http.StatusOK, w.Code) 254 resp := decodeAPIResponse(t, w) 255 // No agents → fail 256 assert.Equal(t, float64(1), resp["status"]) 257 } 258 259 func TestSubmitTask_AgentScan_EmptyAgentID(t *testing.T) { 260 tm, cleanup := newTestTaskManager(t) 261 defer cleanup() 262 r := newRouter(tm) 263 264 body := map[string]interface{}{ 265 "type": "agent_scan", 266 "content": map[string]interface{}{ 267 "agent_id": "", 268 }, 269 } 270 w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body) 271 resp := decodeAPIResponse(t, w) 272 assert.Equal(t, float64(1), resp["status"]) 273 assert.Contains(t, resp["message"], "agent_id") 274 } 275 276 // --------------------------------------------------------------------------- 277 // GetTaskStatus 278 // --------------------------------------------------------------------------- 279 280 func TestGetTaskStatus_EmptyID(t *testing.T) { 281 tm, cleanup := newTestTaskManager(t) 282 defer cleanup() 283 r := newRouter(tm) 284 285 // Gin will not match the route with empty param; test with placeholder 286 req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/status/", nil) 287 w := httptest.NewRecorder() 288 r.ServeHTTP(w, req) 289 // No route match → 404 290 assert.Equal(t, http.StatusNotFound, w.Code) 291 } 292 293 func TestGetTaskStatus_InvalidIDFormat(t *testing.T) { 294 tm, cleanup := newTestTaskManager(t) 295 defer cleanup() 296 r := newRouter(tm) 297 298 req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/status/invalid@id!", nil) 299 w := httptest.NewRecorder() 300 r.ServeHTTP(w, req) 301 302 assert.Equal(t, http.StatusOK, w.Code) 303 resp := decodeAPIResponse(t, w) 304 assert.Equal(t, float64(1), resp["status"]) 305 assert.Contains(t, resp["message"], "无效的任务ID格式") 306 } 307 308 func TestGetTaskStatus_NotFound(t *testing.T) { 309 tm, cleanup := newTestTaskManager(t) 310 defer cleanup() 311 r := newRouter(tm) 312 313 req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/status/no-such-id", nil) 314 w := httptest.NewRecorder() 315 r.ServeHTTP(w, req) 316 317 assert.Equal(t, http.StatusOK, w.Code) 318 resp := decodeAPIResponse(t, w) 319 assert.Equal(t, float64(1), resp["status"]) 320 assert.Contains(t, resp["message"], "任务不存在") 321 } 322 323 func TestGetTaskStatus_ExistingSession(t *testing.T) { 324 tm, cleanup := newTestTaskManager(t) 325 defer cleanup() 326 r := newRouter(tm) 327 328 // Seed DB directly 329 require.NoError(t, tm.taskStore.CreateUser(&database.User{ 330 UserID: "api-u1", Username: "apiuser", Email: "api@t.com", 331 })) 332 require.NoError(t, tm.taskStore.CreateSession(&database.Session{ 333 ID: "valid-session-id", 334 Username: "apiuser", 335 Title: "Test API Task", 336 TaskType: "ai_infra_scan", 337 Content: "http://127.0.0.1", 338 Status: "todo", 339 })) 340 341 req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/status/valid-session-id", nil) 342 w := httptest.NewRecorder() 343 r.ServeHTTP(w, req) 344 345 assert.Equal(t, http.StatusOK, w.Code) 346 resp := decodeAPIResponse(t, w) 347 assert.Equal(t, float64(0), resp["status"]) 348 data, ok := resp["data"].(map[string]interface{}) 349 require.True(t, ok) 350 assert.Equal(t, "valid-session-id", data["session_id"]) 351 assert.Equal(t, "todo", data["status"]) 352 } 353 354 // --------------------------------------------------------------------------- 355 // GetTaskResult 356 // --------------------------------------------------------------------------- 357 358 func TestGetTaskResult_InvalidIDFormat(t *testing.T) { 359 tm, cleanup := newTestTaskManager(t) 360 defer cleanup() 361 r := newRouter(tm) 362 363 req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/result/bad-id-here!", nil) 364 w := httptest.NewRecorder() 365 r.ServeHTTP(w, req) 366 367 assert.Equal(t, http.StatusOK, w.Code) 368 resp := decodeAPIResponse(t, w) 369 assert.Equal(t, float64(1), resp["status"]) 370 assert.Contains(t, resp["message"], "无效的任务ID格式") 371 } 372 373 func TestGetTaskResult_NoResults(t *testing.T) { 374 tm, cleanup := newTestTaskManager(t) 375 defer cleanup() 376 r := newRouter(tm) 377 378 // Session exists but no resultUpdate events 379 require.NoError(t, tm.taskStore.CreateUser(&database.User{ 380 UserID: "api-u2", Username: "apiuser2", Email: "api2@t.com", 381 })) 382 require.NoError(t, tm.taskStore.CreateSession(&database.Session{ 383 ID: "session-no-result", 384 Username: "apiuser2", 385 Title: "Task without result", 386 TaskType: "mcp_scan", 387 Content: "test", 388 Status: "doing", 389 })) 390 391 req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/result/session-no-result", nil) 392 w := httptest.NewRecorder() 393 r.ServeHTTP(w, req) 394 395 assert.Equal(t, http.StatusOK, w.Code) 396 resp := decodeAPIResponse(t, w) 397 assert.Equal(t, float64(1), resp["status"]) 398 } 399 400 func TestGetTaskResult_WithResult(t *testing.T) { 401 tm, cleanup := newTestTaskManager(t) 402 defer cleanup() 403 r := newRouter(tm) 404 405 require.NoError(t, tm.taskStore.CreateUser(&database.User{ 406 UserID: "api-u3", Username: "apiuser3", Email: "api3@t.com", 407 })) 408 require.NoError(t, tm.taskStore.CreateSession(&database.Session{ 409 ID: "session-with-result", 410 Username: "apiuser3", 411 Title: "Done Task", 412 TaskType: "ai_infra_scan", 413 Content: "test", 414 Status: "done", 415 })) 416 417 // Store a resultUpdate event 418 resultPayload := map[string]interface{}{ 419 "findings": []string{"vuln-A", "vuln-B"}, 420 "score": 99, 421 } 422 eventJSON, err := json.Marshal(resultPayload) 423 require.NoError(t, err) 424 425 require.NoError(t, tm.taskStore.CreateTaskMessage(&database.TaskMessage{ 426 ID: "result-ev-1", 427 SessionID: "session-with-result", 428 Type: "resultUpdate", 429 EventData: datatypes.JSON(eventJSON), 430 Timestamp: 1000, 431 })) 432 433 req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/result/session-with-result", nil) 434 w := httptest.NewRecorder() 435 r.ServeHTTP(w, req) 436 437 assert.Equal(t, http.StatusOK, w.Code) 438 resp := decodeAPIResponse(t, w) 439 assert.Equal(t, float64(0), resp["status"]) 440 data, ok := resp["data"].(map[string]interface{}) 441 require.True(t, ok) 442 findings, ok := data["findings"].([]interface{}) 443 require.True(t, ok) 444 assert.Len(t, findings, 2) 445 assert.Equal(t, float64(99), data["score"]) 446 } 447 448 // --------------------------------------------------------------------------- 449 // resolveTaskAPIUsername 450 // --------------------------------------------------------------------------- 451 452 func TestResolveTaskAPIUsername_Default(t *testing.T) { 453 c, _ := gin.CreateTestContext(httptest.NewRecorder()) 454 c.Request = httptest.NewRequest(http.MethodGet, "/", nil) 455 assert.Equal(t, "api_user", resolveTaskAPIUsername(c)) 456 } 457 458 func TestResolveTaskAPIUsername_FromHeader(t *testing.T) { 459 c, _ := gin.CreateTestContext(httptest.NewRecorder()) 460 req := httptest.NewRequest(http.MethodGet, "/", nil) 461 req.Header.Set("username", "header-user") 462 c.Request = req 463 assert.Equal(t, "header-user", resolveTaskAPIUsername(c)) 464 } 465 466 func TestResolveTaskAPIUsername_FromContext(t *testing.T) { 467 c, _ := gin.CreateTestContext(httptest.NewRecorder()) 468 c.Request = httptest.NewRequest(http.MethodGet, "/", nil) 469 c.Set("api_user", "ctx-user") 470 assert.Equal(t, "ctx-user", resolveTaskAPIUsername(c)) 471 }