# functioncall.py
  1  import argparse
  2  import torch
  3  import json
  4  
  5  from transformers import (
  6      AutoModelForCausalLM,
  7      AutoTokenizer,
  8      BitsAndBytesConfig
  9  )
 10  
 11  import functions
 12  from prompter import PromptManager
 13  from validator import validate_function_call_schema
 14  
 15  from utils import (
 16      print_nous_text_art,
 17      inference_logger,
 18      get_assistant_message,
 19      get_chat_template,
 20      validate_and_extract_tool_calls
 21  )
 22  
 23  class ModelInference:
 24      def __init__(self, model_path, chat_template, load_in_4bit):
 25          inference_logger.info(print_nous_text_art())
 26          self.prompter = PromptManager()
 27          self.bnb_config = None
 28  
 29          if load_in_4bit == "True":
 30              self.bnb_config = BitsAndBytesConfig(
 31                  load_in_4bit=True,
 32                  bnb_4bit_quant_type="nf4",
 33                  bnb_4bit_use_double_quant=True,
 34              )
 35          self.model = AutoModelForCausalLM.from_pretrained(
 36              model_path,
 37              trust_remote_code=True,
 38              return_dict=True,
 39              quantization_config=self.bnb_config,
 40              torch_dtype=torch.float16,
 41              attn_implementation="flash_attention_2",
 42              device_map="auto",
 43          )
 44  
 45          self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 46          self.tokenizer.pad_token = self.tokenizer.eos_token
 47          self.tokenizer.padding_side = "left"
 48  
 49          if self.tokenizer.chat_template is None:
 50              print("No chat template defined, getting chat_template...")
 51              self.tokenizer.chat_template = get_chat_template(chat_template)
 52          
 53          inference_logger.info(self.model.config)
 54          inference_logger.info(self.model.generation_config)
 55          inference_logger.info(self.tokenizer.special_tokens_map)
 56  
 57      def process_completion_and_validate(self, completion, chat_template):
 58  
 59          assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)
 60  
 61          if assistant_message:
 62              validation, tool_calls, error_message = validate_and_extract_tool_calls(assistant_message)
 63  
 64              if validation:
 65                  inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
 66                  return tool_calls, assistant_message, error_message
 67              else:
 68                  tool_calls = None
 69                  return tool_calls, assistant_message, error_message
 70          else:
 71              inference_logger.warning("Assistant message is None")
 72              raise ValueError("Assistant message is None")
 73          
 74      def execute_function_call(self, tool_call):
 75          function_name = tool_call.get("name")
 76          function_to_call = getattr(functions, function_name, None)
 77          function_args = tool_call.get("arguments", {})
 78  
 79          inference_logger.info(f"Invoking function call {function_name} ...")
 80          function_response = function_to_call(*function_args.values())
 81          results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
 82          return results_dict
 83      
 84      def run_inference(self, prompt):
 85          inputs = self.tokenizer.apply_chat_template(
 86              prompt,
 87              add_generation_prompt=True,
 88              return_tensors='pt'
 89          )
 90  
 91          tokens = self.model.generate(
 92              inputs.to(self.model.device),
 93              max_new_tokens=1500,
 94              temperature=0.8,
 95              repetition_penalty=1.1,
 96              do_sample=True,
 97              eos_token_id=self.tokenizer.eos_token_id
 98          )
 99          completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False, clean_up_tokenization_space=True)
100          return completion
101  
102      def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
103          try:
104              depth = 0
105              user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
106              chat = [{"role": "user", "content": user_message}]
107              tools = functions.get_openai_tools()
108              prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
109              completion = self.run_inference(prompt)
110  
111              def recursive_loop(prompt, completion, depth):
112                  nonlocal max_depth
113                  tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
114                  prompt.append({"role": "assistant", "content": assistant_message})
115  
116                  tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
117                  if tool_calls:
118                      inference_logger.info(f"Assistant Message:\n{assistant_message}")
119  
120                      for tool_call in tool_calls:
121                          validation, message = validate_function_call_schema(tool_call, tools)
122                          if validation:
123                              try:
124                                  function_response = self.execute_function_call(tool_call)
125                                  tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
126                                  inference_logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
127                              except Exception as e:
128                                  inference_logger.info(f"Could not execute function: {e}")
129                                  tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
130                          else:
131                              inference_logger.info(message)
132                              tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
133                      prompt.append({"role": "tool", "content": tool_message})
134  
135                      depth += 1
136                      if depth >= max_depth:
137                          print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
138                          return
139  
140                      completion = self.run_inference(prompt)
141                      recursive_loop(prompt, completion, depth)
142                  elif error_message:
143                      inference_logger.info(f"Assistant Message:\n{assistant_message}")
144                      tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax<tool_response>"
145                      prompt.append({"role": "tool", "content": tool_message})
146  
147                      depth += 1
148                      if depth >= max_depth:
149                          print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
150                          return
151  
152                      completion = self.run_inference(prompt)
153                      recursive_loop(prompt, completion, depth)
154                  else:
155                      inference_logger.info(f"Assistant Message:\n{assistant_message}")
156  
157              recursive_loop(prompt, completion, depth)
158  
159          except Exception as e:
160              inference_logger.error(f"Exception occurred: {e}")
161              raise e
162  
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run recursive function calling loop")
    parser.add_argument("--model_path", type=str, help="Path to the model folder")
    parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting")
    parser.add_argument("--num_fewshot", type=int, default=None, help="Option to use json mode examples")
    parser.add_argument("--load_in_4bit", type=str, default="False", help="Option to load in 4bit with bitsandbytes")
    parser.add_argument("--query", type=str, default="I need the current stock price of Tesla (TSLA)")
    parser.add_argument("--max_depth", type=int, default=5, help="Maximum number of recursive iteration")
    args = parser.parse_args()

    # Fall back to the default Hermes model when no custom path is given
    # (collapses the duplicated ModelInference construction of the original
    # if/else into a single call).
    model_path = args.model_path or 'NousResearch/Hermes-2-Pro-Llama-3-8B'
    inference = ModelInference(model_path, args.chat_template, args.load_in_4bit)

    # Run the model evaluator
    inference.generate_function_call(args.query, args.chat_template, args.num_fewshot, args.max_depth)