/ agent-scan / utils / llm.py
llm.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  import asyncio
 20  import time
 21  
 22  import openai
 23  from typing import List
 24  from utils.logging import logger
 25  
# Marker placed at the very start of every error string built by
# format_llm_error_message(); is_llm_error_response() tests for it so that
# other modules can detect a failed LLM call without parsing the message.
LLM_ERROR_PREFIX = "[LLM Error:"
 28  
 29  
 30  def is_llm_error_response(response: str) -> bool:
 31      return isinstance(response, str) and response.startswith(LLM_ERROR_PREFIX)
 32  
 33  
 34  def format_llm_error_message(language: str, zh_message: str, en_message: str) -> str:
 35      message = en_message if language == "en" else zh_message
 36      return f"{LLM_ERROR_PREFIX} {message}]"
 37  
 38  
 39  class LLM:
 40      def __init__(self, model, api_key, base_url):
 41          self.model = model
 42          self.api_key = api_key
 43          self.base_url = base_url
 44          self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url, timeout=60)
 45          self.temperature = 0.7
 46  
 47      async def chat_async(self, message: List[dict], language: str = "zh") -> str:
 48          """Non-blocking wrapper around :meth:`chat` for use inside async contexts.
 49  
 50          Runs the synchronous ``openai.OpenAI`` call in a thread-pool executor
 51          via :func:`asyncio.to_thread`, so the event loop is free to schedule
 52          other coroutines (e.g. parallel skill workers) while waiting for the
 53          LLM response.
 54  
 55          Args:
 56              message: Conversation history in OpenAI chat format.
 57              language: Language for error messages ("zh" or "en").
 58  
 59          Returns:
 60              The model's response text.
 61          """
 62          return await asyncio.to_thread(self.chat, message, False, language)
 63  
 64      def chat(self, message: List[dict], p=False, language: str = "zh"):
 65          """Send a chat request to the LLM.
 66  
 67          Args:
 68              message: Conversation history in OpenAI chat format.
 69              p: Whether to print the response.
 70              language: Language for error messages ("zh" or "en").
 71  
 72          Returns:
 73              The model's response text, or an error string prefixed with LLM_ERROR_PREFIX.
 74          """
 75          retry = 0
 76          while True:
 77              ret = ''
 78              try:
 79                  for word in self.chat_stream(message):
 80                      ret += word
 81                  if ret != '':
 82                      break
 83                  else:
 84                      # Empty response: network jitter or model occasionally returns empty, can retry
 85                      retry += 1
 86                      logger.error(f'LLM chat error (empty response), retry {retry}')
 87                      time.sleep(1.3)
 88                      if retry > 3:
 89                          logger.error('LLM chat error, retry 3 times, exit')
 90                          return format_llm_error_message(
 91                              language,
 92                              "连接LLM失败,已重试3次,模型输出为空,请等待1分钟后再试",
 93                              "Failed to connect to LLM, retried 3 times, model output is empty, please try again after 1 minute",
 94                          )
 95                      continue
 96              except openai.BadRequestError as e:
 97                  # 400 error (e.g. DataInspectionFailed): content issue, retry is meaningless, return immediately
 98                  error_msg = str(e)
 99                  logger.warning(f"LLM BadRequestError (400), no retry: {error_msg}")
100                  return format_llm_error_message(
101                      language,
102                      "输入内容触发安全过滤 (400)",
103                      "Input content triggered safety filter (400)",
104                  )
105              except (openai.APIConnectionError, openai.APITimeoutError) as e:
106                  # Network/timeout error: can retry
107                  retry += 1
108                  logger.warning(f'LLM connection/timeout error, retry {retry}: {e}')
109                  if retry > 5:
110                      logger.error('LLM connection error, retry 5 times, exit')
111                      return format_llm_error_message(
112                          language,
113                          "无法连接到LLM服务,已重试5次",
114                          "Unable to connect to LLM service, retried 5 times",
115                      )
116                  time.sleep(2)
117                  continue
118              except openai.APIError as e:
119                  # Other API errors (5xx, etc.): can retry
120                  retry += 1
121                  logger.warning(f'LLM API error, retry {retry}: {e}')
122                  if retry > 3:
123                      logger.error('LLM API error, retry 3 times, exit')
124                      return format_llm_error_message(
125                          language,
126                          "无法连接到LLM服务,已重试3次",
127                          "Unable to connect to LLM service, retried 3 times",
128                      )
129                  time.sleep(1)
130                  continue
131              except Exception as e:
132                  # Unexpected exception: return immediately, do not retry
133                  logger.error(f'Unexpected LLM error: {e}', exc_info=True)
134                  return format_llm_error_message(
135                      language,
136                      f"发生未预期的错误 - {str(e)[:100]}",
137                      f"Unexpected error occurred - {str(e)[:100]}",
138                  )
139  
140          if p:
141              print(ret)
142          return ret
143  
144  
145      def chat_stream(self, message: List[dict]):
146          """Stream chat completions from the LLM.
147  
148          Exceptions from the underlying API call propagate to chat() for
149          centralized handling and retry logic. Only unexpected (non-OpenAI)
150          exceptions are logged here before re-raising.
151  
152          Args:
153              message: Conversation history in OpenAI chat format.
154  
155          Yields:
156              Content chunks from the model response.
157  
158          Raises:
159              openai.BadRequestError: Content triggered safety filter (400).
160              openai.APIConnectionError: Network connection failed.
161              openai.APITimeoutError: Request timed out.
162              openai.APIError: Other API errors (5xx, etc.).
163          """
164          try:
165              response = self.client.chat.completions.create(
166                  model=self.model,
167                  messages=message,
168                  temperature=self.temperature,
169                  stream=True
170              )
171  
172              for chunk in response:
173                  choices = getattr(chunk, "choices", None)
174  
175                  # Ensure choices is a non-empty list
176                  if not isinstance(choices, list) or not choices:
177                      continue
178                  choice = choices[0]
179  
180                  delta = getattr(choice, "delta", None)
181                  if not delta:
182                      continue
183  
184                  content = getattr(delta, "content", None)
185                  if content:
186                      yield content
187  
188          except (openai.BadRequestError, openai.APIConnectionError,
189                  openai.APITimeoutError, openai.APIError):
190              # OpenAI exceptions propagate directly to chat() for handling
191              raise
192          except Exception as e:
193              # Log unexpected (non-OpenAI) exceptions before re-raising
194              logger.error(f'Unexpected error in chat_stream: {e}', exc_info=True)
195              raise