/ validator.py
validator.py
  1  import ast
  2  import json
  3  from jsonschema import validate
  4  from pydantic import ValidationError
  5  from utils import inference_logger, extract_json_from_markdown
  6  from schema import FunctionCall, FunctionSignature
  7  
  8  def validate_function_call_schema(call, signatures):
  9      try:
 10          call_data = FunctionCall(**call)
 11      except ValidationError as e:
 12          return False, str(e)
 13  
 14      for signature in signatures:
 15          try:
 16              signature_data = FunctionSignature(**signature)
 17              if signature_data.function.name == call_data.name:
 18                  # Validate types in function arguments
 19                  for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items():
 20                      if arg_name in call_data.arguments:
 21                          call_arg_value = call_data.arguments[arg_name]
 22                          if call_arg_value:
 23                              try:
 24                                  validate_argument_type(arg_name, call_arg_value, arg_schema)
 25                              except Exception as arg_validation_error:
 26                                  return False, str(arg_validation_error)
 27  
 28                  # Check if all required arguments are present
 29                  required_arguments = signature_data.function.parameters.get('required', [])
 30                  result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments)
 31                  if not result:
 32                      return False, f"Missing required arguments: {missing_arguments}"
 33  
 34                  return True, None
 35          except Exception as e:
 36              # Handle validation errors for the function signature
 37              return False, str(e)
 38  
 39      # No matching function signature found
 40      return False, f"No matching function signature found for function: {call_data.name}"
 41  
 42  def check_required_arguments(call_arguments, required_arguments):
 43      missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
 44      return not bool(missing_arguments), missing_arguments
 45  
 46  def validate_enum_value(arg_name, arg_value, enum_values):
 47      if arg_value not in enum_values:
 48          raise Exception(
 49              f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
 50          )
 51  
 52  def validate_argument_type(arg_name, arg_value, arg_schema):
 53      arg_type = arg_schema.get('type', None)
 54      if arg_type:
 55          if arg_type == 'string' and 'enum' in arg_schema:
 56              enum_values = arg_schema['enum']
 57              if None not in enum_values and enum_values != []:
 58                  try:
 59                      validate_enum_value(arg_name, arg_value, enum_values)
 60                  except Exception as e:
 61                      # Propagate the validation error message
 62                      raise Exception(f"Error validating function call: {e}")
 63  
 64          python_type = get_python_type(arg_type)
 65          if not isinstance(arg_value, python_type):
 66              raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")
 67  
 68  def get_python_type(json_type):
 69      type_mapping = {
 70          'string': str,
 71          'number': (int, float),
 72          'integer': int,
 73          'boolean': bool,
 74          'array': list,
 75          'object': dict,
 76          'null': type(None),
 77      }
 78      return type_mapping[json_type]
 79  
 80  def validate_json_data(json_object, json_schema):
 81      valid = False
 82      error_message = None
 83      result_json = None
 84  
 85      try:
 86          # Attempt to load JSON using json.loads
 87          try:
 88              result_json = json.loads(json_object)
 89          except json.decoder.JSONDecodeError:
 90              # If json.loads fails, try ast.literal_eval
 91              try:
 92                  result_json = ast.literal_eval(json_object)
 93              except (SyntaxError, ValueError) as e:
 94                  try:
 95                      result_json = extract_json_from_markdown(json_object)
 96                  except Exception as e:
 97                      error_message = f"JSON decoding error: {e}"
 98                      inference_logger.info(f"Validation failed for JSON data: {error_message}")
 99                      return valid, result_json, error_message
100  
101          # Return early if both json.loads and ast.literal_eval fail
102          if result_json is None:
103              error_message = "Failed to decode JSON data"
104              inference_logger.info(f"Validation failed for JSON data: {error_message}")
105              return valid, result_json, error_message
106  
107          # Validate each item in the list against schema if it's a list
108          if isinstance(result_json, list):
109              for index, item in enumerate(result_json):
110                  try:
111                      validate(instance=item, schema=json_schema)
112                      inference_logger.info(f"Item {index+1} is valid against the schema.")
113                  except ValidationError as e:
114                      error_message = f"Validation failed for item {index+1}: {e}"
115                      break
116          else:
117              # Default to validation without list
118              try:
119                  validate(instance=result_json, schema=json_schema)
120              except ValidationError as e:
121                  error_message = f"Validation failed: {e}"
122  
123      except Exception as e:
124          error_message = f"Error occurred: {e}"
125  
126      if error_message is None:
127          valid = True
128          inference_logger.info("JSON data is valid against the schema.")
129      else:
130          inference_logger.info(f"Validation failed for JSON data: {error_message}")
131  
132      return valid, result_json, error_message