# mcp-scan/agent/base_agent.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
import json
import uuid
from collections.abc import Callable

from tools.dispatcher import ToolDispatcher
from utils.aig_logger import mcpLogger
from utils.llm import LLM
from utils.loging import logger
from utils.parse import clean_content, parse_tool_invocations
from utils.prompt_manager import prompt_manager
from utils.tool_context import ToolContext
 29  
 30  
 31  class BaseAgent:
 32      def __init__(
 33          self,
 34          name: str,
 35          instruction: str,
 36          llm: LLM,
 37          dispatcher: ToolDispatcher,
 38          specialized_llms: dict = None,
 39          log_step_id: str = None,
 40          debug: bool = False,
 41          capabilities: list[str] = None,
 42          output_format: str | None = None,
 43          output_check_fn: callable = None,
 44          language: str = "zh",
 45      ):
 46          self.llm = llm
 47          self.name = name
 48          self.dispatcher = dispatcher
 49          self.specialized_llms = specialized_llms or {}
 50          self.instruction = instruction
 51          self.capabilities = capabilities or ["standard"]
 52          self.output_format = output_format
 53          self.step_id = log_step_id
 54          self.debug = debug
 55          self.repo_dir = ""
 56          self.output_check_fn = output_check_fn
 57          self.language = language
 58          # loop control
 59          self.iter = 0
 60          self.max_iter = 80
 61          self.is_finished = False
 62          # context
 63          self.history = []
 64          self.original_task = ""
 65          self.summary_memory = ""
 66          # 在模型上下文窗口达到 60% 左右时开始压缩,给后续输出和工具结果留余量。
 67          self.max_history_tokens = max(int(self.llm.context_window * 0.6), 1)
 68          # 压缩时保留最近若干条对话,避免丢失当前执行轨迹。
 69          self.keep_recent_msgs = 8
 70  
 71      async def initialize(self):
 72          """异步初始化系统提示词"""
 73          if not self.history:
 74              system_prompt = await self.generate_system_prompt()
 75              self.history.append({"role": "system", "content": system_prompt})
 76  
 77      def add_user_message(self, message: str):
 78          self.history.append({"role": "user", "content": message})
 79  
 80      def set_repo_dir(self, repo_dir: str):
 81          self.repo_dir = repo_dir
 82  
 83      def should_compact_history(self, usage: dict | None = None) -> bool:
 84          # 没有可压缩的旧消息时,不触发压缩。
 85          if len(self.history) - 2 <= self.keep_recent_msgs:
 86              return False
 87  
 88          prompt_tokens = None
 89          if usage:
 90              prompt_tokens = usage.get("prompt_tokens")
 91          if isinstance(prompt_tokens, int):
 92              return prompt_tokens >= self.max_history_tokens
 93  
 94          # 某些兼容接口没有 usage,退回到消息条数做保守判定。
 95          return len(self.history) > 24
 96  
 97      def compact_history(self):
 98          recent_start = max(2, len(self.history) - self.keep_recent_msgs)
 99  
100          msgs_to_compact = []
101          if self.summary_memory:
102              msgs_to_compact.append(
103                  {"role": "user", "content": self._build_summary_memory_message()}
104              )
105          msgs_to_compact.extend(self.history[2:recent_start])
106          if not msgs_to_compact:
107              return
108  
109          compact_prompt = prompt_manager.load_template("compact")
110          msgs_to_compact.append({"role": "user", "content": compact_prompt})
111          compacted_msgs = self.llm.chat(msgs_to_compact)
112          self.summary_memory = compacted_msgs
113  
114          if not self.original_task:
115              self.original_task = self.history[1]["content"]
116  
117          system_prompt = self.history[0]
118          recent_msgs = self.history[-self.keep_recent_msgs :]
119          self.history = [
120              system_prompt,
121              {
122                  "role": "user",
123                  "content": self._build_task_message(),
124              },
125              *recent_msgs,
126          ]
127  
128      async def generate_system_prompt(self):
129          tools_prompt = await self.dispatcher.get_all_tools_prompt()
130  
131          template_name = "system_prompt"
132          format_kwargs = {
133              "generate_tools": tools_prompt,
134              "name": self.name,
135              "instruction": self.instruction,
136          }
137  
138          return prompt_manager.format_prompt(template_name, **format_kwargs)
139  
140      def next_prompt(self):
141          return prompt_manager.format_prompt("next_prompt", round=self.iter)
142  
143      async def run(self):
144          await self.initialize()
145          return await self._run()
146  
147      async def _run(self):
148          logger.info(f"Agent {self.name} started with max_iter={self.max_iter}")
149          result = ""
150          while not self.is_finished and self.iter < self.max_iter:
151              logger.debug(f"\n{'=' * 50}\nIteration {self.iter}\n{'=' * 50}")
152              response, usage = self.llm.chat(self.history, self.debug, ret_usage=True)
153              logger.debug(f"LLM Response: {response}")
154  
155              self.history.append({"role": "assistant", "content": response})
156              res = await self.handle_response(response)
157              if res is not None:
158                  result = res
159  
160              self.iter += 1
161              if self.should_compact_history(usage) and not self.is_finished:
162                  logger.info(
163                      "Prompt tokens %s exceeded limit %s, compacting context",
164                      usage.get("prompt_tokens") if usage else None,
165                      self.max_history_tokens,
166                  )
167                  self.compact_history()
168  
169          if not self.is_finished:
170              logger.warning(f"Max iterations ({self.max_iter}) reached")
171              mcpLogger.status_update(
172                  self.step_id,
173                  "达到最大迭代次数,返回当前结果"
174                  if self.language != "en"
175                  else "Max iterations reached, returning current result",
176                  "",
177                  "completed",
178              )
179              if not result:
180                  result = await self._format_final_output()
181          return result
182  
183      async def handle_response(self, response: str):
184          tool_invocations = parse_tool_invocations(response)
185          description = clean_content(response)
186          if tool_invocations and tool_invocations["toolName"] == "finish" and description == "":
187              description = "报告完成。"
188              if self.language == "en":
189                  description = "Report completed."
190          if description == "":
191              description = "我将继续执行"
192              if self.language == "en":
193                  description = "I will continue to execute"
194  
195          if tool_invocations:
196              if tool_invocations["toolName"] != "finish":
197                  mcpLogger.status_update(self.step_id, description, "", "running")
198              return await self.process_tool_call(tool_invocations, description)
199          else:
200              mcpLogger.status_update(self.step_id, description, "", "running")
201              return await self.handle_no_tool(description)
202  
203      async def process_tool_call(self, tool_call: dict, description: str):
204          tool_name = tool_call["toolName"]
205          tool_args = tool_call["args"]
206          tool_id = uuid.uuid4().__str__()
207  
208          params = json.dumps(tool_args, ensure_ascii=False) if tool_args else ""
209          if isinstance(params, str):
210              params = params.replace(self.repo_dir, "")
211  
212          mcpLogger.tool_used(self.step_id, tool_id, tool_name, "done", tool_name, f"{params}")
213  
214          if tool_name == "finish":
215              self.is_finished = True
216              logger.info("Finish tool called, final result formatted.")
217  
218              mcpLogger.status_update(self.step_id, description, "", "completed")
219              result = await self._format_final_output()
220              mcpLogger.action_log(tool_id, tool_name, self.step_id, result)
221              return result
222  
223          # 构造上下文
224          context = ToolContext(
225              llm=self.llm,
226              history=self.history,
227              agent_name=self.name,
228              iteration=self.iter,
229              specialized_llms=self.specialized_llms,
230              folder=self.repo_dir,
231              tool_dispatcher=self.dispatcher,
232          )
233  
234          # 通过 Dispatcher 调用工具
235          tool_result = await self.dispatcher.call_tool(tool_name, tool_args, context)
236  
237          # 格式化工具结果并添加到历史
238          result_message = f"{tool_result}"
239  
240          # 添加下一轮提示
241          next_p = self.next_prompt()
242          full_message = f"{next_p}\n\n{result_message}"
243  
244          self.history.append({"role": "user", "content": full_message})
245          mcpLogger.status_update(self.step_id, description, "", "completed")
246  
247          if tool_name != "read_file":
248              mcpLogger.action_log(tool_id, tool_name, self.step_id, f"```\n{result_message}\n```")
249  
250          return None
251  
252      async def handle_no_tool(self, description: str):
253          next_p = self.next_prompt()
254          if self.language == "en":
255              reminder = (
256                  f"{next_p}\n\n"
257                  "No tool call was detected. You must call exactly one tool in your next response. "
258                  "If the task is complete, call finish."
259              )
260          else:
261              reminder = (
262                  f"{next_p}\n\n"
263                  "未检测到工具调用。你下一次回复必须严格调用一个工具。"
264                  "如果任务已完成,请调用 finish。"
265              )
266  
267          self.history.append({"role": "user", "content": reminder})
268          return None
269  
270      async def _format_final_output(self) -> str:
271          """使用 LLM 根据历史记录和预设格式生成最终输出"""
272          # 取最近的对话历史作为参考
273          recent_history = self.history[1:]
274          formatting_prompt = prompt_manager.format_prompt(
275              "format_report", output_format=self.output_format
276          )
277          recent_history.append({"role": "user", "content": formatting_prompt})
278          final_output = ""
279          for _ in range(3):
280              final_output = self.llm.chat(recent_history)
281              logger.info(f"Final Output: {final_output}")
282              if self.output_check_fn:
283                  ret = self.output_check_fn(final_output)
284                  if isinstance(ret, bool) and ret:
285                      break
286              else:
287                  break
288          return final_output
289  
290      def _build_task_message(self) -> str:
291          if not self.summary_memory:
292              return self.original_task
293  
294          if self.language == "en":
295              return (
296                  f"I want you to complete: {self.original_task}\n\n"
297                  f"The following context is provided for your reference:\n{self.summary_memory}"
298              )
299          return (
300              f"我希望你完成: {self.original_task}\n\n有以下上下文提供你参考:\n{self.summary_memory}"
301          )
302  
303      def _build_summary_memory_message(self) -> str:
304          if self.language == "en":
305              return f"Summary of previous context:\n{self.summary_memory}"
306          return f"此前上下文摘要:\n{self.summary_memory}"