/ common / websocket / api_test.go
api_test.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  
 15  package websocket
 16  
 17  import (
 18  	"bytes"
 19  	"encoding/json"
 20  	"net/http"
 21  	"net/http/httptest"
 22  	"os"
 23  	"testing"
 24  
 25  	"github.com/Tencent/AI-Infra-Guard/pkg/database"
 26  	"github.com/gin-gonic/gin"
 27  	"github.com/stretchr/testify/assert"
 28  	"github.com/stretchr/testify/require"
 29  	"gorm.io/datatypes"
 30  )
 31  
// init switches gin into gin.TestMode once for every test in this package,
// before any engine is built by newRouter or gin.CreateTestContext.
func init() {
	gin.SetMode(gin.TestMode)
}
 35  
 36  // ---------------------------------------------------------------------------
 37  // helpers
 38  // ---------------------------------------------------------------------------
 39  
 40  // newTestTaskManager builds a minimal TaskManager backed by an in-memory DB.
 41  func newTestTaskManager(t *testing.T) (*TaskManager, func()) {
 42  	t.Helper()
 43  
 44  	f, err := os.CreateTemp("", "ws-testdb-*.db")
 45  	require.NoError(t, err)
 46  	dbPath := f.Name()
 47  	f.Close()
 48  
 49  	cfg := database.NewConfig(dbPath)
 50  	db, err := database.InitDB(cfg)
 51  	require.NoError(t, err)
 52  
 53  	ts := database.NewTaskStore(db)
 54  	require.NoError(t, ts.Init())
 55  
 56  	ms := database.NewModelStore(db)
 57  	require.NoError(t, ms.Init())
 58  
 59  	am := NewAgentManager()
 60  	sseM := NewSSEManager()
 61  	tm := NewTaskManager(am, ts, ms, nil, sseM)
 62  
 63  	cleanup := func() {
 64  		sqlDB, _ := db.DB()
 65  		if sqlDB != nil {
 66  			sqlDB.Close()
 67  		}
 68  		os.Remove(dbPath)
 69  	}
 70  	return tm, cleanup
 71  }
 72  
 73  // newRouter wires the three task-API handlers onto a minimal gin engine.
 74  func newRouter(tm *TaskManager) *gin.Engine {
 75  	r := gin.New()
 76  	r.POST("/api/v1/app/taskapi/tasks", func(c *gin.Context) {
 77  		SubmitTask(c, tm)
 78  	})
 79  	r.GET("/api/v1/app/taskapi/status/:id", func(c *gin.Context) {
 80  		GetTaskStatus(c, tm)
 81  	})
 82  	r.GET("/api/v1/app/taskapi/result/:id", func(c *gin.Context) {
 83  		GetTaskResult(c, tm)
 84  	})
 85  	return r
 86  }
 87  
 88  // postJSON sends a JSON POST and returns the recorder.
 89  func postJSON(t *testing.T, r *gin.Engine, path string, body interface{}) *httptest.ResponseRecorder {
 90  	t.Helper()
 91  	b, err := json.Marshal(body)
 92  	require.NoError(t, err)
 93  	req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(b))
 94  	req.Header.Set("Content-Type", "application/json")
 95  	w := httptest.NewRecorder()
 96  	r.ServeHTTP(w, req)
 97  	return w
 98  }
 99  
100  // decodeAPIResponse decodes the response body into an APIResponse-like map.
101  func decodeAPIResponse(t *testing.T, w *httptest.ResponseRecorder) map[string]interface{} {
102  	t.Helper()
103  	var resp map[string]interface{}
104  	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
105  	return resp
106  }
107  
108  // ---------------------------------------------------------------------------
109  // isValidSessionID (internal helper, white-box test)
110  // ---------------------------------------------------------------------------
111  
112  func TestIsValidSessionID(t *testing.T) {
113  	cases := []struct {
114  		id    string
115  		valid bool
116  	}{
117  		{"abc123", true},
118  		{"abc-123_XYZ", true},
119  		{"a", true},
120  		// Exactly 50 chars – valid
121  		{"12345678901234567890123456789012345678901234567890", true},
122  		// 51 chars – invalid
123  		{"123456789012345678901234567890123456789012345678901", false},
124  		{"", false},
125  		{"has space", false},
126  		{"has/slash", false},
127  		{"has.dot", false},
128  		{"has@at", false},
129  	}
130  	for _, tc := range cases {
131  		t.Run(tc.id, func(t *testing.T) {
132  			assert.Equal(t, tc.valid, isValidSessionID(tc.id))
133  		})
134  	}
135  }
136  
137  // ---------------------------------------------------------------------------
138  // SubmitTask – invalid / missing body
139  // ---------------------------------------------------------------------------
140  
141  func TestSubmitTask_InvalidJSON(t *testing.T) {
142  	tm, cleanup := newTestTaskManager(t)
143  	defer cleanup()
144  	r := newRouter(tm)
145  
146  	req := httptest.NewRequest(http.MethodPost, "/api/v1/app/taskapi/tasks",
147  		bytes.NewBufferString("not-json"))
148  	req.Header.Set("Content-Type", "application/json")
149  	w := httptest.NewRecorder()
150  	r.ServeHTTP(w, req)
151  
152  	assert.Equal(t, http.StatusOK, w.Code)
153  	resp := decodeAPIResponse(t, w)
154  	assert.Equal(t, float64(1), resp["status"])
155  }
156  
157  func TestSubmitTask_InvalidTaskType(t *testing.T) {
158  	tm, cleanup := newTestTaskManager(t)
159  	defer cleanup()
160  	r := newRouter(tm)
161  
162  	body := map[string]interface{}{
163  		"type":    "unknown_type",
164  		"content": map[string]interface{}{},
165  	}
166  	w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body)
167  	assert.Equal(t, http.StatusOK, w.Code)
168  	resp := decodeAPIResponse(t, w)
169  	assert.Equal(t, float64(1), resp["status"])
170  	assert.Contains(t, resp["message"], "无效的任务类型")
171  }
172  
173  func TestSubmitTask_MCPScan_MissingModelFields(t *testing.T) {
174  	tm, cleanup := newTestTaskManager(t)
175  	defer cleanup()
176  	r := newRouter(tm)
177  
178  	// model.model and model.token are required for mcp_scan
179  	body := map[string]interface{}{
180  		"type": "mcp_scan",
181  		"content": map[string]interface{}{
182  			"prompt": "scan this",
183  			"model":  map[string]interface{}{}, // empty – missing required fields
184  		},
185  	}
186  	w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body)
187  	assert.Equal(t, http.StatusOK, w.Code)
188  	resp := decodeAPIResponse(t, w)
189  	assert.Equal(t, float64(1), resp["status"])
190  	assert.Contains(t, resp["message"], "model.model")
191  }
192  
193  func TestSubmitTask_MCPScan_MissingToken(t *testing.T) {
194  	tm, cleanup := newTestTaskManager(t)
195  	defer cleanup()
196  	r := newRouter(tm)
197  
198  	body := map[string]interface{}{
199  		"type": "mcp_scan",
200  		"content": map[string]interface{}{
201  			"prompt": "scan this",
202  			"model": map[string]interface{}{
203  				"model": "gpt-4",
204  				// token intentionally omitted
205  			},
206  		},
207  	}
208  	w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body)
209  	resp := decodeAPIResponse(t, w)
210  	assert.Equal(t, float64(1), resp["status"])
211  }
212  
213  func TestSubmitTask_AIInfraScan_NoAgents(t *testing.T) {
214  	tm, cleanup := newTestTaskManager(t)
215  	defer cleanup()
216  	r := newRouter(tm)
217  
218  	// ai_infra_scan is valid but no agents are registered → should fail at dispatch
219  	body := map[string]interface{}{
220  		"type": "ai_infra_scan",
221  		"content": map[string]interface{}{
222  			"target":  []string{"http://127.0.0.1:9999"},
223  			"timeout": 5,
224  		},
225  	}
226  	w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body)
227  	assert.Equal(t, http.StatusOK, w.Code)
228  	resp := decodeAPIResponse(t, w)
229  	// No agents available → status=1, message contains "Agent"
230  	assert.Equal(t, float64(1), resp["status"])
231  }
232  
233  func TestSubmitTask_ModelRedteam_NoAgents(t *testing.T) {
234  	tm, cleanup := newTestTaskManager(t)
235  	defer cleanup()
236  	r := newRouter(tm)
237  
238  	body := map[string]interface{}{
239  		"type": "model_redteam_report",
240  		"content": map[string]interface{}{
241  			"model": []map[string]interface{}{
242  				{"model": "gpt-4", "token": "sk-x", "base_url": "https://api.openai.com/v1"},
243  			},
244  			"eval_model": map[string]interface{}{
245  				"model": "gpt-4", "token": "sk-x",
246  			},
247  			"dataset": map[string]interface{}{
248  				"dataFile": []string{"JailBench-Tiny"}, "numPrompts": 10, "randomSeed": 42,
249  			},
250  		},
251  	}
252  	w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body)
253  	assert.Equal(t, http.StatusOK, w.Code)
254  	resp := decodeAPIResponse(t, w)
255  	// No agents → fail
256  	assert.Equal(t, float64(1), resp["status"])
257  }
258  
259  func TestSubmitTask_AgentScan_EmptyAgentID(t *testing.T) {
260  	tm, cleanup := newTestTaskManager(t)
261  	defer cleanup()
262  	r := newRouter(tm)
263  
264  	body := map[string]interface{}{
265  		"type": "agent_scan",
266  		"content": map[string]interface{}{
267  			"agent_id": "",
268  		},
269  	}
270  	w := postJSON(t, r, "/api/v1/app/taskapi/tasks", body)
271  	resp := decodeAPIResponse(t, w)
272  	assert.Equal(t, float64(1), resp["status"])
273  	assert.Contains(t, resp["message"], "agent_id")
274  }
275  
276  // ---------------------------------------------------------------------------
277  // GetTaskStatus
278  // ---------------------------------------------------------------------------
279  
280  func TestGetTaskStatus_EmptyID(t *testing.T) {
281  	tm, cleanup := newTestTaskManager(t)
282  	defer cleanup()
283  	r := newRouter(tm)
284  
285  	// Gin will not match the route with empty param; test with placeholder
286  	req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/status/", nil)
287  	w := httptest.NewRecorder()
288  	r.ServeHTTP(w, req)
289  	// No route match → 404
290  	assert.Equal(t, http.StatusNotFound, w.Code)
291  }
292  
293  func TestGetTaskStatus_InvalidIDFormat(t *testing.T) {
294  	tm, cleanup := newTestTaskManager(t)
295  	defer cleanup()
296  	r := newRouter(tm)
297  
298  	req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/status/invalid@id!", nil)
299  	w := httptest.NewRecorder()
300  	r.ServeHTTP(w, req)
301  
302  	assert.Equal(t, http.StatusOK, w.Code)
303  	resp := decodeAPIResponse(t, w)
304  	assert.Equal(t, float64(1), resp["status"])
305  	assert.Contains(t, resp["message"], "无效的任务ID格式")
306  }
307  
308  func TestGetTaskStatus_NotFound(t *testing.T) {
309  	tm, cleanup := newTestTaskManager(t)
310  	defer cleanup()
311  	r := newRouter(tm)
312  
313  	req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/status/no-such-id", nil)
314  	w := httptest.NewRecorder()
315  	r.ServeHTTP(w, req)
316  
317  	assert.Equal(t, http.StatusOK, w.Code)
318  	resp := decodeAPIResponse(t, w)
319  	assert.Equal(t, float64(1), resp["status"])
320  	assert.Contains(t, resp["message"], "任务不存在")
321  }
322  
323  func TestGetTaskStatus_ExistingSession(t *testing.T) {
324  	tm, cleanup := newTestTaskManager(t)
325  	defer cleanup()
326  	r := newRouter(tm)
327  
328  	// Seed DB directly
329  	require.NoError(t, tm.taskStore.CreateUser(&database.User{
330  		UserID: "api-u1", Username: "apiuser", Email: "api@t.com",
331  	}))
332  	require.NoError(t, tm.taskStore.CreateSession(&database.Session{
333  		ID:       "valid-session-id",
334  		Username: "apiuser",
335  		Title:    "Test API Task",
336  		TaskType: "ai_infra_scan",
337  		Content:  "http://127.0.0.1",
338  		Status:   "todo",
339  	}))
340  
341  	req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/status/valid-session-id", nil)
342  	w := httptest.NewRecorder()
343  	r.ServeHTTP(w, req)
344  
345  	assert.Equal(t, http.StatusOK, w.Code)
346  	resp := decodeAPIResponse(t, w)
347  	assert.Equal(t, float64(0), resp["status"])
348  	data, ok := resp["data"].(map[string]interface{})
349  	require.True(t, ok)
350  	assert.Equal(t, "valid-session-id", data["session_id"])
351  	assert.Equal(t, "todo", data["status"])
352  }
353  
354  // ---------------------------------------------------------------------------
355  // GetTaskResult
356  // ---------------------------------------------------------------------------
357  
358  func TestGetTaskResult_InvalidIDFormat(t *testing.T) {
359  	tm, cleanup := newTestTaskManager(t)
360  	defer cleanup()
361  	r := newRouter(tm)
362  
363  	req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/result/bad-id-here!", nil)
364  	w := httptest.NewRecorder()
365  	r.ServeHTTP(w, req)
366  
367  	assert.Equal(t, http.StatusOK, w.Code)
368  	resp := decodeAPIResponse(t, w)
369  	assert.Equal(t, float64(1), resp["status"])
370  	assert.Contains(t, resp["message"], "无效的任务ID格式")
371  }
372  
373  func TestGetTaskResult_NoResults(t *testing.T) {
374  	tm, cleanup := newTestTaskManager(t)
375  	defer cleanup()
376  	r := newRouter(tm)
377  
378  	// Session exists but no resultUpdate events
379  	require.NoError(t, tm.taskStore.CreateUser(&database.User{
380  		UserID: "api-u2", Username: "apiuser2", Email: "api2@t.com",
381  	}))
382  	require.NoError(t, tm.taskStore.CreateSession(&database.Session{
383  		ID:       "session-no-result",
384  		Username: "apiuser2",
385  		Title:    "Task without result",
386  		TaskType: "mcp_scan",
387  		Content:  "test",
388  		Status:   "doing",
389  	}))
390  
391  	req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/result/session-no-result", nil)
392  	w := httptest.NewRecorder()
393  	r.ServeHTTP(w, req)
394  
395  	assert.Equal(t, http.StatusOK, w.Code)
396  	resp := decodeAPIResponse(t, w)
397  	assert.Equal(t, float64(1), resp["status"])
398  }
399  
400  func TestGetTaskResult_WithResult(t *testing.T) {
401  	tm, cleanup := newTestTaskManager(t)
402  	defer cleanup()
403  	r := newRouter(tm)
404  
405  	require.NoError(t, tm.taskStore.CreateUser(&database.User{
406  		UserID: "api-u3", Username: "apiuser3", Email: "api3@t.com",
407  	}))
408  	require.NoError(t, tm.taskStore.CreateSession(&database.Session{
409  		ID:       "session-with-result",
410  		Username: "apiuser3",
411  		Title:    "Done Task",
412  		TaskType: "ai_infra_scan",
413  		Content:  "test",
414  		Status:   "done",
415  	}))
416  
417  	// Store a resultUpdate event
418  	resultPayload := map[string]interface{}{
419  		"findings": []string{"vuln-A", "vuln-B"},
420  		"score":    99,
421  	}
422  	eventJSON, err := json.Marshal(resultPayload)
423  	require.NoError(t, err)
424  
425  	require.NoError(t, tm.taskStore.CreateTaskMessage(&database.TaskMessage{
426  		ID:        "result-ev-1",
427  		SessionID: "session-with-result",
428  		Type:      "resultUpdate",
429  		EventData: datatypes.JSON(eventJSON),
430  		Timestamp: 1000,
431  	}))
432  
433  	req := httptest.NewRequest(http.MethodGet, "/api/v1/app/taskapi/result/session-with-result", nil)
434  	w := httptest.NewRecorder()
435  	r.ServeHTTP(w, req)
436  
437  	assert.Equal(t, http.StatusOK, w.Code)
438  	resp := decodeAPIResponse(t, w)
439  	assert.Equal(t, float64(0), resp["status"])
440  	data, ok := resp["data"].(map[string]interface{})
441  	require.True(t, ok)
442  	findings, ok := data["findings"].([]interface{})
443  	require.True(t, ok)
444  	assert.Len(t, findings, 2)
445  	assert.Equal(t, float64(99), data["score"])
446  }
447  
448  // ---------------------------------------------------------------------------
449  // resolveTaskAPIUsername
450  // ---------------------------------------------------------------------------
451  
452  func TestResolveTaskAPIUsername_Default(t *testing.T) {
453  	c, _ := gin.CreateTestContext(httptest.NewRecorder())
454  	c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
455  	assert.Equal(t, "api_user", resolveTaskAPIUsername(c))
456  }
457  
458  func TestResolveTaskAPIUsername_FromHeader(t *testing.T) {
459  	c, _ := gin.CreateTestContext(httptest.NewRecorder())
460  	req := httptest.NewRequest(http.MethodGet, "/", nil)
461  	req.Header.Set("username", "header-user")
462  	c.Request = req
463  	assert.Equal(t, "header-user", resolveTaskAPIUsername(c))
464  }
465  
466  func TestResolveTaskAPIUsername_FromContext(t *testing.T) {
467  	c, _ := gin.CreateTestContext(httptest.NewRecorder())
468  	c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
469  	c.Set("api_user", "ctx-user")
470  	assert.Equal(t, "ctx-user", resolveTaskAPIUsername(c))
471  }