/ agent-scan / tools / batch / batch.py
batch.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  """
 20  Batch 工具 - 批量执行多个工具调用
 21  最多执行 10 个子调用(超出部分不执行,并记为失败),禁止嵌套批处理/finish
 22  """
 23  import asyncio
 24  from typing import Any, List, Dict, Optional
 25  from tools.registry import register_tool, get_tool_by_name, needs_context
 26  from utils.logging import logger
 27  from utils.tool_context import ToolContext
 28  import inspect
 29  
 30  
 31  # 禁止在 batch 中调用的工具
 32  DISALLOWED_TOOLS = {'batch', 'finish'}
 33  MAX_BATCH_SIZE = 10
 34  
 35  
 36  @register_tool
 37  async def batch(
 38      tool_calls: List[Dict[str, Any]],
 39      context: ToolContext = None
 40  ) -> dict[str, Any]:
 41      """
 42      批量执行多个工具调用
 43      
 44      Args:
 45          tool_calls: 工具调用列表,每个元素包含 tool 和 parameters
 46          context: 工具上下文
 47          
 48      Returns:
 49          包含执行结果的字典
 50      """
 51      if not tool_calls:
 52          return {
 53              "success": False,
 54              "error": "No tool calls provided. Provide at least one tool call."
 55          }
 56      
 57      if not isinstance(tool_calls, list):
 58          return {
 59              "success": False,
 60              "error": "tool_calls must be an array of tool call objects"
 61          }
 62      
 63      # 限制调用数量
 64      if len(tool_calls) > MAX_BATCH_SIZE:
 65          logger.warning(f"Batch size {len(tool_calls)} exceeds limit {MAX_BATCH_SIZE}, truncating")
 66      
 67      calls_to_execute = tool_calls[:MAX_BATCH_SIZE]
 68      discarded_calls = tool_calls[MAX_BATCH_SIZE:]
 69      
 70      results = []
 71      
 72      async def execute_single_call(call: Dict[str, Any], index: int) -> Dict[str, Any]:
 73          """执行单个工具调用"""
 74          tool_name = call.get('tool', '')
 75          parameters = call.get('parameters', {})
 76          
 77          # 检查是否为禁止的工具
 78          if tool_name in DISALLOWED_TOOLS:
 79              return {
 80                  "index": index,
 81                  "tool": tool_name,
 82                  "success": False,
 83                  "error": f"Tool '{tool_name}' is not allowed in batch. Disallowed tools: {', '.join(DISALLOWED_TOOLS)}"
 84              }
 85          
 86          # 获取工具函数
 87          tool_func = get_tool_by_name(tool_name)
 88          if not tool_func:
 89              return {
 90                  "index": index,
 91                  "tool": tool_name,
 92                  "success": False,
 93                  "error": f"Tool '{tool_name}' not found in registry"
 94              }
 95          
 96          try:
 97              # 注入 context 如果需要
 98              if needs_context(tool_name) and context:
 99                  parameters = {**parameters, 'context': context}
100              
101              # 执行工具
102              result = tool_func(**parameters)
103              
104              # 处理异步结果
105              if inspect.isawaitable(result):
106                  result = await result
107              
108              # 格式化结果
109              if isinstance(result, dict):
110                  return {
111                      "index": index,
112                      "tool": tool_name,
113                      "success": result.get('success', True),
114                      "result": result
115                  }
116              else:
117                  return {
118                      "index": index,
119                      "tool": tool_name,
120                      "success": True,
121                      "result": str(result)
122                  }
123                  
124          except Exception as e:
125              logger.error(f"Error executing tool '{tool_name}' in batch: {e}")
126              return {
127                  "index": index,
128                  "tool": tool_name,
129                  "success": False,
130                  "error": str(e)
131              }
132      
133      # 串行执行所有调用(保持顺序和可预测性)
134      for i, call in enumerate(calls_to_execute):
135          result = await execute_single_call(call, i)
136          results.append(result)
137      
138      # 添加被丢弃的调用结果
139      for i, call in enumerate(discarded_calls):
140          results.append({
141              "index": MAX_BATCH_SIZE + i,
142              "tool": call.get('tool', 'unknown'),
143              "success": False,
144              "error": f"Maximum of {MAX_BATCH_SIZE} tools allowed in batch"
145          })
146      
147      # 统计结果
148      successful = sum(1 for r in results if r.get('success', False))
149      failed = len(results) - successful
150      
151      if failed > 0:
152          output = f"Executed {successful}/{len(results)} tools successfully. {failed} failed."
153      else:
154          output = f"All {successful} tools executed successfully.\n\nKeep using the batch tool for optimal performance!"
155      
156      return {
157          "success": failed == 0,
158          "title": f"Batch execution ({successful}/{len(results)} successful)",
159          "output": output,
160          "metadata": {
161              "total_calls": len(results),
162              "successful": successful,
163              "failed": failed,
164              "tools": [call.get('tool', 'unknown') for call in tool_calls],
165              "details": results
166          }
167      }
168