test/components/validators/test_json_schema.py
test_json_schema.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import json
  6  
  7  import pytest
  8  
  9  from haystack import Pipeline, component
 10  from haystack.components.validators import JsonSchemaValidator
 11  from haystack.dataclasses import ChatMessage
 12  
 13  
 14  @pytest.fixture
 15  def genuine_fc_message():
 16      return """[{"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", "function": {"arguments": "{\\n    \\"basehead\\": \\"main...amzn_chat\\",\\n    \\"owner\\": \\"deepset-ai\\",\\n    \\"repo\\": \\"haystack-core-integrations\\"\\n  }", "name": "compare_branches"}, "type": "function"}]"""  # noqa: E501
 17  
 18  
 19  @pytest.fixture
 20  def json_schema_github_compare():
 21      return {
 22          "type": "object",
 23          "properties": {
 24              "id": {"type": "string", "description": "A unique identifier for the call"},
 25              "function": {
 26                  "type": "object",
 27                  "properties": {
 28                      "arguments": {
 29                          "type": "object",
 30                          "properties": {
 31                              "basehead": {
 32                                  "type": "string",
 33                                  "pattern": "^[^\\.]+(\\.{3}).+$",
 34                                  "description": "Branch names must be in the format 'base_branch...head_branch'",
 35                              },
 36                              "owner": {"type": "string", "description": "Owner of the repository"},
 37                              "repo": {"type": "string", "description": "Name of the repository"},
 38                          },
 39                          "required": ["basehead", "owner", "repo"],
 40                          "description": "Parameters for the function call",
 41                      },
 42                      "name": {"type": "string", "description": "Name of the function to be called"},
 43                  },
 44                  "required": ["arguments", "name"],
 45                  "description": "Details of the function being called",
 46              },
 47              "type": {"type": "string", "description": "Type of the call (e.g., 'function')"},
 48          },
 49          "required": ["function", "type"],
 50          "description": "Structure representing a function call",
 51      }
 52  
 53  
 54  @pytest.fixture
 55  def json_schema_github_compare_openai():
 56      return {
 57          "name": "compare_branches",
 58          "description": "Compares two branches in a GitHub repository",
 59          "parameters": {
 60              "type": "object",
 61              "properties": {
 62                  "basehead": {
 63                      "type": "string",
 64                      "pattern": "^[^\\.]+(\\.{3}).+$",
 65                      "description": "Branch names must be in the format 'base_branch...head_branch'",
 66                  },
 67                  "owner": {"type": "string", "description": "Owner of the repository"},
 68                  "repo": {"type": "string", "description": "Name of the repository"},
 69              },
 70              "required": ["basehead", "owner", "repo"],
 71              "description": "Parameters for the function call",
 72          },
 73      }
 74  
 75  
 76  class TestJsonSchemaValidator:
 77      #  Validates a message against a provided JSON schema successfully.
 78      def test_validates_message_against_json_schema(self, json_schema_github_compare, genuine_fc_message):
 79          validator = JsonSchemaValidator()
 80          message = ChatMessage.from_assistant(genuine_fc_message)
 81  
 82          result = validator.run([message], json_schema_github_compare)
 83  
 84          assert "validated" in result
 85          assert len(result["validated"]) == 1
 86          assert result["validated"][0] == message
 87  
 88      # Validates recursive_json_to_object method
 89      def test_recursive_json_to_object(self, genuine_fc_message):
 90          arguments_is_string = json.loads(genuine_fc_message)
 91          assert isinstance(arguments_is_string[0]["function"]["arguments"], str)
 92  
 93          # but ensure_json_objects converts the string to a json object
 94          validator = JsonSchemaValidator()
 95          result = validator._recursive_json_to_object({"key": genuine_fc_message})
 96  
 97          # we need this recursive json conversion to validate the message
 98          assert result["key"][0]["function"]["arguments"]["basehead"] == "main...amzn_chat"
 99  
100      #  Validates multiple messages against a provided JSON schema successfully.
101      def test_validates_multiple_messages_against_json_schema(self, json_schema_github_compare, genuine_fc_message):
102          validator = JsonSchemaValidator()
103  
104          messages = [
105              ChatMessage.from_user("I'm not being validated, but the message after me is!"),
106              ChatMessage.from_assistant(genuine_fc_message),
107          ]
108  
109          result = validator.run(messages, json_schema_github_compare)
110          assert "validated" in result
111          assert len(result["validated"]) == 1
112          assert result["validated"][0] == messages[1]
113  
114      #  Validates a message against an OpenAI function calling schema successfully.
115      def test_validates_message_against_openai_function_calling_schema(
116          self, json_schema_github_compare_openai, genuine_fc_message
117      ):
118          validator = JsonSchemaValidator()
119  
120          message = ChatMessage.from_assistant(genuine_fc_message)
121          result = validator.run([message], json_schema_github_compare_openai)
122  
123          assert "validated" in result
124          assert len(result["validated"]) == 1
125          assert result["validated"][0] == message
126  
127      #  Validates multiple messages against an OpenAI function calling schema successfully.
128      def test_validates_multiple_messages_against_openai_function_calling_schema(
129          self, json_schema_github_compare_openai, genuine_fc_message
130      ):
131          validator = JsonSchemaValidator()
132  
133          messages = [
134              ChatMessage.from_system("Common use case is that this is for example system message"),
135              ChatMessage.from_assistant(genuine_fc_message),
136          ]
137  
138          result = validator.run(messages, json_schema_github_compare_openai)
139  
140          assert "validated" in result
141          assert len(result["validated"]) == 1
142          assert result["validated"][0] == messages[1]
143  
144      #  Constructs a custom error recovery message when validation fails.
145      def test_construct_custom_error_recovery_message(self):
146          validator = JsonSchemaValidator()
147  
148          new_error_template = (
149              "Error details:\n- Message: {error_message}\n"
150              "- Error Path in JSON: {error_path}\n"
151              "- Schema Path: {error_schema_path}\n"
152              "Please match the following schema:\n"
153              "{json_schema}\n"
154              "Failing Json: {failing_json}\n"
155          )
156  
157          recovery_message = validator._construct_error_recovery_message(
158              new_error_template, "Error message", "Error path", "Error schema path", {"type": "object"}, "Failing Json"
159          )
160  
161          expected_recovery_message = (
162              "Error details:\n- Message: Error message\n"
163              "- Error Path in JSON: Error path\n"
164              "- Schema Path: Error schema path\n"
165              "Please match the following schema:\n"
166              "{'type': 'object'}\n"
167              "Failing Json: Failing Json\n"
168          )
169          assert recovery_message == expected_recovery_message
170  
171      def test_schema_validator_in_pipeline_validated(self, json_schema_github_compare, genuine_fc_message):
172          @component
173          class ChatMessageProducer:
174              @component.output_types(messages=list[ChatMessage])
175              def run(self):
176                  return {"messages": [ChatMessage.from_assistant(genuine_fc_message)]}
177  
178          pipe = Pipeline()
179          pipe.add_component(name="schema_validator", instance=JsonSchemaValidator())
180          pipe.add_component(name="message_producer", instance=ChatMessageProducer())
181          pipe.connect("message_producer", "schema_validator")
182          result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
183          assert "validated" in result["schema_validator"]
184          assert len(result["schema_validator"]["validated"]) == 1
185          assert result["schema_validator"]["validated"][0].text == genuine_fc_message
186  
187      def test_schema_validator_in_pipeline_validation_error(self, json_schema_github_compare):
188          @component
189          class ChatMessageProducer:
190              @component.output_types(messages=list[ChatMessage])
191              def run(self):
192                  # example json string that is not valid
193                  simple_invalid_json = '{"key": "value"}'
194                  return {"messages": [ChatMessage.from_assistant(simple_invalid_json)]}  # invalid message
195  
196          pipe = Pipeline()
197          pipe.add_component(name="schema_validator", instance=JsonSchemaValidator())
198          pipe.add_component(name="message_producer", instance=ChatMessageProducer())
199          pipe.connect("message_producer", "schema_validator")
200          result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
201          assert "validation_error" in result["schema_validator"]
202          assert len(result["schema_validator"]["validation_error"]) == 1
203          assert "Error details" in result["schema_validator"]["validation_error"][0].text