state.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable
from copy import deepcopy
from typing import Any, get_args

from haystack.dataclasses import ChatMessage
from haystack.utils import _deserialize_value_with_schema, _serialize_value_with_schema
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.type_serialization import deserialize_type, serialize_type

from .state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values


def _schema_to_dict(schema: dict[str, Any]) -> dict[str, Any]:
    """
    Convert a schema dictionary to a serializable format.

    Converts each parameter's type and optional handler function into a serializable
    format using type and callable serialization utilities.

    :param schema: Dictionary mapping parameter names to their type and handler configs
    :returns: Dictionary with serialized type and handler information
    """
    serialized_schema: dict[str, Any] = {}
    for param, config in schema.items():
        entry = {"type": serialize_type(config["type"])}
        # The "handler" key is only emitted when a handler is actually configured.
        if config.get("handler"):
            entry["handler"] = serialize_callable(config["handler"])
        serialized_schema[param] = entry

    return serialized_schema


def _schema_from_dict(schema: dict[str, Any]) -> dict[str, Any]:
    """
    Convert a serialized schema dictionary back to its original format.

    Deserializes the type and optional handler function for each parameter from their
    serialized format back into Python types and callables.

    :param schema: Dictionary containing serialized schema information
    :returns: Dictionary with deserialized type and handler configurations
    """
    deserialized_schema: dict[str, Any] = {}
    for param, config in schema.items():
        entry = {"type": deserialize_type(config["type"])}
        # Entries serialized without a handler stay without one.
        if config.get("handler"):
            entry["handler"] = deserialize_callable(config["handler"])
        deserialized_schema[param] = entry

    return deserialized_schema


def _validate_schema(schema: dict[str, Any]) -> None:
    """
    Validate that a schema dictionary meets all required constraints.

    Checks that each parameter definition has a valid type field and that any handler
    specified is a callable function. The `messages` key, when present, must additionally
    be a list type whose element type is `ChatMessage` (or a subclass of it).

    :param schema: Dictionary mapping parameter names to their type and handler configs
    :raises ValueError: If schema validation fails due to missing or invalid fields
    """
    for param, definition in schema.items():
        if "type" not in definition:
            raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.")
        if not _is_valid_type(definition["type"]):
            raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
        if definition.get("handler") is not None and not callable(definition["handler"]):
            raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")

        if param != "messages":
            continue

        # Checked in two steps (container first, then element type) rather than comparing against
        # `list[ChatMessage]` directly, so that `typing.List[ChatMessage]` is accepted as well.
        messages_error = f"StateSchema: 'messages' must be of type list[ChatMessage], got {definition['type']}"
        if not _is_list_type(definition["type"]):
            raise ValueError(messages_error)

        args = get_args(definition["type"])
        try:
            holds_chat_messages = bool(args) and issubclass(args[0], ChatMessage)
        except TypeError:
            # `issubclass` rejects element types that are not plain classes (unions, nested generics, ...).
            # Such a type is simply not `ChatMessage`, so report it with the documented ValueError.
            holds_chat_messages = False
        if not holds_chat_messages:
            raise ValueError(messages_error)


class State:
    """
    State is a container for storing shared information during the execution of an Agent and its tools.

    For instance, State can be used to store documents, context, and intermediate results.

    Internally it wraps a `_data` dictionary defined by a `schema`. Each schema entry has:
    ```json
      "parameter_name": {
        "type": SomeType,  # expected type
        "handler": Optional[Callable[[Any, Any], Any]]  # merge/update function
      }
    ```

    Handlers control how values are merged when using the `set()` method:
    - For list types: defaults to `merge_lists` (concatenates lists)
    - For other types: defaults to `replace_values` (overwrites existing value)

    A `messages` field with type `list[ChatMessage]` is automatically added to the schema.

    This makes it possible for the Agent to read from and write to the same context.

    ### Usage example
    ```python
    from haystack.components.agents.state import State

    my_state = State(
        schema={"gh_repo_name": {"type": str}, "user_name": {"type": str}},
        data={"gh_repo_name": "my_repo", "user_name": "my_user_name"}
    )
    ```
    """

    def __init__(self, schema: dict[str, Any], data: dict[str, Any] | None = None) -> None:
        """
        Initialize a State object with a schema and optional data.

        :param schema: Dictionary mapping parameter names to their type and handler configs.
            Type must be a valid Python type, and handler must be a callable function or None.
            If handler is None, the default handler for the type will be used. The default handlers are:
                - For list types: `haystack.components.agents.state.state_utils.merge_lists`
                - For all other types: `haystack.components.agents.state.state_utils.replace_values`
        :param data: Optional dictionary of initial data to populate the state
        :raises ValueError: If the schema fails validation
        """
        _validate_schema(schema)
        # Work on a private copy so that filling in defaults below never mutates the caller's schema.
        self.schema = deepcopy(schema)
        if self.schema.get("messages") is None:
            self.schema["messages"] = {"type": list[ChatMessage], "handler": merge_lists}
        # NOTE: a non-empty `data` dict is stored as-is (not copied), so the caller keeps a reference to it.
        self._data = data or {}

        # Set default handlers if not provided in schema
        for definition in self.schema.values():
            # Skip if handler is already defined and not None
            if definition.get("handler") is not None:
                continue
            # Set default handler based on type
            definition["handler"] = merge_lists if _is_list_type(definition["type"]) else replace_values

    def get(self, key: str, default: Any = None) -> Any:
        """
        Retrieve a value from the state by key.

        The returned object is a deep copy, so mutating it does not change the state.

        :param key: Key to look up in the state
        :param default: Value to return if key is not found
        :returns: Value associated with key or default if not found
        """
        return deepcopy(self._data.get(key, default))

    def set(self, key: str, value: Any, handler_override: Callable[[Any, Any], Any] | None = None) -> None:
        """
        Set or merge a value in the state according to schema rules.

        Value is merged or overwritten according to these rules:
          - if handler_override is given, use that
          - else use the handler defined in the schema for 'key'

        The handler is called as `handler(current_value, value)`, where `current_value` is `None`
        when the key has not been set yet, and its return value becomes the new stored value.

        :param key: Key to store the value under
        :param value: Value to store or merge
        :param handler_override: Optional function to override the default merge behavior
        :raises ValueError: If `key` is not part of the schema
        """
        # If key not in schema, we throw an error
        definition = self.schema.get(key, None)
        if definition is None:
            raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}")

        # Get current value from state and apply handler
        current_value = self._data.get(key, None)
        handler = handler_override or definition["handler"]
        self._data[key] = handler(current_value, value)

    @property
    def data(self) -> dict[str, Any]:
        """
        All current data of the state.

        Unlike `get()`, this returns the internal dictionary itself rather than a copy.
        """
        return self._data

    def has(self, key: str) -> bool:
        """
        Check if a key exists in the state.

        :param key: Key to check for existence
        :returns: True if key exists in state, False otherwise
        """
        return key in self._data

    def to_dict(self) -> dict[str, Any]:
        """
        Convert the State object to a dictionary.

        :returns: Dictionary with a `schema` entry (serialized types and handlers) and a `data` entry
            (the state data serialized together with its schema)
        """
        return {"schema": _schema_to_dict(self.schema), "data": _serialize_value_with_schema(self._data)}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "State":
        """
        Convert a dictionary back to a State object.

        :param data: Dictionary as produced by `to_dict()`; missing `schema` or `data` entries
            are treated as empty
        :returns: A new instance of the class this method is called on
        """
        schema = _schema_from_dict(data.get("schema", {}))
        deserialized_data = _deserialize_value_with_schema(data.get("data", {}))
        # Build through `cls` so that subclasses get an instance of their own type back.
        return cls(schema, deserialized_data)