# jsonmode.py
import argparse
import torch
import json

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)

from validator import validate_json_data

from utils import (
    print_nous_text_art,
    inference_logger,
    get_assistant_message,
    get_chat_template,
    validate_and_extract_tool_calls
)

# create your pydantic model for json object here
from typing import List, Optional
from pydantic import BaseModel


class Character(BaseModel):
    """Pydantic model describing the JSON object the LLM must produce."""
    name: str
    species: str
    role: str
    personality_traits: Optional[List[str]]
    special_attacks: Optional[List[str]]

    class Config:
        # Disallow keys not declared above when the schema is enforced
        # (pydantic v1-style schema customization).
        schema_extra = {
            "additionalProperties": False
        }


# serialize pydantic model into a json schema string (pydantic v1 API);
# this is injected verbatim into the system prompt below.
pydantic_schema = Character.schema_json()


class ModelInference:
    """Loads a causal LM + tokenizer and runs schema-constrained JSON generation."""

    def __init__(self, model_path, chat_template, load_in_4bit):
        inference_logger.info(print_nous_text_art())
        self.bnb_config = None

        # load_in_4bit arrives as a string from argparse, hence the "True" compare.
        if load_in_4bit == "True":
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            return_dict=True,
            quantization_config=self.bnb_config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Reuse EOS as pad token; left padding matches decoder-only generation.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        # Fall back to a named template from utils when the model ships none.
        if self.tokenizer.chat_template is None:
            print("No chat template defined, getting chat_template...")
            self.tokenizer.chat_template = get_chat_template(chat_template)

        inference_logger.info(self.model.config)
        inference_logger.info(self.model.generation_config)
        inference_logger.info(self.tokenizer.special_tokens_map)

    def run_inference(self, prompt):
        """Apply the chat template to `prompt` (list of role/content dicts),
        sample a completion, and return the decoded text (special tokens kept).
        """
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            return_tensors='pt'
        )

        tokens = self.model.generate(
            inputs.to(self.model.device),
            max_new_tokens=1500,
            temperature=0.8,
            repetition_penalty=1.1,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id
        )
        # FIX: the kwarg is `clean_up_tokenization_spaces` (plural); the original
        # misspelling was silently ignored by **kwargs-tolerant versions.
        completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False, clean_up_tokenization_spaces=True)
        return completion

    def generate_json_completion(self, query, chat_template, max_depth=5):
        """Ask the model for a JSON object matching `pydantic_schema`, validating
        each attempt and feeding validation errors back as tool messages, up to
        `max_depth` corrective iterations.
        """
        try:
            depth = 0
            sys_prompt = f"You are a helpful assistant that answers in JSON. Here's the json schema you must adhere to:\n<schema>\n{pydantic_schema}\n</schema>"
            prompt = [{"role": "system", "content": sys_prompt}]
            prompt.append({"role": "user", "content": query})

            inference_logger.info(f"Running inference to generate json object for pydantic schema:\n{json.dumps(json.loads(pydantic_schema), indent=2)}")
            completion = self.run_inference(prompt)

            def recursive_loop(prompt, completion, depth):
                nonlocal max_depth

                # Strip template scaffolding to get just the assistant's reply.
                assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)

                tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
                if assistant_message is not None:
                    validation, json_object, error_message = validate_json_data(assistant_message, json.loads(pydantic_schema))
                    if validation:
                        inference_logger.info(f"Assistant Message:\n{assistant_message}")
                        inference_logger.info(f"json schema validation passed")
                        inference_logger.info(f"parsed json object:\n{json.dumps(json_object, indent=2)}")
                    elif error_message:
                        inference_logger.info(f"Assistant Message:\n{assistant_message}")
                        inference_logger.info(f"json schema validation failed")
                        # FIX: close the tag with </tool_response> (original repeated the
                        # opening tag) and correct the "corrrect" typo in the feedback.
                        tool_message += f"<tool_response>\nJson schema validation failed\nHere's the error stacktrace: {error_message}\nPlease return correct json object\n</tool_response>"

                    depth += 1
                    if depth >= max_depth:
                        print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                        return

                    # NOTE(review): recursion continues even after a successful
                    # validation — preserved as-is; confirm this is intentional.
                    prompt.append({"role": "tool", "content": tool_message})
                    completion = self.run_inference(prompt)
                    recursive_loop(prompt, completion, depth)
                else:
                    inference_logger.warning("Assistant message is None")

            recursive_loop(prompt, completion, depth)
        except Exception as e:
            inference_logger.error(f"Exception occurred: {e}")
            raise e


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run json mode completion")
    parser.add_argument("--model_path", type=str, help="Path to the model folder")
    parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting")
    parser.add_argument("--load_in_4bit", type=str, default="False", help="Option to load in 4bit with bitsandbytes")
    parser.add_argument("--query", type=str, default="Please return a json object to represent Goku from the anime Dragon Ball Z?")
    parser.add_argument("--max_depth", type=int, default=5, help="Maximum number of recursive iteration")
    args = parser.parse_args()

    # specify custom model path
    if args.model_path:
        inference = ModelInference(args.model_path, args.chat_template, args.load_in_4bit)
    else:
        model_path = 'NousResearch/Hermes-2-Pro-Llama-3-8B'
        inference = ModelInference(model_path, args.chat_template, args.load_in_4bit)

    # Run the model evaluator
    inference.generate_json_completion(args.query, args.chat_template, args.max_depth)