test_state_class.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import inspect
from dataclasses import dataclass
from typing import Dict, Generic, List, Optional, TypeVar, Union

import pytest

from haystack.components.agents.state.state import (
    State,
    _is_list_type,
    _is_valid_type,
    _schema_from_dict,
    _schema_to_dict,
    _validate_schema,
    merge_lists,
)
from haystack.dataclasses import ChatMessage

# Fully qualified import paths expected in serialized State schemas (see the to_dict/from_dict tests).
_REPLACE_VALUES = "haystack.components.agents.state.state_utils.replace_values"
_MERGE_LISTS = "haystack.components.agents.state.state_utils.merge_lists"
_CHAT_MESSAGE = "haystack.dataclasses.chat_message.ChatMessage"


@pytest.fixture
def basic_schema():
    """Schema with three plainly-typed keys and no custom handlers."""
    return {"numbers": {"type": list}, "metadata": {"type": dict}, "name": {"type": str}}


def numbers_handler(current, new):
    """
    Merge handler returning the sorted, de-duplicated union of the current and new values.

    :param current: The value currently stored in the state, or ``None`` if the key is unset.
    :param new: The incoming list of values.
    :returns: A sorted list without duplicates.
    """
    if current is None:
        return sorted(set(new))
    return sorted(set(current + new))


@pytest.fixture
def complex_schema():
    """Same keys as ``basic_schema``, but ``numbers`` uses ``numbers_handler`` as its merge handler."""
    return {"numbers": {"type": list, "handler": numbers_handler}, "metadata": {"type": dict}, "name": {"type": str}}


def _serialized_state(messages_type: str) -> dict:
    """
    Build the serialized State shared by the to_dict/from_dict tests.

    :param messages_type: Serialized type string of the ``messages`` key, e.g.
        ``"list[haystack.dataclasses.chat_message.ChatMessage]"``.
    :returns: A dict with ``schema`` and ``data`` entries. A fresh object is built on every call so one test
        cannot leak mutations into another.
    """
    return {
        "schema": {
            "numbers": {"type": "int", "handler": _REPLACE_VALUES},
            "messages": {"type": messages_type, "handler": _MERGE_LISTS},
            "dict_of_lists": {"type": "dict", "handler": _REPLACE_VALUES},
        },
        "data": {
            "serialization_schema": {
                "type": "object",
                "properties": {
                    "numbers": {"type": "integer"},
                    "messages": {"type": "array", "items": {"type": _CHAT_MESSAGE}},
                    "dict_of_lists": {
                        "type": "object",
                        "properties": {"numbers": {"type": "array", "items": {"type": "integer"}}},
                    },
                },
            },
            "serialized_data": {
                "numbers": 1,
                "messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}],
                "dict_of_lists": {"numbers": [1, 2, 3]},
            },
        },
    }


def _sample_data() -> dict:
    """Return the in-memory State data whose serialized form is ``_serialized_state(...)["data"]``."""
    return {
        "numbers": 1,
        "messages": [ChatMessage.from_user(text="Hello, world!")],
        "dict_of_lists": {"numbers": [1, 2, 3]},
    }


def test_is_list_type():
    assert _is_list_type(list) is True
    assert _is_list_type(list[int]) is True
    assert _is_list_type(list[str]) is True
    assert _is_list_type(dict) is False
    assert _is_list_type(int) is False
    # Unions that merely contain a list type are not list types themselves
    assert _is_list_type(Union[list[int], None]) is False
    assert _is_list_type(list[int] | None) is False


class TestMergeLists:
    """Tests for the default list-merging handler ``merge_lists``."""

    def test_merge_two_lists(self):
        current = [1, 2, 3]
        new = [4, 5, 6]
        result = merge_lists(current, new)
        assert result == [1, 2, 3, 4, 5, 6]
        # Ensure original lists weren't modified
        assert current == [1, 2, 3]
        assert new == [4, 5, 6]

    def test_append_to_list(self):
        current = [1, 2, 3]
        new = 4
        result = merge_lists(current, new)
        assert result == [1, 2, 3, 4]
        assert current == [1, 2, 3]  # Ensure original wasn't modified

    def test_create_new_list(self):
        current = 1
        new = 2
        result = merge_lists(current, new)
        assert result == [1, 2]

    def test_replace_with_list(self):
        # A scalar current value is not replaced: it is wrapped in a list and the new list is appended to it
        current = 1
        new = [2, 3]
        result = merge_lists(current, new)
        assert result == [1, 2, 3]


class TestIsValidType:
    """Tests for ``_is_valid_type``: only types (plain, generic, union) are valid, never plain values."""

    def test_builtin_types(self):
        assert _is_valid_type(str) is True
        assert _is_valid_type(int) is True
        assert _is_valid_type(dict) is True
        assert _is_valid_type(list) is True
        assert _is_valid_type(tuple) is True
        assert _is_valid_type(set) is True
        assert _is_valid_type(bool) is True
        assert _is_valid_type(float) is True

    def test_generic_types(self):
        assert _is_valid_type(list[str]) is True
        assert _is_valid_type(List[str]) is True
        assert _is_valid_type(dict[str, int]) is True
        assert _is_valid_type(Dict[str, int]) is True
        assert _is_valid_type(list[dict[str, int]]) is True
        assert _is_valid_type(List[Dict[str, int]]) is True
        assert _is_valid_type(dict[str, list[int]]) is True
        assert _is_valid_type(Dict[str, List[int]]) is True

    def test_custom_classes(self):
        @dataclass
        class CustomClass:
            value: int

        T = TypeVar("T")

        class GenericCustomClass(Generic[T]):
            pass

        # Test regular and generic custom classes
        assert _is_valid_type(CustomClass) is True
        assert _is_valid_type(GenericCustomClass) is True
        assert _is_valid_type(GenericCustomClass[int]) is True

        # Test generic types with custom classes
        assert _is_valid_type(list[CustomClass]) is True
        assert _is_valid_type(List[CustomClass]) is True
        assert _is_valid_type(dict[str, CustomClass]) is True
        assert _is_valid_type(Dict[str, CustomClass]) is True
        assert _is_valid_type(dict[str, GenericCustomClass[int]]) is True
        assert _is_valid_type(Dict[str, GenericCustomClass[int]]) is True

    def test_invalid_types(self):
        # Test regular values
        assert _is_valid_type(42) is False
        assert _is_valid_type("string") is False
        assert _is_valid_type([1, 2, 3]) is False
        assert _is_valid_type({"a": 1}) is False
        assert _is_valid_type(True) is False

        # Test class instances
        @dataclass
        class SampleClass:
            value: int

        instance = SampleClass(42)
        assert _is_valid_type(instance) is False

        # Test callable objects
        assert _is_valid_type(len) is False
        assert _is_valid_type(lambda x: x) is False
        assert _is_valid_type(print) is False

    def test_union_and_optional_types(self):
        # Test basic Union types
        assert _is_valid_type(Union[str, int]) is True
        assert _is_valid_type(Union[str, None]) is True
        assert _is_valid_type(Union[list[int], dict[str, str]]) is True

        # Test Optional types (which are Union[T, None])
        assert _is_valid_type(Optional[str]) is True
        assert _is_valid_type(Optional[list[int]]) is True
        assert _is_valid_type(Optional[dict[str, list]]) is True

        # Test that Union itself is not a valid type (only instantiated Unions are)
        assert _is_valid_type(Union) is False

        # Test PEP 604 union types (X | Y syntax)
        assert _is_valid_type(str | int) is True
        assert _is_valid_type(str | None) is True
        assert _is_valid_type(list[int] | dict[str, str]) is True

        # Test PEP 604 Optional-like types (X | None syntax)
        assert _is_valid_type(list[int] | None) is True
        assert _is_valid_type(dict[str, list] | None) is True

    def test_nested_generic_types(self):
        assert _is_valid_type(list[list[dict[str, list[int]]]]) is True
        assert _is_valid_type(dict[str, list[dict[str, set]]]) is True
        assert _is_valid_type(dict[str, Optional[list[int]]]) is True
        assert _is_valid_type(list[Union[str, dict[str, list[int]]]]) is True
        # PEP 604 nested types
        assert _is_valid_type(dict[str, list[int] | None]) is True
        assert _is_valid_type(list[str | dict[str, list[int]]]) is True

    def test_edge_cases(self):
        # Test None and NoneType
        assert _is_valid_type(None) is False
        assert _is_valid_type(type(None)) is True

        # Test functions and methods
        def sample_func():
            pass

        assert _is_valid_type(sample_func) is False
        assert _is_valid_type(type(sample_func)) is True

        # Test modules
        assert _is_valid_type(inspect) is False

        # Test type itself
        assert _is_valid_type(type) is True

    @pytest.mark.parametrize(
        "test_input,expected",
        [
            (str, True),
            (int, True),
            (list[int], True),
            (dict[str, int], True),
            (List[int], True),
            (Dict[str, int], True),
            (Union[str, int], True),
            (Optional[str], True),
            # PEP 604 union types
            (str | int, True),
            (str | None, True),
            (list[int] | None, True),
            (42, False),
            ("string", False),
            ([1, 2, 3], False),
            (lambda x: x, False),
        ],
    )
    def test_parametrized_cases(self, test_input, expected):
        assert _is_valid_type(test_input) is expected


class TestState:
    """Tests for schema validation, get/set/merge semantics and (de)serialization of ``State``."""

    @staticmethod
    def _assert_deserialized_state(state):
        """Check the State produced by ``State.from_dict(_serialized_state(...))``."""
        # Check types are correctly converted
        assert state.schema["numbers"]["type"] is int
        assert state.schema["dict_of_lists"]["type"] is dict
        # Check handlers are functions, not comparing exact functions as they might be different references
        assert callable(state.schema["numbers"]["handler"])
        assert callable(state.schema["messages"]["handler"])
        assert callable(state.schema["dict_of_lists"]["handler"])
        # Check data is correct
        assert state.data["numbers"] == 1
        assert state.data["messages"] == [ChatMessage.from_user(text="Hello, world!")]
        assert state.data["dict_of_lists"] == {"numbers": [1, 2, 3]}

    def test_validate_schema_valid(self, basic_schema):
        # Should not raise any exceptions
        _validate_schema(basic_schema)

    def test_validate_schema_invalid_type(self):
        invalid_schema = {"test": {"type": "not_a_type"}}
        with pytest.raises(ValueError, match="must be a Python type"):
            _validate_schema(invalid_schema)

    def test_validate_schema_missing_type(self):
        invalid_schema = {"test": {"handler": lambda x, y: x + y}}
        with pytest.raises(ValueError, match="missing a 'type' entry"):
            _validate_schema(invalid_schema)

    def test_validate_schema_invalid_handler(self):
        invalid_schema = {"test": {"type": list, "handler": "not_callable"}}
        with pytest.raises(ValueError, match="must be callable or None"):
            _validate_schema(invalid_schema)

    def test_validate_schema_with_messages(self):
        class ChatMessageSubclass(ChatMessage):
            pass

        schema_with_messages = {"messages": {"type": List[ChatMessage]}}
        _validate_schema(schema_with_messages)

        schema_with_messages_subclass = {"messages": {"type": List[ChatMessageSubclass]}}
        _validate_schema(schema_with_messages_subclass)

    def test_state_initialization(self, basic_schema):
        # Test empty initialization
        state = State(basic_schema)
        assert state.data == {}

        # Test initialization with data
        initial_data = {"numbers": [1, 2, 3], "name": "test"}
        state = State(basic_schema, initial_data)
        assert state.data["numbers"] == [1, 2, 3]
        assert state.data["name"] == "test"

    def test_state_get(self, basic_schema):
        state = State(basic_schema, {"name": "test"})
        assert state.get("name") == "test"
        assert state.get("non_existent") is None
        assert state.get("non_existent", "default") == "default"

    def test_state_set_basic(self, basic_schema):
        state = State(basic_schema)

        # Test setting new values
        state.set("numbers", [1, 2])
        assert state.get("numbers") == [1, 2]

        # Test updating existing values: list-typed keys merge by default instead of overwriting
        state.set("numbers", [3, 4])
        assert state.get("numbers") == [1, 2, 3, 4]

    def test_state_set_with_handler(self, complex_schema):
        state = State(complex_schema)

        # Test custom handler for numbers
        state.set("numbers", [3, 2, 1])
        assert state.get("numbers") == [1, 2, 3]

        state.set("numbers", [6, 5, 4])
        assert state.get("numbers") == [1, 2, 3, 4, 5, 6]

    def test_state_set_with_handler_override(self, basic_schema):
        state = State(basic_schema)

        def custom_handler(current, new):
            # Concatenate strings with a dash once a value exists; otherwise take the new value as-is
            return f"{current}-{new}" if current else new

        state.set("name", "first")
        state.set("name", "second", handler_override=custom_handler)
        assert state.get("name") == "first-second"

    def test_state_has(self, basic_schema):
        state = State(basic_schema, {"name": "test"})
        assert state.has("name") is True
        assert state.has("non_existent") is False

    def test_state_empty_schema(self):
        state = State({})
        assert state.data == {}

        # Instead of comparing the entire schema directly, check structure separately
        assert "messages" in state.schema
        assert state.schema["messages"]["type"] == list[ChatMessage]
        assert callable(state.schema["messages"]["handler"])

        with pytest.raises(ValueError, match="Key 'any_key' not found in schema"):
            state.set("any_key", "value")

    def test_state_none_values(self, basic_schema):
        state = State(basic_schema)
        state.set("name", None)
        assert state.get("name") is None
        state.set("name", "value")
        assert state.get("name") == "value"

    def test_state_merge_lists(self, basic_schema):
        state = State(basic_schema)
        state.set("numbers", "not_a_list")
        assert state.get("numbers") == ["not_a_list"]
        state.set("numbers", [1, 2])
        assert state.get("numbers") == ["not_a_list", 1, 2]

    def test_state_nested_structures(self):
        schema = {
            "complex": {
                "type": dict[str, list[int]],
                "handler": lambda current, new: (
                    {k: current.get(k, []) + new.get(k, []) for k in set(current.keys()) | set(new.keys())}
                    if current
                    else new
                ),
            }
        }

        state = State(schema)
        state.set("complex", {"a": [1, 2], "b": [3, 4]})
        state.set("complex", {"b": [5, 6], "c": [7, 8]})

        expected = {"a": [1, 2], "b": [3, 4, 5, 6], "c": [7, 8]}
        assert state.get("complex") == expected

    def test_schema_to_dict(self, basic_schema):
        expected_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}}
        result = _schema_to_dict(basic_schema)
        assert result == expected_dict

    def test_schema_to_dict_with_handlers(self, complex_schema):
        # The handler path depends on this module being importable as ``test_state_class``
        expected_dict = {
            "numbers": {"type": "list", "handler": "test_state_class.numbers_handler"},
            "metadata": {"type": "dict"},
            "name": {"type": "str"},
        }
        result = _schema_to_dict(complex_schema)
        assert result == expected_dict

    def test_schema_from_dict(self, basic_schema):
        schema_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}}
        result = _schema_from_dict(schema_dict)
        assert result == basic_schema

    def test_schema_from_dict_with_handlers(self, complex_schema):
        schema_dict = {
            "numbers": {"type": "list", "handler": "test_state_class.numbers_handler"},
            "metadata": {"type": "dict"},
            "name": {"type": "str"},
        }
        result = _schema_from_dict(schema_dict)
        assert result == complex_schema

    def test_state_mutability(self):
        state = State({"my_list": {"type": list}}, {"my_list": [1, 2]})

        # Mutating a value returned by get() must not leak back into the state
        my_list = state.get("my_list")
        my_list.append(3)

        assert state.get("my_list") == [1, 2]

    def test_state_to_dict(self):
        # Covers a builtin type (int), a bare container (dict) and a list of haystack dataclasses (ChatMessage)
        state_schema = {
            "numbers": {"type": int},
            "messages": {"type": list[ChatMessage]},
            "dict_of_lists": {"type": dict},
        }
        state_dict = State(state_schema, _sample_data()).to_dict()

        expected = _serialized_state(f"list[{_CHAT_MESSAGE}]")
        assert state_dict["schema"] == expected["schema"]
        assert state_dict["data"] == expected["data"]

    def test_state_from_dict(self):
        state = State.from_dict(_serialized_state(f"list[{_CHAT_MESSAGE}]"))
        self._assert_deserialized_state(state)

    def test_state_to_dict_typing_list(self):
        # Same as test_state_to_dict, but the messages type is declared with typing.List
        state_schema = {
            "numbers": {"type": int},
            "messages": {"type": List[ChatMessage]},
            "dict_of_lists": {"type": dict},
        }
        state_dict = State(state_schema, _sample_data()).to_dict()

        expected = _serialized_state(f"typing.List[{_CHAT_MESSAGE}]")
        assert state_dict["schema"] == expected["schema"]
        assert state_dict["data"] == expected["data"]

    def test_state_from_dict_typing_list(self):
        state = State.from_dict(_serialized_state(f"typing.List[{_CHAT_MESSAGE}]"))
        self._assert_deserialized_state(state)