base_serialization.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from enum import Enum
from typing import Any

import pydantic

from haystack import logging
from haystack.core.errors import DeserializationError, SerializationError
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
from haystack.utils import deserialize_callable, serialize_callable

logger = logging.getLogger(__name__)

# Maps Python primitive types to JSON Schema type names. Insertion order matters: `bool` must come
# before `int` because `bool` is a subclass of `int` and `_primitive_schema_type` returns the first match.
_PRIMITIVE_TO_SCHEMA_MAP = {type(None): "null", bool: "boolean", int: "integer", float: "number", str: "string"}


def serialize_class_instance(obj: Any) -> dict[str, Any]:
    """
    Serializes an object that has a `to_dict` method into a dictionary.

    :param obj:
        The object to be serialized.
    :returns:
        A dictionary of the form `{"type": <qualified class name>, "data": <output of obj.to_dict()>}`.
    :raises SerializationError:
        If the object does not have a `to_dict` method.
    """
    if not hasattr(obj, "to_dict"):
        raise SerializationError(f"Object of class '{type(obj).__name__}' does not have a 'to_dict' method")

    output = obj.to_dict()
    return {"type": generate_qualified_class_name(type(obj)), "data": output}


def deserialize_class_instance(data: dict[str, Any]) -> Any:
    """
    Deserializes an object from a dictionary representation generated by `serialize_class_instance`.

    :param data:
        The dictionary to deserialize from. It must contain the keys `"type"` (the qualified class name)
        and `"data"` (the payload handed to the class's `from_dict`).
    :returns:
        The deserialized object, as returned by the class's `from_dict` method.
    :raises DeserializationError:
        If the serialization data is malformed, the class type cannot be imported, or the
        class does not have a `from_dict` method.
    """
    if "type" not in data:
        raise DeserializationError("Missing 'type' in serialization data")
    if "data" not in data:
        raise DeserializationError("Missing 'data' in serialization data")

    try:
        obj_class = import_class_by_name(data["type"])
    except ImportError as e:
        raise DeserializationError(f"Class '{data['type']}' not correctly imported") from e

    if not hasattr(obj_class, "from_dict"):
        raise DeserializationError(f"Class '{data['type']}' does not have a 'from_dict' method")

    return obj_class.from_dict(data["data"])


def _serialize_value_with_schema(payload: Any) -> dict[str, Any]:  # noqa: PLR0911
    """
    Serializes a value into a schema-aware format suitable for storage or transmission.

    The output format separates the schema information from the actual data, making it easier
    to deserialize complex nested structures correctly.

    The value is handled by the first matching case, in this order:
    - Pydantic models (dumped with `model_dump()`)
    - Dictionaries (each value serialized recursively)
    - Lists, tuples, and sets (each element serialized recursively). Mixed-type arrays are not
      supported: the item schema is taken from the first element only.
    - Objects with a callable `to_dict()` method (e.g. Haystack dataclasses and Components)
    - Callables that are not classes (serialized with `serialize_callable`)
    - Enum members (serialized by member name)
    - Other objects with a `__dict__` (each attribute's data is serialized recursively, but no
      per-attribute schema is kept)
    - Primitive types (str, int, float, bool, None); any other type is given schema type "string"
      and a warning is logged

    :param payload: The value to serialize (can be any type)
    :returns: The serialized dict representation of the given value. Contains two keys:
        - "serialization_schema": Contains type information for each field.
        - "serialized_data": Contains the actual data in a simplified format.
    """
    # Handle pydantic
    if isinstance(payload, pydantic.BaseModel):
        type_name = generate_qualified_class_name(type(payload))
        return {"serialization_schema": {"type": type_name}, "serialized_data": payload.model_dump()}

    # Handle dictionary case - iterate through fields
    if isinstance(payload, dict):
        schema: dict[str, Any] = {}
        data: dict[str, Any] = {}

        for field, val in payload.items():
            # Recursively serialize each field
            serialized_value = _serialize_value_with_schema(val)
            schema[field] = serialized_value["serialization_schema"]
            data[field] = serialized_value["serialized_data"]

        return {"serialization_schema": {"type": "object", "properties": schema}, "serialized_data": data}

    # Handle array case - iterate through elements
    if isinstance(payload, (list, tuple, set)):
        # Serialize every element exactly once; the first element's schema doubles as the item schema,
        # so objects with side-effecting `to_dict`/`model_dump` are not serialized twice.
        serialized_items = [_serialize_value_with_schema(item) for item in payload]

        # NOTE: We do not support mixed-type lists
        item_schema = serialized_items[0]["serialization_schema"] if serialized_items else {}
        base_schema: dict[str, Any] = {"type": "array", "items": item_schema}

        # Add JSON Schema properties to infer sets and tuples
        if isinstance(payload, set):
            base_schema["uniqueItems"] = True
        elif isinstance(payload, tuple):
            base_schema["minItems"] = len(payload)
            base_schema["maxItems"] = len(payload)

        serialized_list = [item["serialized_data"] for item in serialized_items]
        return {"serialization_schema": base_schema, "serialized_data": serialized_list}

    # Handle Haystack style objects (e.g. dataclasses and Components)
    if hasattr(payload, "to_dict") and callable(payload.to_dict):
        type_name = generate_qualified_class_name(type(payload))
        schema = {"type": type_name}
        return {"serialization_schema": schema, "serialized_data": payload.to_dict()}

    # Handle callable functions serialization (classes themselves are excluded)
    if callable(payload) and not isinstance(payload, type):
        serialized = serialize_callable(payload)
        return {"serialization_schema": {"type": "typing.Callable"}, "serialized_data": serialized}

    # Handle Enums - stored by member name so `Enum[name]` can restore them
    if isinstance(payload, Enum):
        type_name = generate_qualified_class_name(type(payload))
        return {"serialization_schema": {"type": type_name}, "serialized_data": payload.name}

    # Handle arbitrary objects with __dict__. Only each attribute's serialized data is kept (no
    # per-attribute schema), so the deserializer restores these attributes as plain data.
    if hasattr(payload, "__dict__"):
        type_name = generate_qualified_class_name(type(payload))
        schema = {"type": type_name}
        serialized_data = {}
        for key, value in vars(payload).items():
            serialized_value = _serialize_value_with_schema(value)
            serialized_data[key] = serialized_value["serialized_data"]
        return {"serialization_schema": schema, "serialized_data": serialized_data}

    # Handle primitives
    schema = {"type": _primitive_schema_type(payload)}
    return {"serialization_schema": schema, "serialized_data": payload}


def _primitive_schema_type(value: Any) -> str:
    """
    Helper function to determine the schema type for primitive values.

    :param value: The value whose schema type should be determined.
    :returns: The JSON Schema type name of the first matching entry in `_PRIMITIVE_TO_SCHEMA_MAP`,
        or "string" (with a logged warning) if no entry matches.
    """
    for py_type, schema_value in _PRIMITIVE_TO_SCHEMA_MAP.items():
        if isinstance(value, py_type):
            return schema_value
    logger.warning(
        "Unsupported primitive type '{value_type}', falling back to 'string'", value_type=type(value).__name__
    )
    return "string"  # fallback


def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any:
    """
    Deserializes a value with schema information back to its original form.

    Takes a dict of the form:
        {
            "serialization_schema": {"type": "integer"} or {"type": "object", "properties": {...}},
            "serialized_data": <the actual data>
        }

    NOTE: For array types we only support homogeneous lists (all elements of the same type).

    :param serialized: The serialized dict with schema and data.
    :returns: The deserialized value in its original form.
    :raises DeserializationError:
        If the payload does not have the expected keys, if the schema has no "type" (the pre-2.15.0
        format), or if a nested custom type cannot be deserialized.
    """

    if not serialized or "serialization_schema" not in serialized or "serialized_data" not in serialized:
        raise DeserializationError(
            f"Invalid format of passed serialized payload. Expected a dictionary with keys "
            f"'serialization_schema' and 'serialized_data'. Got: {serialized}"
        )
    schema = serialized["serialization_schema"]
    data = serialized["serialized_data"]

    schema_type = schema.get("type")

    if not schema_type:
        # The legacy (pre-2.15.0) format stored no "type" key; support for it has been removed.
        raise DeserializationError(
            "Missing 'type' key in 'serialization_schema'. This likely indicates that you're using a serialized "
            "State object created with a version of Haystack older than 2.15.0. "
            "Support for the old serialization format is removed in Haystack 2.16.0. "
            "Please upgrade to the new serialization format to ensure forward compatibility."
        )

    # Handle object case (dictionary with properties)
    if schema_type == "object":
        properties = schema["properties"]
        result: dict[str, Any] = {}
        for field, raw_value in data.items():
            field_schema = properties[field]
            # Recursively deserialize each field with its own schema
            result[field] = _deserialize_value_with_schema(
                {"serialization_schema": field_schema, "serialized_data": raw_value}
            )
        return result

    # Handle array case
    if schema_type == "array":
        # Deserialize each item with the shared item schema
        deserialized_items = [
            _deserialize_value_with_schema({"serialization_schema": schema["items"], "serialized_data": item})
            for item in data
        ]
        final_array: list | set | tuple
        # Is a set if uniqueItems is True
        if schema.get("uniqueItems") is True:
            final_array = set(deserialized_items)
        # Is a tuple if minItems and maxItems are set
        elif schema.get("minItems") is not None and schema.get("maxItems") is not None:
            final_array = tuple(deserialized_items)
        else:
            # Otherwise, it's a list
            final_array = list(deserialized_items)
        return final_array

    # Handle primitive types
    if schema_type in _PRIMITIVE_TO_SCHEMA_MAP.values():
        return data

    # Handle callable functions
    if schema_type == "typing.Callable":
        return deserialize_callable(data)

    # Handle custom class types
    return _deserialize_value({"type": schema_type, "data": data})


def _deserialize_value(value: dict[str, Any]) -> Any:
    """
    Helper function to deserialize values from their envelope format {"type": T, "data": D}.

    This handles, in order:
    - Custom classes (with a callable from_dict method)
    - Pydantic models (validated with `model_validate`)
    - Enums (looked up by member name)
    - Fallback for arbitrary classes (sets attributes on a blank instance)

    :param value: The envelope to deserialize, with keys "type" (qualified class name) and "data".
    :returns:
        The deserialized value
    :raises DeserializationError:
        If the data is not valid for the Pydantic model or Enum, or if the fallback path receives
        data that is not a dictionary of attributes.
    :raises ImportError:
        Presumably propagated unchanged from `import_class_by_name` when the type cannot be imported
        (deserialize_class_instance handles the same failure as ImportError) - confirm.
    """
    value_type = value["type"]
    payload = value["data"]

    # Custom class where value_type is a qualified class name
    cls = import_class_by_name(value_type)

    # try from_dict (e.g. Haystack dataclasses and Components)
    if hasattr(cls, "from_dict") and callable(cls.from_dict):
        return cls.from_dict(payload)

    # handle pydantic models
    if issubclass(cls, pydantic.BaseModel):
        try:
            return cls.model_validate(payload)
        except Exception as e:
            raise DeserializationError(
                f"Failed to deserialize data '{payload}' into Pydantic model '{value_type}'"
            ) from e

    # handle enum types (serialized by member name)
    if issubclass(cls, Enum):
        try:
            return cls[payload]
        except (KeyError, TypeError) as e:
            # KeyError: unknown member name; TypeError: unhashable payload
            raise DeserializationError(f"Value '{payload}' is not a valid member of Enum '{value_type}'") from e

    # fallback: set attributes on a blank instance.
    # The serializer stores each attribute's raw serialized data (no nested {"type", "data"} envelope
    # and no per-attribute schema), so values are assigned as-is rather than passed back through
    # `_deserialize_value`, which would fail on any non-envelope value.
    if not isinstance(payload, dict):
        raise DeserializationError(
            f"Cannot deserialize data '{payload}' into '{value_type}': expected a dictionary of attributes"
        )
    instance = cls.__new__(cls)  # bypass __init__; attributes are restored directly below
    for attr_name, attr_value in payload.items():
        setattr(instance, attr_name, attr_value)
    return instance