model.py
"""
Model module
"""

import re

from enum import Enum

from smolagents import ChatMessage, Model, get_clean_message_list, tool_role_conversions
from smolagents.models import get_tool_call_from_text, remove_content_after_stop_sequences

from ..pipeline import LLM


class PipelineModel(Model):
    """
    smolagents Model implementation backed by a LLM pipeline.
    """

    def __init__(self, path=None, method=None, **kwargs):
        """
        Creates a new LLM model.

        Args:
            path: model path or instance
            method: llm model framework, infers from path if not provided
            kwargs: model keyword arguments
        """

        # Reuse an existing pipeline instance when one is passed, otherwise build a new one
        if isinstance(path, LLM):
            self.llm = path
        else:
            self.llm = LLM(path, method, **kwargs)

        # Default maximum sequence length, override via parameters()
        self.maxlength = 8192

        # Call parent constructor. Text-only pipelines flatten messages to plain text.
        super().__init__(flatten_messages_as_text=not self.llm.isvision(), model_id=self.llm.generator.path, **kwargs)

    # pylint: disable=W0613
    def generate(self, messages, stop_sequences=None, response_format=None, tools_to_call_from=None, **kwargs):
        """
        Runs LLM inference. This method signature must match the smolagents specification.

        Args:
            messages: list of messages to run
            stop_sequences: optional list of stop sequences
            response_format: response format to use in the model's response.
            tools_to_call_from: list of tools that the model can use to generate responses.
            kwargs: additional keyword arguments

        Returns:
            result
        """

        # Normalize the incoming message list
        inputs = self.clean(messages)

        # Run LLM inference
        output = self.llm(inputs, maxlength=self.maxlength, stop=stop_sequences, **kwargs)

        # Strip any trailing content generated after a stop sequence
        if stop_sequences is not None:
            output = remove_content_after_stop_sequences(output, stop_sequences)

        # Wrap the raw output as an assistant chat message
        result = ChatMessage(role="assistant", content=output)

        # Extract the first tool action from the output, if necessary
        if tools_to_call_from:
            action = re.sub(r".*?Action:(.*?\n\}).*", r"\1", output, flags=re.DOTALL)
            result.tool_calls = [get_tool_call_from_text(action, self.tool_name_key, self.tool_arguments_key)]

        return result

    def parameters(self, maxlength):
        """
        Set LLM inference parameters.

        Args:
            maxlength: maximum sequence length
        """

        self.maxlength = maxlength

    def clean(self, messages):
        """
        Gets a clean message list.

        Args:
            messages: input messages

        Returns:
            clean messages
        """

        # Get clean message list
        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions, flatten_messages_as_text=self.flatten_messages_as_text)

        # Ensure all roles are strings and not enums for compatibility across LLM frameworks
        for entry in messages:
            role = entry.get("role")
            if isinstance(role, Enum):
                entry["role"] = role.value

        return messages