test_json_schema.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""Tests for the ``JsonSchemaValidator`` component."""

import json

import pytest

from haystack import Pipeline, component
from haystack.components.validators import JsonSchemaValidator
from haystack.dataclasses import ChatMessage


@pytest.fixture
def genuine_fc_message():
    """Return a list of function calls serialized as a JSON string.

    The nested ``arguments`` value is itself a JSON-encoded string, so a single
    ``json.loads`` does not turn it into an object.
    """
    return """[{"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", "function": {"arguments": "{\\n \\"basehead\\": \\"main...amzn_chat\\",\\n \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack-core-integrations\\"\\n }", "name": "compare_branches"}, "type": "function"}]"""  # noqa: E501


@pytest.fixture
def json_schema_github_compare():
    """Return a JSON schema for one function call whose arguments compare two GitHub branches."""
    return {
        "type": "object",
        "properties": {
            "id": {"type": "string", "description": "A unique identifier for the call"},
            "function": {
                "type": "object",
                "properties": {
                    "arguments": {
                        "type": "object",
                        "properties": {
                            "basehead": {
                                "type": "string",
                                # Requires exactly three dots between the base and head branch names.
                                "pattern": "^[^\\.]+(\\.{3}).+$",
                                "description": "Branch names must be in the format 'base_branch...head_branch'",
                            },
                            "owner": {"type": "string", "description": "Owner of the repository"},
                            "repo": {"type": "string", "description": "Name of the repository"},
                        },
                        "required": ["basehead", "owner", "repo"],
                        "description": "Parameters for the function call",
                    },
                    "name": {"type": "string", "description": "Name of the function to be called"},
                },
                "required": ["arguments", "name"],
                "description": "Details of the function being called",
            },
            "type": {"type": "string", "description": "Type of the call (e.g., 'function')"},
        },
        "required": ["function", "type"],
        "description": "Structure representing a function call",
    }


@pytest.fixture
def json_schema_github_compare_openai():
    """Return the ``compare_branches`` function definition in OpenAI function-calling format.

    Only the ``parameters`` entry is a JSON schema; it describes the call arguments.
    """
    return {
        "name": "compare_branches",
        "description": "Compares two branches in a GitHub repository",
        "parameters": {
            "type": "object",
            "properties": {
                "basehead": {
                    "type": "string",
                    "pattern": "^[^\\.]+(\\.{3}).+$",
                    "description": "Branch names must be in the format 'base_branch...head_branch'",
                },
                "owner": {"type": "string", "description": "Owner of the repository"},
                "repo": {"type": "string", "description": "Name of the repository"},
            },
            "required": ["basehead", "owner", "repo"],
            "description": "Parameters for the function call",
        },
    }


class TestJsonSchemaValidator:
    """Unit and pipeline-level tests for ``JsonSchemaValidator``."""

    def test_validates_message_against_json_schema(self, json_schema_github_compare, genuine_fc_message):
        """A single conforming message is returned under ``validated``."""
        validator = JsonSchemaValidator()
        message = ChatMessage.from_assistant(genuine_fc_message)

        result = validator.run([message], json_schema_github_compare)

        assert "validated" in result
        assert len(result["validated"]) == 1
        assert result["validated"][0] == message

    def test_recursive_json_to_object(self, genuine_fc_message):
        """``_recursive_json_to_object`` also decodes JSON strings nested inside JSON strings."""
        # A plain json.loads leaves the nested "arguments" payload as a string...
        parsed_calls = json.loads(genuine_fc_message)
        assert isinstance(parsed_calls[0]["function"]["arguments"], str)

        # ...but _recursive_json_to_object decodes it into a JSON object as well.
        validator = JsonSchemaValidator()
        result = validator._recursive_json_to_object({"key": genuine_fc_message})

        # The message can only be validated against the schema after this recursive conversion.
        assert result["key"][0]["function"]["arguments"]["basehead"] == "main...amzn_chat"

    def test_validates_multiple_messages_against_json_schema(self, json_schema_github_compare, genuine_fc_message):
        """With several messages, only the final (conforming) message ends up in ``validated``."""
        validator = JsonSchemaValidator()

        # NOTE(review): the assertions below imply only the last message is validated; the
        # preceding user message is not expected in the output.
        messages = [
            ChatMessage.from_user("I'm not being validated, but the message after me is!"),
            ChatMessage.from_assistant(genuine_fc_message),
        ]

        result = validator.run(messages, json_schema_github_compare)
        assert "validated" in result
        assert len(result["validated"]) == 1
        assert result["validated"][0] == messages[1]

    def test_validates_message_against_openai_function_calling_schema(
        self, json_schema_github_compare_openai, genuine_fc_message
    ):
        """A conforming message also validates against an OpenAI function-calling definition."""
        validator = JsonSchemaValidator()

        message = ChatMessage.from_assistant(genuine_fc_message)
        result = validator.run([message], json_schema_github_compare_openai)

        assert "validated" in result
        assert len(result["validated"]) == 1
        assert result["validated"][0] == message

    def test_validates_multiple_messages_against_openai_function_calling_schema(
        self, json_schema_github_compare_openai, genuine_fc_message
    ):
        """A leading system message does not affect validation against an OpenAI function definition."""
        validator = JsonSchemaValidator()

        messages = [
            ChatMessage.from_system("Common use case is that this is for example system message"),
            ChatMessage.from_assistant(genuine_fc_message),
        ]

        result = validator.run(messages, json_schema_github_compare_openai)

        assert "validated" in result
        assert len(result["validated"]) == 1
        assert result["validated"][0] == messages[1]

    def test_construct_custom_error_recovery_message(self):
        """A custom error template is filled with the error details and the schema's ``str`` form."""
        validator = JsonSchemaValidator()

        new_error_template = (
            "Error details:\n- Message: {error_message}\n"
            "- Error Path in JSON: {error_path}\n"
            "- Schema Path: {error_schema_path}\n"
            "Please match the following schema:\n"
            "{json_schema}\n"
            "Failing Json: {failing_json}\n"
        )

        recovery_message = validator._construct_error_recovery_message(
            new_error_template, "Error message", "Error path", "Error schema path", {"type": "object"}, "Failing Json"
        )

        expected_recovery_message = (
            "Error details:\n- Message: Error message\n"
            "- Error Path in JSON: Error path\n"
            "- Schema Path: Error schema path\n"
            "Please match the following schema:\n"
            "{'type': 'object'}\n"
            "Failing Json: Failing Json\n"
        )
        assert recovery_message == expected_recovery_message

    def test_schema_validator_in_pipeline_validated(self, json_schema_github_compare, genuine_fc_message):
        """Inside a pipeline, a conforming upstream message is emitted under ``validated``."""

        @component
        class ChatMessageProducer:
            @component.output_types(messages=list[ChatMessage])
            def run(self):
                return {"messages": [ChatMessage.from_assistant(genuine_fc_message)]}

        pipe = Pipeline()
        pipe.add_component(name="schema_validator", instance=JsonSchemaValidator())
        pipe.add_component(name="message_producer", instance=ChatMessageProducer())
        pipe.connect("message_producer", "schema_validator")
        result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
        assert "validated" in result["schema_validator"]
        assert len(result["schema_validator"]["validated"]) == 1
        assert result["schema_validator"]["validated"][0].text == genuine_fc_message

    def test_schema_validator_in_pipeline_validation_error(self, json_schema_github_compare):
        """Inside a pipeline, a non-conforming message yields a recovery message under ``validation_error``."""

        @component
        class ChatMessageProducer:
            @component.output_types(messages=list[ChatMessage])
            def run(self):
                # Well-formed JSON that does not conform to the schema: it lacks the
                # required "function" and "type" keys.
                simple_invalid_json = '{"key": "value"}'
                return {"messages": [ChatMessage.from_assistant(simple_invalid_json)]}  # invalid message

        pipe = Pipeline()
        pipe.add_component(name="schema_validator", instance=JsonSchemaValidator())
        pipe.add_component(name="message_producer", instance=ChatMessageProducer())
        pipe.connect("message_producer", "schema_validator")
        result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
        assert "validation_error" in result["schema_validator"]
        assert len(result["schema_validator"]["validation_error"]) == 1
        # "Error details" comes from the validator's default error-recovery template.
        assert "Error details" in result["schema_validator"]["validation_error"][0].text