/ utils.py
utils.py
  1  import ast
  2  import os
  3  import re
  4  import json
  5  import logging
  6  import datetime
  7  import xml.etree.ElementTree as ET
  8  
  9  from art import text2art
 10  from logging.handlers import RotatingFileHandler
 11  
 12  logging.basicConfig(
 13      format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
 14      datefmt="%Y-%m-%d:%H:%M:%S",
 15      level=logging.INFO,
 16  )
 17  script_dir = os.path.dirname(os.path.abspath(__file__))
 18  now = datetime.datetime.now()
 19  log_folder = os.path.join(script_dir, "inference_logs")
 20  os.makedirs(log_folder, exist_ok=True)
 21  log_file_path = os.path.join(
 22      log_folder, f"function-calling-inference_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
 23  )
 24  # Use RotatingFileHandler from the logging.handlers module
 25  file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
 26  file_handler.setLevel(logging.INFO)
 27  
 28  formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S")
 29  file_handler.setFormatter(formatter)
 30  
 31  inference_logger = logging.getLogger("function-calling-inference")
 32  inference_logger.addHandler(file_handler)
 33  
 34  def print_nous_text_art(suffix=None):
 35      font = "nancyj"
 36      ascii_text = "  nousresearch"
 37      if suffix:
 38          ascii_text += f"  x  {suffix}"
 39      ascii_art = text2art(ascii_text, font=font)
 40      print(ascii_art)
 41  
 42  def get_fewshot_examples(num_fewshot):
 43      """return a list of few shot examples"""
 44      example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
 45      with open(example_path, 'r') as file:
 46          examples = json.load(file)  # Use json.load with the file object, not the file path
 47      if num_fewshot > len(examples):
 48          raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
 49      return examples[:num_fewshot]
 50  
 51  def get_chat_template(chat_template):
 52      """read chat template from jinja file"""
 53      template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")
 54  
 55      if not os.path.exists(template_path):
 56          print
 57          inference_logger.error(f"Template file not found: {chat_template}")
 58          return None
 59      try:
 60          with open(template_path, 'r') as file:
 61              template = file.read()
 62          return template
 63      except Exception as e:
 64          print(f"Error loading template: {e}")
 65          return None
 66  
 67  def get_assistant_message(completion, chat_template, eos_token):
 68      """define and match pattern to find the assistant message"""
 69      completion = completion.strip()
 70  
 71      if chat_template == "zephyr":
 72          assistant_pattern = re.compile(r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$', re.DOTALL)
 73      elif chat_template == "chatml":
 74          assistant_pattern = re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL)
 75  
 76      elif chat_template == "vicuna":
 77          assistant_pattern = re.compile(r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$', re.DOTALL)
 78      else:
 79          raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")
 80      
 81      assistant_match = assistant_pattern.search(completion)
 82      if assistant_match:
 83          assistant_content = assistant_match.group(1).strip()
 84          if chat_template == "vicuna":
 85              eos_token = f"</s>{eos_token}"
 86          return assistant_content.replace(eos_token, "")
 87      else:
 88          assistant_content = None
 89          inference_logger.info("No match found for the assistant pattern")
 90          return assistant_content
 91  
 92  def validate_and_extract_tool_calls(assistant_content):
 93      validation_result = False
 94      tool_calls = []
 95      error_message = None
 96  
 97      try:
 98          # wrap content in root element
 99          xml_root_element = f"<root>{assistant_content}</root>"
100          root = ET.fromstring(xml_root_element)
101  
102          # extract JSON data
103          for element in root.findall(".//tool_call"):
104              json_data = None
105              try:
106                  json_text = element.text.strip()
107  
108                  try:
109                      # Prioritize json.loads for better error handling
110                      json_data = json.loads(json_text)
111                  except json.JSONDecodeError as json_err:
112                      try:
113                          # Fallback to ast.literal_eval if json.loads fails
114                          json_data = ast.literal_eval(json_text)
115                      except (SyntaxError, ValueError) as eval_err:
116                          error_message = f"JSON parsing failed with both json.loads and ast.literal_eval:\n"\
117                                          f"- JSON Decode Error: {json_err}\n"\
118                                          f"- Fallback Syntax/Value Error: {eval_err}\n"\
119                                          f"- Problematic JSON text: {json_text}"
120                          inference_logger.error(error_message)
121                          continue
122              except Exception as e:
123                  error_message = f"Cannot strip text: {e}"
124                  inference_logger.error(error_message)
125  
126              if json_data is not None:
127                  tool_calls.append(json_data)
128                  validation_result = True
129  
130      except ET.ParseError as err:
131          error_message = f"XML Parse Error: {err}"
132          inference_logger.error(f"XML Parse Error: {err}")
133  
134      # Return default values if no valid data is extracted
135      return validation_result, tool_calls, error_message
136  
def extract_json_from_markdown(text):
    """
    Parse the first ```json fenced block found in the given text.

    The opening fence must be followed by a line break, and the closing fence
    must be preceded by one.

    Args:
        text (str): Text that may contain a ```json fenced block.

    Returns:
        The decoded JSON value (a dict, list, or other JSON type) from the
        first such block, or None if there is no block or its content is not
        valid JSON. A message is printed in both failure cases.
    """
    fenced = re.search(r'```json\r?\n(.*?)\r?\n```', text, re.DOTALL)
    if fenced is None:
        print("JSON string not found in the text.")
        return None

    try:
        return json.loads(fenced.group(1))
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON string: {e}")
        return None
159