src/python/txtai/agent/model.py
  1  """
  2  Model module
  3  """
  4  
  5  import re
  6  
  7  from enum import Enum
  8  
  9  from smolagents import ChatMessage, Model, get_clean_message_list, tool_role_conversions
 10  from smolagents.models import get_tool_call_from_text, remove_content_after_stop_sequences
 11  
 12  from ..pipeline import LLM
 13  
 14  
class PipelineModel(Model):
    """
    Model backed by an LLM pipeline.
    """

    def __init__(self, path=None, method=None, **kwargs):
        """
        Creates a new LLM model.

        Args:
            path: model path or instance
            method: LLM framework, inferred from path if not provided
            kwargs: model keyword arguments
        """

        self.llm = path if isinstance(path, LLM) else LLM(path, method, **kwargs)

        # Default maximum sequence length, overridden via parameters()
        self.maxlength = 8192

        # Call parent constructor
        super().__init__(flatten_messages_as_text=not self.llm.isvision(), model_id=self.llm.generator.path, **kwargs)

    # pylint: disable=W0613
    def generate(self, messages, stop_sequences=None, response_format=None, tools_to_call_from=None, **kwargs):
        """
        Runs LLM inference. This method signature must match the smolagents specification.

        Args:
            messages: list of messages to run
            stop_sequences: optional list of stop sequences
            response_format: response format to use in the model's response
            tools_to_call_from: list of tools the model can use to generate responses
            kwargs: additional keyword arguments

        Returns:
            ChatMessage with the model response
        """

        # Get clean message list
        messages = self.clean(messages)

        # Get LLM output
        response = self.llm(messages, maxlength=self.maxlength, stop=stop_sequences, **kwargs)

        # Remove stop sequences from LLM output
        if stop_sequences is not None:
            response = remove_content_after_stop_sequences(response, stop_sequences)

        # Load response into a chat message
        message = ChatMessage(role="assistant", content=response)

        # Extract first tool action, if necessary. The regex isolates the first JSON block
        # following an "Action:" marker, which is then parsed into a tool call.
        if tools_to_call_from:
            message.tool_calls = [
                get_tool_call_from_text(
                    re.sub(r".*?Action:(.*?\n\}).*", r"\1", response, flags=re.DOTALL), self.tool_name_key, self.tool_arguments_key
                )
            ]

        return message

    def parameters(self, maxlength):
        """
        Sets LLM inference parameters.

        Args:
            maxlength: maximum sequence length
        """

        self.maxlength = maxlength

    def clean(self, messages):
        """
        Gets a clean message list.

        Args:
            messages: input messages

        Returns:
            clean messages
        """

        # Get clean message list
        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions, flatten_messages_as_text=self.flatten_messages_as_text)

        # Ensure all roles are strings, not enums, for compatibility across LLM frameworks
        for message in messages:
            if "role" in message:
                message["role"] = message["role"].value if isinstance(message["role"], Enum) else message["role"]

        return messages
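

# A minimal usage sketch, not part of the upstream module. The model path below is an
# assumption for illustration; any LLM path supported by txtai should work the same way.
if __name__ == "__main__":
    from smolagents import CodeAgent

    # Wrap a txtai LLM pipeline as a smolagents-compatible model
    model = PipelineModel(path="hugging-quants/Llama-3.2-1B-Instruct")
    model.parameters(4096)

    # Run a smolagents agent backed by the pipeline model
    agent = CodeAgent(tools=[], model=model)
    print(agent.run("What is 2 + 2?"))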