agent.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package agent 20 21 import ( 22 "context" 23 "encoding/json" 24 "errors" 25 "fmt" 26 "net/url" 27 "sync" 28 "time" 29 30 "github.com/Tencent/AI-Infra-Guard/internal/gologger" 31 32 "github.com/google/uuid" 33 "github.com/gorilla/websocket" 34 ) 35 36 // Agent 客户端结构 37 type Agent struct { 38 // 基本信息 39 info AgentInfo 40 serverURL string 41 conn *websocket.Conn 42 43 // 任务管理 44 Tasks []*TaskContext 45 taskFunc []TaskInterface 46 47 // 通信管理 48 sendChan chan interface{} 49 ctx context.Context 50 cancel context.CancelFunc 51 52 // 配置 53 mutex sync.RWMutex 54 } 55 56 // TaskContext 任务上下文 57 type TaskContext struct { 58 Request TaskRequest 59 Status string 60 Progress int 61 StartTime time.Time 62 Cancel context.CancelFunc 63 Result interface{} 64 Error error 65 } 66 67 // AgentConfig Agent配置 68 type AgentConfig struct { 69 ServerURL string 70 Info AgentInfo 71 } 72 73 // NewAgent 创建新的Agent实例 74 func NewAgent(config AgentConfig) *Agent { 75 ctx, cancel := context.WithCancel(context.Background()) 76 agent := &Agent{ 77 info: config.Info, 78 serverURL: config.ServerURL, 79 conn: nil, 80 Tasks: make([]*TaskContext, 0), 81 sendChan: make(chan interface{}, 100), 82 ctx: ctx, 83 cancel: cancel, 84 mutex: sync.RWMutex{}, 85 taskFunc: make([]TaskInterface, 0), 86 } 87 return agent 88 } 89 90 func (a *Agent) RegisterTaskFunc(taskFunc TaskInterface) { 91 a.mutex.Lock() 92 defer a.mutex.Unlock() 93 a.taskFunc = append(a.taskFunc, taskFunc) 94 a.info.Capabilities = append(a.info.Capabilities, taskFunc.GetName()) 95 } 96 97 // Start 启动Agent 98 func (a *Agent) Start() error { 99 // 尝试连接到服务器 100 if err := a.connect(); err != nil { 101 return fmt.Errorf("failed to connect to server: %v", err) 102 } 103 // 启动各种协程 104 go a.handleSend() 105 a.handleReceive() 106 return nil 107 } 108 109 // Stop 停止Agent 110 func (a *Agent) Stop() { 111 a.mutex.Lock() 112 defer a.mutex.Unlock() 113 114 // 取消所有运行中的任务 115 for _, task := range a.Tasks { 116 if task.Cancel != nil { 117 task.Cancel() 118 } 119 } 120 // 发送停止信号 121 a.cancel() 122 123 // 关闭连接 124 if a.conn != nil { 125 a.conn.Close() 126 } 127 } 128 129 // connect 连接到服务器 130 func (a *Agent) connect() error { 131 u, err := url.Parse(a.serverURL) 132 if err != nil { 133 return err 134 } 135 dialer := websocket.DefaultDialer 136 conn, _, err := dialer.Dial(u.String(), nil) 137 if err != nil { 138 return err 139 } 140 conn.SetReadLimit(1024 * 1024 * 5) 141 a.conn = conn 142 143 // 设置ping处理器:收到ping消息后自动回复pong 144 a.conn.SetPingHandler(func(appData string) error { 145 gologger.Debugln("Received ping message, sending pong response", appData) 146 return a.conn.WriteControl(websocket.PongMessage, []byte(""), time.Now().Add(time.Second*60)) 147 }) 148 149 // 设置pong处理器:收到pong消息时的处理逻辑 150 a.conn.SetPongHandler(func(appData string) error { 151 gologger.Debugln("Received pong message") 152 // 更新读取超时时间,保持连接活跃 153 return a.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) 154 }) 155 156 // 发送注册消息 157 return a.register() 158 } 159 160 // register 向服务器注册 161 func (a *Agent) register() error { 162 registerMsg := RequestData{ 163 Type: AgentMsgTypeRegister, 164 Content: a.info, 165 } 166 a.sendChan <- registerMsg 167 return nil 168 } 169 170 // Disconnect 断开连接 171 func (a *Agent) Disconnect(reason string) error { 172 msg := Disconnect{ 173 AgentID: a.info.ID, 174 Reason: reason, 175 } 176 return a.sendMessage(msg) 177 } 178 179 // sendMessage 发送消息 180 func (a *Agent) sendMessage(msg interface{}) error { 181 if a.conn == nil { 182 return fmt.Errorf("connection is nil") 183 } 184 185 data, err := json.Marshal(msg) 186 if err != nil { 187 return err 188 } 189 190 return a.conn.WriteMessage(websocket.TextMessage, data) 191 } 192 193 // handleSend 处理发送队列 194 func (a *Agent) handleSend() { 195 for { 196 select { 197 case <-a.ctx.Done(): 198 return 199 case msg := <-a.sendChan: 200 if err := a.sendMessage(msg); err != nil { 201 gologger.WithError(err).Errorln("Failed to send message") 202 } 203 } 204 } 205 } 206 207 // handleReceive 处理接收消息 208 func (a *Agent) handleReceive() { 209 for { 210 select { 211 case <-a.ctx.Done(): 212 return 213 default: 214 if a.conn == nil { 215 break 216 } 217 _, message, err := a.conn.ReadMessage() 218 if err != nil { 219 gologger.WithError(err).Errorln("Failed to read message") 220 a.conn = nil 221 return 222 } 223 gologger.Debugln("recv", string(message)) 224 if err = a.processMessage(message); err != nil { 225 gologger.WithError(err).Errorln("Failed to send message") 226 } 227 } 228 } 229 } 230 231 // processMessage 处理接收到的消息 232 func (a *Agent) processMessage(data []byte) error { 233 var baseMsg ResponseData 234 235 if err := json.Unmarshal(data, &baseMsg); err != nil { 236 return err 237 } 238 239 switch baseMsg.Type { 240 case ServerMsgTypeRegisterResp: 241 case ServerMsgTypeTaskAssign: 242 var task TaskRequest 243 if err := json.Unmarshal(baseMsg.Content, &task); err != nil { 244 return err 245 } 246 taskType := task.TaskType 247 taskCtx, cancel := context.WithCancel(a.ctx) 248 // 创建任务上下文 249 if task.Timeout > 0 { 250 //taskCtx, cancel = context.WithTimeout(taskCtx, time.Duration(task.Timeout)*time.Second) 251 } 252 // 加入task 上下文 253 taskContext := &TaskContext{ 254 Request: task, 255 Status: TaskStatusPending, 256 Progress: 0, 257 StartTime: time.Now(), 258 Cancel: cancel, 259 } 260 a.Tasks = append(a.Tasks, taskContext) 261 for _, taskFunc := range a.taskFunc { 262 if taskType == taskFunc.GetName() { 263 gologger.Debugln("执行任务", taskFunc.GetName()) 264 taskContext.Status = TaskStatusRunning 265 // 创建回调函数集合 266 callbacks := TaskCallbacks{ 267 ResultCallback: func(result map[string]interface{}) { 268 gologger.Debugln("ResultCallback", result) 269 a.SendTaskResult(task.SessionId, result) 270 gologger.Debugln("ResultCallback end") 271 }, 272 ToolUseLogCallback: func(actionId, tool, planStepId, actionLog string) { 273 a.SendsToolUsedLog(task.SessionId, actionId, tool, planStepId, actionLog) 274 gologger.Debugln("ToolUseLogCallback", actionId, tool, planStepId, actionLog) 275 }, 276 ToolUsedCallback: func(planStepId, statusId, description string, tools []Tool) { 277 a.SendToolUsed(task.SessionId, planStepId, statusId, description, tools) 278 gologger.Debugln("ToolUsedCallback", planStepId, statusId, description, tools) 279 }, 280 NewPlanStepCallback: func(stepId, title string) { 281 a.SendNewPlanStep(task.SessionId, stepId, title) 282 gologger.Debugln("NewPlanStepCallback", stepId, title) 283 }, 284 StepStatusUpdateCallback: func(planStepId, statusId, agentStatus, brief, description string) { 285 a.SendStepStatusUpdate(task.SessionId, planStepId, statusId, agentStatus, brief, description) 286 gologger.Debugln("StepStatusUpdateCallback", planStepId, statusId, agentStatus, brief, description) 287 }, 288 PlanUpdateCallback: func(tasks []SubTask) { 289 a.SendPlanUpdate(task.SessionId, tasks) 290 gologger.Debugln("PlanUpdateCallback", tasks) 291 }, 292 ErrorCallback: func(error string) { 293 a.SendError(task.SessionId, error) 294 gologger.Debugln("ErrorCallback", error) 295 }, 296 } 297 go func() { 298 err := taskFunc.Execute(taskCtx, task, callbacks) 299 if err != nil { 300 if errors.Is(err, context.Canceled) { 301 gologger.Infof("任务已取消: sessionId=%s", task.SessionId) 302 taskContext.Status = TaskStatusFailed 303 a.removeTask(task.SessionId) 304 return 305 } 306 taskContext.Status = TaskStatusFailed 307 a.SendError(task.SessionId, err.Error()) 308 a.removeTask(task.SessionId) 309 return 310 } 311 taskContext.Status = TaskStatusComplete 312 a.removeTask(task.SessionId) 313 }() 314 break 315 } 316 } 317 case ServerMsgTypeTerminate: 318 var terminateReq TerminateTaskRequest 319 if err := json.Unmarshal(baseMsg.Content, &terminateReq); err != nil { 320 return err 321 } 322 if terminateReq.SessionID == "" { 323 return fmt.Errorf("terminate message missing session_id") 324 } 325 if !a.cancelTask(terminateReq.SessionID) { 326 gologger.Warningf("未找到可终止的任务: sessionId=%s", terminateReq.SessionID) 327 } 328 default: 329 return nil 330 } 331 return nil 332 } 333 334 // SendTaskResult 发送任务最终结果 335 func (a *Agent) SendTaskResult(sessionId string, result map[string]interface{}) error { 336 timestamp := time.Now().Unix() 337 msgId := uuid.New().String() 338 339 // 构建事件数据 340 event := Event{ 341 ID: msgId, 342 Type: "resultUpdate", 343 Timestamp: timestamp, 344 Result: result, 345 } 346 347 // 构建结果更新消息 348 resultUpdate := ResultUpdate{ 349 ID: msgId, 350 Type: "event", 351 SessionId: sessionId, 352 Timestamp: timestamp, 353 Event: event, 354 } 355 356 // 构建发送给服务器的消息 357 resultUpdateContent := RequestData{ 358 Type: AgentMsgTypeResultUpdate, 359 Content: resultUpdate, 360 } 361 362 // 通过发送通道发送消息 363 a.sendChan <- resultUpdateContent 364 return nil 365 } 366 367 // GetTaskBySessionId 根据SessionId获取任务上下文 368 func (a *Agent) GetTaskBySessionId(sessionId string) *TaskContext { 369 a.mutex.RLock() 370 defer a.mutex.RUnlock() 371 372 for _, task := range a.Tasks { 373 if task.Request.SessionId == sessionId { 374 return task 375 } 376 } 377 return nil 378 } 379 380 func (a *Agent) cancelTask(sessionId string) bool { 381 a.mutex.RLock() 382 defer a.mutex.RUnlock() 383 384 for _, task := range a.Tasks { 385 if task.Request.SessionId != sessionId { 386 continue 387 } 388 if task.Cancel != nil { 389 task.Cancel() 390 } 391 return true 392 } 393 return false 394 } 395 396 func (a *Agent) removeTask(sessionId string) { 397 a.mutex.Lock() 398 defer a.mutex.Unlock() 399 400 for idx, task := range a.Tasks { 401 if task.Request.SessionId != sessionId { 402 continue 403 } 404 a.Tasks = append(a.Tasks[:idx], a.Tasks[idx+1:]...) 405 return 406 } 407 } 408 409 func (a *Agent) SendsToolUsedLog(sessionId, actionId, tool, planStepId, actionLog string) error { 410 timestamp := time.Now().Unix() 411 msgId := uuid.New().String() 412 413 // 构建插件日志事件数据 414 event := ActionLogEvent{ 415 ID: msgId, 416 Type: "actionLog", 417 Timestamp: timestamp, 418 ActionId: actionId, 419 Tool: tool, 420 PlanStepId: planStepId, 421 ActionLog: actionLog, 422 } 423 424 // 构建插件日志更新消息 425 actionLogUpdate := ActionLogUpdate{ 426 ID: msgId, 427 Type: "event", 428 SessionId: sessionId, 429 Timestamp: timestamp, 430 Event: event, 431 } 432 433 // 构建发送给服务器的消息 434 actionLogContent := ActionLogContent{ 435 Type: AgentMsgTypeActionLog, 436 Content: actionLogUpdate, 437 } 438 439 // 通过发送通道发送消息 440 a.sendChan <- actionLogContent 441 return nil 442 } 443 444 func (a *Agent) SendToolUsed(sessionId, planStepId, statusId, description string, tools []Tool) error { 445 timestamp := time.Now().Unix() 446 msgId := uuid.New().String() 447 448 // 构建插件工作状态事件数据 449 event := ToolUsedEvent{ 450 ID: msgId, 451 Type: "toolUsed", 452 Timestamp: timestamp, 453 Description: description, 454 PlanStepId: planStepId, 455 StatusId: statusId, 456 Tools: tools, 457 } 458 459 // 构建插件工作状态更新消息 460 toolUsedUpdate := ToolUsedUpdate{ 461 ID: msgId, 462 Type: "event", 463 SessionId: sessionId, 464 Timestamp: timestamp, 465 Event: event, 466 } 467 468 // 构建发送给服务器的消息 469 toolUsedContent := ToolUsedContent{ 470 Type: AgentMsgTypeToolUsed, 471 Content: toolUsedUpdate, 472 } 473 474 // 通过发送通道发送消息 475 a.sendChan <- toolUsedContent 476 return nil 477 } 478 479 // CreateTool 创建工具使用信息 480 func CreateTool(toolId, tool string, status statusString, brief, action, param, result string) Tool { 481 return Tool{ 482 ToolId: toolId, 483 Tool: tool, 484 Status: status, 485 Brief: brief, 486 Message: ToolMessage{ 487 Action: action, 488 Param: param, 489 }, 490 Result: result, 491 } 492 } 493 494 // SendNewPlanStep 新建执行步骤 495 func (a *Agent) SendNewPlanStep(sessionId, stepId, title string) error { 496 timestamp := time.Now().Unix() 497 msgId := uuid.New().String() 498 499 // 构建新建执行步骤事件数据 500 event := NewPlanStepEvent{ 501 ID: msgId, 502 Type: AgentMsgTypeNewPlanStep, 503 Timestamp: timestamp, 504 StepId: stepId, 505 Title: title, 506 } 507 508 // 构建新建执行步骤更新消息 509 newPlanStepUpdate := NewPlanStepUpdate{ 510 ID: msgId, 511 Type: "event", 512 SessionId: sessionId, 513 Timestamp: timestamp, 514 Event: event, 515 } 516 517 // 构建发送给服务器的消息 518 newPlanStepContent := NewPlanStepContent{ 519 Type: AgentMsgTypeNewPlanStep, 520 Content: newPlanStepUpdate, 521 } 522 523 // 通过发送通道发送消息 524 a.sendChan <- newPlanStepContent 525 return nil 526 } 527 528 // SendStepStatusUpdate 发送更新步骤状态 529 func (a *Agent) SendStepStatusUpdate(sessionId, planStepId, statusId, agentStatus, brief, description string) error { 530 timestamp := time.Now().Unix() 531 532 // 构建更新步骤状态事件数据 533 event := StatusUpdateEvent{ 534 ID: statusId, 535 Type: AgentMsgTypeStatusUpdate, 536 Timestamp: timestamp, 537 AgentStatus: agentStatus, 538 Brief: brief, 539 Description: description, 540 NoRender: false, 541 PlanStepId: planStepId, 542 } 543 544 // 构建更新步骤状态更新消息 545 statusUpdateUpdate := StatusUpdateUpdate{ 546 ID: statusId, 547 Type: "event", 548 SessionId: sessionId, 549 Timestamp: timestamp, 550 Event: event, 551 } 552 553 // 构建发送给服务器的消息 554 statusUpdateContent := StatusUpdateContent{ 555 Type: AgentMsgTypeStatusUpdate, 556 Content: statusUpdateUpdate, 557 } 558 559 // 通过发送通道发送消息 560 a.sendChan <- statusUpdateContent 561 return nil 562 } 563 564 // SendPlanUpdate 整体计划更新 565 func (a *Agent) SendPlanUpdate(sessionId string, tasks []SubTask) error { 566 timestamp := time.Now().Unix() 567 msgId := uuid.New().String() 568 569 // 构建更新任务计划事件数据 570 event := PlanUpdateEvent{ 571 ID: msgId, 572 Type: "planUpdate", 573 Timestamp: timestamp, 574 Tasks: tasks, 575 } 576 577 // 构建更新任务计划更新消息 578 planUpdateUpdate := PlanUpdateUpdate{ 579 ID: msgId, 580 Type: "event", 581 SessionId: sessionId, 582 Timestamp: timestamp, 583 Event: event, 584 } 585 586 // 构建发送给服务器的消息 587 planUpdateContent := PlanUpdateContent{ 588 Type: AgentMsgTypePlanUpdate, 589 Content: planUpdateUpdate, 590 } 591 592 // 通过发送通道发送消息 593 a.sendChan <- planUpdateContent 594 return nil 595 } 596 597 // CreateSubTask 创建子任务的便捷方法 598 func CreateSubTask(status statusString, title string, startedAt int64, stepId string) SubTask { 599 return SubTask{ 600 Status: status, 601 Title: title, 602 StartedAt: startedAt, 603 StepId: stepId, 604 } 605 } 606 607 // SendError 发送错误 608 func (a *Agent) SendError(sessionId, msg string) error { 609 timestamp := time.Now().Unix() 610 msgId := uuid.New().String() 611 612 // 构建更新任务计划事件数据 613 event := ErrorEvent{ 614 Id: msgId, 615 Type: "error", 616 Timestamp: timestamp, 617 Message: msg, 618 } 619 620 // 构建更新任务计划更新消息 621 planUpdateUpdate := ErrorUpdate{ 622 ID: msgId, 623 Type: "event", 624 SessionID: sessionId, 625 Timestamp: timestamp, 626 Event: event, 627 } 628 629 // 构建发送给服务器的消息 630 planUpdateContent := ErrorUpdateContent{ 631 Type: AgentMsgTypeError, 632 Content: planUpdateUpdate, 633 } 634 635 // 通过发送通道发送消息 636 a.sendChan <- planUpdateContent 637 return nil 638 639 }