# utils.py
"""Helpers for function-calling inference.

Provides logging setup (console + per-run log file), prompt asset loading
(few-shot examples, chat templates), and parsing helpers that pull the
assistant turn, ``<tool_call>`` payloads, and fenced JSON out of model output.
"""
import ast
import os
import re
import json
import logging
import datetime
import xml.etree.ElementTree as ET

from logging.handlers import RotatingFileHandler

try:
    from art import text2art
except ImportError:
    # The banner is purely cosmetic; keep the parsing helpers importable
    # when the optional `art` package is not installed.
    text2art = None

_LOG_FORMAT = "%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d:%H:%M:%S"

logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
    level=logging.INFO,
)
script_dir = os.path.dirname(os.path.abspath(__file__))
now = datetime.datetime.now()
log_folder = os.path.join(script_dir, "inference_logs")
os.makedirs(log_folder, exist_ok=True)
log_file_path = os.path.join(
    log_folder, f"function-calling-inference_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
)
# RotatingFileHandler with maxBytes=0 never rolls over, so each run writes a
# single timestamped log file.
file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
file_handler.setLevel(logging.INFO)

formatter = logging.Formatter(_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
file_handler.setFormatter(formatter)

inference_logger = logging.getLogger("function-calling-inference")
inference_logger.addHandler(file_handler)

# One pattern per supported chat template. Each captures everything after the
# LAST assistant marker up to the end of the completion (the tempered dot
# forbids another marker inside the capture).
_ASSISTANT_PATTERNS = {
    "zephyr": re.compile(r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$', re.DOTALL),
    "chatml": re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL),
    "vicuna": re.compile(r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$', re.DOTALL),
}

# Matches the body of the first ```json fenced block.
_JSON_FENCE_PATTERN = re.compile(r'```json\r?\n(.*?)\r?\n```', re.DOTALL)


def print_nous_text_art(suffix=None):
    """Print the "nousresearch" banner, optionally followed by ``x <suffix>``.

    Falls back to printing the plain text when the optional ``art`` package
    is not available.
    """
    font = "nancyj"
    ascii_text = " nousresearch"
    if suffix:
        ascii_text += f" x {suffix}"
    if text2art is None:
        print(ascii_text)
        return
    ascii_art = text2art(ascii_text, font=font)
    print(ascii_art)


def get_fewshot_examples(num_fewshot):
    """Return the first ``num_fewshot`` examples from ``prompt_assets/few_shot.json``.

    Args:
        num_fewshot (int): Number of examples to return.

    Returns:
        list: The leading ``num_fewshot`` entries of the JSON file.

    Raises:
        ValueError: If ``num_fewshot`` is negative or exceeds the number of
            examples available in the file.
    """
    # A negative count would silently slice from the end of the list.
    if num_fewshot < 0:
        raise ValueError(f"num_fewshot must be non-negative (got {num_fewshot}).")
    example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
    with open(example_path, 'r', encoding='utf-8') as file:
        examples = json.load(file)
    if num_fewshot > len(examples):
        raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
    return examples[:num_fewshot]


def get_chat_template(chat_template):
    """Read ``chat_templates/<chat_template>.j2`` and return its text.

    Args:
        chat_template (str): Template name without the ``.j2`` extension.

    Returns:
        str | None: The template source, or None if the file is missing or
        cannot be read (the failure is logged).
    """
    template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")

    try:
        with open(template_path, 'r', encoding='utf-8') as file:
            return file.read()
    except FileNotFoundError:
        inference_logger.error(f"Template file not found: {chat_template}")
    except (OSError, UnicodeError) as e:
        inference_logger.error(f"Error loading template: {e}")
    return None


def get_assistant_message(completion, chat_template, eos_token):
    """Extract the final assistant turn from a raw completion.

    Args:
        completion (str): Full decoded text including prompt and response.
        chat_template (str): One of "zephyr", "chatml" or "vicuna".
        eos_token (str | None): End-of-sequence token to strip from the
            assistant text. For "vicuna" the removed string is
            ``"</s>" + eos_token``.

    Returns:
        str | None: The assistant text with the EOS token removed, or None
        when no assistant marker is found.

    Raises:
        NotImplementedError: If ``chat_template`` is not supported.
    """
    completion = completion.strip()

    assistant_pattern = _ASSISTANT_PATTERNS.get(chat_template)
    if assistant_pattern is None:
        raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")

    assistant_match = assistant_pattern.search(completion)
    if not assistant_match:
        inference_logger.info("No match found for the assistant pattern")
        return None

    assistant_content = assistant_match.group(1).strip()
    if chat_template == "vicuna":
        eos_token = f"</s>{eos_token}"
    # Skip removal when there is no EOS token; str.replace(None, ...) raises.
    if eos_token:
        assistant_content = assistant_content.replace(eos_token, "")
    return assistant_content


def _parse_tool_call_payload(json_text):
    """Parse one tool-call payload; return ``(data, error_message)``.

    Tries strict JSON first, then falls back to ``ast.literal_eval`` so that
    Python-literal style payloads (single quotes, True/False/None) are also
    accepted. Exactly one of the two returned values is meaningful:
    ``error_message`` is None on success.
    """
    try:
        return json.loads(json_text), None
    except json.JSONDecodeError as json_err:
        try:
            return ast.literal_eval(json_text), None
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as eval_err:
            # literal_eval can fail with more than SyntaxError/ValueError
            # (e.g. TypeError for an unhashable dict key); report them all
            # as a parsing failure.
            error_message = f"JSON parsing failed with both json.loads and ast.literal_eval:\n"\
                            f"- JSON Decode Error: {json_err}\n"\
                            f"- Fallback Syntax/Value Error: {eval_err}\n"\
                            f"- Problematic JSON text: {json_text}"
            return None, error_message


def validate_and_extract_tool_calls(assistant_content):
    """Extract every parseable ``<tool_call>`` payload from assistant output.

    The content is wrapped in a synthetic ``<root>`` element and parsed as
    XML, so it must be well-formed XML once wrapped.

    Args:
        assistant_content (str): Assistant text possibly containing
            ``<tool_call>...</tool_call>`` elements.

    Returns:
        tuple: ``(validation_result, tool_calls, error_message)`` where
        ``validation_result`` is True when at least one payload was parsed,
        ``tool_calls`` is the list of parsed payloads, and ``error_message``
        is the most recent failure message (or None if nothing failed).
    """
    tool_calls = []
    error_message = None

    try:
        # wrap content in root element so multiple top-level tags parse
        root = ET.fromstring(f"<root>{assistant_content}</root>")
    except ET.ParseError as err:
        error_message = f"XML Parse Error: {err}"
        inference_logger.error(error_message)
        return False, tool_calls, error_message

    for element in root.findall(".//tool_call"):
        if element.text is None:
            error_message = "Cannot strip text: <tool_call> element has no text content"
            inference_logger.error(error_message)
            continue

        json_data, parse_error = _parse_tool_call_payload(element.text.strip())
        if parse_error is not None:
            error_message = parse_error
            inference_logger.error(error_message)
            continue

        if json_data is not None:
            tool_calls.append(json_data)

    return bool(tool_calls), tool_calls, error_message


def extract_json_from_markdown(text):
    """Extract and parse the first ```json fenced block found in ``text``.

    Args:
        text (str): The input text containing the fenced JSON string.

    Returns:
        The decoded JSON value (typically a dict), or None if no fenced block
        is found or its content is not valid JSON. Failures are logged.
    """
    match = _JSON_FENCE_PATTERN.search(text)
    if match is None:
        inference_logger.warning("JSON string not found in the text.")
        return None
    try:
        return json.loads(match.group(1))
    except json.JSONDecodeError as e:
        inference_logger.error(f"Error decoding JSON string: {e}")
        return None