# batch.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

"""
Batch tool - execute multiple tool calls in a single request.
Enforces a maximum of 10 sub-calls and forbids nested batch/finish calls.
"""
import asyncio
from typing import Any, List, Dict, Optional
from tools.registry import register_tool, get_tool_by_name, needs_context
from utils.logging import logger
from utils.tool_context import ToolContext
import inspect


# Tools that must not be invoked from inside a batch
DISALLOWED_TOOLS = {'batch', 'finish'}
MAX_BATCH_SIZE = 10


def _tool_label(call: Any) -> Any:
    """Return the ``tool`` field of a call entry, or ``'unknown'`` if it has none."""
    if isinstance(call, dict):
        return call.get('tool', 'unknown')
    return 'unknown'


@register_tool
async def batch(
    tool_calls: List[Dict[str, Any]],
    context: ToolContext = None
) -> dict[str, Any]:
    """
    Execute several tool calls in one batch.

    The first ``MAX_BATCH_SIZE`` calls are executed one after another, in
    order. Any further calls are not executed and are reported as failures.
    A failure of one call (malformed entry, disallowed or unknown tool, or an
    exception raised by the tool) is recorded in that call's result and does
    not stop the remaining calls.

    Args:
        tool_calls: List of tool calls; each element contains ``tool`` and
            ``parameters``
        context: Tool context, passed on to tools that need it

    Returns:
        Dictionary with the overall ``success`` flag, a summary in ``title``
        and ``output``, and per-call results under ``metadata["details"]``
    """
    if not tool_calls:
        return {
            "success": False,
            "error": "No tool calls provided. Provide at least one tool call."
        }

    if not isinstance(tool_calls, list):
        return {
            "success": False,
            "error": "tool_calls must be an array of tool call objects"
        }

    # Enforce the batch size limit
    if len(tool_calls) > MAX_BATCH_SIZE:
        logger.warning(f"Batch size {len(tool_calls)} exceeds limit {MAX_BATCH_SIZE}, truncating")

    calls_to_execute = tool_calls[:MAX_BATCH_SIZE]
    discarded_calls = tool_calls[MAX_BATCH_SIZE:]

    results = []

    async def execute_single_call(call: Dict[str, Any], index: int) -> Dict[str, Any]:
        """Run one tool call and return its result record; never raises for bad input."""
        # A malformed entry must fail on its own instead of aborting the batch
        if not isinstance(call, dict):
            return {
                "index": index,
                "tool": "unknown",
                "success": False,
                "error": "Each tool call must be an object with 'tool' and 'parameters' fields"
            }

        tool_name = call.get('tool', '')
        parameters = call.get('parameters', {})

        # An explicit null for parameters means "no arguments"
        if parameters is None:
            parameters = {}

        # Reject non-string names up front: an unhashable value would raise
        # on the set membership test below
        if not isinstance(tool_name, str):
            return {
                "index": index,
                "tool": tool_name,
                "success": False,
                "error": f"Invalid tool name {tool_name!r}: 'tool' must be a string"
            }

        # Refuse tools that are not allowed inside a batch
        if tool_name in DISALLOWED_TOOLS:
            # sorted() keeps the message stable; set iteration order is not
            disallowed = ', '.join(sorted(DISALLOWED_TOOLS))
            return {
                "index": index,
                "tool": tool_name,
                "success": False,
                "error": f"Tool '{tool_name}' is not allowed in batch. Disallowed tools: {disallowed}"
            }

        # Look up the tool function
        tool_func = get_tool_by_name(tool_name)
        if not tool_func:
            return {
                "index": index,
                "tool": tool_name,
                "success": False,
                "error": f"Tool '{tool_name}' not found in registry"
            }

        try:
            # Inject the context when the tool needs it; it overrides any
            # caller-supplied 'context' parameter
            if needs_context(tool_name) and context:
                parameters = {**parameters, 'context': context}

            # Run the tool
            result = tool_func(**parameters)

            # Await the result if the tool is asynchronous
            if inspect.isawaitable(result):
                result = await result

            # Normalize the result
            if isinstance(result, dict):
                return {
                    "index": index,
                    "tool": tool_name,
                    "success": result.get('success', True),
                    "result": result
                }
            else:
                return {
                    "index": index,
                    "tool": tool_name,
                    "success": True,
                    "result": str(result)
                }

        except Exception as e:
            logger.error(f"Error executing tool '{tool_name}' in batch: {e}")
            return {
                "index": index,
                "tool": tool_name,
                "success": False,
                "error": str(e)
            }

    # Run the calls serially (keeps ordering and behavior predictable)
    for i, call in enumerate(calls_to_execute):
        result = await execute_single_call(call, i)
        results.append(result)

    # Record the calls that were dropped because of the size limit
    for i, call in enumerate(discarded_calls):
        results.append({
            "index": MAX_BATCH_SIZE + i,
            "tool": _tool_label(call),
            "success": False,
            "error": f"Maximum of {MAX_BATCH_SIZE} tools allowed in batch"
        })

    # Summarize the results
    successful = sum(1 for r in results if r.get('success', False))
    failed = len(results) - successful

    if failed > 0:
        output = f"Executed {successful}/{len(results)} tools successfully. {failed} failed."
    else:
        output = f"All {successful} tools executed successfully.\n\nKeep using the batch tool for optimal performance!"

    return {
        "success": failed == 0,
        "title": f"Batch execution ({successful}/{len(results)} successful)",
        "output": output,
        "metadata": {
            "total_calls": len(results),
            "successful": successful,
            "failed": failed,
            "tools": [_tool_label(call) for call in tool_calls],
            "details": results
        }
    }