/ test / components / agents / test_state_class.py
test_state_class.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import inspect
  6  from dataclasses import dataclass
  7  from typing import Dict, Generic, List, Optional, TypeVar, Union
  8  
  9  import pytest
 10  
 11  from haystack.components.agents.state.state import (
 12      State,
 13      _is_list_type,
 14      _is_valid_type,
 15      _schema_from_dict,
 16      _schema_to_dict,
 17      _validate_schema,
 18      merge_lists,
 19  )
 20  from haystack.dataclasses import ChatMessage
 21  
 22  
 23  @pytest.fixture
 24  def basic_schema():
 25      return {"numbers": {"type": list}, "metadata": {"type": dict}, "name": {"type": str}}
 26  
 27  
 28  def numbers_handler(current, new):
 29      if current is None:
 30          return sorted(set(new))
 31      return sorted(set(current + new))
 32  
 33  
 34  @pytest.fixture
 35  def complex_schema():
 36      return {"numbers": {"type": list, "handler": numbers_handler}, "metadata": {"type": dict}, "name": {"type": str}}
 37  
 38  
 39  def test_is_list_type():
 40      assert _is_list_type(list) is True
 41      assert _is_list_type(list[int]) is True
 42      assert _is_list_type(list[str]) is True
 43      assert _is_list_type(dict) is False
 44      assert _is_list_type(int) is False
 45      assert _is_list_type(Union[list[int], None]) is False
 46      assert _is_list_type(list[int] | None) is False
 47  
 48  
 49  class TestMergeLists:
 50      def test_merge_two_lists(self):
 51          current = [1, 2, 3]
 52          new = [4, 5, 6]
 53          result = merge_lists(current, new)
 54          assert result == [1, 2, 3, 4, 5, 6]
 55          # Ensure original lists weren't modified
 56          assert current == [1, 2, 3]
 57          assert new == [4, 5, 6]
 58  
 59      def test_append_to_list(self):
 60          current = [1, 2, 3]
 61          new = 4
 62          result = merge_lists(current, new)
 63          assert result == [1, 2, 3, 4]
 64          assert current == [1, 2, 3]  # Ensure original wasn't modified
 65  
 66      def test_create_new_list(self):
 67          current = 1
 68          new = 2
 69          result = merge_lists(current, new)
 70          assert result == [1, 2]
 71  
 72      def test_replace_with_list(self):
 73          current = 1
 74          new = [2, 3]
 75          result = merge_lists(current, new)
 76          assert result == [1, 2, 3]
 77  
 78  
 79  class TestIsValidType:
 80      def test_builtin_types(self):
 81          assert _is_valid_type(str) is True
 82          assert _is_valid_type(int) is True
 83          assert _is_valid_type(dict) is True
 84          assert _is_valid_type(list) is True
 85          assert _is_valid_type(tuple) is True
 86          assert _is_valid_type(set) is True
 87          assert _is_valid_type(bool) is True
 88          assert _is_valid_type(float) is True
 89  
 90      def test_generic_types(self):
 91          assert _is_valid_type(list[str]) is True
 92          assert _is_valid_type(List[str]) is True
 93          assert _is_valid_type(dict[str, int]) is True
 94          assert _is_valid_type(Dict[str, int]) is True
 95          assert _is_valid_type(list[dict[str, int]]) is True
 96          assert _is_valid_type(List[Dict[str, int]]) is True
 97          assert _is_valid_type(dict[str, list[int]]) is True
 98          assert _is_valid_type(Dict[str, List[int]]) is True
 99  
100      def test_custom_classes(self):
101          @dataclass
102          class CustomClass:
103              value: int
104  
105          T = TypeVar("T")
106  
107          class GenericCustomClass(Generic[T]):
108              pass
109  
110          # Test regular and generic custom classes
111          assert _is_valid_type(CustomClass) is True
112          assert _is_valid_type(GenericCustomClass) is True
113          assert _is_valid_type(GenericCustomClass[int]) is True
114  
115          # Test generic types with custom classes
116          assert _is_valid_type(list[CustomClass]) is True
117          assert _is_valid_type(List[CustomClass]) is True
118          assert _is_valid_type(dict[str, CustomClass]) is True
119          assert _is_valid_type(Dict[str, CustomClass]) is True
120          assert _is_valid_type(dict[str, GenericCustomClass[int]]) is True
121          assert _is_valid_type(Dict[str, GenericCustomClass[int]]) is True
122  
123      def test_invalid_types(self):
124          # Test regular values
125          assert _is_valid_type(42) is False
126          assert _is_valid_type("string") is False
127          assert _is_valid_type([1, 2, 3]) is False
128          assert _is_valid_type({"a": 1}) is False
129          assert _is_valid_type(True) is False
130  
131          # Test class instances
132          @dataclass
133          class SampleClass:
134              value: int
135  
136          instance = SampleClass(42)
137          assert _is_valid_type(instance) is False
138  
139          # Test callable objects
140          assert _is_valid_type(len) is False
141          assert _is_valid_type(lambda x: x) is False
142          assert _is_valid_type(print) is False
143  
144      def test_union_and_optional_types(self):
145          # Test basic Union types
146          assert _is_valid_type(Union[str, int]) is True
147          assert _is_valid_type(Union[str, None]) is True
148          assert _is_valid_type(Union[list[int], dict[str, str]]) is True
149  
150          # Test Optional types (which are Union[T, None])
151          assert _is_valid_type(Optional[str]) is True
152          assert _is_valid_type(Optional[list[int]]) is True
153          assert _is_valid_type(Optional[dict[str, list]]) is True
154  
155          # Test that Union itself is not a valid type (only instantiated Unions are)
156          assert _is_valid_type(Union) is False
157  
158          # Test PEP 604 union types (X | Y syntax)
159          assert _is_valid_type(str | int) is True
160          assert _is_valid_type(str | None) is True
161          assert _is_valid_type(list[int] | dict[str, str]) is True
162  
163          # Test PEP 604 Optional-like types (X | None syntax)
164          assert _is_valid_type(list[int] | None) is True
165          assert _is_valid_type(dict[str, list] | None) is True
166  
167      def test_nested_generic_types(self):
168          assert _is_valid_type(list[list[dict[str, list[int]]]]) is True
169          assert _is_valid_type(dict[str, list[dict[str, set]]]) is True
170          assert _is_valid_type(dict[str, Optional[list[int]]]) is True
171          assert _is_valid_type(list[Union[str, dict[str, list[int]]]]) is True
172          # PEP 604 nested types
173          assert _is_valid_type(dict[str, list[int] | None]) is True
174          assert _is_valid_type(list[str | dict[str, list[int]]]) is True
175  
176      def test_edge_cases(self):
177          # Test None and NoneType
178          assert _is_valid_type(None) is False
179          assert _is_valid_type(type(None)) is True
180  
181          # Test functions and methods
182          def sample_func():
183              pass
184  
185          assert _is_valid_type(sample_func) is False
186          assert _is_valid_type(type(sample_func)) is True
187  
188          # Test modules
189          assert _is_valid_type(inspect) is False
190  
191          # Test type itself
192          assert _is_valid_type(type) is True
193  
194      @pytest.mark.parametrize(
195          "test_input,expected",
196          [
197              (str, True),
198              (int, True),
199              (list[int], True),
200              (dict[str, int], True),
201              (List[int], True),
202              (Dict[str, int], True),
203              (Union[str, int], True),
204              (Optional[str], True),
205              # PEP 604 union types
206              (str | int, True),
207              (str | None, True),
208              (list[int] | None, True),
209              (42, False),
210              ("string", False),
211              ([1, 2, 3], False),
212              (lambda x: x, False),
213          ],
214      )
215      def test_parametrized_cases(self, test_input, expected):
216          assert _is_valid_type(test_input) is expected
217  
218  
class TestState:
    """Tests for ``State``: schema validation, get/set/has semantics, handlers and (de)serialization."""

    # Fully-qualified names that appear in serialized states for the default handlers and ChatMessage.
    _REPLACE_VALUES = "haystack.components.agents.state.state_utils.replace_values"
    _MERGE_LISTS = "haystack.components.agents.state.state_utils.merge_lists"
    _CHAT_MESSAGE = "haystack.dataclasses.chat_message.ChatMessage"

    @staticmethod
    def _serialization_data():
        """Return fresh state data mixing a plain int, a list of ChatMessages and a dict of lists."""
        return {
            "numbers": 1,
            "messages": [ChatMessage.from_user(text="Hello, world!")],
            "dict_of_lists": {"numbers": [1, 2, 3]},
        }

    @classmethod
    def _serialized_schema(cls, messages_type):
        """
        Return the serialized schema matching the fields of ``_serialization_data``.

        :param messages_type: Serialized type string of the ``messages`` field; it differs between
            ``list[ChatMessage]`` and ``typing.List[ChatMessage]``.
        :returns: A fresh dict of field name to serialized type and handler path.
        """
        return {
            "numbers": {"type": "int", "handler": cls._REPLACE_VALUES},
            "messages": {"type": messages_type, "handler": cls._MERGE_LISTS},
            "dict_of_lists": {"type": "dict", "handler": cls._REPLACE_VALUES},
        }

    @classmethod
    def _serialized_data(cls):
        """Return a fresh copy of the serialized form of ``_serialization_data()``."""
        return {
            "serialization_schema": {
                "type": "object",
                "properties": {
                    "numbers": {"type": "integer"},
                    "messages": {"type": "array", "items": {"type": cls._CHAT_MESSAGE}},
                    "dict_of_lists": {
                        "type": "object",
                        "properties": {"numbers": {"type": "array", "items": {"type": "integer"}}},
                    },
                },
            },
            "serialized_data": {
                "numbers": 1,
                "messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}],
                "dict_of_lists": {"numbers": [1, 2, 3]},
            },
        }

    @staticmethod
    def _assert_deserialized_state(state):
        """Check a State rebuilt from the serialized helpers has the expected types, handlers and data."""
        # Check types are correctly converted
        assert state.schema["numbers"]["type"] == int
        assert state.schema["dict_of_lists"]["type"] == dict
        # Check handlers are functions, not comparing exact functions as they might be different references
        for key in ("numbers", "messages", "dict_of_lists"):
            assert callable(state.schema[key]["handler"]), key
        # Check data is correct
        assert state.data["numbers"] == 1
        assert state.data["messages"] == [ChatMessage.from_user(text="Hello, world!")]
        assert state.data["dict_of_lists"] == {"numbers": [1, 2, 3]}

    def test_validate_schema_valid(self, basic_schema):
        # Should not raise any exceptions
        _validate_schema(basic_schema)

    def test_validate_schema_invalid_type(self):
        invalid_schema = {"test": {"type": "not_a_type"}}
        with pytest.raises(ValueError, match="must be a Python type"):
            _validate_schema(invalid_schema)

    def test_validate_schema_missing_type(self):
        invalid_schema = {"test": {"handler": lambda x, y: x + y}}
        with pytest.raises(ValueError, match="missing a 'type' entry"):
            _validate_schema(invalid_schema)

    def test_validate_schema_invalid_handler(self):
        invalid_schema = {"test": {"type": list, "handler": "not_callable"}}
        with pytest.raises(ValueError, match="must be callable or None"):
            _validate_schema(invalid_schema)

    def test_validate_schema_with_messages(self):
        class ChatMessageSubclass(ChatMessage):
            pass

        schema_with_messages = {"messages": {"type": List[ChatMessage]}}
        _validate_schema(schema_with_messages)

        schema_with_messages_subclass = {"messages": {"type": List[ChatMessageSubclass]}}
        _validate_schema(schema_with_messages_subclass)

    def test_state_initialization(self, basic_schema):
        # Test empty initialization
        state = State(basic_schema)
        assert state.data == {}

        # Test initialization with data
        initial_data = {"numbers": [1, 2, 3], "name": "test"}
        state = State(basic_schema, initial_data)
        assert state.data["numbers"] == [1, 2, 3]
        assert state.data["name"] == "test"

    def test_state_get(self, basic_schema):
        state = State(basic_schema, {"name": "test"})
        assert state.get("name") == "test"
        assert state.get("non_existent") is None
        assert state.get("non_existent", "default") == "default"

    def test_state_set_basic(self, basic_schema):
        state = State(basic_schema)

        # Test setting new values
        state.set("numbers", [1, 2])
        assert state.get("numbers") == [1, 2]

        # Setting a list-typed field again merges rather than replaces
        state.set("numbers", [3, 4])
        assert state.get("numbers") == [1, 2, 3, 4]

    def test_state_set_with_handler(self, complex_schema):
        state = State(complex_schema)

        # Test custom handler for numbers
        state.set("numbers", [3, 2, 1])
        assert state.get("numbers") == [1, 2, 3]

        state.set("numbers", [6, 5, 4])
        assert state.get("numbers") == [1, 2, 3, 4, 5, 6]

    def test_state_set_with_handler_override(self, basic_schema):
        state = State(basic_schema)

        def custom_handler(current, new):
            # Concatenates strings with a dash, or takes the new value when nothing is stored yet
            return f"{current}-{new}" if current else new

        state.set("name", "first")
        state.set("name", "second", handler_override=custom_handler)
        assert state.get("name") == "first-second"

    def test_state_has(self, basic_schema):
        state = State(basic_schema, {"name": "test"})
        assert state.has("name") is True
        assert state.has("non_existent") is False

    def test_state_empty_schema(self):
        state = State({})
        assert state.data == {}

        # An empty user schema still gets a default "messages" field; check its parts individually
        assert "messages" in state.schema
        assert state.schema["messages"]["type"] == list[ChatMessage]
        assert callable(state.schema["messages"]["handler"])

        with pytest.raises(ValueError, match="Key 'any_key' not found in schema"):
            state.set("any_key", "value")

    def test_state_none_values(self, basic_schema):
        state = State(basic_schema)
        state.set("name", None)
        assert state.get("name") is None
        state.set("name", "value")
        assert state.get("name") == "value"

    def test_state_merge_lists(self, basic_schema):
        state = State(basic_schema)
        state.set("numbers", "not_a_list")
        assert state.get("numbers") == ["not_a_list"]
        state.set("numbers", [1, 2])
        assert state.get("numbers") == ["not_a_list", 1, 2]

    def test_state_nested_structures(self):
        schema = {
            "complex": {
                "type": dict[str, list[int]],
                "handler": lambda current, new: (
                    {k: current.get(k, []) + new.get(k, []) for k in set(current.keys()) | set(new.keys())}
                    if current
                    else new
                ),
            }
        }

        state = State(schema)
        state.set("complex", {"a": [1, 2], "b": [3, 4]})
        state.set("complex", {"b": [5, 6], "c": [7, 8]})

        expected = {"a": [1, 2], "b": [3, 4, 5, 6], "c": [7, 8]}
        assert state.get("complex") == expected

    def test_schema_to_dict(self, basic_schema):
        expected_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}}
        result = _schema_to_dict(basic_schema)
        assert result == expected_dict

    def test_schema_to_dict_with_handlers(self, complex_schema):
        expected_dict = {
            "numbers": {"type": "list", "handler": "test_state_class.numbers_handler"},
            "metadata": {"type": "dict"},
            "name": {"type": "str"},
        }
        result = _schema_to_dict(complex_schema)
        assert result == expected_dict

    def test_schema_from_dict(self, basic_schema):
        schema_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}}
        result = _schema_from_dict(schema_dict)
        assert result == basic_schema

    def test_schema_from_dict_with_handlers(self, complex_schema):
        schema_dict = {
            "numbers": {"type": "list", "handler": "test_state_class.numbers_handler"},
            "metadata": {"type": "dict"},
            "name": {"type": "str"},
        }
        result = _schema_from_dict(schema_dict)
        assert result == complex_schema

    def test_state_mutability(self):
        state = State({"my_list": {"type": list}}, {"my_list": [1, 2]})

        # Mutating a value returned by get() must not change what the state stores
        my_list = state.get("my_list")
        my_list.append(3)

        assert state.get("my_list") == [1, 2]

    def test_state_to_dict(self):
        # Covers a plain Python type (int), a builtin list of a haystack dataclass and a dict
        state_schema = {
            "numbers": {"type": int},
            "messages": {"type": list[ChatMessage]},
            "dict_of_lists": {"type": dict},
        }
        state = State(state_schema, self._serialization_data())
        state_dict = state.to_dict()
        assert state_dict["schema"] == self._serialized_schema(f"list[{self._CHAT_MESSAGE}]")
        assert state_dict["data"] == self._serialized_data()

    def test_state_from_dict(self):
        state_dict = {
            "schema": self._serialized_schema(f"list[{self._CHAT_MESSAGE}]"),
            "data": self._serialized_data(),
        }
        state = State.from_dict(state_dict)
        self._assert_deserialized_state(state)

    def test_state_to_dict_typing_list(self):
        # Same fields as test_state_to_dict, but messages is declared with typing.List
        state_schema = {
            "numbers": {"type": int},
            "messages": {"type": List[ChatMessage]},
            "dict_of_lists": {"type": dict},
        }
        state = State(state_schema, self._serialization_data())
        state_dict = state.to_dict()
        assert state_dict["schema"] == self._serialized_schema(f"typing.List[{self._CHAT_MESSAGE}]")
        assert state_dict["data"] == self._serialized_data()

    def test_state_from_dict_typing_list(self):
        state_dict = {
            "schema": self._serialized_schema(f"typing.List[{self._CHAT_MESSAGE}]"),
            "data": self._serialized_data(),
        }
        state = State.from_dict(state_dict)
        self._assert_deserialized_state(state)
560          assert state.data["messages"] == [ChatMessage.from_user(text="Hello, world!")]
561          assert state.data["dict_of_lists"] == {"numbers": [1, 2, 3]}