llm.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

import asyncio
import time

import openai
from typing import List
from utils.logging import logger

# Error prefix constant for consistent error detection across modules
LLM_ERROR_PREFIX = "[LLM Error:"


def is_llm_error_response(response: str) -> bool:
    return isinstance(response, str) and response.startswith(LLM_ERROR_PREFIX)


def format_llm_error_message(language: str, zh_message: str, en_message: str) -> str:
    message = en_message if language == "en" else zh_message
    return f"{LLM_ERROR_PREFIX} {message}]"


class LLM:
    def __init__(self, model, api_key, base_url):
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url, timeout=60)
        self.temperature = 0.7

    async def chat_async(self, message: List[dict], language: str = "zh") -> str:
        """Non-blocking wrapper around :meth:`chat` for use inside async contexts.

        Runs the synchronous ``openai.OpenAI`` call in a thread-pool executor
        via :func:`asyncio.to_thread`, so the event loop is free to schedule
        other coroutines (e.g. parallel skill workers) while waiting for the
        LLM response.

        Args:
            message: Conversation history in OpenAI chat format.
            language: Language for error messages ("zh" or "en").

        Returns:
            The model's response text.
        """
        return await asyncio.to_thread(self.chat, message, False, language)

    def chat(self, message: List[dict], p=False, language: str = "zh"):
        """Send a chat request to the LLM.

        Args:
            message: Conversation history in OpenAI chat format.
            p: Whether to print the response.
            language: Language for error messages ("zh" or "en").

        Returns:
            The model's response text, or an error string prefixed with LLM_ERROR_PREFIX.
        """
        retry = 0
        while True:
            ret = ''
            try:
                for word in self.chat_stream(message):
                    ret += word
                if ret != '':
                    break
                else:
                    # Empty response: network jitter or the model occasionally
                    # returning nothing; safe to retry
                    retry += 1
                    logger.error(f'LLM chat error (empty response), retry {retry}')
                    time.sleep(1.3)
                    if retry > 3:
                        logger.error('LLM chat error, retry 3 times, exit')
                        return format_llm_error_message(
                            language,
                            "连接LLM失败,已重试3次,模型输出为空,请等待1分钟后再试",
                            "Failed to connect to LLM, retried 3 times, model output is empty, please try again after 1 minute",
                        )
                    continue
            except openai.BadRequestError as e:
                # 400 error (e.g. DataInspectionFailed): content issue, retry is
                # meaningless, return immediately
                error_msg = str(e)
                logger.warning(f"LLM BadRequestError (400), no retry: {error_msg}")
                return format_llm_error_message(
                    language,
                    "输入内容触发安全过滤 (400)",
                    "Input content triggered safety filter (400)",
                )
            except (openai.APIConnectionError, openai.APITimeoutError) as e:
                # Network/timeout error: can retry
                retry += 1
                logger.warning(f'LLM connection/timeout error, retry {retry}: {e}')
                if retry > 5:
                    logger.error('LLM connection error, retry 5 times, exit')
                    return format_llm_error_message(
                        language,
                        "无法连接到LLM服务,已重试5次",
                        "Unable to connect to LLM service, retried 5 times",
                    )
                time.sleep(2)
                continue
            except openai.APIError as e:
                # Other API errors (5xx, etc.): can retry
                retry += 1
                logger.warning(f'LLM API error, retry {retry}: {e}')
                if retry > 3:
                    logger.error('LLM API error, retry 3 times, exit')
                    return format_llm_error_message(
                        language,
                        "无法连接到LLM服务,已重试3次",
                        "Unable to connect to LLM service, retried 3 times",
                    )
                time.sleep(1)
                continue
            except Exception as e:
                # Unexpected exception: return immediately, do not retry
                logger.error(f'Unexpected LLM error: {e}', exc_info=True)
                return format_llm_error_message(
                    language,
                    f"发生未预期的错误 - {str(e)[:100]}",
                    f"Unexpected error occurred - {str(e)[:100]}",
                )

        if p:
            print(ret)
        return ret

    def chat_stream(self, message: List[dict]):
        """Stream chat completions from the LLM.

        Exceptions from the underlying API call propagate to chat() for
        centralized handling and retry logic. Only unexpected (non-OpenAI)
        exceptions are logged here before re-raising.

        Args:
            message: Conversation history in OpenAI chat format.

        Yields:
            Content chunks from the model response.

        Raises:
            openai.BadRequestError: Content triggered safety filter (400).
            openai.APIConnectionError: Network connection failed.
            openai.APITimeoutError: Request timed out.
            openai.APIError: Other API errors (5xx, etc.).
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=message,
                temperature=self.temperature,
                stream=True
            )

            for chunk in response:
                choices = getattr(chunk, "choices", None)

                # Ensure choices is a non-empty list
                if not isinstance(choices, list) or not choices:
                    continue
                choice = choices[0]

                delta = getattr(choice, "delta", None)
                if not delta:
                    continue

                content = getattr(delta, "content", None)
                if content:
                    yield content

        except (openai.BadRequestError, openai.APIConnectionError,
                openai.APITimeoutError, openai.APIError):
            # OpenAI exceptions propagate directly to chat() for handling
            raise
        except Exception as e:
            # Log unexpected (non-OpenAI) exceptions before re-raising
            logger.error(f'Unexpected error in chat_stream: {e}', exc_info=True)
            raise
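

# Usage sketch (illustrative only, not part of the project's entry points):
# drives the wrapper above from an async context via chat_async() and checks
# the result with is_llm_error_response(). The model name, base URL, and the
# "YOUR_API_KEY" placeholder below are assumptions chosen for this example,
# not values used by AI-Infra-Guard.
async def _demo() -> None:
    llm = LLM(
        model="gpt-4o-mini",                     # assumed model name
        api_key="YOUR_API_KEY",                  # placeholder credential
        base_url="https://api.openai.com/v1",    # assumed endpoint
    )
    history = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one short sentence."},
    ]
    reply = await llm.chat_async(history, language="en")
    if is_llm_error_response(reply):
        # Error strings carry the LLM_ERROR_PREFIX marker defined above
        logger.error(f"LLM call failed: {reply}")
    else:
        print(reply)


if __name__ == "__main__":
    asyncio.run(_demo())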