# jsonmode.py
  1  import argparse
  2  import torch
  3  import json
  4  
  5  from transformers import (
  6      AutoModelForCausalLM,
  7      AutoTokenizer,
  8      BitsAndBytesConfig
  9  )
 10  
 11  from validator import validate_json_data
 12  
 13  from utils import (
 14      print_nous_text_art,
 15      inference_logger,
 16      get_assistant_message,
 17      get_chat_template,
 18      validate_and_extract_tool_calls
 19  )
 20  
 21  # create your pydantic model for json object here
 22  from typing import List, Optional
 23  from pydantic import BaseModel
 24  
class Character(BaseModel):
    """Target schema for the generated JSON: a fictional character profile."""

    # Required identity fields.
    name: str
    species: str
    role: str
    # NOTE(review): written against the pydantic v1 API, where Optional[...]
    # with no default implies default=None; under pydantic v2 these fields
    # would be required — confirm the pinned pydantic version.
    personality_traits: Optional[List[str]]
    special_attacks: Optional[List[str]]

    class Config:
        # Forbid unspecified keys in the emitted JSON schema so the model
        # cannot invent extra properties. (v1 name; pydantic v2 renamed this
        # to json_schema_extra inside model_config.)
        schema_extra = {
            "additionalProperties": False
        }

# serialize pydantic model into json schema
# (v1 API; pydantic v2 deprecates schema_json in favor of model_json_schema)
pydantic_schema = Character.schema_json()
 39  
 40  class ModelInference:
 41      def __init__(self, model_path, chat_template, load_in_4bit):
 42          inference_logger.info(print_nous_text_art())
 43          self.bnb_config = None
 44  
 45          if load_in_4bit == "True":
 46              self.bnb_config = BitsAndBytesConfig(
 47                  load_in_4bit=True,
 48                  bnb_4bit_quant_type="nf4",
 49                  bnb_4bit_use_double_quant=True,
 50              )
 51          self.model = AutoModelForCausalLM.from_pretrained(
 52              model_path,
 53              trust_remote_code=True,
 54              return_dict=True,
 55              quantization_config=self.bnb_config,
 56              torch_dtype=torch.float16,
 57              attn_implementation="flash_attention_2",
 58              device_map="auto",
 59          )
 60  
 61          self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 62          self.tokenizer.pad_token = self.tokenizer.eos_token
 63          self.tokenizer.padding_side = "left"
 64  
 65          if self.tokenizer.chat_template is None:
 66              print("No chat template defined, getting chat_template...")
 67              self.tokenizer.chat_template = get_chat_template(chat_template)
 68          
 69          inference_logger.info(self.model.config)
 70          inference_logger.info(self.model.generation_config)
 71          inference_logger.info(self.tokenizer.special_tokens_map)
 72      
 73      def run_inference(self, prompt):
 74          inputs = self.tokenizer.apply_chat_template(
 75              prompt,
 76              add_generation_prompt=True,
 77              return_tensors='pt'
 78          )
 79  
 80          tokens = self.model.generate(
 81              inputs.to(self.model.device),
 82              max_new_tokens=1500,
 83              temperature=0.8,
 84              repetition_penalty=1.1,
 85              do_sample=True,
 86              eos_token_id=self.tokenizer.eos_token_id
 87          )
 88          completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False, clean_up_tokenization_space=True)
 89          return completion
 90  
 91      def generate_json_completion(self, query, chat_template, max_depth=5):
 92          try:
 93              depth = 0
 94              sys_prompt = f"You are a helpful assistant that answers in JSON. Here's the json schema you must adhere to:\n<schema>\n{pydantic_schema}\n</schema>"
 95              prompt = [{"role": "system", "content": sys_prompt}]
 96              prompt.append({"role": "user", "content": query})
 97  
 98              inference_logger.info(f"Running inference to generate json object for pydantic schema:\n{json.dumps(json.loads(pydantic_schema), indent=2)}")
 99              completion = self.run_inference(prompt)
100  
101              def recursive_loop(prompt, completion, depth):
102                  nonlocal max_depth
103  
104                  assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)
105  
106                  tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
107                  if assistant_message is not None:
108                      validation, json_object, error_message = validate_json_data(assistant_message, json.loads(pydantic_schema))
109                      if validation:
110                          inference_logger.info(f"Assistant Message:\n{assistant_message}")
111                          inference_logger.info(f"json schema validation passed")
112                          inference_logger.info(f"parsed json object:\n{json.dumps(json_object, indent=2)}")
113                      elif error_message:
114                          inference_logger.info(f"Assistant Message:\n{assistant_message}")
115                          inference_logger.info(f"json schema validation failed")
116                          tool_message += f"<tool_response>\nJson schema validation failed\nHere's the error stacktrace: {error_message}\nPlease return corrrect json object\n<tool_response>"
117                          
118                          depth += 1
119                          if depth >= max_depth:
120                              print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
121                              return
122                          
123                          prompt.append({"role": "tool", "content": tool_message})
124                          completion = self.run_inference(prompt)
125                          recursive_loop(prompt, completion, depth)
126                  else:
127                      inference_logger.warning("Assistant message is None")
128              recursive_loop(prompt, completion, depth)
129          except Exception as e:
130              inference_logger.error(f"Exception occurred: {e}")
131              raise e
132  
if __name__ == "__main__":
    # CLI entry point: parse options, build the inference wrapper, and run
    # one schema-constrained JSON generation for the given query.
    parser = argparse.ArgumentParser(description="Run json mode completion")
    parser.add_argument("--model_path", type=str, help="Path to the model folder")
    parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting")
    parser.add_argument("--load_in_4bit", type=str, default="False", help="Option to load in 4bit with bitsandbytes")
    parser.add_argument("--query", type=str, default="Please return a json object to represent Goku from the anime Dragon Ball Z?")
    parser.add_argument("--max_depth", type=int, default=5, help="Maximum number of recursive iteration")
    args = parser.parse_args()

    # Use the custom model path when supplied, otherwise fall back to the
    # default Hermes checkpoint.
    chosen_path = args.model_path if args.model_path else 'NousResearch/Hermes-2-Pro-Llama-3-8B'
    inference = ModelInference(chosen_path, args.chat_template, args.load_in_4bit)

    # Run the model evaluator
    inference.generate_json_completion(args.query, args.chat_template, args.max_depth)