# base_agent.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

import json
import uuid

from tools.dispatcher import ToolDispatcher
from utils.aig_logger import mcpLogger
from utils.llm import LLM
from utils.loging import logger
from utils.parse import clean_content, parse_tool_invocations
from utils.prompt_manager import prompt_manager
from utils.tool_context import ToolContext


class BaseAgent:
    """Iterative LLM agent loop: chat -> parse tool call -> dispatch -> repeat.

    The agent keeps a chat ``history`` whose layout is assumed throughout:
    ``history[0]`` is the system prompt and ``history[1]`` is the original
    task message; everything after that is the running conversation. When the
    prompt grows past a token budget, older messages are compacted into a
    textual summary (``summary_memory``) and the history is rebuilt.
    """

    def __init__(
        self,
        name: str,
        instruction: str,
        llm: LLM,
        dispatcher: ToolDispatcher,
        specialized_llms: dict | None = None,
        log_step_id: str | None = None,
        debug: bool = False,
        capabilities: list[str] | None = None,
        output_format: str | None = None,
        output_check_fn: callable = None,
        language: str = "zh",
    ):
        """Configure the agent.

        Args:
            name: Agent display name, substituted into the system prompt.
            instruction: Task instruction, substituted into the system prompt.
            llm: Primary chat model; must expose ``chat`` and ``context_window``.
            dispatcher: Tool dispatcher used to list and invoke tools.
            specialized_llms: Optional mapping of purpose -> LLM handed to tools
                via ``ToolContext``.
            log_step_id: Step id used for mcpLogger status/action reporting.
            debug: Forwarded to ``llm.chat`` for verbose output.
            capabilities: Capability tags; defaults to ``["standard"]``.
            output_format: Template inserted into the final formatting prompt.
            output_check_fn: Optional validator for the final output; when it
                returns ``True`` the formatting retry loop stops early.
            language: "zh" (default) or "en"; selects user-facing strings.
        """
        self.llm = llm
        self.name = name
        self.dispatcher = dispatcher
        self.specialized_llms = specialized_llms or {}
        self.instruction = instruction
        self.capabilities = capabilities or ["standard"]
        self.output_format = output_format
        self.step_id = log_step_id
        self.debug = debug
        self.repo_dir = ""
        self.output_check_fn = output_check_fn
        self.language = language
        # Loop control.
        self.iter = 0
        self.max_iter = 80
        self.is_finished = False
        # Conversation context.
        self.history = []
        self.original_task = ""
        self.summary_memory = ""
        # Start compacting at ~60% of the model's context window so there is
        # headroom left for the next completion and for tool results.
        self.max_history_tokens = max(int(self.llm.context_window * 0.6), 1)
        # Keep the most recent messages verbatim during compaction so the
        # current execution trace is not lost.
        self.keep_recent_msgs = 8

    async def initialize(self):
        """Lazily seed the history with the generated system prompt (async)."""
        if not self.history:
            system_prompt = await self.generate_system_prompt()
            self.history.append({"role": "system", "content": system_prompt})

    def add_user_message(self, message: str):
        """Append a user-role message to the conversation history."""
        self.history.append({"role": "user", "content": message})

    def set_repo_dir(self, repo_dir: str):
        """Record the repository root; it is stripped from logged tool params."""
        self.repo_dir = repo_dir

    def should_compact_history(self, usage: dict | None = None) -> bool:
        """Decide whether the history should be compacted before the next turn.

        Args:
            usage: Token-usage dict from the last LLM call (may be absent or
                lack ``prompt_tokens`` on some OpenAI-compatible backends).

        Returns:
            True when the prompt-token count reached the budget, or — when no
            usage is available — when the history is conservatively long.
        """
        # Nothing old enough to compact: indices 0/1 (system + task) plus the
        # protected recent tail already cover the whole history.
        if len(self.history) - 2 <= self.keep_recent_msgs:
            return False

        prompt_tokens = None
        if usage:
            prompt_tokens = usage.get("prompt_tokens")
        if isinstance(prompt_tokens, int):
            return prompt_tokens >= self.max_history_tokens

        # Some compatible APIs report no usage; fall back to a conservative
        # message-count heuristic.
        return len(self.history) > 24

    def compact_history(self):
        """Summarize old messages into ``summary_memory`` and rebuild history.

        The rebuilt history is: original system prompt, a task message that
        embeds the running summary, then the ``keep_recent_msgs`` most recent
        messages verbatim.
        """
        recent_start = max(2, len(self.history) - self.keep_recent_msgs)

        msgs_to_compact = []
        if self.summary_memory:
            # Fold the previous summary into the new one so context chains.
            msgs_to_compact.append(
                {"role": "user", "content": self._build_summary_memory_message()}
            )
        msgs_to_compact.extend(self.history[2:recent_start])
        if not msgs_to_compact:
            return

        compact_prompt = prompt_manager.load_template("compact")
        msgs_to_compact.append({"role": "user", "content": compact_prompt})
        compacted_msgs = self.llm.chat(msgs_to_compact)
        self.summary_memory = compacted_msgs

        # history[1] holds the original task message (set once, never lost).
        if not self.original_task:
            self.original_task = self.history[1]["content"]

        system_prompt = self.history[0]
        recent_msgs = self.history[-self.keep_recent_msgs :]
        self.history = [
            system_prompt,
            {
                "role": "user",
                "content": self._build_task_message(),
            },
            *recent_msgs,
        ]

    async def generate_system_prompt(self):
        """Render the system prompt with the current tool list, name and task."""
        tools_prompt = await self.dispatcher.get_all_tools_prompt()

        template_name = "system_prompt"
        format_kwargs = {
            "generate_tools": tools_prompt,
            "name": self.name,
            "instruction": self.instruction,
        }

        return prompt_manager.format_prompt(template_name, **format_kwargs)

    def next_prompt(self):
        """Render the per-round reminder prompt for the current iteration."""
        return prompt_manager.format_prompt("next_prompt", round=self.iter)

    async def run(self):
        """Public entry point: initialize the system prompt, then run the loop."""
        await self.initialize()
        return await self._run()

    async def _run(self):
        """Main agent loop; returns the final formatted result string."""
        logger.info(f"Agent {self.name} started with max_iter={self.max_iter}")
        result = ""
        while not self.is_finished and self.iter < self.max_iter:
            logger.debug(f"\n{'=' * 50}\nIteration {self.iter}\n{'=' * 50}")
            response, usage = self.llm.chat(self.history, self.debug, ret_usage=True)
            logger.debug(f"LLM Response: {response}")

            self.history.append({"role": "assistant", "content": response})
            res = await self.handle_response(response)
            if res is not None:
                # Only the finish path returns a value; keep the latest one.
                result = res

            self.iter += 1
            if self.should_compact_history(usage) and not self.is_finished:
                logger.info(
                    "Prompt tokens %s exceeded limit %s, compacting context",
                    usage.get("prompt_tokens") if usage else None,
                    self.max_history_tokens,
                )
                self.compact_history()

        if not self.is_finished:
            # Ran out of iterations without a finish tool call; report and
            # format whatever we have so far.
            logger.warning(f"Max iterations ({self.max_iter}) reached")
            mcpLogger.status_update(
                self.step_id,
                "达到最大迭代次数,返回当前结果"
                if self.language != "en"
                else "Max iterations reached, returning current result",
                "",
                "completed",
            )
            if not result:
                result = await self._format_final_output()
        return result

    async def handle_response(self, response: str):
        """Parse one LLM response, report status and dispatch any tool call.

        Returns the final result string when the finish tool ran, else None.
        """
        tool_invocations = parse_tool_invocations(response)
        description = clean_content(response)
        # A bare finish call carries no narration; substitute a stock message.
        if tool_invocations and tool_invocations["toolName"] == "finish" and description == "":
            description = "报告完成。"
            if self.language == "en":
                description = "Report completed."
        if description == "":
            description = "我将继续执行"
            if self.language == "en":
                description = "I will continue to execute"

        if tool_invocations:
            if tool_invocations["toolName"] != "finish":
                # finish reports its own "completed" status inside
                # process_tool_call; avoid a duplicate "running" update here.
                mcpLogger.status_update(self.step_id, description, "", "running")
            return await self.process_tool_call(tool_invocations, description)
        else:
            mcpLogger.status_update(self.step_id, description, "", "running")
            return await self.handle_no_tool(description)

    async def process_tool_call(self, tool_call: dict, description: str):
        """Execute one tool invocation and feed its result back into history.

        Args:
            tool_call: Parsed invocation with ``toolName`` and ``args`` keys.
            description: Human-readable narration for status reporting.

        Returns:
            The formatted final output when the tool is ``finish``, else None.
        """
        tool_name = tool_call["toolName"]
        tool_args = tool_call["args"]
        tool_id = str(uuid.uuid4())

        params = json.dumps(tool_args, ensure_ascii=False) if tool_args else ""
        # Strip the absolute repo path so logs stay environment-agnostic.
        params = params.replace(self.repo_dir, "")

        mcpLogger.tool_used(self.step_id, tool_id, tool_name, "done", tool_name, params)

        if tool_name == "finish":
            self.is_finished = True
            logger.info("Finish tool called, final result formatted.")

            mcpLogger.status_update(self.step_id, description, "", "completed")
            result = await self._format_final_output()
            mcpLogger.action_log(tool_id, tool_name, self.step_id, result)
            return result

        # Build the execution context handed to the tool implementation.
        context = ToolContext(
            llm=self.llm,
            history=self.history,
            agent_name=self.name,
            iteration=self.iter,
            specialized_llms=self.specialized_llms,
            folder=self.repo_dir,
            tool_dispatcher=self.dispatcher,
        )

        # Invoke the tool through the dispatcher.
        tool_result = await self.dispatcher.call_tool(tool_name, tool_args, context)

        # Stringify the tool result for the conversation history.
        result_message = f"{tool_result}"

        # Prepend the next-round reminder prompt.
        next_p = self.next_prompt()
        full_message = f"{next_p}\n\n{result_message}"

        self.history.append({"role": "user", "content": full_message})
        mcpLogger.status_update(self.step_id, description, "", "completed")

        # read_file results are typically large and low-value in the action
        # log, so they are skipped.
        if tool_name != "read_file":
            mcpLogger.action_log(tool_id, tool_name, self.step_id, f"```\n{result_message}\n```")

        return None

    async def handle_no_tool(self, description: str):
        """Nudge the model to call a tool when the response contained none."""
        next_p = self.next_prompt()
        if self.language == "en":
            reminder = (
                f"{next_p}\n\n"
                "No tool call was detected. You must call exactly one tool in your next response. "
                "If the task is complete, call finish."
            )
        else:
            reminder = (
                f"{next_p}\n\n"
                "未检测到工具调用。你下一次回复必须严格调用一个工具。"
                "如果任务已完成,请调用 finish。"
            )

        self.history.append({"role": "user", "content": reminder})
        return None

    async def _format_final_output(self) -> str:
        """Produce the final report from the history via the formatting prompt.

        Retries up to three times when ``output_check_fn`` rejects the output;
        without a checker the first attempt is returned.
        """
        # Use everything after the system prompt as formatting context
        # (slicing copies, so self.history is not mutated).
        recent_history = self.history[1:]
        formatting_prompt = prompt_manager.format_prompt(
            "format_report", output_format=self.output_format
        )
        recent_history.append({"role": "user", "content": formatting_prompt})
        final_output = ""
        for _ in range(3):
            final_output = self.llm.chat(recent_history)
            logger.info(f"Final Output: {final_output}")
            if self.output_check_fn:
                ret = self.output_check_fn(final_output)
                if isinstance(ret, bool) and ret:
                    break
            else:
                break
        return final_output

    def _build_task_message(self) -> str:
        """Rebuild the task message, embedding the summary when one exists."""
        if not self.summary_memory:
            return self.original_task

        if self.language == "en":
            return (
                f"I want you to complete: {self.original_task}\n\n"
                f"The following context is provided for your reference:\n{self.summary_memory}"
            )
        return (
            f"我希望你完成: {self.original_task}\n\n有以下上下文提供你参考:\n{self.summary_memory}"
        )

    def _build_summary_memory_message(self) -> str:
        """Wrap the stored summary in a language-appropriate preamble."""
        if self.language == "en":
            return f"Summary of previous context:\n{self.summary_memory}"
        return f"此前上下文摘要:\n{self.summary_memory}"