/ haystack / components / agents / state / state.py
state.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from collections.abc import Callable
  6  from copy import deepcopy
  7  from typing import Any, get_args
  8  
  9  from haystack.dataclasses import ChatMessage
 10  from haystack.utils import _deserialize_value_with_schema, _serialize_value_with_schema
 11  from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 12  from haystack.utils.type_serialization import deserialize_type, serialize_type
 13  
 14  from .state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values
 15  
 16  
 17  def _schema_to_dict(schema: dict[str, Any]) -> dict[str, Any]:
 18      """
 19      Convert a schema dictionary to a serializable format.
 20  
 21      Converts each parameter's type and optional handler function into a serializable
 22      format using type and callable serialization utilities.
 23  
 24      :param schema: Dictionary mapping parameter names to their type and handler configs
 25      :returns: Dictionary with serialized type and handler information
 26      """
 27      serialized_schema = {}
 28      for param, config in schema.items():
 29          serialized_schema[param] = {"type": serialize_type(config["type"])}
 30          if config.get("handler"):
 31              serialized_schema[param]["handler"] = serialize_callable(config["handler"])
 32  
 33      return serialized_schema
 34  
 35  
 36  def _schema_from_dict(schema: dict[str, Any]) -> dict[str, Any]:
 37      """
 38      Convert a serialized schema dictionary back to its original format.
 39  
 40      Deserializes the type and optional handler function for each parameter from their
 41      serialized format back into Python types and callables.
 42  
 43      :param schema: Dictionary containing serialized schema information
 44      :returns: Dictionary with deserialized type and handler configurations
 45      """
 46      deserialized_schema = {}
 47      for param, config in schema.items():
 48          deserialized_schema[param] = {"type": deserialize_type(config["type"])}
 49  
 50          if config.get("handler"):
 51              deserialized_schema[param]["handler"] = deserialize_callable(config["handler"])
 52  
 53      return deserialized_schema
 54  
 55  
 56  def _validate_schema(schema: dict[str, Any]) -> None:
 57      """
 58      Validate that a schema dictionary meets all required constraints.
 59  
 60      Checks that each parameter definition has a valid type field and that any handler
 61      specified is a callable function.
 62  
 63      :param schema: Dictionary mapping parameter names to their type and handler configs
 64      :raises ValueError: If schema validation fails due to missing or invalid fields
 65      """
 66      for param, definition in schema.items():
 67          if "type" not in definition:
 68              raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.")
 69          if not _is_valid_type(definition["type"]):
 70              raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
 71          if definition.get("handler") is not None and not callable(definition["handler"]):
 72              raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
 73          if param == "messages":  # definition["type"] != list[ChatMessage] but split to cover also List[ChatMessage]
 74              if not _is_list_type(definition["type"]):
 75                  raise ValueError(f"StateSchema: 'messages' must be of type list[ChatMessage], got {definition['type']}")
 76              # Check if the list contains ChatMessage elements
 77              args = get_args(definition["type"])
 78              if not args or not issubclass(args[0], ChatMessage):
 79                  raise ValueError(f"StateSchema: 'messages' must be of type list[ChatMessage], got {definition['type']}")
 80  
 81  
 82  class State:
 83      """
 84      State is a container for storing shared information during the execution of an Agent and its tools.
 85  
 86      For instance, State can be used to store documents, context, and intermediate results.
 87  
 88      Internally it wraps a `_data` dictionary defined by a `schema`. Each schema entry has:
 89      ```json
 90        "parameter_name": {
 91          "type": SomeType,  # expected type
 92          "handler": Optional[Callable[[Any, Any], Any]]  # merge/update function
 93        }
 94        ```
 95  
 96      Handlers control how values are merged when using the `set()` method:
 97      - For list types: defaults to `merge_lists` (concatenates lists)
 98      - For other types: defaults to `replace_values` (overwrites existing value)
 99  
100      A `messages` field with type `list[ChatMessage]` is automatically added to the schema.
101  
102      This makes it possible for the Agent to read from and write to the same context.
103  
104      ### Usage example
105      ```python
106      from haystack.components.agents.state import State
107  
108      my_state = State(
109          schema={"gh_repo_name": {"type": str}, "user_name": {"type": str}},
110          data={"gh_repo_name": "my_repo", "user_name": "my_user_name"}
111      )
112      ```
113      """
114  
115      def __init__(self, schema: dict[str, Any], data: dict[str, Any] | None = None) -> None:
116          """
117          Initialize a State object with a schema and optional data.
118  
119          :param schema: Dictionary mapping parameter names to their type and handler configs.
120              Type must be a valid Python type, and handler must be a callable function or None.
121              If handler is None, the default handler for the type will be used. The default handlers are:
122                  - For list types: `haystack.agents.state.state_utils.merge_lists`
123                  - For all other types: `haystack.agents.state.state_utils.replace_values`
124          :param data: Optional dictionary of initial data to populate the state
125          """
126          _validate_schema(schema)
127          self.schema = deepcopy(schema)
128          if self.schema.get("messages") is None:
129              self.schema["messages"] = {"type": list[ChatMessage], "handler": merge_lists}
130          self._data = data or {}
131  
132          # Set default handlers if not provided in schema
133          for definition in self.schema.values():
134              # Skip if handler is already defined and not None
135              if definition.get("handler") is not None:
136                  continue
137              # Set default handler based on type
138              if _is_list_type(definition["type"]):
139                  definition["handler"] = merge_lists
140              else:
141                  definition["handler"] = replace_values
142  
143      def get(self, key: str, default: Any = None) -> Any:
144          """
145          Retrieve a value from the state by key.
146  
147          :param key: Key to look up in the state
148          :param default: Value to return if key is not found
149          :returns: Value associated with key or default if not found
150          """
151          return deepcopy(self._data.get(key, default))
152  
153      def set(self, key: str, value: Any, handler_override: Callable[[Any, Any], Any] | None = None) -> None:
154          """
155          Set or merge a value in the state according to schema rules.
156  
157          Value is merged or overwritten according to these rules:
158            - if handler_override is given, use that
159            - else use the handler defined in the schema for 'key'
160  
161          :param key: Key to store the value under
162          :param value: Value to store or merge
163          :param handler_override: Optional function to override the default merge behavior
164          """
165          # If key not in schema, we throw an error
166          definition = self.schema.get(key, None)
167          if definition is None:
168              raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}")
169  
170          # Get current value from state and apply handler
171          current_value = self._data.get(key, None)
172          handler = handler_override or definition["handler"]
173          self._data[key] = handler(current_value, value)
174  
175      @property
176      def data(self) -> dict[str, Any]:
177          """
178          All current data of the state.
179          """
180          return self._data
181  
182      def has(self, key: str) -> bool:
183          """
184          Check if a key exists in the state.
185  
186          :param key: Key to check for existence
187          :returns: True if key exists in state, False otherwise
188          """
189          return key in self._data
190  
191      def to_dict(self) -> dict[str, Any]:
192          """
193          Convert the State object to a dictionary.
194          """
195          serialized = {}
196          serialized["schema"] = _schema_to_dict(self.schema)
197          serialized["data"] = _serialize_value_with_schema(self._data)
198          return serialized
199  
200      @classmethod
201      def from_dict(cls, data: dict[str, Any]) -> "State":
202          """
203          Convert a dictionary back to a State object.
204          """
205          schema = _schema_from_dict(data.get("schema", {}))
206          deserialized_data = _deserialize_value_with_schema(data.get("data", {}))
207          return State(schema, deserialized_data)