/ common / agent / agent.go
agent.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package agent
 20  
 21  import (
 22  	"context"
 23  	"encoding/json"
 24  	"errors"
 25  	"fmt"
 26  	"net/url"
 27  	"sync"
 28  	"time"
 29  
 30  	"github.com/Tencent/AI-Infra-Guard/internal/gologger"
 31  
 32  	"github.com/google/uuid"
 33  	"github.com/gorilla/websocket"
 34  )
 35  
// Agent is the client side of the agent/server websocket protocol. It holds
// the identity sent to the server on registration, the websocket connection,
// the registered task handlers and the list of tasks assigned to it.
type Agent struct {
	// Basic information.
	info      AgentInfo       // sent as the content of the register message
	serverURL string          // URL dialed by connect
	conn      *websocket.Conn // nil until connect succeeds

	// Task management.
	Tasks    []*TaskContext  // tasks that have been assigned and not yet removed
	taskFunc []TaskInterface // handlers added through RegisterTaskFunc

	// Communication management.
	sendChan chan interface{}   // outgoing message queue drained by handleSend
	ctx      context.Context    // agent-wide context; parent of every task context
	cancel   context.CancelFunc // cancels ctx; called by Stop

	// Configuration.
	mutex sync.RWMutex // taken by the methods that read or modify Tasks and taskFunc
}
 55  
// TaskContext is the agent's bookkeeping record for one assigned task.
type TaskContext struct {
	Request   TaskRequest        // the task request as decoded from the server message
	Status    string             // one of the TaskStatus* values
	Progress  int                // initialised to 0 when the task is created
	StartTime time.Time          // time at which the task was accepted
	Cancel    context.CancelFunc // cancels the task's context; may be nil
	Result    interface{}        // not written anywhere in this file
	Error     error              // not written anywhere in this file
}
 66  
// AgentConfig carries the settings NewAgent copies into a new Agent.
type AgentConfig struct {
	ServerURL string    // websocket URL of the server
	Info      AgentInfo // identity reported to the server on registration
}
 72  
 73  // NewAgent 创建新的Agent实例
 74  func NewAgent(config AgentConfig) *Agent {
 75  	ctx, cancel := context.WithCancel(context.Background())
 76  	agent := &Agent{
 77  		info:      config.Info,
 78  		serverURL: config.ServerURL,
 79  		conn:      nil,
 80  		Tasks:     make([]*TaskContext, 0),
 81  		sendChan:  make(chan interface{}, 100),
 82  		ctx:       ctx,
 83  		cancel:    cancel,
 84  		mutex:     sync.RWMutex{},
 85  		taskFunc:  make([]TaskInterface, 0),
 86  	}
 87  	return agent
 88  }
 89  
 90  func (a *Agent) RegisterTaskFunc(taskFunc TaskInterface) {
 91  	a.mutex.Lock()
 92  	defer a.mutex.Unlock()
 93  	a.taskFunc = append(a.taskFunc, taskFunc)
 94  	a.info.Capabilities = append(a.info.Capabilities, taskFunc.GetName())
 95  }
 96  
 97  // Start 启动Agent
 98  func (a *Agent) Start() error {
 99  	// 尝试连接到服务器
100  	if err := a.connect(); err != nil {
101  		return fmt.Errorf("failed to connect to server: %v", err)
102  	}
103  	// 启动各种协程
104  	go a.handleSend()
105  	a.handleReceive()
106  	return nil
107  }
108  
109  // Stop 停止Agent
110  func (a *Agent) Stop() {
111  	a.mutex.Lock()
112  	defer a.mutex.Unlock()
113  
114  	// 取消所有运行中的任务
115  	for _, task := range a.Tasks {
116  		if task.Cancel != nil {
117  			task.Cancel()
118  		}
119  	}
120  	// 发送停止信号
121  	a.cancel()
122  
123  	// 关闭连接
124  	if a.conn != nil {
125  		a.conn.Close()
126  	}
127  }
128  
129  // connect 连接到服务器
130  func (a *Agent) connect() error {
131  	u, err := url.Parse(a.serverURL)
132  	if err != nil {
133  		return err
134  	}
135  	dialer := websocket.DefaultDialer
136  	conn, _, err := dialer.Dial(u.String(), nil)
137  	if err != nil {
138  		return err
139  	}
140  	conn.SetReadLimit(1024 * 1024 * 5)
141  	a.conn = conn
142  
143  	// 设置ping处理器:收到ping消息后自动回复pong
144  	a.conn.SetPingHandler(func(appData string) error {
145  		gologger.Debugln("Received ping message, sending pong response", appData)
146  		return a.conn.WriteControl(websocket.PongMessage, []byte(""), time.Now().Add(time.Second*60))
147  	})
148  
149  	// 设置pong处理器:收到pong消息时的处理逻辑
150  	a.conn.SetPongHandler(func(appData string) error {
151  		gologger.Debugln("Received pong message")
152  		// 更新读取超时时间,保持连接活跃
153  		return a.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
154  	})
155  
156  	// 发送注册消息
157  	return a.register()
158  }
159  
160  // register 向服务器注册
161  func (a *Agent) register() error {
162  	registerMsg := RequestData{
163  		Type:    AgentMsgTypeRegister,
164  		Content: a.info,
165  	}
166  	a.sendChan <- registerMsg
167  	return nil
168  }
169  
170  // Disconnect 断开连接
171  func (a *Agent) Disconnect(reason string) error {
172  	msg := Disconnect{
173  		AgentID: a.info.ID,
174  		Reason:  reason,
175  	}
176  	return a.sendMessage(msg)
177  }
178  
179  // sendMessage 发送消息
180  func (a *Agent) sendMessage(msg interface{}) error {
181  	if a.conn == nil {
182  		return fmt.Errorf("connection is nil")
183  	}
184  
185  	data, err := json.Marshal(msg)
186  	if err != nil {
187  		return err
188  	}
189  
190  	return a.conn.WriteMessage(websocket.TextMessage, data)
191  }
192  
193  // handleSend 处理发送队列
194  func (a *Agent) handleSend() {
195  	for {
196  		select {
197  		case <-a.ctx.Done():
198  			return
199  		case msg := <-a.sendChan:
200  			if err := a.sendMessage(msg); err != nil {
201  				gologger.WithError(err).Errorln("Failed to send message")
202  			}
203  		}
204  	}
205  }
206  
207  // handleReceive 处理接收消息
208  func (a *Agent) handleReceive() {
209  	for {
210  		select {
211  		case <-a.ctx.Done():
212  			return
213  		default:
214  			if a.conn == nil {
215  				break
216  			}
217  			_, message, err := a.conn.ReadMessage()
218  			if err != nil {
219  				gologger.WithError(err).Errorln("Failed to read message")
220  				a.conn = nil
221  				return
222  			}
223  			gologger.Debugln("recv", string(message))
224  			if err = a.processMessage(message); err != nil {
225  				gologger.WithError(err).Errorln("Failed to send message")
226  			}
227  		}
228  	}
229  }
230  
231  // processMessage 处理接收到的消息
232  func (a *Agent) processMessage(data []byte) error {
233  	var baseMsg ResponseData
234  
235  	if err := json.Unmarshal(data, &baseMsg); err != nil {
236  		return err
237  	}
238  
239  	switch baseMsg.Type {
240  	case ServerMsgTypeRegisterResp:
241  	case ServerMsgTypeTaskAssign:
242  		var task TaskRequest
243  		if err := json.Unmarshal(baseMsg.Content, &task); err != nil {
244  			return err
245  		}
246  		taskType := task.TaskType
247  		taskCtx, cancel := context.WithCancel(a.ctx)
248  		// 创建任务上下文
249  		if task.Timeout > 0 {
250  			//taskCtx, cancel = context.WithTimeout(taskCtx, time.Duration(task.Timeout)*time.Second)
251  		}
252  		// 加入task 上下文
253  		taskContext := &TaskContext{
254  			Request:   task,
255  			Status:    TaskStatusPending,
256  			Progress:  0,
257  			StartTime: time.Now(),
258  			Cancel:    cancel,
259  		}
260  		a.Tasks = append(a.Tasks, taskContext)
261  		for _, taskFunc := range a.taskFunc {
262  			if taskType == taskFunc.GetName() {
263  				gologger.Debugln("执行任务", taskFunc.GetName())
264  				taskContext.Status = TaskStatusRunning
265  				// 创建回调函数集合
266  				callbacks := TaskCallbacks{
267  					ResultCallback: func(result map[string]interface{}) {
268  						gologger.Debugln("ResultCallback", result)
269  						a.SendTaskResult(task.SessionId, result)
270  						gologger.Debugln("ResultCallback end")
271  					},
272  					ToolUseLogCallback: func(actionId, tool, planStepId, actionLog string) {
273  						a.SendsToolUsedLog(task.SessionId, actionId, tool, planStepId, actionLog)
274  						gologger.Debugln("ToolUseLogCallback", actionId, tool, planStepId, actionLog)
275  					},
276  					ToolUsedCallback: func(planStepId, statusId, description string, tools []Tool) {
277  						a.SendToolUsed(task.SessionId, planStepId, statusId, description, tools)
278  						gologger.Debugln("ToolUsedCallback", planStepId, statusId, description, tools)
279  					},
280  					NewPlanStepCallback: func(stepId, title string) {
281  						a.SendNewPlanStep(task.SessionId, stepId, title)
282  						gologger.Debugln("NewPlanStepCallback", stepId, title)
283  					},
284  					StepStatusUpdateCallback: func(planStepId, statusId, agentStatus, brief, description string) {
285  						a.SendStepStatusUpdate(task.SessionId, planStepId, statusId, agentStatus, brief, description)
286  						gologger.Debugln("StepStatusUpdateCallback", planStepId, statusId, agentStatus, brief, description)
287  					},
288  					PlanUpdateCallback: func(tasks []SubTask) {
289  						a.SendPlanUpdate(task.SessionId, tasks)
290  						gologger.Debugln("PlanUpdateCallback", tasks)
291  					},
292  					ErrorCallback: func(error string) {
293  						a.SendError(task.SessionId, error)
294  						gologger.Debugln("ErrorCallback", error)
295  					},
296  				}
297  				go func() {
298  					err := taskFunc.Execute(taskCtx, task, callbacks)
299  					if err != nil {
300  						if errors.Is(err, context.Canceled) {
301  							gologger.Infof("任务已取消: sessionId=%s", task.SessionId)
302  							taskContext.Status = TaskStatusFailed
303  							a.removeTask(task.SessionId)
304  							return
305  						}
306  						taskContext.Status = TaskStatusFailed
307  						a.SendError(task.SessionId, err.Error())
308  						a.removeTask(task.SessionId)
309  						return
310  					}
311  					taskContext.Status = TaskStatusComplete
312  					a.removeTask(task.SessionId)
313  				}()
314  				break
315  			}
316  		}
317  	case ServerMsgTypeTerminate:
318  		var terminateReq TerminateTaskRequest
319  		if err := json.Unmarshal(baseMsg.Content, &terminateReq); err != nil {
320  			return err
321  		}
322  		if terminateReq.SessionID == "" {
323  			return fmt.Errorf("terminate message missing session_id")
324  		}
325  		if !a.cancelTask(terminateReq.SessionID) {
326  			gologger.Warningf("未找到可终止的任务: sessionId=%s", terminateReq.SessionID)
327  		}
328  	default:
329  		return nil
330  	}
331  	return nil
332  }
333  
334  // SendTaskResult 发送任务最终结果
335  func (a *Agent) SendTaskResult(sessionId string, result map[string]interface{}) error {
336  	timestamp := time.Now().Unix()
337  	msgId := uuid.New().String()
338  
339  	// 构建事件数据
340  	event := Event{
341  		ID:        msgId,
342  		Type:      "resultUpdate",
343  		Timestamp: timestamp,
344  		Result:    result,
345  	}
346  
347  	// 构建结果更新消息
348  	resultUpdate := ResultUpdate{
349  		ID:        msgId,
350  		Type:      "event",
351  		SessionId: sessionId,
352  		Timestamp: timestamp,
353  		Event:     event,
354  	}
355  
356  	// 构建发送给服务器的消息
357  	resultUpdateContent := RequestData{
358  		Type:    AgentMsgTypeResultUpdate,
359  		Content: resultUpdate,
360  	}
361  
362  	// 通过发送通道发送消息
363  	a.sendChan <- resultUpdateContent
364  	return nil
365  }
366  
367  // GetTaskBySessionId 根据SessionId获取任务上下文
368  func (a *Agent) GetTaskBySessionId(sessionId string) *TaskContext {
369  	a.mutex.RLock()
370  	defer a.mutex.RUnlock()
371  
372  	for _, task := range a.Tasks {
373  		if task.Request.SessionId == sessionId {
374  			return task
375  		}
376  	}
377  	return nil
378  }
379  
380  func (a *Agent) cancelTask(sessionId string) bool {
381  	a.mutex.RLock()
382  	defer a.mutex.RUnlock()
383  
384  	for _, task := range a.Tasks {
385  		if task.Request.SessionId != sessionId {
386  			continue
387  		}
388  		if task.Cancel != nil {
389  			task.Cancel()
390  		}
391  		return true
392  	}
393  	return false
394  }
395  
396  func (a *Agent) removeTask(sessionId string) {
397  	a.mutex.Lock()
398  	defer a.mutex.Unlock()
399  
400  	for idx, task := range a.Tasks {
401  		if task.Request.SessionId != sessionId {
402  			continue
403  		}
404  		a.Tasks = append(a.Tasks[:idx], a.Tasks[idx+1:]...)
405  		return
406  	}
407  }
408  
409  func (a *Agent) SendsToolUsedLog(sessionId, actionId, tool, planStepId, actionLog string) error {
410  	timestamp := time.Now().Unix()
411  	msgId := uuid.New().String()
412  
413  	// 构建插件日志事件数据
414  	event := ActionLogEvent{
415  		ID:         msgId,
416  		Type:       "actionLog",
417  		Timestamp:  timestamp,
418  		ActionId:   actionId,
419  		Tool:       tool,
420  		PlanStepId: planStepId,
421  		ActionLog:  actionLog,
422  	}
423  
424  	// 构建插件日志更新消息
425  	actionLogUpdate := ActionLogUpdate{
426  		ID:        msgId,
427  		Type:      "event",
428  		SessionId: sessionId,
429  		Timestamp: timestamp,
430  		Event:     event,
431  	}
432  
433  	// 构建发送给服务器的消息
434  	actionLogContent := ActionLogContent{
435  		Type:    AgentMsgTypeActionLog,
436  		Content: actionLogUpdate,
437  	}
438  
439  	// 通过发送通道发送消息
440  	a.sendChan <- actionLogContent
441  	return nil
442  }
443  
444  func (a *Agent) SendToolUsed(sessionId, planStepId, statusId, description string, tools []Tool) error {
445  	timestamp := time.Now().Unix()
446  	msgId := uuid.New().String()
447  
448  	// 构建插件工作状态事件数据
449  	event := ToolUsedEvent{
450  		ID:          msgId,
451  		Type:        "toolUsed",
452  		Timestamp:   timestamp,
453  		Description: description,
454  		PlanStepId:  planStepId,
455  		StatusId:    statusId,
456  		Tools:       tools,
457  	}
458  
459  	// 构建插件工作状态更新消息
460  	toolUsedUpdate := ToolUsedUpdate{
461  		ID:        msgId,
462  		Type:      "event",
463  		SessionId: sessionId,
464  		Timestamp: timestamp,
465  		Event:     event,
466  	}
467  
468  	// 构建发送给服务器的消息
469  	toolUsedContent := ToolUsedContent{
470  		Type:    AgentMsgTypeToolUsed,
471  		Content: toolUsedUpdate,
472  	}
473  
474  	// 通过发送通道发送消息
475  	a.sendChan <- toolUsedContent
476  	return nil
477  }
478  
479  // CreateTool 创建工具使用信息
480  func CreateTool(toolId, tool string, status statusString, brief, action, param, result string) Tool {
481  	return Tool{
482  		ToolId: toolId,
483  		Tool:   tool,
484  		Status: status,
485  		Brief:  brief,
486  		Message: ToolMessage{
487  			Action: action,
488  			Param:  param,
489  		},
490  		Result: result,
491  	}
492  }
493  
494  // SendNewPlanStep 新建执行步骤
495  func (a *Agent) SendNewPlanStep(sessionId, stepId, title string) error {
496  	timestamp := time.Now().Unix()
497  	msgId := uuid.New().String()
498  
499  	// 构建新建执行步骤事件数据
500  	event := NewPlanStepEvent{
501  		ID:        msgId,
502  		Type:      AgentMsgTypeNewPlanStep,
503  		Timestamp: timestamp,
504  		StepId:    stepId,
505  		Title:     title,
506  	}
507  
508  	// 构建新建执行步骤更新消息
509  	newPlanStepUpdate := NewPlanStepUpdate{
510  		ID:        msgId,
511  		Type:      "event",
512  		SessionId: sessionId,
513  		Timestamp: timestamp,
514  		Event:     event,
515  	}
516  
517  	// 构建发送给服务器的消息
518  	newPlanStepContent := NewPlanStepContent{
519  		Type:    AgentMsgTypeNewPlanStep,
520  		Content: newPlanStepUpdate,
521  	}
522  
523  	// 通过发送通道发送消息
524  	a.sendChan <- newPlanStepContent
525  	return nil
526  }
527  
528  // SendStepStatusUpdate 发送更新步骤状态
529  func (a *Agent) SendStepStatusUpdate(sessionId, planStepId, statusId, agentStatus, brief, description string) error {
530  	timestamp := time.Now().Unix()
531  
532  	// 构建更新步骤状态事件数据
533  	event := StatusUpdateEvent{
534  		ID:          statusId,
535  		Type:        AgentMsgTypeStatusUpdate,
536  		Timestamp:   timestamp,
537  		AgentStatus: agentStatus,
538  		Brief:       brief,
539  		Description: description,
540  		NoRender:    false,
541  		PlanStepId:  planStepId,
542  	}
543  
544  	// 构建更新步骤状态更新消息
545  	statusUpdateUpdate := StatusUpdateUpdate{
546  		ID:        statusId,
547  		Type:      "event",
548  		SessionId: sessionId,
549  		Timestamp: timestamp,
550  		Event:     event,
551  	}
552  
553  	// 构建发送给服务器的消息
554  	statusUpdateContent := StatusUpdateContent{
555  		Type:    AgentMsgTypeStatusUpdate,
556  		Content: statusUpdateUpdate,
557  	}
558  
559  	// 通过发送通道发送消息
560  	a.sendChan <- statusUpdateContent
561  	return nil
562  }
563  
564  // SendPlanUpdate 整体计划更新
565  func (a *Agent) SendPlanUpdate(sessionId string, tasks []SubTask) error {
566  	timestamp := time.Now().Unix()
567  	msgId := uuid.New().String()
568  
569  	// 构建更新任务计划事件数据
570  	event := PlanUpdateEvent{
571  		ID:        msgId,
572  		Type:      "planUpdate",
573  		Timestamp: timestamp,
574  		Tasks:     tasks,
575  	}
576  
577  	// 构建更新任务计划更新消息
578  	planUpdateUpdate := PlanUpdateUpdate{
579  		ID:        msgId,
580  		Type:      "event",
581  		SessionId: sessionId,
582  		Timestamp: timestamp,
583  		Event:     event,
584  	}
585  
586  	// 构建发送给服务器的消息
587  	planUpdateContent := PlanUpdateContent{
588  		Type:    AgentMsgTypePlanUpdate,
589  		Content: planUpdateUpdate,
590  	}
591  
592  	// 通过发送通道发送消息
593  	a.sendChan <- planUpdateContent
594  	return nil
595  }
596  
597  // CreateSubTask 创建子任务的便捷方法
598  func CreateSubTask(status statusString, title string, startedAt int64, stepId string) SubTask {
599  	return SubTask{
600  		Status:    status,
601  		Title:     title,
602  		StartedAt: startedAt,
603  		StepId:    stepId,
604  	}
605  }
606  
607  // SendError 发送错误
608  func (a *Agent) SendError(sessionId, msg string) error {
609  	timestamp := time.Now().Unix()
610  	msgId := uuid.New().String()
611  
612  	// 构建更新任务计划事件数据
613  	event := ErrorEvent{
614  		Id:        msgId,
615  		Type:      "error",
616  		Timestamp: timestamp,
617  		Message:   msg,
618  	}
619  
620  	// 构建更新任务计划更新消息
621  	planUpdateUpdate := ErrorUpdate{
622  		ID:        msgId,
623  		Type:      "event",
624  		SessionID: sessionId,
625  		Timestamp: timestamp,
626  		Event:     event,
627  	}
628  
629  	// 构建发送给服务器的消息
630  	planUpdateContent := ErrorUpdateContent{
631  		Type:    AgentMsgTypeError,
632  		Content: planUpdateUpdate,
633  	}
634  
635  	// 通过发送通道发送消息
636  	a.sendChan <- planUpdateContent
637  	return nil
638  
639  }