/ common / websocket / agent.go
agent.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package websocket
 20  
 21  import (
 22  	"encoding/json"
 23  	"fmt"
 24  	"strings"
 25  	"sync"
 26  	"time"
 27  
 28  	"github.com/gin-gonic/gin"
 29  	"github.com/go-playground/validator/v10"
 30  	"github.com/gorilla/websocket"
 31  	"trpc.group/trpc-go/trpc-go/log"
 32  	// "gorm.io/datatypes"
 33  )
 34  
 35  const (
 36  	// WebSocket相关常量
 37  	maxMessageSize    = 512 * 1024 * 1024 // 512MB
 38  	pongWait          = 120 * time.Second
 39  	pingPeriod        = (pongWait * 8) / 10
 40  	writeWait         = 60 * time.Second
 41  	WSMsgTypeRegister = "register"
 42  	// WSMsgTypeTaskAssign = "task_assign" // 任务分配
 43  	WSMsgTypeDisconnect = "disconnect" // 主动断开连接的消息类型
 44  
 45  	// Agent 端事件类型(与前端 SSE 事件类型一致)
 46  	WSMsgTypeLiveStatus   = "liveStatus"   // 存活状态
 47  	WSMsgTypePlanUpdate   = "planUpdate"   // 计划更新
 48  	WSMsgTypeNewPlanStep  = "newPlanStep"  // 新计划步骤
 49  	WSMsgTypeStatusUpdate = "statusUpdate" // 状态更新
 50  	WSMsgTypeToolUsed     = "toolUsed"     // 工具使用
 51  	WSMsgTypeResultUpdate = "resultUpdate" // 结果更新
 52  	WSMsgTypeActionLog    = "actionLog"    // 日志
 53  	WSMsgTypeError        = "error"        // 日志
 54  )
 55  
 56  // Agent 端事件消息(Agent -> Server,直接使用 task.go 中的结构体)
 57  // 注意:Agent 端返回的事件体直接使用 task.go 中定义的结构体:
 58  // - LiveStatusEvent
 59  // - PlanUpdateEvent
 60  // - PlanTaskItem
 61  // - NewPlanStepEvent
 62  // - StatusUpdateEvent
 63  // - ToolUsedEvent
 64  // 这样可以确保格式完全一致,避免重复定义
 65  
 66  // AgentConnection 管理单个agent的连接
 67  type AgentConnection struct {
 68  	conn    *websocket.Conn
 69  	agentID string
 70  
 71  	// 细粒度的锁控制
 72  	stateMu sync.RWMutex // 保护连接状态(agentID, isActive)
 73  	writeMu sync.Mutex   // 保护写操作(发送消息)
 74  
 75  	isActive bool
 76  }
 77  
 78  // AgentManager 管理所有agent连接
 79  type AgentManager struct {
 80  	connections map[string]*AgentConnection
 81  	mu          sync.RWMutex
 82  	taskManager *TaskManager // 新增:引用 TaskManager
 83  	// store       *database.AgentStore // 注释掉数据库字段
 84  }
 85  
 86  // 注册/心跳消息内容
 87  type AgentRegisterContent struct {
 88  	AgentID      string   `json:"agent_id" validate:"required"` // 必需字段
 89  	Hostname     string   `json:"hostname" validate:"required"` // 必需字段
 90  	IP           string   `json:"ip" validate:"required,ip"`    // 必需且必须是IP格式
 91  	Version      string   `json:"version" validate:"required"`  // 必需字段
 92  	Capabilities []string `json:"capabilities,omitempty"`       // 可选字段
 93  	Meta         string   `json:"meta,omitempty"`               // 可选字段
 94  }
 95  
 96  // 断开连接消息内容
 97  type DisconnectContent struct {
 98  	AgentID string `json:"agent_id" validate:"required"` // 必需字段
 99  	Reason  string `json:"reason,omitempty"`             // 可选字段
100  }
101  
102  // 全局验证器实例
103  var validate *validator.Validate
104  
105  // 初始化验证器
106  func init() {
107  	validate = validator.New()
108  }
109  
110  // formatValidationErrors 格式化验证错误信息
111  func formatValidationErrors(err error) string {
112  	if validationErrors, ok := err.(validator.ValidationErrors); ok {
113  		var errorMessages []string
114  		for _, fieldError := range validationErrors {
115  			fieldName := fieldError.Field()
116  			switch fieldError.Tag() {
117  			case "required":
118  				errorMessages = append(errorMessages,
119  					fmt.Sprintf("缺少必需字段: %s", fieldName))
120  			case "ip":
121  				errorMessages = append(errorMessages,
122  					fmt.Sprintf("字段 %s 必须是有效的IP地址", fieldName))
123  			case "email":
124  				errorMessages = append(errorMessages,
125  					fmt.Sprintf("字段 %s 必须是有效的邮箱格式", fieldName))
126  			case "url":
127  				errorMessages = append(errorMessages,
128  					fmt.Sprintf("字段 %s 必须是有效的URL", fieldName))
129  			case "min":
130  				errorMessages = append(errorMessages,
131  					fmt.Sprintf("字段 %s 长度不能小于 %s", fieldName, fieldError.Param()))
132  			case "max":
133  				errorMessages = append(errorMessages,
134  					fmt.Sprintf("字段 %s 长度不能大于 %s", fieldName, fieldError.Param()))
135  			default:
136  				errorMessages = append(errorMessages,
137  					fmt.Sprintf("字段 %s 验证失败: %s", fieldName, fieldError.Tag()))
138  			}
139  		}
140  		return fmt.Sprintf("验证失败: %s", strings.Join(errorMessages, "; "))
141  	}
142  	return "验证失败"
143  }
144  
145  // NewAgentManager 创建新的AgentManager
146  // func NewAgentManager(store *database.AgentStore) *AgentManager {
147  func NewAgentManager() *AgentManager {
148  	return &AgentManager{
149  		connections: make(map[string]*AgentConnection),
150  		// store:       store,
151  	}
152  }
153  
154  // NewAgentConnection 创建新的AgentConnection
155  // func NewAgentConnection(conn *websocket.Conn, store *database.AgentStore) *AgentConnection {
156  func NewAgentConnection(conn *websocket.Conn) *AgentConnection {
157  	return &AgentConnection{
158  		conn: conn,
159  		// store:    store,
160  		isActive: true,
161  	}
162  }
163  
164  // HandleAgentWebSocket 处理agent的WebSocket连接
165  func (am *AgentManager) HandleAgentWebSocket() gin.HandlerFunc {
166  
167  	return func(c *gin.Context) {
168  		conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
169  		if err != nil {
170  			log.Errorf("WebSocket升级失败: error=%v", err)
171  			return
172  		}
173  
174  		// ac := NewAgentConnection(conn, am.store)
175  		ac := NewAgentConnection(conn)
176  		log.Infof("新的Agent连接建立: remoteAddr=%s", conn.RemoteAddr().String())
177  		go ac.handleConnection(am)
178  	}
179  }
180  
181  // handleConnection 处理单个连接的消息
182  func (ac *AgentConnection) handleConnection(am *AgentManager) {
183  	defer func() {
184  		ac.stateMu.RLock()
185  		agentID := ac.agentID
186  		remoteAddr := ac.conn.RemoteAddr().String()
187  		ac.stateMu.RUnlock()
188  		ac.cleanup(am)
189  		log.Infof("Agent连接处理结束: agentId=%s, remoteAddr=%s", agentID, remoteAddr)
190  	}()
191  
192  	// 设置连接参数
193  	ac.conn.SetReadLimit(maxMessageSize)
194  	ac.conn.SetPongHandler(func(string) error {
195  		ac.conn.SetReadDeadline(time.Now().Add(pongWait))
196  		return nil
197  	})
198  
199  	// 启动心跳检测
200  	go ac.writePump()
201  
202  	// 处理消息
203  	for {
204  		_, message, err := ac.conn.ReadMessage()
205  		if err != nil {
206  			ac.stateMu.RLock()
207  			agentID := ac.agentID
208  			ac.stateMu.RUnlock()
209  
210  			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
211  				log.Errorf("Agent连接异常断开: agentId=%s, error=%v", agentID, err)
212  			} else {
213  				log.Infof("Agent连接正常断开: agentId=%s, closeCode=%v", agentID, err)
214  			}
215  			break
216  		}
217  
218  		var wsMsg WSMessage
219  		if err := json.Unmarshal(message, &wsMsg); err != nil {
220  			log.Errorf("Agent消息解析失败: agentId=%s, error=%v", ac.agentID, err)
221  			// 发送错误响应但不断开连接
222  			ac.sendError("消息格式错误请检查JSON格式")
223  			continue
224  		}
225  
226  		// 验证消息类型
227  		if wsMsg.Type == "" {
228  			log.Errorf("Agent消息类型为空: agentId=%s", ac.agentID)
229  			ac.sendError("消息类型不能为空")
230  			continue
231  		}
232  
233  		switch wsMsg.Type {
234  		case WSMsgTypeRegister:
235  			ac.handleRegister(am, wsMsg.Content)
236  		case WSMsgTypeDisconnect:
237  			// 只有在身份验证成功时才断开连接
238  			ac.handleDisconnect(am, wsMsg.Content)
239  			// 检查连接是否仍然活跃,如果不活跃则退出
240  			ac.stateMu.RLock()
241  			if !ac.isActive {
242  				ac.stateMu.RUnlock()
243  				return
244  			}
245  			ac.stateMu.RUnlock()
246  		case WSMsgTypeLiveStatus, WSMsgTypePlanUpdate, WSMsgTypeNewPlanStep, WSMsgTypeStatusUpdate, WSMsgTypeToolUsed, WSMsgTypeResultUpdate, WSMsgTypeActionLog, WSMsgTypeError:
247  			// 所有事件类型都统一处理
248  			ac.handleAgentEvent(am, wsMsg.Content, wsMsg.Type)
249  		default:
250  			log.Warnf("Agent发送未知消息类型: agentId=%s, type=%s", ac.agentID, wsMsg.Type)
251  			ac.sendError(fmt.Sprintf("未知的消息类型: %s。支持的类型: register, disconnect, liveStatus, planUpdate, newPlanStep, statusUpdate, toolUsed, resultUpdate, actionLog", wsMsg.Type))
252  		}
253  	}
254  }
255  
256  // handleRegister 处理注册消息
257  func (ac *AgentConnection) handleRegister(am *AgentManager, content interface{}) {
258  	contentBytes, _ := json.Marshal(content)
259  	var rc AgentRegisterContent
260  	if err := json.Unmarshal(contentBytes, &rc); err != nil {
261  		log.Errorf("Agent注册消息解析失败: error=%v", err)
262  		ac.sendError("注册消息格式错误")
263  		return
264  	}
265  
266  	// 使用validator验证结构体
267  	if err := validate.Struct(rc); err != nil {
268  		errorMsg := formatValidationErrors(err)
269  		log.Errorf("Agent注册验证失败: agentId=%s, error=%s", rc.AgentID, errorMsg)
270  		ac.sendError(errorMsg)
271  		return
272  	}
273  
274  	// 检查是否已存在相同ID的Agent
275  	am.mu.Lock()
276  	if existingConn, exists := am.connections[rc.AgentID]; exists {
277  		am.mu.Unlock()
278  		log.Warnf("Agent ID已存在,断开旧连接: agentId=%s", rc.AgentID)
279  		// 断开旧连接
280  		existingConn.stateMu.Lock()
281  		existingConn.isActive = false
282  		existingConn.stateMu.Unlock()
283  		existingConn.conn.Close()
284  	} else {
285  		am.mu.Unlock()
286  	}
287  
288  	// 注册新连接
289  	am.mu.Lock()
290  	am.connections[rc.AgentID] = ac
291  	am.mu.Unlock()
292  
293  	// 更新连接状态
294  	ac.stateMu.Lock()
295  	ac.agentID = rc.AgentID
296  	ac.isActive = true
297  	ac.stateMu.Unlock()
298  
299  	log.Infof("Agent注册成功: agentId=%s, hostname=%s, ip=%s, version=%s", rc.AgentID, rc.Hostname, rc.IP, rc.Version)
300  	// 发送注册成功响应
301  	response := WSMessage{
302  		Type: "register_ack",
303  		Content: Response{
304  			Status:  0,
305  			Message: "注册成功",
306  		},
307  	}
308  	ac.conn.WriteJSON(response)
309  }
310  
311  // handleDisconnect 处理主动断开连接
312  func (ac *AgentConnection) handleDisconnect(am *AgentManager, content interface{}) {
313  	contentBytes, _ := json.Marshal(content)
314  	var dc DisconnectContent
315  	if err := json.Unmarshal(contentBytes, &dc); err != nil {
316  		ac.sendError("断开连接消息格式错误")
317  		return
318  	}
319  
320  	// 使用validator验证结构体
321  	if err := validate.Struct(dc); err != nil {
322  		errorMsg := formatValidationErrors(err)
323  		ac.sendError(errorMsg)
324  		return
325  	}
326  
327  	// 验证身份一致性
328  	ac.stateMu.RLock()
329  	agentID := ac.agentID
330  	ac.stateMu.RUnlock()
331  
332  	if agentID == "" || agentID != dc.AgentID {
333  		ac.sendError("断开连接消息身份验证失败")
334  		return
335  	}
336  
337  	// 从连接管理器中移除
338  	am.mu.Lock()
339  	delete(am.connections, agentID)
340  	am.mu.Unlock()
341  
342  	// 发送断开确认
343  	response := WSMessage{
344  		Type: "disconnect_ack",
345  		Content: Response{
346  			Status:  0,
347  			Message: "断开连接成功",
348  		},
349  	}
350  	ac.conn.WriteJSON(response)
351  
352  	// 标记连接为非活跃
353  	ac.stateMu.Lock()
354  	ac.isActive = false
355  	ac.stateMu.Unlock()
356  }
357  
358  // writePump 发送心跳包
359  func (ac *AgentConnection) writePump() {
360  	ticker := time.NewTicker(pingPeriod)
361  	defer func() {
362  		ticker.Stop()
363  		log.Infof("Agent心跳检测已停止: agentId=%s", ac.agentID)
364  	}()
365  
366  	log.Infof("Agent心跳检测已启动: agentId=%s, pingPeriod=%v", ac.agentID, pingPeriod)
367  
368  	for range ticker.C {
369  		ac.stateMu.RLock()
370  		if !ac.isActive {
371  			ac.stateMu.RUnlock()
372  			log.Infof("Agent连接已标记为非活跃,停止心跳检测: agentId=%s", ac.agentID)
373  			return
374  		}
375  		agentID := ac.agentID
376  		ac.stateMu.RUnlock()
377  
378  		// 设置写超时
379  		ac.conn.SetWriteDeadline(time.Now().Add(writeWait))
380  
381  		// 尝试发送ping消息
382  		err := ac.conn.WriteMessage(websocket.PingMessage, nil)
383  		if err != nil {
384  			log.Warnf("Agent心跳发送失败,准备重试: agentId=%s, error=%v", agentID, err)
385  
386  			// 尝试重试一次
387  			time.Sleep(1 * time.Second)
388  			ac.stateMu.RLock()
389  			if !ac.isActive {
390  				ac.stateMu.RUnlock()
391  				log.Infof("Agent连接在重试期间已标记为非活跃: agentId=%s", agentID)
392  				return
393  			}
394  			ac.stateMu.RUnlock()
395  
396  			ac.conn.SetWriteDeadline(time.Now().Add(writeWait))
397  			err = ac.conn.WriteMessage(websocket.PingMessage, nil)
398  			if err != nil {
399  				log.Errorf("Agent心跳重试失败,连接已失效: agentId=%s, error=%v", agentID, err)
400  
401  				// 标记连接为非活跃
402  				ac.stateMu.Lock()
403  				ac.isActive = false
404  				ac.stateMu.Unlock()
405  
406  				log.Errorf("Agent连接已标记为失效: agentId=%s, 原因=心跳失败", agentID)
407  				return
408  			} else {
409  				log.Infof("Agent心跳重试成功: agentId=%s", agentID)
410  			}
411  		} else {
412  			log.Debugf("Agent心跳发送成功: agentId=%s", agentID)
413  		}
414  	}
415  }
416  
417  // cleanup 清理连接
418  func (ac *AgentConnection) cleanup(am *AgentManager) {
419  	ac.stateMu.Lock()
420  	agentID := ac.agentID
421  	wasActive := ac.isActive
422  	ac.isActive = false
423  	ac.stateMu.Unlock()
424  
425  	log.Infof("开始清理Agent连接: agentId=%s, wasActive=%v", agentID, wasActive)
426  
427  	if agentID != "" {
428  		am.mu.Lock()
429  		// 检查是否真的存在于连接管理器中
430  		if _, exists := am.connections[agentID]; exists {
431  			delete(am.connections, agentID)
432  			log.Infof("Agent已从连接管理器中移除: agentId=%s", agentID)
433  		} else {
434  			log.Warnf("Agent不在连接管理器中,可能已被移除: agentId=%s", agentID)
435  		}
436  		am.mu.Unlock()
437  
438  		// ac.store.UpdateOnlineStatus(ac.agentID, false)
439  	} else {
440  		log.Warnf("清理未注册的Agent连接: remoteAddr=%s", ac.conn.RemoteAddr().String())
441  	}
442  
443  	// 关闭WebSocket连接
444  	err := ac.conn.Close()
445  	if err != nil {
446  		log.Warnf("关闭Agent连接时出错: agentId=%s, error=%v", agentID, err)
447  	} else {
448  		log.Infof("Agent连接已关闭: agentId=%s", agentID)
449  	}
450  
451  	log.Infof("Agent连接清理完成: agentId=%s", agentID)
452  }
453  
454  // sendError 发送错误响应
455  func (ac *AgentConnection) sendError(message string) {
456  	response := WSMessage{
457  		Type: "error",
458  		Content: Response{
459  			Status:  1,
460  			Message: message,
461  		},
462  	}
463  
464  	// 设置写超时
465  	ac.conn.SetWriteDeadline(time.Now().Add(writeWait))
466  
467  	err := ac.conn.WriteJSON(response)
468  	if err != nil {
469  		// 如果发送错误响应都失败,说明连接可能有问题
470  		ac.stateMu.Lock()
471  		ac.isActive = false
472  		ac.stateMu.Unlock()
473  	}
474  }
475  
476  // 通用事件处理函数
477  func (ac *AgentConnection) handleAgentEvent(am *AgentManager, content interface{}, eventType string) {
478  	contentBytes, err := json.Marshal(content)
479  	if err != nil {
480  		log.Errorf("Agent事件序列化失败: agentId=%s, eventType=%s, error=%v", ac.agentID, eventType, err)
481  		ac.sendError(fmt.Sprintf("%s事件序列化失败: %v", eventType, err))
482  		return
483  	}
484  
485  	var eventMessage TaskEventMessage
486  	if err := json.Unmarshal(contentBytes, &eventMessage); err != nil {
487  		log.Errorf("Agent事件格式错误: agentId=%s, eventType=%s, error=%v", ac.agentID, eventType, err)
488  		ac.sendError(fmt.Sprintf("%s事件格式错误: %v", eventType, err))
489  		return
490  	}
491  
492  	// 使用validator验证TaskEventMessage
493  	if err := validate.Struct(eventMessage); err != nil {
494  		errorMsg := formatValidationErrors(err)
495  		log.Errorf("Agent事件验证失败: agentId=%s, eventType=%s, error=%s", ac.agentID, eventType, errorMsg)
496  		ac.sendError(fmt.Sprintf("%s事件验证失败: %s", eventType, errorMsg))
497  		return
498  	}
499  
500  	// 从TaskEventMessage中提取sessionId和事件数据
501  	sessionId := eventMessage.SessionID
502  	event := eventMessage.Event
503  
504  	log.Debugf("收到Agent事件: agentId=%s, sessionId=%s, eventType=%s", ac.agentID, sessionId, eventType)
505  
506  	// 转发给 TaskManager 处理
507  	am.mu.RLock()
508  	am.taskManager.HandleAgentEvent(sessionId, eventType, event)
509  	am.mu.RUnlock()
510  }
511  
512  // 添加获取可用 Agent 的方法
513  func (am *AgentManager) GetAvailableAgents() []*AgentConnection {
514  	am.mu.RLock()
515  	defer am.mu.RUnlock()
516  
517  	var availableAgents []*AgentConnection
518  	for _, conn := range am.connections {
519  		conn.stateMu.RLock()
520  		if conn.isActive {
521  			availableAgents = append(availableAgents, conn)
522  		}
523  		conn.stateMu.RUnlock()
524  	}
525  	return availableAgents
526  }
527  
528  // SetTaskManager 设置 TaskManager 引用
529  func (am *AgentManager) SetTaskManager(taskManager *TaskManager) {
530  	am.mu.Lock()
531  	defer am.mu.Unlock()
532  	am.taskManager = taskManager
533  }