# functioncall.py
import argparse
import json

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)

import functions
from prompter import PromptManager
from validator import validate_function_call_schema
from utils import (
    print_nous_text_art,
    inference_logger,
    get_assistant_message,
    get_chat_template,
    validate_and_extract_tool_calls
)


class ModelInference:
    """Drive a local causal LM through a recursive function-calling loop.

    Loads the model/tokenizer, prompts with the available tools, parses
    <tool_call> blocks out of each completion, executes the named functions
    from the `functions` module, and feeds <tool_response> results back
    until the model stops calling tools or max_depth is reached.
    """

    def __init__(self, model_path, chat_template, load_in_4bit):
        """Load model and tokenizer.

        Args:
            model_path: HF hub id or local path of the model.
            chat_template: template name used when the tokenizer has none.
            load_in_4bit: the *string* "True"/"False" (comes straight from
                argparse) selecting bitsandbytes NF4 quantization.
        """
        inference_logger.info(print_nous_text_art())
        self.prompter = PromptManager()
        self.bnb_config = None

        # argparse delivers a string, so compare against "True" literally.
        if load_in_4bit == "True":
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            return_dict=True,
            quantization_config=self.bnb_config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        if self.tokenizer.chat_template is None:
            print("No chat template defined, getting chat_template...")
            self.tokenizer.chat_template = get_chat_template(chat_template)

        inference_logger.info(self.model.config)
        inference_logger.info(self.model.generation_config)
        inference_logger.info(self.tokenizer.special_tokens_map)

    def process_completion_and_validate(self, completion, chat_template):
        """Extract the assistant turn and parse any tool calls from it.

        Returns:
            (tool_calls, assistant_message, error_message) — tool_calls is
            None when extraction/validation failed (error_message explains).

        Raises:
            ValueError: when no assistant message could be extracted at all.
        """
        assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)

        if assistant_message:
            validation, tool_calls, error_message = validate_and_extract_tool_calls(assistant_message)

            if validation:
                inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
                return tool_calls, assistant_message, error_message
            else:
                tool_calls = None
                return tool_calls, assistant_message, error_message
        else:
            inference_logger.warning("Assistant message is None")
            raise ValueError("Assistant message is None")

    def execute_function_call(self, tool_call):
        """Resolve the named function in the `functions` module and invoke it.

        Raises:
            ValueError: when the model asked for a function that does not
                exist. (Fix: previously `getattr(..., None)` was called
                unchecked, producing an opaque "'NoneType' object is not
                callable" TypeError. The caller's except-block still turns
                this into a <tool_response> error for the model to retry.)
        """
        function_name = tool_call.get("name")
        function_to_call = getattr(functions, function_name, None)
        if function_to_call is None:
            raise ValueError(f"No such function: {function_name}")
        function_args = tool_call.get("arguments", {})

        inference_logger.info(f"Invoking function call {function_name} ...")
        function_response = function_to_call(*function_args.values())
        # Kept as a hand-built JSON-ish string for backward compatibility;
        # assumes function_response is already JSON-serializable text.
        results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
        return results_dict

    def run_inference(self, prompt):
        """Render `prompt` through the chat template and sample one completion."""
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            return_tensors='pt'
        )

        tokens = self.model.generate(
            inputs.to(self.model.device),
            max_new_tokens=1500,
            temperature=0.8,
            repetition_penalty=1.1,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id
        )
        # Fix: kwarg is `clean_up_tokenization_spaces` — the original
        # misspelling `clean_up_tokenization_space` was silently ignored.
        completion = self.tokenizer.decode(
            tokens[0],
            skip_special_tokens=False,
            clean_up_tokenization_spaces=True,
        )
        return completion

    def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
        """Run the full tool-use loop for one user query.

        Args:
            query: the user's request.
            chat_template: template name forwarded to message extraction.
            num_fewshot: optional number of few-shot examples for the prompt.
            max_depth: maximum number of model/tool round-trips.
        """
        try:
            depth = 0
            user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
            chat = [{"role": "user", "content": user_message}]
            tools = functions.get_openai_tools()
            prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
            completion = self.run_inference(prompt)

            def recursive_loop(prompt, completion, depth):
                # max_depth is only read, so no `nonlocal` is needed.
                tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
                prompt.append({"role": "assistant", "content": assistant_message})

                tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
                if tool_calls:
                    inference_logger.info(f"Assistant Message:\n{assistant_message}")

                    for tool_call in tool_calls:
                        validation, message = validate_function_call_schema(tool_call, tools)
                        if validation:
                            try:
                                function_response = self.execute_function_call(tool_call)
                                tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
                                inference_logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
                            except Exception as e:
                                # Execution errors are fed back to the model
                                # as a tool response so it can retry.
                                inference_logger.info(f"Could not execute function: {e}")
                                tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
                        else:
                            inference_logger.info(message)
                            tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
                    prompt.append({"role": "tool", "content": tool_message})

                    depth += 1
                    if depth >= max_depth:
                        print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                        return

                    completion = self.run_inference(prompt)
                    recursive_loop(prompt, completion, depth)
                elif error_message:
                    inference_logger.info(f"Assistant Message:\n{assistant_message}")
                    # Fix: closing tag was "<tool_response>"; must be "</tool_response>".
                    tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax</tool_response>"
                    prompt.append({"role": "tool", "content": tool_message})

                    depth += 1
                    if depth >= max_depth:
                        print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                        return

                    completion = self.run_inference(prompt)
                    recursive_loop(prompt, completion, depth)
                else:
                    # No tool calls and no parse error: final answer reached.
                    inference_logger.info(f"Assistant Message:\n{assistant_message}")

            recursive_loop(prompt, completion, depth)

        except Exception as e:
            inference_logger.error(f"Exception occurred: {e}")
            raise  # Fix: bare raise preserves the original traceback (was `raise e`).


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run recursive function calling loop")
    parser.add_argument("--model_path", type=str, help="Path to the model folder")
    parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting")
    parser.add_argument("--num_fewshot", type=int, default=None, help="Option to use json mode examples")
    parser.add_argument("--load_in_4bit", type=str, default="False", help="Option to load in 4bit with bitsandbytes")
    parser.add_argument("--query", type=str, default="I need the current stock price of Tesla (TSLA)")
    parser.add_argument("--max_depth", type=int, default=5, help="Maximum number of recursive iteration")
    args = parser.parse_args()

    # specify custom model path
    if args.model_path:
        inference = ModelInference(args.model_path, args.chat_template, args.load_in_4bit)
    else:
        model_path = 'NousResearch/Hermes-2-Pro-Llama-3-8B'
        inference = ModelInference(model_path, args.chat_template, args.load_in_4bit)

    # Run the model evaluator
    inference.generate_function_call(args.query, args.chat_template, args.num_fewshot, args.max_depth)