agent.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package websocket 20 21 import ( 22 "encoding/json" 23 "fmt" 24 "strings" 25 "sync" 26 "time" 27 28 "github.com/gin-gonic/gin" 29 "github.com/go-playground/validator/v10" 30 "github.com/gorilla/websocket" 31 "trpc.group/trpc-go/trpc-go/log" 32 // "gorm.io/datatypes" 33 ) 34 35 const ( 36 // WebSocket相关常量 37 maxMessageSize = 512 * 1024 * 1024 // 512MB 38 pongWait = 120 * time.Second 39 pingPeriod = (pongWait * 8) / 10 40 writeWait = 60 * time.Second 41 WSMsgTypeRegister = "register" 42 // WSMsgTypeTaskAssign = "task_assign" // 任务分配 43 WSMsgTypeDisconnect = "disconnect" // 主动断开连接的消息类型 44 45 // Agent 端事件类型(与前端 SSE 事件类型一致) 46 WSMsgTypeLiveStatus = "liveStatus" // 存活状态 47 WSMsgTypePlanUpdate = "planUpdate" // 计划更新 48 WSMsgTypeNewPlanStep = "newPlanStep" // 新计划步骤 49 WSMsgTypeStatusUpdate = "statusUpdate" // 状态更新 50 WSMsgTypeToolUsed = "toolUsed" // 工具使用 51 WSMsgTypeResultUpdate = "resultUpdate" // 结果更新 52 WSMsgTypeActionLog = "actionLog" // 日志 53 WSMsgTypeError = "error" // 日志 54 ) 55 56 // Agent 端事件消息(Agent -> Server,直接使用 task.go 中的结构体) 57 // 注意:Agent 端返回的事件体直接使用 task.go 中定义的结构体: 58 // - LiveStatusEvent 59 // - PlanUpdateEvent 60 // - PlanTaskItem 61 // - NewPlanStepEvent 62 // - StatusUpdateEvent 63 // - ToolUsedEvent 64 // 这样可以确保格式完全一致,避免重复定义 65 66 // AgentConnection 管理单个agent的连接 67 type AgentConnection struct { 68 conn *websocket.Conn 69 agentID string 70 71 // 细粒度的锁控制 72 stateMu sync.RWMutex // 保护连接状态(agentID, isActive) 73 writeMu sync.Mutex // 保护写操作(发送消息) 74 75 isActive bool 76 } 77 78 // AgentManager 管理所有agent连接 79 type AgentManager struct { 80 connections map[string]*AgentConnection 81 mu sync.RWMutex 82 taskManager *TaskManager // 新增:引用 TaskManager 83 // store *database.AgentStore // 注释掉数据库字段 84 } 85 86 // 注册/心跳消息内容 87 type AgentRegisterContent struct { 88 AgentID string `json:"agent_id" validate:"required"` // 必需字段 89 Hostname string `json:"hostname" validate:"required"` // 必需字段 90 IP string `json:"ip" validate:"required,ip"` // 必需且必须是IP格式 91 Version string `json:"version" validate:"required"` // 必需字段 92 Capabilities []string `json:"capabilities,omitempty"` // 可选字段 93 Meta string `json:"meta,omitempty"` // 可选字段 94 } 95 96 // 断开连接消息内容 97 type DisconnectContent struct { 98 AgentID string `json:"agent_id" validate:"required"` // 必需字段 99 Reason string `json:"reason,omitempty"` // 可选字段 100 } 101 102 // 全局验证器实例 103 var validate *validator.Validate 104 105 // 初始化验证器 106 func init() { 107 validate = validator.New() 108 } 109 110 // formatValidationErrors 格式化验证错误信息 111 func formatValidationErrors(err error) string { 112 if validationErrors, ok := err.(validator.ValidationErrors); ok { 113 var errorMessages []string 114 for _, fieldError := range validationErrors { 115 fieldName := fieldError.Field() 116 switch fieldError.Tag() { 117 case "required": 118 errorMessages = append(errorMessages, 119 fmt.Sprintf("缺少必需字段: %s", fieldName)) 120 case "ip": 121 errorMessages = append(errorMessages, 122 fmt.Sprintf("字段 %s 必须是有效的IP地址", fieldName)) 123 case "email": 124 errorMessages = append(errorMessages, 125 fmt.Sprintf("字段 %s 必须是有效的邮箱格式", fieldName)) 126 case "url": 127 errorMessages = append(errorMessages, 128 fmt.Sprintf("字段 %s 必须是有效的URL", fieldName)) 129 case "min": 130 errorMessages = append(errorMessages, 131 fmt.Sprintf("字段 %s 长度不能小于 %s", fieldName, fieldError.Param())) 132 case "max": 133 errorMessages = append(errorMessages, 134 fmt.Sprintf("字段 %s 长度不能大于 %s", fieldName, fieldError.Param())) 135 default: 136 errorMessages = append(errorMessages, 137 fmt.Sprintf("字段 %s 验证失败: %s", fieldName, fieldError.Tag())) 138 } 139 } 140 return fmt.Sprintf("验证失败: %s", strings.Join(errorMessages, "; ")) 141 } 142 return "验证失败" 143 } 144 145 // NewAgentManager 创建新的AgentManager 146 // func NewAgentManager(store *database.AgentStore) *AgentManager { 147 func NewAgentManager() *AgentManager { 148 return &AgentManager{ 149 connections: make(map[string]*AgentConnection), 150 // store: store, 151 } 152 } 153 154 // NewAgentConnection 创建新的AgentConnection 155 // func NewAgentConnection(conn *websocket.Conn, store *database.AgentStore) *AgentConnection { 156 func NewAgentConnection(conn *websocket.Conn) *AgentConnection { 157 return &AgentConnection{ 158 conn: conn, 159 // store: store, 160 isActive: true, 161 } 162 } 163 164 // HandleAgentWebSocket 处理agent的WebSocket连接 165 func (am *AgentManager) HandleAgentWebSocket() gin.HandlerFunc { 166 167 return func(c *gin.Context) { 168 conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) 169 if err != nil { 170 log.Errorf("WebSocket升级失败: error=%v", err) 171 return 172 } 173 174 // ac := NewAgentConnection(conn, am.store) 175 ac := NewAgentConnection(conn) 176 log.Infof("新的Agent连接建立: remoteAddr=%s", conn.RemoteAddr().String()) 177 go ac.handleConnection(am) 178 } 179 } 180 181 // handleConnection 处理单个连接的消息 182 func (ac *AgentConnection) handleConnection(am *AgentManager) { 183 defer func() { 184 ac.stateMu.RLock() 185 agentID := ac.agentID 186 remoteAddr := ac.conn.RemoteAddr().String() 187 ac.stateMu.RUnlock() 188 ac.cleanup(am) 189 log.Infof("Agent连接处理结束: agentId=%s, remoteAddr=%s", agentID, remoteAddr) 190 }() 191 192 // 设置连接参数 193 ac.conn.SetReadLimit(maxMessageSize) 194 ac.conn.SetPongHandler(func(string) error { 195 ac.conn.SetReadDeadline(time.Now().Add(pongWait)) 196 return nil 197 }) 198 199 // 启动心跳检测 200 go ac.writePump() 201 202 // 处理消息 203 for { 204 _, message, err := ac.conn.ReadMessage() 205 if err != nil { 206 ac.stateMu.RLock() 207 agentID := ac.agentID 208 ac.stateMu.RUnlock() 209 210 if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { 211 log.Errorf("Agent连接异常断开: agentId=%s, error=%v", agentID, err) 212 } else { 213 log.Infof("Agent连接正常断开: agentId=%s, closeCode=%v", agentID, err) 214 } 215 break 216 } 217 218 var wsMsg WSMessage 219 if err := json.Unmarshal(message, &wsMsg); err != nil { 220 log.Errorf("Agent消息解析失败: agentId=%s, error=%v", ac.agentID, err) 221 // 发送错误响应但不断开连接 222 ac.sendError("消息格式错误请检查JSON格式") 223 continue 224 } 225 226 // 验证消息类型 227 if wsMsg.Type == "" { 228 log.Errorf("Agent消息类型为空: agentId=%s", ac.agentID) 229 ac.sendError("消息类型不能为空") 230 continue 231 } 232 233 switch wsMsg.Type { 234 case WSMsgTypeRegister: 235 ac.handleRegister(am, wsMsg.Content) 236 case WSMsgTypeDisconnect: 237 // 只有在身份验证成功时才断开连接 238 ac.handleDisconnect(am, wsMsg.Content) 239 // 检查连接是否仍然活跃,如果不活跃则退出 240 ac.stateMu.RLock() 241 if !ac.isActive { 242 ac.stateMu.RUnlock() 243 return 244 } 245 ac.stateMu.RUnlock() 246 case WSMsgTypeLiveStatus, WSMsgTypePlanUpdate, WSMsgTypeNewPlanStep, WSMsgTypeStatusUpdate, WSMsgTypeToolUsed, WSMsgTypeResultUpdate, WSMsgTypeActionLog, WSMsgTypeError: 247 // 所有事件类型都统一处理 248 ac.handleAgentEvent(am, wsMsg.Content, wsMsg.Type) 249 default: 250 log.Warnf("Agent发送未知消息类型: agentId=%s, type=%s", ac.agentID, wsMsg.Type) 251 ac.sendError(fmt.Sprintf("未知的消息类型: %s。支持的类型: register, disconnect, liveStatus, planUpdate, newPlanStep, statusUpdate, toolUsed, resultUpdate, actionLog", wsMsg.Type)) 252 } 253 } 254 } 255 256 // handleRegister 处理注册消息 257 func (ac *AgentConnection) handleRegister(am *AgentManager, content interface{}) { 258 contentBytes, _ := json.Marshal(content) 259 var rc AgentRegisterContent 260 if err := json.Unmarshal(contentBytes, &rc); err != nil { 261 log.Errorf("Agent注册消息解析失败: error=%v", err) 262 ac.sendError("注册消息格式错误") 263 return 264 } 265 266 // 使用validator验证结构体 267 if err := validate.Struct(rc); err != nil { 268 errorMsg := formatValidationErrors(err) 269 log.Errorf("Agent注册验证失败: agentId=%s, error=%s", rc.AgentID, errorMsg) 270 ac.sendError(errorMsg) 271 return 272 } 273 274 // 检查是否已存在相同ID的Agent 275 am.mu.Lock() 276 if existingConn, exists := am.connections[rc.AgentID]; exists { 277 am.mu.Unlock() 278 log.Warnf("Agent ID已存在,断开旧连接: agentId=%s", rc.AgentID) 279 // 断开旧连接 280 existingConn.stateMu.Lock() 281 existingConn.isActive = false 282 existingConn.stateMu.Unlock() 283 existingConn.conn.Close() 284 } else { 285 am.mu.Unlock() 286 } 287 288 // 注册新连接 289 am.mu.Lock() 290 am.connections[rc.AgentID] = ac 291 am.mu.Unlock() 292 293 // 更新连接状态 294 ac.stateMu.Lock() 295 ac.agentID = rc.AgentID 296 ac.isActive = true 297 ac.stateMu.Unlock() 298 299 log.Infof("Agent注册成功: agentId=%s, hostname=%s, ip=%s, version=%s", rc.AgentID, rc.Hostname, rc.IP, rc.Version) 300 // 发送注册成功响应 301 response := WSMessage{ 302 Type: "register_ack", 303 Content: Response{ 304 Status: 0, 305 Message: "注册成功", 306 }, 307 } 308 ac.conn.WriteJSON(response) 309 } 310 311 // handleDisconnect 处理主动断开连接 312 func (ac *AgentConnection) handleDisconnect(am *AgentManager, content interface{}) { 313 contentBytes, _ := json.Marshal(content) 314 var dc DisconnectContent 315 if err := json.Unmarshal(contentBytes, &dc); err != nil { 316 ac.sendError("断开连接消息格式错误") 317 return 318 } 319 320 // 使用validator验证结构体 321 if err := validate.Struct(dc); err != nil { 322 errorMsg := formatValidationErrors(err) 323 ac.sendError(errorMsg) 324 return 325 } 326 327 // 验证身份一致性 328 ac.stateMu.RLock() 329 agentID := ac.agentID 330 ac.stateMu.RUnlock() 331 332 if agentID == "" || agentID != dc.AgentID { 333 ac.sendError("断开连接消息身份验证失败") 334 return 335 } 336 337 // 从连接管理器中移除 338 am.mu.Lock() 339 delete(am.connections, agentID) 340 am.mu.Unlock() 341 342 // 发送断开确认 343 response := WSMessage{ 344 Type: "disconnect_ack", 345 Content: Response{ 346 Status: 0, 347 Message: "断开连接成功", 348 }, 349 } 350 ac.conn.WriteJSON(response) 351 352 // 标记连接为非活跃 353 ac.stateMu.Lock() 354 ac.isActive = false 355 ac.stateMu.Unlock() 356 } 357 358 // writePump 发送心跳包 359 func (ac *AgentConnection) writePump() { 360 ticker := time.NewTicker(pingPeriod) 361 defer func() { 362 ticker.Stop() 363 log.Infof("Agent心跳检测已停止: agentId=%s", ac.agentID) 364 }() 365 366 log.Infof("Agent心跳检测已启动: agentId=%s, pingPeriod=%v", ac.agentID, pingPeriod) 367 368 for range ticker.C { 369 ac.stateMu.RLock() 370 if !ac.isActive { 371 ac.stateMu.RUnlock() 372 log.Infof("Agent连接已标记为非活跃,停止心跳检测: agentId=%s", ac.agentID) 373 return 374 } 375 agentID := ac.agentID 376 ac.stateMu.RUnlock() 377 378 // 设置写超时 379 ac.conn.SetWriteDeadline(time.Now().Add(writeWait)) 380 381 // 尝试发送ping消息 382 err := ac.conn.WriteMessage(websocket.PingMessage, nil) 383 if err != nil { 384 log.Warnf("Agent心跳发送失败,准备重试: agentId=%s, error=%v", agentID, err) 385 386 // 尝试重试一次 387 time.Sleep(1 * time.Second) 388 ac.stateMu.RLock() 389 if !ac.isActive { 390 ac.stateMu.RUnlock() 391 log.Infof("Agent连接在重试期间已标记为非活跃: agentId=%s", agentID) 392 return 393 } 394 ac.stateMu.RUnlock() 395 396 ac.conn.SetWriteDeadline(time.Now().Add(writeWait)) 397 err = ac.conn.WriteMessage(websocket.PingMessage, nil) 398 if err != nil { 399 log.Errorf("Agent心跳重试失败,连接已失效: agentId=%s, error=%v", agentID, err) 400 401 // 标记连接为非活跃 402 ac.stateMu.Lock() 403 ac.isActive = false 404 ac.stateMu.Unlock() 405 406 log.Errorf("Agent连接已标记为失效: agentId=%s, 原因=心跳失败", agentID) 407 return 408 } else { 409 log.Infof("Agent心跳重试成功: agentId=%s", agentID) 410 } 411 } else { 412 log.Debugf("Agent心跳发送成功: agentId=%s", agentID) 413 } 414 } 415 } 416 417 // cleanup 清理连接 418 func (ac *AgentConnection) cleanup(am *AgentManager) { 419 ac.stateMu.Lock() 420 agentID := ac.agentID 421 wasActive := ac.isActive 422 ac.isActive = false 423 ac.stateMu.Unlock() 424 425 log.Infof("开始清理Agent连接: agentId=%s, wasActive=%v", agentID, wasActive) 426 427 if agentID != "" { 428 am.mu.Lock() 429 // 检查是否真的存在于连接管理器中 430 if _, exists := am.connections[agentID]; exists { 431 delete(am.connections, agentID) 432 log.Infof("Agent已从连接管理器中移除: agentId=%s", agentID) 433 } else { 434 log.Warnf("Agent不在连接管理器中,可能已被移除: agentId=%s", agentID) 435 } 436 am.mu.Unlock() 437 438 // ac.store.UpdateOnlineStatus(ac.agentID, false) 439 } else { 440 log.Warnf("清理未注册的Agent连接: remoteAddr=%s", ac.conn.RemoteAddr().String()) 441 } 442 443 // 关闭WebSocket连接 444 err := ac.conn.Close() 445 if err != nil { 446 log.Warnf("关闭Agent连接时出错: agentId=%s, error=%v", agentID, err) 447 } else { 448 log.Infof("Agent连接已关闭: agentId=%s", agentID) 449 } 450 451 log.Infof("Agent连接清理完成: agentId=%s", agentID) 452 } 453 454 // sendError 发送错误响应 455 func (ac *AgentConnection) sendError(message string) { 456 response := WSMessage{ 457 Type: "error", 458 Content: Response{ 459 Status: 1, 460 Message: message, 461 }, 462 } 463 464 // 设置写超时 465 ac.conn.SetWriteDeadline(time.Now().Add(writeWait)) 466 467 err := ac.conn.WriteJSON(response) 468 if err != nil { 469 // 如果发送错误响应都失败,说明连接可能有问题 470 ac.stateMu.Lock() 471 ac.isActive = false 472 ac.stateMu.Unlock() 473 } 474 } 475 476 // 通用事件处理函数 477 func (ac *AgentConnection) handleAgentEvent(am *AgentManager, content interface{}, eventType string) { 478 contentBytes, err := json.Marshal(content) 479 if err != nil { 480 log.Errorf("Agent事件序列化失败: agentId=%s, eventType=%s, error=%v", ac.agentID, eventType, err) 481 ac.sendError(fmt.Sprintf("%s事件序列化失败: %v", eventType, err)) 482 return 483 } 484 485 var eventMessage TaskEventMessage 486 if err := json.Unmarshal(contentBytes, &eventMessage); err != nil { 487 log.Errorf("Agent事件格式错误: agentId=%s, eventType=%s, error=%v", ac.agentID, eventType, err) 488 ac.sendError(fmt.Sprintf("%s事件格式错误: %v", eventType, err)) 489 return 490 } 491 492 // 使用validator验证TaskEventMessage 493 if err := validate.Struct(eventMessage); err != nil { 494 errorMsg := formatValidationErrors(err) 495 log.Errorf("Agent事件验证失败: agentId=%s, eventType=%s, error=%s", ac.agentID, eventType, errorMsg) 496 ac.sendError(fmt.Sprintf("%s事件验证失败: %s", eventType, errorMsg)) 497 return 498 } 499 500 // 从TaskEventMessage中提取sessionId和事件数据 501 sessionId := eventMessage.SessionID 502 event := eventMessage.Event 503 504 log.Debugf("收到Agent事件: agentId=%s, sessionId=%s, eventType=%s", ac.agentID, sessionId, eventType) 505 506 // 转发给 TaskManager 处理 507 am.mu.RLock() 508 am.taskManager.HandleAgentEvent(sessionId, eventType, event) 509 am.mu.RUnlock() 510 } 511 512 // 添加获取可用 Agent 的方法 513 func (am *AgentManager) GetAvailableAgents() []*AgentConnection { 514 am.mu.RLock() 515 defer am.mu.RUnlock() 516 517 var availableAgents []*AgentConnection 518 for _, conn := range am.connections { 519 conn.stateMu.RLock() 520 if conn.isActive { 521 availableAgents = append(availableAgents, conn) 522 } 523 conn.stateMu.RUnlock() 524 } 525 return availableAgents 526 } 527 528 // SetTaskManager 设置 TaskManager 引用 529 func (am *AgentManager) SetTaskManager(taskManager *TaskManager) { 530 am.mu.Lock() 531 defer am.mu.Unlock() 532 am.taskManager = taskManager 533 }