task_manager.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package websocket 20 21 import ( 22 "encoding/json" 23 "fmt" 24 "io" 25 "mime" 26 "mime/multipart" 27 "net/http" 28 "net/url" 29 "os" 30 "path/filepath" 31 "reflect" 32 "strings" 33 "sync" 34 "time" 35 36 "github.com/Tencent/AI-Infra-Guard/common/agent" 37 38 "github.com/Tencent/AI-Infra-Guard/pkg/database" 39 "github.com/gin-gonic/gin" 40 "gorm.io/datatypes" 41 "trpc.group/trpc-go/trpc-go/log" 42 ) 43 44 // 任务管理器相关数据结构 45 46 const ( 47 WSMsgTypeTaskAssign = "task_assign" // 任务分配 48 49 // 任务状态常量 50 TaskStatusTodo = "todo" // 待执行 51 TaskStatusDoing = "doing" // 执行中 52 TaskStatusDone = "done" // 已完成 53 TaskStatusError = "error" 54 TaskStatusTerminated = "terminated" // 已终止 55 ) 56 57 type TaskManager struct { 58 mu sync.RWMutex 59 tasks map[string]*TaskCreateRequest // sessionId -> 任务请求 60 agentManager *AgentManager // 新增:引用 AgentManager 61 taskStore *database.TaskStore // 新增:引用 TaskStore 62 modelStore *database.ModelStore // 新增:引用 ModelStore 63 fileConfig *FileUploadConfig // 新增:文件上传配置 64 sseManager *SSEManager // 新增:SSE管理器 65 } 66 67 func NewTaskManager(agentManager *AgentManager, taskStore *database.TaskStore, modelStore *database.ModelStore, fileConfig *FileUploadConfig, sseManager *SSEManager) *TaskManager { 68 if fileConfig == nil { 69 fileConfig = DefaultFileUploadConfig() 70 } 71 if sseManager == nil { 72 sseManager = NewSSEManager() 73 } 74 return &TaskManager{ 75 tasks: make(map[string]*TaskCreateRequest), 76 agentManager: agentManager, // 注入 AgentManager 77 taskStore: taskStore, // 注入 TaskStore 78 modelStore: modelStore, // 注入 ModelStore 79 fileConfig: fileConfig, // 注入文件上传配置 80 sseManager: sseManager, // 注入SSE管理器 81 } 82 } 83 84 // 添加任务 85 func (tm *TaskManager) AddTask(req *TaskCreateRequest, traceID string) error { 86 log.Infof("开始添加任务: trace_id=%s, sessionId=%s, taskType=%s, username=%s", traceID, req.SessionID, req.Task, req.Username) 87 88 // 监控相关代码已移除 89 90 // 1. 先检查数据库中是否已存在相同的sessionId 91 existingSession, err := tm.taskStore.GetSession(req.SessionID) 92 if err == nil && existingSession != nil { 93 log.Errorf("任务已存在: trace_id=%s, sessionId=%s, username=%s", traceID, req.SessionID, req.Username) 94 return fmt.Errorf("任务已存在,sessionId: %s", req.SessionID) 95 } 96 97 // 2. 预存任务到数据库(状态为todo,assigned_agent为空) 98 session := &database.Session{ 99 ID: req.SessionID, 100 Username: req.Username, 101 Title: tm.generateTaskTitle(req), 102 TaskType: req.Task, 103 Content: req.Content, 104 Params: mustMarshalJSON(req.Params), 105 Attachments: mustMarshalJSON(req.Attachments), 106 Status: TaskStatusDoing, 107 AssignedAgent: "", // 预存时为空 108 CountryIsoCode: req.CountryIsoCode, 109 Share: true, 110 } 111 112 err = tm.taskStore.CreateSession(session) 113 if err != nil { 114 log.Errorf("预存任务到数据库失败: trace_id=%s, sessionId=%s, error=%v", traceID, req.SessionID, err) 115 return fmt.Errorf("预存任务失败: %v", err) 116 } 117 118 log.Infof("任务预存成功: trace_id=%s, sessionId=%s", traceID, req.SessionID) 119 120 // 3. 等待SSE连接建立 121 timeout := 100 * time.Second 122 start := time.Now() 123 for time.Since(start) < timeout { 124 if tm.sseManager.HasConnection(req.SessionID) { 125 break // 连接已建立 126 } 127 time.Sleep(500 * time.Millisecond) // 每50ms检查一次 128 } 129 130 if !tm.sseManager.HasConnection(req.SessionID) { 131 // SSE连接超时,清理预存的任务 132 tm.cleanupFailedTask(req.SessionID, traceID) 133 log.Errorf("SSE连接建立超时: trace_id=%s, sessionId=%s, username=%s, timeout=%v", traceID, req.SessionID, req.Username, timeout) 134 return fmt.Errorf("SSE连接建立超时,请重试,sessionId: %s", req.SessionID) 135 } 136 137 // 4. 存储任务到内存(dispatchTask需要从内存中获取任务) 138 tm.mu.Lock() 139 tm.tasks[req.SessionID] = req 140 tm.mu.Unlock() 141 142 // 5. 尝试分发任务 143 err = tm.dispatchTask(req.SessionID, traceID) 144 if err != nil { 145 // 分发失败,清理内存和数据库中的预存内容 146 tm.cleanupFailedTask(req.SessionID, traceID) 147 log.Errorf("任务分发失败: trace_id=%s, sessionId=%s, error=%v", traceID, req.SessionID, err) 148 return fmt.Errorf("任务分发失败: %v", err) 149 } 150 151 log.Infof("任务添加成功: trace_id=%s, sessionId=%s, taskType=%s", traceID, req.SessionID, req.Task) 152 return nil 153 } 154 155 // 一键添加任务并执行 156 func (tm *TaskManager) AddTaskApi(req *TaskCreateRequest) error { 157 // 1. 先检查数据库中是否已存在相同的sessionId 158 existingSession, err := tm.taskStore.GetSession(req.SessionID) 159 if err == nil && existingSession != nil { 160 return fmt.Errorf("任务已存在,sessionId: %s", req.SessionID) 161 } 162 163 // 2. 预存任务到数据库(状态为todo,assigned_agent为空) 164 session := &database.Session{ 165 ID: req.SessionID, 166 Username: req.Username, 167 Title: tm.generateTaskTitle(req), 168 TaskType: req.Task, 169 Content: req.Content, 170 Params: mustMarshalJSON(req.Params), 171 Attachments: mustMarshalJSON(req.Attachments), 172 Status: TaskStatusTodo, 173 AssignedAgent: "", // 预存时为空 174 CountryIsoCode: req.CountryIsoCode, 175 Share: true, 176 } 177 err = tm.taskStore.CreateSession(session) 178 if err != nil { 179 return fmt.Errorf("预存任务失败: %v", err) 180 } 181 182 // 获取可用 Agent(简化:不做额外健康检查) 183 availableAgents := tm.agentManager.GetAvailableAgents() 184 if len(availableAgents) == 0 { 185 return fmt.Errorf("没有可用的Agent") 186 } 187 188 // 3. 选择 Agent(简单策略:选择第一个,相信GetAvailableAgents的过滤结果) 189 selectedAgent := availableAgents[0] 190 191 // 4. 更新session的assigned_agent和开始时间 192 err = tm.taskStore.UpdateSessionAssignedAgent(req.SessionID, selectedAgent.agentID) 193 if err != nil { 194 return fmt.Errorf("无法更新session的assigned_agent") 195 } 196 197 // 6. 构造任务分配消息 198 taskMsg := WSMessage{ 199 Type: WSMsgTypeTaskAssign, 200 Content: TaskContent{ 201 SessionID: req.SessionID, 202 TaskType: req.Task, 203 Content: req.Content, 204 Params: req.Params, 205 Attachments: req.Attachments, 206 Timeout: 3600, 207 CountryIsoCode: req.CountryIsoCode, 208 }, 209 } 210 211 // 7. 直接发送给 Agent(简化:无重试,无额外健康检查) 212 selectedAgent.stateMu.RLock() 213 agentID := selectedAgent.agentID 214 selectedAgent.stateMu.RUnlock() 215 216 // 设置写超时并直接发送 217 selectedAgent.conn.SetWriteDeadline(time.Now().Add(writeWait)) 218 err = selectedAgent.conn.WriteJSON(taskMsg) 219 if err != nil { 220 return fmt.Errorf("下发任务给 %s 失败: %v", agentID, err) 221 } 222 223 log.Infof("任务分发成功: sessionId=%s, agentId=%s", req.SessionID, agentID) 224 return nil 225 } 226 227 // cleanupFailedTask 清理失败的任务(内存和数据库) 228 func (tm *TaskManager) cleanupFailedTask(sessionId string, traceID string) { 229 log.Infof("开始清理失败任务: trace_id=%s, sessionId=%s", traceID, sessionId) 230 231 // 清理内存中的任务 232 tm.mu.Lock() 233 delete(tm.tasks, sessionId) 234 tm.mu.Unlock() 235 236 // 清理数据库中的预存任务 237 err := tm.taskStore.DeleteSession(sessionId) 238 if err != nil { 239 log.Errorf("清理数据库中的失败任务失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) 240 } else { 241 log.Infof("失败任务清理完成: trace_id=%s, sessionId=%s", traceID, sessionId) 242 } 243 } 244 245 // 获取任务 246 func (tm *TaskManager) GetTask(sessionId string) (*TaskCreateRequest, bool) { 247 tm.mu.RLock() 248 defer tm.mu.RUnlock() 249 task, ok := tm.tasks[sessionId] 250 return task, ok 251 } 252 253 // 新增:任务分发方法(简化版本,减少死锁风险) 254 func (tm *TaskManager) dispatchTask(sessionId string, traceID string) error { 255 log.Infof("开始分发任务: trace_id=%s, sessionId=%s", traceID, sessionId) 256 257 // 1. 获取任务 258 task, exists := tm.GetTask(sessionId) 259 if !exists { 260 log.Errorf("任务不存在: trace_id=%s, sessionId=%s", traceID, sessionId) 261 return fmt.Errorf("任务不存在") 262 } 263 264 // 2. 获取可用 Agent(简化:不做额外健康检查) 265 availableAgents := tm.agentManager.GetAvailableAgents() 266 if len(availableAgents) == 0 { 267 log.Warnf("没有可用的Agent: trace_id=%s, sessionId=%s", traceID, sessionId) 268 return fmt.Errorf("没有可用的Agent") 269 } 270 271 log.Infof("找到可用Agent数量: trace_id=%s, sessionId=%s, count=%d", traceID, sessionId, len(availableAgents)) 272 273 // 3. 选择 Agent(简单策略:选择第一个,相信GetAvailableAgents的过滤结果) 274 selectedAgent := availableAgents[0] 275 276 // 4. 更新session的assigned_agent和开始时间 277 err := tm.taskStore.UpdateSessionAssignedAgent(task.SessionID, selectedAgent.agentID) 278 if err != nil { 279 log.Errorf("无法更新session的assigned_agent: trace_id=%s, sessionId=%s, agentId=%s, error=%v", traceID, task.SessionID, selectedAgent.agentID, err) 280 return fmt.Errorf("无法更新session的assigned_agent") 281 } 282 283 // 5. 处理params中的modelid,获取模型信息 284 enhancedParams := make(map[string]interface{}) 285 for k, v := range task.Params { 286 enhancedParams[k] = v 287 } 288 addModel := func(modelId string) (*database.ModelParams, error) { 289 model, err := tm.modelStore.GetModel(modelId) 290 if err != nil { 291 // 检查是否是记录不存在的错误 292 if err.Error() == "record not found" { 293 log.Errorf("模型不存在: trace_id=%s, sessionId=%s, modelID=%s", traceID, sessionId, modelId) 294 return nil, fmt.Errorf("模型ID '%s' 不存在,请检查模型配置", modelId) 295 } 296 log.Errorf("获取模型信息失败: trace_id=%s, sessionId=%s, modelID=%s, error=%v", traceID, sessionId, modelId, err) 297 return nil, fmt.Errorf("获取模型信息失败: %v", err) 298 } 299 // 测试模型是否有效 300 //ai := models.NewOpenAI(model.Token, model.ModelName, model.BaseURL) 301 //err = ai.Vaild(context.Background()) 302 //if err != nil { 303 // log.Errorf("模型无效: trace_id=%s, sessionId=%s, modelID=%s, error=%v", traceID, sessionId, modelId, err) 304 // return nil, fmt.Errorf("模型无效: %v", err) 305 //} 306 p := database.ModelParams{ 307 Model: model.ModelName, 308 Token: model.Token, 309 BaseUrl: model.BaseURL, 310 Limit: model.Limit, 311 } 312 return &p, nil 313 } 314 if task.Params != nil { 315 if modelID, exists := task.Params["model_id"]; exists { 316 log.Infof("找到模型ID: trace_id=%s, sessionId=%s, modelID=%v", traceID, sessionId, modelID) 317 switch v := modelID.(type) { 318 case string: 319 modelInfo, err := addModel(v) 320 if err != nil { 321 return err 322 } 323 enhancedParams["model"] = modelInfo 324 case []interface{}: 325 modelsList := make([]*database.ModelParams, 0) 326 log.Infof("找到多个模型ID: trace_id=%s, sessionId=%s, modelID=%v", traceID, sessionId, v) 327 for _, vv := range v { 328 vv, ok := vv.(string) 329 if !ok { 330 log.Errorf("无效的模型ID类型: trace_id=%s, sessionId=%s, modelID=%v", traceID, sessionId, vv) 331 continue 332 } 333 modelInfo, err := addModel(vv) 334 if err != nil { 335 return err 336 } 337 modelsList = append(modelsList, modelInfo) 338 } 339 enhancedParams["model"] = modelsList 340 default: 341 log.Errorf("无效的模型ID类型: trace_id=%s, sessionId=%s, modelID=%v", traceID, sessionId, v) 342 } 343 } 344 if evalModelStr, exists := task.Params["eval_model_id"]; exists { 345 evalModelId, ok := evalModelStr.(string) 346 if ok { 347 evalModelInfo, err := addModel(evalModelId) 348 if err != nil { 349 return err 350 } 351 enhancedParams["eval_model"] = evalModelInfo 352 } 353 } 354 // 处理agent_id,将其转换为agent_data(yaml文本) 355 if agentIdStr, exists := task.Params["agent_id"]; exists { 356 agentId, ok := agentIdStr.(string) 357 if ok && agentId != "" { 358 log.Infof("找到AgentID: trace_id=%s, sessionId=%s, agentID=%s", traceID, sessionId, agentId) 359 // 使用任务的用户名读取配置,如果为空则使用公共用户 360 username := task.Username 361 if username == "" { 362 username = PublicUser 363 } 364 agentData, err := readAgentConfigContent(username, agentId) 365 if err != nil { 366 log.Errorf("获取Agent配置失败: trace_id=%s, sessionId=%s, agentID=%s, error=%v", traceID, sessionId, agentId, err) 367 return fmt.Errorf("获取Agent配置 '%s' 失败: %v", agentId, err) 368 } 369 enhancedParams["agent_data"] = string(agentData) 370 } 371 } 372 } 373 374 // 6. 构造任务分配消息 375 taskMsg := WSMessage{ 376 Type: WSMsgTypeTaskAssign, 377 Content: TaskContent{ 378 SessionID: task.SessionID, 379 TaskType: task.Task, 380 Content: task.Content, 381 Params: enhancedParams, 382 Attachments: task.Attachments, 383 Timeout: 3600, 384 CountryIsoCode: task.CountryIsoCode, 385 }, 386 } 387 log.Infof("任务分配消息: trace_id=%s, sessionId=%s, taskMsg=%+v", traceID, sessionId, taskMsg) 388 389 // 7. 直接发送给 Agent(简化:无重试,无额外健康检查) 390 selectedAgent.stateMu.RLock() 391 agentID := selectedAgent.agentID 392 isActive := selectedAgent.isActive 393 selectedAgent.stateMu.RUnlock() 394 395 if !isActive { 396 log.Errorf("选中的Agent已不活跃: trace_id=%s, sessionId=%s, agentId=%s", traceID, sessionId, agentID) 397 // 重置assigned_agent 398 tm.taskStore.UpdateSessionAssignedAgent(task.SessionID, "") 399 return fmt.Errorf("选中的Agent已不活跃: %s", agentID) 400 } 401 402 // 设置写超时并直接发送 403 selectedAgent.conn.SetWriteDeadline(time.Now().Add(writeWait)) 404 err = selectedAgent.conn.WriteJSON(taskMsg) 405 if err != nil { 406 log.Errorf("下发任务给Agent失败: trace_id=%s, sessionId=%s, agentId=%s, error=%v", traceID, task.SessionID, agentID, err) 407 return fmt.Errorf("下发任务给 %s 失败: %v", agentID, err) 408 } 409 410 log.Infof("任务分发成功: trace_id=%s, sessionId=%s, agentId=%s", traceID, task.SessionID, agentID) 411 return nil 412 } 413 414 // HandleAgentEvent 处理来自Agent的事件 415 func (tm *TaskManager) HandleAgentEvent(sessionId string, eventType string, event interface{}) { 416 log.Debugf("收到Agent事件: sessionId=%s, eventType=%s", sessionId, eventType) 417 418 if tm.shouldIgnoreAgentEvent(sessionId, eventType) { 419 log.Infof("忽略无效或终态任务的Agent事件: sessionId=%s, eventType=%s", sessionId, eventType) 420 return 421 } 422 423 // 使用通用事件处理函数 424 tm.handleEvent(sessionId, eventType, event) 425 426 // 根据事件类型记录特定日志 427 switch eventType { 428 case "liveStatus": 429 if convertedEvent, err := convertToStruct(event, &LiveStatusEvent{}); err == nil { 430 if liveStatusEvent, ok := convertedEvent.(*LiveStatusEvent); ok { 431 log.Debugf("liveStatus事件详情: sessionId=%s, text=%s", sessionId, liveStatusEvent.Text) 432 } 433 } 434 case "planUpdate": 435 if convertedEvent, err := convertToStruct(event, &PlanUpdateEvent{}); err == nil { 436 if planUpdateEvent, ok := convertedEvent.(*PlanUpdateEvent); ok { 437 log.Infof("收到计划更新: sessionId=%s, tasks=%d", sessionId, len(planUpdateEvent.Tasks)) 438 } 439 } 440 case "newPlanStep": 441 if convertedEvent, err := convertToStruct(event, &NewPlanStepEvent{}); err == nil { 442 if newPlanStepEvent, ok := convertedEvent.(*NewPlanStepEvent); ok { 443 log.Infof("新计划步骤: sessionId=%s, stepId=%s", sessionId, newPlanStepEvent.StepID) 444 } 445 } 446 case "statusUpdate": 447 if convertedEvent, err := convertToStruct(event, &StatusUpdateEvent{}); err == nil { 448 if statusUpdateEvent, ok := convertedEvent.(*StatusUpdateEvent); ok { 449 log.Infof("状态更新: sessionId=%s, status=%s", sessionId, statusUpdateEvent.AgentStatus) 450 } 451 } 452 case "toolUsed": 453 if convertedEvent, err := convertToStruct(event, &ToolUsedEvent{}); err == nil { 454 if toolUsedEvent, ok := convertedEvent.(*ToolUsedEvent); ok { 455 log.Infof("工具使用: sessionId=%s, tools=%d", sessionId, len(toolUsedEvent.Tools)) 456 } 457 } 458 case "actionLog": 459 if convertedEvent, err := convertToStruct(event, &ActionLogEvent{}); err == nil { 460 if actionLogEvent, ok := convertedEvent.(*ActionLogEvent); ok { 461 log.Debugf("动作日志: sessionId=%s, actionId=%s", sessionId, actionLogEvent.ActionID) 462 } 463 } 464 case "error": 465 log.Errorf("错误事件: sessionId=%s %v", sessionId, event) 466 err := tm.taskStore.UpdateSessionStatus(sessionId, TaskStatusError) 467 if err != nil { 468 log.Errorf("更新任务失败: sessionId=%s, error=%v", sessionId, err) 469 } 470 case "resultUpdate": 471 if convertedEvent, err := convertToStruct(event, &ResultUpdateEvent{}); err == nil { 472 if _, ok := convertedEvent.(*ResultUpdateEvent); ok { 473 log.Infof("任务完成: sessionId=%s", sessionId) 474 475 // 监控相关代码已移除 476 477 // 更新任务状态为已完成 478 err := tm.taskStore.UpdateSessionStatus(sessionId, TaskStatusDone) 479 if err != nil { 480 log.Errorf("更新任务状态为已完成失败: sessionId=%s, error=%v", sessionId, err) 481 } else { 482 log.Infof("任务状态已更新为已完成: sessionId=%s", sessionId) 483 } 484 // 任务完成,可以清理资源 485 go tm.cleanupTask(sessionId) 486 } 487 } 488 default: 489 log.Debugf("未知事件类型: sessionId=%s, eventType=%s", sessionId, eventType) 490 } 491 } 492 493 // convertToStruct 将 interface{} 转换为指定的结构体类型 494 func convertToStruct(data interface{}, target interface{}) (interface{}, error) { 495 // 先序列化为JSON 496 jsonData, err := json.Marshal(data) 497 if err != nil { 498 return nil, err 499 } 500 501 // 再反序列化为目标结构体 502 err = json.Unmarshal(jsonData, target) 503 if err != nil { 504 return nil, err 505 } 506 507 return target, nil 508 } 509 510 // generateSecureFileName 生成安全的唯一文件名 511 func generateSecureFileName(originalName string) string { 512 // 获取文件扩展名 513 ext := filepath.Ext(originalName) 514 515 // 获取不带扩展名的原始文件名 516 baseName := strings.TrimSuffix(originalName, ext) 517 518 // 生成UUID 519 uuid := generateUUID() 520 521 // 组合:UUID_原始文件名.扩展名 522 return fmt.Sprintf("%s_%s%s", baseName, uuid, ext) 523 } 524 525 // generateUUID 生成简单的UUID 526 func generateUUID() string { 527 return fmt.Sprintf("%d_%d", time.Now().UnixNano(), time.Now().Unix()) 528 } 529 530 // 通用事件处理函数 531 func (tm *TaskManager) handleEvent(sessionId string, eventType string, event interface{}) { 532 log.Debugf("开始处理事件: sessionId=%s, eventType=%s", sessionId, eventType) 533 534 // 生成事件ID 535 id := generateEventID() 536 537 // 获取事件的时间戳 538 timestamp := getEventTimestamp(event) 539 540 // 存储事件到数据库 541 err := tm.taskStore.StoreEvent(id, sessionId, eventType, event, timestamp) 542 if err != nil { 543 log.Errorf("存储事件失败: sessionId=%s, eventType=%s, error=%v", sessionId, eventType, err) 544 return 545 } 546 547 // 推送事件到SSE 548 err = tm.sseManager.SendEvent(id, sessionId, eventType, event) 549 if err != nil { 550 // 如果是连接不存在的错误,记录为调试信息而不是错误 551 if strings.Contains(err.Error(), "连接不存在") { 552 log.Debugf("SSE连接已关闭,跳过事件推送: sessionId=%s, eventType=%s", sessionId, eventType) 553 } else { 554 log.Errorf("推送事件到SSE失败: sessionId=%s, eventType=%s, error=%v", sessionId, eventType, err) 555 } 556 return 557 } 558 559 // 记录日志 560 log.Debugf("事件处理完成: sessionId=%s, eventType=%s", sessionId, eventType) 561 } 562 563 // getEventTimestamp 获取事件的时间戳 564 func getEventTimestamp(event interface{}) int64 { 565 // 使用反射获取Timestamp字段 566 v := reflect.ValueOf(event) 567 if v.Kind() == reflect.Ptr { 568 v = v.Elem() 569 } 570 571 if v.Kind() == reflect.Struct { 572 if field := v.FieldByName("Timestamp"); field.IsValid() && field.CanInterface() { 573 if timestamp, ok := field.Interface().(int64); ok { 574 return timestamp 575 } 576 } 577 } 578 579 // 如果无法获取时间戳,使用当前时间 580 return time.Now().UnixMilli() 581 } 582 583 // TerminateTask 终止任务 584 func (tm *TaskManager) TerminateTask(sessionId string, username string, traceID string) error { 585 log.Infof("开始终止任务: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) 586 587 // 检查任务是否存在 588 session, err := tm.taskStore.GetSession(sessionId) 589 if err != nil { 590 log.Errorf("任务不存在: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) 591 return fmt.Errorf("任务不存在") 592 } 593 594 // 验证用户权限(只有任务创建者才能终止任务) 595 if session.Username != username { 596 log.Errorf("无权限终止任务: trace_id=%s, sessionId=%s, username=%s, owner=%s", traceID, sessionId, username, session.Username) 597 return fmt.Errorf("无权限操作此任务") 598 } 599 600 if isTerminalTaskStatus(session.Status) { 601 log.Infof("任务已结束,无需终止: trace_id=%s, sessionId=%s, status=%s", traceID, sessionId, session.Status) 602 return fmt.Errorf("任务已结束,无需终止") 603 } 604 605 // 通知 Agent 终止任务 606 if session.AssignedAgent != "" { 607 log.Infof("通知Agent终止任务: trace_id=%s, sessionId=%s, agentId=%s", traceID, sessionId, session.AssignedAgent) 608 tm.notifyAgentToTerminate(session.AssignedAgent, sessionId, traceID) 609 } 610 611 // 发送终止事件给前端 612 tm.sendTerminationEvent(sessionId, traceID) 613 614 // 更新任务状态为已终止 615 err = tm.taskStore.UpdateSessionStatus(sessionId, TaskStatusTerminated) 616 if err != nil { 617 log.Errorf("更新任务状态失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) 618 return fmt.Errorf("更新任务状态失败") 619 } 620 621 log.Infof("任务终止完成: trace_id=%s, sessionId=%s", traceID, sessionId) 622 623 // 监控相关代码已移除 624 625 // 异步清理任务资源 626 go tm.cleanupTask(sessionId) 627 628 return nil 629 } 630 631 // notifyAgentToTerminate 通知 Agent 终止任务(简化版本) 632 func (tm *TaskManager) notifyAgentToTerminate(agentID string, sessionId string, traceID string) { 633 // 获取 Agent 连接 634 availableAgents := tm.agentManager.GetAvailableAgents() 635 for _, agent := range availableAgents { 636 agent.stateMu.RLock() 637 currentAgentID := agent.agentID 638 isActive := agent.isActive 639 agent.stateMu.RUnlock() 640 641 if currentAgentID == agentID && isActive { 642 // 发送终止消息给 Agent 643 terminateMsg := WSMessage{ 644 Type: "terminate", 645 Content: map[string]interface{}{ 646 "session_id": sessionId, 647 "reason": "用户主动终止", 648 }, 649 } 650 651 // 直接发送,无重试机制 652 agent.conn.SetWriteDeadline(time.Now().Add(writeWait)) 653 err := agent.conn.WriteJSON(terminateMsg) 654 if err != nil { 655 log.Errorf("发送终止消息给Agent %s失败: %v", agentID, err) 656 } else { 657 log.Infof("终止消息已发送给Agent %s: trace_id=%s, sessionId=%s", agentID, traceID, sessionId) 658 } 659 return 660 } 661 } 662 663 log.Warnf("未找到可终止的Agent连接: trace_id=%s, sessionId=%s, agentId=%s", traceID, sessionId, agentID) 664 } 665 666 // sendTerminationEvent 发送终止事件给前端 667 func (tm *TaskManager) sendTerminationEvent(sessionId string, traceID string) { 668 event := StatusUpdateEvent{ 669 ID: generateEventID(), 670 Type: "statusUpdate", 671 Timestamp: time.Now().UnixMilli(), 672 AgentStatus: "terminated", 673 Brief: "任务已终止", 674 Description: "用户主动终止了任务执行", 675 NoRender: false, 676 } 677 678 // 使用通用事件处理函数 679 tm.handleEvent(sessionId, "statusUpdate", event) 680 681 log.Infof("终止事件已发送: trace_id=%s, sessionId=%s", traceID, sessionId) 682 } 683 684 func isTerminalTaskStatus(status string) bool { 685 switch status { 686 case TaskStatusDone, TaskStatusError, TaskStatusTerminated: 687 return true 688 default: 689 return false 690 } 691 } 692 693 func (tm *TaskManager) shouldIgnoreAgentEvent(sessionId string, eventType string) bool { 694 session, err := tm.taskStore.GetSession(sessionId) 695 if err != nil || session == nil { 696 return true 697 } 698 699 if !isTerminalTaskStatus(session.Status) { 700 return false 701 } 702 703 return true 704 } 705 706 // generateEventID 生成事件ID 707 func generateEventID() string { 708 return time.Now().Format("20060102150405") + "_" + fmt.Sprintf("%d", time.Now().UnixNano()) 709 } 710 711 // UpdateTask 更新任务信息 712 func (tm *TaskManager) UpdateTask(sessionId string, req *TaskUpdateRequest, username string, traceID string) error { 713 log.Infof("开始更新任务: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) 714 715 // 1. 验证任务是否存在 716 session, err := tm.taskStore.GetSession(sessionId) 717 if err != nil { 718 log.Errorf("任务不存在: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) 719 return fmt.Errorf("任务不存在") 720 } 721 722 // 2. 验证权限(只有任务创建者才能更新) 723 if session.Username != username { 724 log.Errorf("无权限操作此任务: trace_id=%s, sessionId=%s, username=%s, owner=%s", traceID, sessionId, username, session.Username) 725 return fmt.Errorf("无权限操作此任务") 726 } 727 728 // 3. 更新任务信息 729 updates := map[string]interface{}{ 730 "title": req.Title, 731 } 732 err = tm.taskStore.UpdateSession(sessionId, updates) 733 if err != nil { 734 log.Errorf("更新任务信息失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) 735 return fmt.Errorf("更新任务信息失败: %v", err) 736 } 737 738 log.Infof("任务信息更新成功: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) 739 return nil 740 } 741 742 // DeleteTask 删除任务 743 func (tm *TaskManager) DeleteTask(sessionId string, username string, traceID string) error { 744 log.Infof("开始删除任务: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) 745 746 // 检查任务是否存在 747 session, err := tm.taskStore.GetSession(sessionId) 748 if err != nil { 749 log.Errorf("任务不存在: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) 750 return fmt.Errorf("任务不存在") 751 } 752 753 // 验证用户权限: 754 // 1. 任务创建者始终可以删除 755 // 2. 共享任务允许 Web 公共视角(public_user / 空用户)删除, 756 // 以便在前端管理通过 API 创建的共享任务 757 isSharedWebDelete := session.Share && (username == "" || username == PublicUser) 758 if session.Username != username && !isSharedWebDelete { 759 log.Errorf("无权限操作此任务: trace_id=%s, sessionId=%s, username=%s, owner=%s", traceID, sessionId, username, session.Username) 760 return fmt.Errorf("无权限操作此任务") 761 } 762 763 if session.Status == TaskStatusDoing { 764 log.Infof("删除运行中任务前先终止执行: trace_id=%s, sessionId=%s, agentId=%s", traceID, sessionId, session.AssignedAgent) 765 if session.AssignedAgent != "" { 766 tm.notifyAgentToTerminate(session.AssignedAgent, sessionId, traceID) 767 } 768 if err := tm.taskStore.UpdateSessionStatus(sessionId, TaskStatusTerminated); err != nil { 769 log.Errorf("删除前标记任务终止失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) 770 return fmt.Errorf("删除前终止任务失败: %v", err) 771 } 772 tm.sendTerminationEvent(sessionId, traceID) 773 } 774 775 // 使用事务删除会话及其所有消息 776 err = tm.taskStore.DeleteSessionWithMessages(sessionId) 777 if err != nil { 778 log.Errorf("删除任务失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) 779 return fmt.Errorf("删除任务失败: %v", err) 780 } 781 782 // 删除附件文件 783 err = tm.deleteSessionAttachments(session) 784 if err != nil { 785 log.Errorf("删除附件文件失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) 786 // 附件删除失败不影响主流程,只记录警告 787 } 788 789 // 清理内存中的任务数据 790 tm.mu.Lock() 791 delete(tm.tasks, sessionId) 792 tm.mu.Unlock() 793 794 // 关闭SSE连接 795 tm.CloseSSESession(sessionId) 796 797 log.Infof("任务删除完成: trace_id=%s, sessionId=%s", traceID, sessionId) 798 return nil 799 } 800 801 // deleteSessionAttachments 删除会话的附件文件 802 func (tm *TaskManager) deleteSessionAttachments(session *database.Session) error { 803 if session.Attachments == nil { 804 return nil 805 } 806 807 var attachmentURLs []string 808 if err := json.Unmarshal(session.Attachments, &attachmentURLs); err != nil { 809 return fmt.Errorf("解析附件URL失败: %v", err) 810 } 811 812 for _, url := range attachmentURLs { 813 // 从URL中提取文件名 814 fileName := tm.extractFileNameFromURL(url) 815 if fileName == url { 816 // 如果无法提取文件名,跳过 817 continue 818 } 819 820 // 构建完整的文件路径 821 filePath := filepath.Join(tm.fileConfig.UploadDir, fileName) 822 823 // 删除文件 824 if err := os.Remove(filePath); err != nil { 825 if !os.IsNotExist(err) { 826 log.Errorf("删除附件文件失败: %s, error: %v", filePath, err) 827 } 828 } else { 829 log.Debugf("删除附件文件成功: %s", filePath) 830 } 831 } 832 833 return nil 834 } 835 836 // UploadFileResult 文件上传结果 837 type UploadFileResult struct { 838 Filename string `json:"filename"` // 原始文件名 839 FileURL string `json:"fileUrl"` // 文件访问URL 840 } 841 842 // UploadFile 上传文件 843 func (tm *TaskManager) UploadFile(file *multipart.FileHeader, traceID string) (*UploadFileResult, error) { 844 log.Infof("开始文件上传: trace_id=%s, originalName=%s, size=%d", traceID, file.Filename, file.Size) 845 846 // 保存原始文件名 847 originalName := file.Filename 848 849 // 生成安全的唯一文件名 850 fileName := generateSecureFileName(file.Filename) 851 log.Debugf("生成安全文件名: trace_id=%s, originalName=%s, secureName=%s", traceID, originalName, fileName) 852 853 // 使用配置的上传目录 854 uploadDir := tm.fileConfig.UploadDir 855 if err := os.MkdirAll(uploadDir, 0755); err != nil { 856 log.Errorf("创建上传目录失败: trace_id=%s, path=%s, error=%v", traceID, uploadDir, err) 857 return nil, fmt.Errorf("创建上传目录失败: %v", err) 858 } 859 860 // 创建文件路径 861 filePath := filepath.Join(uploadDir, fileName) 862 863 // 保存文件到本地 864 src, err := file.Open() 865 if err != nil { 866 log.Errorf("打开上传文件失败: trace_id=%s, originalName=%s, error=%v", traceID, originalName, err) 867 return nil, fmt.Errorf("打开文件失败: %v", err) 868 } 869 defer src.Close() 870 871 dst, err := os.Create(filePath) 872 if err != nil { 873 log.Errorf("创建目标文件失败: trace_id=%s, filePath=%s, error=%v", traceID, filePath, err) 874 return nil, fmt.Errorf("创建文件失败: %v", err) 875 } 876 defer dst.Close() 877 878 // 复制文件内容并验证 879 written, err := io.Copy(dst, src) 880 if err != nil { 881 // 清理已创建的文件 882 os.Remove(filePath) 883 log.Errorf("文件写入失败: trace_id=%s, filePath=%s, error=%v", traceID, filePath, err) 884 return nil, fmt.Errorf("保存文件失败: %v", err) 885 } 886 887 // 验证写入的文件大小 888 if written != file.Size { 889 os.Remove(filePath) 890 log.Errorf("文件写入不完整: trace_id=%s, expected=%d, actual=%d, filePath=%s", traceID, file.Size, written, filePath) 891 return nil, fmt.Errorf("文件写入不完整") 892 } 893 894 // 生成文件访问URL 895 fileURL := tm.fileConfig.GetFileURL(fileName) 896 897 log.Infof("文件上传成功: trace_id=%s, originalName=%s, secureName=%s, size=%d, fileURL=%s", traceID, originalName, fileName, written, fileURL) 898 899 return &UploadFileResult{ 900 Filename: originalName, 901 FileURL: fileURL, 902 }, nil 903 } 904 905 // ChunkUploadResult 分片上传结果 906 type ChunkUploadResult struct { 907 ChunkIndex int `json:"chunkIndex"` // 当前分片索引 908 TotalChunks int `json:"totalChunks"` // 总分片数 909 Message string `json:"message"` // 消息 910 } 911 912 // MergeChunksResult 合并分片结果 913 type MergeChunksResult struct { 914 Filename string `json:"filename"` // 原始文件名 915 FileURL string `json:"fileUrl"` // 文件访问URL 916 FileSize int64 `json:"fileSize"` // 文件大小 917 } 918 919 // validatePathSafety 验证路径安全性,确保路径在基础目录内 920 func (tm *TaskManager) validatePathSafety(targetPath string) error { 921 // 获取基础目录的绝对路径 922 baseDir, err := filepath.Abs(tm.fileConfig.UploadDir) 923 if err != nil { 924 return fmt.Errorf("获取基础目录失败: %v", err) 925 } 926 927 // 获取目标路径的绝对路径 928 absPath, err := filepath.Abs(targetPath) 929 if err != nil { 930 return fmt.Errorf("获取目标路径失败: %v", err) 931 } 932 933 // 清理路径(处理 .. 等) 934 cleanPath := filepath.Clean(absPath) 935 936 // 验证目标路径是否在基础目录内 937 if !strings.HasPrefix(cleanPath, baseDir+string(filepath.Separator)) && cleanPath != baseDir { 938 return fmt.Errorf("路径越界: %s 不在 %s 内", cleanPath, baseDir) 939 } 940 941 return nil 942 } 943 944 // UploadFileChunk 上传文件分片 945 func (tm *TaskManager) UploadFileChunk(fileID string, filename string, chunkIndex int, totalChunks int, chunkData []byte, traceID string) (*ChunkUploadResult, error) { 946 log.Infof("开始分片上传: trace_id=%s, fileID=%s, filename=%s, chunkIndex=%d/%d, size=%d", 947 traceID, fileID, filename, chunkIndex+1, totalChunks, len(chunkData)) 948 949 // 创建临时目录存储分片 950 tempDir := filepath.Join(tm.fileConfig.UploadDir, "temp", fileID) 951 952 // 验证路径安全性(防止路径穿越) 953 if err := tm.validatePathSafety(tempDir); err != nil { 954 log.Errorf("路径安全校验失败: trace_id=%s, path=%s, error=%v", traceID, tempDir, err) 955 return nil, fmt.Errorf("无效的文件路径") 956 } 957 958 if err := os.MkdirAll(tempDir, 0755); err != nil { 959 log.Errorf("创建临时目录失败: trace_id=%s, path=%s, error=%v", traceID, tempDir, err) 960 return nil, fmt.Errorf("创建临时目录失败: %v", err) 961 } 962 963 // 保存分片到临时目录 964 chunkPath := filepath.Join(tempDir, fmt.Sprintf("chunk_%d", chunkIndex)) 965 966 // 验证分片路径安全性 967 if err := tm.validatePathSafety(chunkPath); err != nil { 968 log.Errorf("分片路径安全校验失败: trace_id=%s, path=%s, error=%v", traceID, chunkPath, err) 969 return nil, fmt.Errorf("无效的文件路径") 970 } 971 972 if err := os.WriteFile(chunkPath, chunkData, 0644); err != nil { 973 log.Errorf("保存分片失败: trace_id=%s, chunkPath=%s, error=%v", traceID, chunkPath, err) 974 return nil, fmt.Errorf("保存分片失败: %v", err) 975 } 976 977 log.Infof("分片上传成功: trace_id=%s, fileID=%s, chunkIndex=%d/%d", traceID, fileID, chunkIndex+1, totalChunks) 978 979 return &ChunkUploadResult{ 980 ChunkIndex: chunkIndex, 981 TotalChunks: totalChunks, 982 Message: fmt.Sprintf("分片 %d/%d 上传成功", chunkIndex+1, totalChunks), 983 }, nil 984 } 985 986 // MergeFileChunks 合并文件分片 987 func (tm *TaskManager) MergeFileChunks(fileID string, filename string, totalChunks int, fileSize int64, traceID string) (*MergeChunksResult, error) { 988 log.Infof("开始合并分片: trace_id=%s, fileID=%s, filename=%s, totalChunks=%d, expectedSize=%d", 989 traceID, fileID, filename, totalChunks, fileSize) 990 991 tempDir := filepath.Join(tm.fileConfig.UploadDir, "temp", fileID) 992 993 // 验证临时目录路径安全性(防止路径穿越) 994 if err := tm.validatePathSafety(tempDir); err != nil { 995 log.Errorf("临时目录路径安全校验失败: trace_id=%s, path=%s, error=%v", traceID, tempDir, err) 996 return nil, fmt.Errorf("无效的文件路径") 997 } 998 999 // 确保最终完成后清理临时目录 1000 defer func() { 1001 if err := os.RemoveAll(tempDir); err != nil { 1002 log.Warnf("清理临时文件失败: trace_id=%s, path=%s, error=%v", traceID, tempDir, err) 1003 } 1004 }() 1005 1006 // 读取并合并所有分片 1007 var mergedData []byte 1008 for i := 0; i < totalChunks; i++ { 1009 chunkPath := filepath.Join(tempDir, fmt.Sprintf("chunk_%d", i)) 1010 1011 // 验证分片路径安全性 1012 if err := tm.validatePathSafety(chunkPath); err != nil { 1013 log.Errorf("分片路径安全校验失败: trace_id=%s, path=%s, error=%v", traceID, chunkPath, err) 1014 return nil, fmt.Errorf("无效的文件路径") 1015 } 1016 1017 chunkData, err := os.ReadFile(chunkPath) 1018 if err != nil { 1019 log.Errorf("读取分片失败: trace_id=%s, chunkPath=%s, error=%v", traceID, chunkPath, err) 1020 return nil, fmt.Errorf("读取分片 %d 失败: %v", i, err) 1021 } 1022 mergedData = append(mergedData, chunkData...) 1023 } 1024 1025 // 验证文件大小 1026 if int64(len(mergedData)) != fileSize { 1027 log.Errorf("文件大小不匹配: trace_id=%s, expected=%d, actual=%d", traceID, fileSize, len(mergedData)) 1028 return nil, fmt.Errorf("文件大小不匹配: 期望 %d 字节, 实际 %d 字节", fileSize, len(mergedData)) 1029 } 1030 1031 // 生成安全的唯一文件名 1032 secureFileName := generateSecureFileName(filename) 1033 1034 // 保存合并后的文件 1035 uploadDir := tm.fileConfig.UploadDir 1036 if err := os.MkdirAll(uploadDir, 0755); err != nil { 1037 log.Errorf("创建上传目录失败: trace_id=%s, path=%s, error=%v", traceID, uploadDir, err) 1038 return nil, fmt.Errorf("创建上传目录失败: %v", err) 1039 } 1040 1041 filePath := filepath.Join(uploadDir, secureFileName) 1042 1043 // 验证最终文件路径安全性 1044 if err := tm.validatePathSafety(filePath); err != nil { 1045 log.Errorf("文件路径安全校验失败: trace_id=%s, path=%s, error=%v", traceID, filePath, err) 1046 return nil, fmt.Errorf("无效的文件路径") 1047 } 1048 1049 if err := os.WriteFile(filePath, mergedData, 0644); err != nil { 1050 log.Errorf("保存合并文件失败: trace_id=%s, filePath=%s, error=%v", traceID, filePath, err) 1051 return nil, fmt.Errorf("保存文件失败: %v", err) 1052 } 1053 1054 // 生成文件访问URL 1055 fileURL := tm.fileConfig.GetFileURL(secureFileName) 1056 1057 log.Infof("文件合并成功: trace_id=%s, fileID=%s, filename=%s, secureName=%s, size=%d, fileURL=%s", 1058 traceID, fileID, filename, secureFileName, len(mergedData), fileURL) 1059 1060 return &MergeChunksResult{ 1061 Filename: filename, 1062 FileURL: fileURL, 1063 FileSize: int64(len(mergedData)), 1064 }, nil 1065 } 1066 1067 // GetUserTasks 获取指定用户的任务列表,只返回属于该用户的会话,确保用户只能看到自己的任务。 1068 func (tm *TaskManager) GetUserTasks(username string, traceID string) ([]map[string]interface{}, error) { 1069 // 从数据库获取用户的任务列表 1070 sessions, err := tm.taskStore.GetUserSessions(username) 1071 if err != nil { 1072 log.Errorf("获取用户任务列表失败: trace_id=%s, username=%s, error=%v", traceID, username, err) 1073 return nil, fmt.Errorf("获取任务列表失败: %v", err) 1074 } 1075 1076 // 转换为前端需要的格式 1077 var tasks []map[string]interface{} 1078 for _, session := range sessions { 1079 task := buildTaskSummary(session) 1080 1081 // 添加完成时间(如果任务已完成) 1082 if session.CompletedAt != nil { 1083 task["completedAt"] = *session.CompletedAt 1084 } else { 1085 task["completedAt"] = nil 1086 } 1087 1088 tasks = append(tasks, task) 1089 } 1090 return tasks, nil 1091 } 1092 1093 // GetUserTasksByType 获取指定用户的任务列表,支持可选的任务类型过滤 1094 func (tm *TaskManager) GetUserTasksByType(username string, taskType string, traceID string) ([]map[string]interface{}, error) { 1095 // 从数据库获取用户的任务列表(支持类型过滤) 1096 sessions, err := tm.taskStore.GetUserSessionsByType(username, taskType) 1097 if err != nil { 1098 log.Errorf("获取用户任务列表失败: trace_id=%s, username=%s, taskType=%s, error=%v", traceID, username, taskType, err) 1099 return nil, fmt.Errorf("获取任务列表失败: %v", err) 1100 } 1101 1102 // 转换为前端需要的格式 1103 var tasks []map[string]interface{} 1104 for _, session := range sessions { 1105 task := buildTaskSummary(session) 1106 1107 // 添加完成时间(如果任务已完成) 1108 if session.CompletedAt != nil { 1109 task["completedAt"] = *session.CompletedAt 1110 } else { 1111 task["completedAt"] = nil 1112 } 1113 1114 tasks = append(tasks, task) 1115 } 1116 return tasks, nil 1117 } 1118 1119 // SearchUserTasksSimple 使用简化参数搜索指定用户的任务,支持单个查询关键词和分页 1120 func (tm *TaskManager) SearchUserTasksSimple(username string, searchParams database.SimpleSearchParams, traceID string) ([]map[string]interface{}, error) { 1121 log.Infof("开始简化搜索用户任务: trace_id=%s, username=%s, query=%s, taskType=%s", traceID, username, searchParams.Query, searchParams.TaskType) 1122 1123 // 验证和设置默认分页参数 1124 if searchParams.Page < 1 { 1125 searchParams.Page = 1 1126 } 1127 if searchParams.PageSize < 1 { 1128 searchParams.PageSize = 10 1129 } 1130 if searchParams.PageSize > 100 { 1131 searchParams.PageSize = 100 // 限制最大页面大小 1132 } 1133 1134 // 从数据库搜索用户的任务列表 1135 sessions, _, err := tm.taskStore.SearchUserSessionsSimple(username, searchParams) 1136 if err != nil { 1137 log.Errorf("简化搜索用户任务失败: trace_id=%s, username=%s, taskType=%s, error=%v", traceID, username, searchParams.TaskType, err) 1138 return nil, fmt.Errorf("搜索任务失败: %v", err) 1139 } 1140 1141 // 转换为前端需要的格式 1142 var tasks []map[string]interface{} 1143 for _, session := range sessions { 1144 task := buildTaskSummary(session) 1145 1146 // 添加完成时间(如果任务已完成) 1147 if session.CompletedAt != nil { 1148 task["completedAt"] = *session.CompletedAt 1149 } else { 1150 task["completedAt"] = nil 1151 } 1152 1153 tasks = append(tasks, task) 1154 } 1155 return tasks, nil 1156 } 1157 1158 func buildTaskSummary(session *database.Session) map[string]interface{} { 1159 source, sourceLabel := resolveTaskSource(session) 1160 return map[string]interface{}{ 1161 "sessionId": session.ID, 1162 "title": decorateTaskTitle(session.Title, sourceLabel), 1163 "rawTitle": session.Title, 1164 "taskType": session.TaskType, 1165 "status": session.Status, 1166 "countryIsoCode": session.CountryIsoCode, 1167 "updatedAt": session.UpdatedAt, 1168 "createdAt": session.CreatedAt, 1169 "source": source, 1170 "sourceLabel": sourceLabel, 1171 } 1172 } 1173 1174 func decorateTaskTitle(title, sourceLabel string) string { 1175 if sourceLabel == "" { 1176 return title 1177 } 1178 return fmt.Sprintf("[%s] %s", sourceLabel, title) 1179 } 1180 1181 func resolveTaskSource(session *database.Session) (string, string) { 1182 switch session.Username { 1183 case "", PublicUser, "demo-test": 1184 return "web", "" 1185 default: 1186 return "api", "API" 1187 } 1188 } 1189 1190 // generateTaskTitle 生成任务标题(用于任务创建API) 1191 func (tm *TaskManager) generateTaskTitle(req *TaskCreateRequest) string { 1192 ret := "" 1193 var ModelName = "" 1194 language := req.CountryIsoCode 1195 if language == "" { 1196 language = "zh" 1197 } 1198 1199 // 定义语言相关的文本 1200 var texts struct { 1201 // 任务类型标题 1202 aiInfraScan, mcpScan, modelJailbreak, modelRedteamReport, agentScan, otherTask string 1203 // 其他文本 1204 model, prompt, github, sse string 1205 } 1206 1207 if language == "en" { 1208 texts.aiInfraScan = "AI Infra Scan - " 1209 texts.mcpScan = "AI Tool and Skill Scan - " 1210 texts.modelJailbreak = "LLM Jailbreaking - " 1211 texts.modelRedteamReport = "Jailbreak Evaluation - " 1212 texts.agentScan = "Agent Scan - " 1213 texts.otherTask = "Other Task - " 1214 texts.model = "Model:" 1215 texts.prompt = "Prompt:" 1216 texts.github = "Github:" 1217 texts.sse = "SSE:" 1218 } else { 1219 texts.aiInfraScan = "AI基础设施扫描 - " 1220 texts.mcpScan = "AI工具技能扫描 - " 1221 texts.modelJailbreak = "一键越狱任务 - " 1222 texts.modelRedteamReport = "大模型安全体检 - " 1223 texts.agentScan = "Agent安全扫描 - " 1224 texts.otherTask = "其他任务 - " 1225 texts.model = "模型:" 1226 texts.prompt = "prompt:" 1227 texts.github = "Github:" 1228 texts.sse = "SSE:" 1229 } 1230 if modelID, exists := req.Params["model_id"]; exists { 1231 switch v := modelID.(type) { 1232 case string: 1233 model, err := tm.modelStore.GetModel(v) 1234 if err == nil { 1235 ModelName = model.ModelName 1236 } 1237 case []interface{}: 1238 modelStr := make([]string, 0) 1239 for _, mid := range v { 1240 mid, ok := mid.(string) 1241 if !ok { 1242 continue 1243 } 1244 model, err := tm.modelStore.GetModel(mid) 1245 if err == nil { 1246 modelStr = append(modelStr, model.ModelName) 1247 } 1248 } 1249 ModelName = strings.Join(modelStr, ",") 1250 } 1251 } 1252 // 1. AI基础 ip/域名 ,文件形式:取第一行等xx个 1253 // 2. MCP:文件名以文件展示,github取项目名,sse取链接 1254 // 3. 评测:模型名 eg:qwen3模型评测任务 1255 // 4. 一键越狱:模型名+prompt 1256 switch req.Task { 1257 case agent.TaskTypeAIInfraScan: 1258 ret = texts.aiInfraScan 1259 if len(req.Attachments) > 0 && req.Attachments[0] != "" { 1260 ret += tm.extractFileNameFromURL(req.Attachments[0]) 1261 } 1262 if req.Content != "" { 1263 ret += req.Content 1264 } 1265 case agent.TaskTypeMcpScan: 1266 ret = texts.mcpScan 1267 if len(req.Attachments) > 0 && req.Attachments[0] != "" { 1268 // 直接调用现有的extractFileNameFromURL方法 1269 ret += tm.extractFileNameFromURL(req.Attachments[0]) 1270 } else if strings.Contains(req.Content, "github.com") { 1271 ret += texts.github + tm.extractFileNameFromURL(req.Content) 1272 } else { 1273 ret += texts.sse + req.Content 1274 } 1275 case agent.TaskTypeModelJailbreak: 1276 ret = texts.modelJailbreak + fmt.Sprintf("%s%s, %s%s", texts.model, ModelName, texts.prompt, req.Content) 1277 case agent.TaskTypeModelRedteamReport: 1278 ret = texts.modelRedteamReport + ModelName 1279 case agent.TaskTypeAgentScan: 1280 agentId, ok := req.Params["agent_id"] 1281 ret = texts.agentScan 1282 if ok { 1283 ret += agentId.(string) 1284 } 1285 if req.Content != "" { 1286 ret += " " + req.Content 1287 } 1288 default: 1289 ret = texts.otherTask + req.Content 1290 } 1291 // 如果content为空,尝试从附件中提取第一个URL的文件名作为title 1292 return ret 1293 } 1294 1295 // 辅助函数:将interface{}转换为datatypes.JSON 1296 func mustMarshalJSON(v interface{}) datatypes.JSON { 1297 if v == nil { 1298 return datatypes.JSON("{}") 1299 } 1300 data, err := json.Marshal(v) 1301 if err != nil { 1302 return datatypes.JSON("{}") 1303 } 1304 return datatypes.JSON(data) 1305 } 1306 1307 // EstablishSSEConnection 建立SSE连接 1308 func (tm *TaskManager) EstablishSSEConnection(w http.ResponseWriter, sessionId string, username string, traceID string) error { 1309 log.Infof("建立SSE连接: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) 1310 err := tm.sseManager.AddConnection(sessionId, username, w) 1311 if err != nil { 1312 log.Errorf("建立SSE连接失败: trace_id=%s, sessionId=%s, username=%s, error=%v", traceID, sessionId, username, err) 1313 } else { 1314 log.Infof("SSE连接建立成功: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) 1315 } 1316 return err 1317 } 1318 1319 // CloseSSESession 关闭SSE会话 1320 func (tm *TaskManager) CloseSSESession(sessionId string) { 1321 log.Infof("关闭SSE会话: sessionId=%s", sessionId) 1322 tm.sseManager.RemoveConnection(sessionId) 1323 log.Infof("SSE会话已关闭: sessionId=%s", sessionId) 1324 } 1325 1326 // 任务完成/中断时的清理 1327 func (tm *TaskManager) cleanupTask(sessionId string) { 1328 log.Infof("开始清理任务资源: sessionId=%s", sessionId) 1329 1330 // 清理内存中的任务数据 1331 tm.mu.Lock() 1332 delete(tm.tasks, sessionId) 1333 tm.mu.Unlock() 1334 1335 // 注意:SSE连接已在resultUpdate事件处理中立即清理 1336 tm.CloseSSESession(sessionId) 1337 1338 log.Infof("任务清理完成: sessionId=%s", sessionId) 1339 } 1340 1341 // GetTaskDetail 获取任务详情 1342 func (tm *TaskManager) GetTaskDetail(sessionId string, username string, traceID string) (map[string]interface{}, error) { 1343 log.Infof("开始获取任务详情: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) 1344 1345 // 检查任务是否存在 1346 session, err := tm.taskStore.GetSession(sessionId) 1347 if err != nil { 1348 log.Errorf("获取任务详情失败: trace_id=%s, sessionId=%s, username=%s, error=%v", traceID, sessionId, username, err) 1349 return nil, fmt.Errorf("任务不存在") 1350 } 1351 1352 // 验证用户权限(只有任务创建者才能查看) 1353 if !session.Share && session.Username != username { 1354 log.Errorf("无权限访问任务详情: trace_id=%s, sessionId=%s, username=%s, owner=%s", traceID, sessionId, username, session.Username) 1355 return nil, fmt.Errorf("无权限查看此任务") 1356 } 1357 1358 // 获取任务的所有消息 1359 messages, err := tm.taskStore.GetSessionMessages(sessionId) 1360 if err != nil { 1361 log.Errorf("获取任务消息失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) 1362 return nil, fmt.Errorf("获取任务消息失败: %v", err) 1363 } 1364 1365 // 处理附件信息 1366 var attachments []map[string]interface{} 1367 if session.Attachments != nil { 1368 var attachmentURLs []string 1369 if err := json.Unmarshal(session.Attachments, &attachmentURLs); err == nil { 1370 for _, url := range attachmentURLs { 1371 // 从URL中提取文件名 1372 fileName := tm.extractFileNameFromURL(url) 1373 attachments = append(attachments, map[string]interface{}{ 1374 "filename": fileName, 1375 "fileUrl": url, 1376 }) 1377 } 1378 } 1379 } 1380 1381 // 处理消息列表 1382 var messageList []map[string]interface{} 1383 for _, msg := range messages { 1384 // 解析事件数据 1385 var eventData map[string]interface{} 1386 if err := json.Unmarshal(msg.EventData, &eventData); err != nil { 1387 continue 1388 } 1389 1390 messageList = append(messageList, map[string]interface{}{ 1391 "id": msg.ID, 1392 "type": msg.Type, 1393 "timestamp": msg.Timestamp, 1394 "event": eventData, 1395 }) 1396 } 1397 1398 // 处理任务参数 1399 var params map[string]interface{} 1400 if session.Params != nil { 1401 if err := json.Unmarshal(session.Params, ¶ms); err != nil { 1402 log.Warnf("解析任务参数失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err) 1403 params = make(map[string]interface{}) 1404 } 1405 } else { 1406 params = make(map[string]interface{}) 1407 } 1408 // Mask token fields in params to avoid leaking sensitive credentials 1409 maskParamsToken(params) 1410 1411 // 构建返回数据 1412 source, sourceLabel := resolveTaskSource(session) 1413 detail := map[string]interface{}{ 1414 "sessionId": session.ID, 1415 "title": decorateTaskTitle(session.Title, sourceLabel), 1416 "rawTitle": session.Title, 1417 "status": session.Status, 1418 "countryIsoCode": session.CountryIsoCode, 1419 "createdAt": session.CreatedAt, 1420 "content": session.Content, 1421 "params": params, 1422 "taskType": session.TaskType, 1423 "attachments": attachments, 1424 "messages": messageList, 1425 "source": source, 1426 "sourceLabel": sourceLabel, 1427 } 1428 if session.Username != username { 1429 delete(detail, "attachments") 1430 } 1431 1432 log.Infof("获取任务详情成功: trace_id=%s, sessionId=%s, username=%s", traceID, sessionId, username) 1433 return detail, nil 1434 } 1435 1436 // extractFileNameFromURL 从文件URL中提取原始文件名 1437 func (tm *TaskManager) extractFileNameFromURL(url string) string { 1438 // 新的文件名格式: UUID_原始文件名.扩展名 1439 if strings.Contains(url, "/") { 1440 parts := strings.Split(url, "/") 1441 if len(parts) > 0 { 1442 fileName := parts[len(parts)-1] 1443 // 新的文件名格式: UUID_原始文件名.扩展名 1444 if strings.Contains(fileName, "_") { 1445 // 查找第一个下划线,之后的部分是原始文件名 1446 firstUnderscoreIndex := strings.Index(fileName, "_") 1447 if firstUnderscoreIndex > 0 { 1448 // 返回下划线后的部分作为原始文件名 1449 return fileName[firstUnderscoreIndex+1:] 1450 } 1451 } 1452 // 如果没有下划线,直接返回文件名 1453 return fileName 1454 } 1455 } 1456 return url 1457 } 1458 1459 // DownloadFile 下载文件 1460 func (tm *TaskManager) DownloadFile(sessionId string, fileUrl string, username string, c *gin.Context, traceID string) error { 1461 log.Infof("开始文件下载: trace_id=%s, sessionId=%s, fileUrl=%s, username=%s", traceID, sessionId, fileUrl, username) 1462 1463 filename := strings.TrimLeft(fileUrl, "/") 1464 filePath, _ := filepath.Abs(filepath.Join(tm.fileConfig.UploadDir, filename)) 1465 1466 if !strings.HasPrefix(filePath, tm.fileConfig.UploadDir) { 1467 return fmt.Errorf("文件路径不合法") 1468 } 1469 1470 if _, err := os.Stat(filePath); os.IsNotExist(err) { 1471 log.Errorf("本地文件不存在: trace_id=%s, filePath=%s", traceID, filePath) 1472 return fmt.Errorf("文件不存在") 1473 } 1474 1475 fileInfo, err := os.Stat(filePath) 1476 if err != nil { 1477 log.Errorf("获取文件信息失败: trace_id=%s, filePath=%s, error=%v", traceID, filePath, err) 1478 return fmt.Errorf("获取文件信息失败: %v", err) 1479 } 1480 1481 log.Debugf("文件信息获取成功: trace_id=%s, filePath=%s, size=%d", traceID, filePath, fileInfo.Size()) 1482 1483 // 8. 设置响应头 1484 // 获取文件的MIME类型 1485 ext := filepath.Ext(filePath) 1486 mimeType := mime.TypeByExtension(ext) 1487 if mimeType == "" { 1488 mimeType = "application/octet-stream" 1489 } 1490 1491 // 设置Content-Type 1492 c.Header("Content-Type", mimeType) 1493 1494 // 设置Content-Disposition,支持中文文件名 1495 // 使用UTF-8编码处理中文文件名 1496 encodedFileName := url.QueryEscape(filepath.Base(filePath)) 1497 c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"; filename*=UTF-8''%s", encodedFileName, encodedFileName)) 1498 1499 // 设置Content-Length 1500 c.Header("Content-Length", fmt.Sprintf("%d", fileInfo.Size())) 1501 1502 // 9. 打开文件并流式传输 1503 file, err := os.Open(filePath) 1504 if err != nil { 1505 log.Errorf("打开文件失败: trace_id=%s, filePath=%s, error=%v", traceID, filePath, err) 1506 return fmt.Errorf("打开文件失败: %v", err) 1507 } 1508 defer file.Close() 1509 1510 // 10. 流式传输文件内容 1511 written, err := io.Copy(c.Writer, file) 1512 if err != nil { 1513 log.Errorf("文件传输失败: trace_id=%s, filePath=%s, error=%v", traceID, filePath, err) 1514 return fmt.Errorf("传输文件失败: %v", err) 1515 } 1516 log.Infof("文件下载成功: trace_id=%s, sessionId=%s, fileName=%s, fileSize=%d, transmittedSize=%d", 1517 traceID, sessionId, filePath, fileInfo.Size(), written) 1518 return nil 1519 } 1520 1521 // maskParamsToken masks sensitive token fields in task params before returning to the client. 1522 // It handles both single model object and list of model objects under the "model" key. 1523 func maskParamsToken(params map[string]interface{}) { 1524 const masked = "********" 1525 maskModel := func(m interface{}) { 1526 if obj, ok := m.(map[string]interface{}); ok { 1527 if _, has := obj["token"]; has { 1528 obj["token"] = masked 1529 } 1530 } 1531 } 1532 if v, ok := params["model"]; ok { 1533 switch val := v.(type) { 1534 case map[string]interface{}: 1535 maskModel(val) 1536 case []interface{}: 1537 for _, item := range val { 1538 maskModel(item) 1539 } 1540 } 1541 } 1542 if v, ok := params["eval_model"]; ok { 1543 maskModel(v) 1544 } 1545 } 1546