/ haystack / utils / base_serialization.py
base_serialization.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from enum import Enum
  6  from typing import Any
  7  
  8  import pydantic
  9  
 10  from haystack import logging
 11  from haystack.core.errors import DeserializationError, SerializationError
 12  from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
 13  from haystack.utils import deserialize_callable, serialize_callable
 14  
logger = logging.getLogger(__name__)

# Maps Python primitive types to the type names stored under the "type" key of a serialization schema.
# NOTE: `bool` has to stay ahead of `int`. `_primitive_schema_type` walks this mapping in insertion order and
# returns the first `isinstance` match, and `isinstance(True, int)` is True, so booleans would otherwise be
# reported as "integer".
_PRIMITIVE_TO_SCHEMA_MAP = {type(None): "null", bool: "boolean", int: "integer", float: "number", str: "string"}
 18  
 19  
 20  def serialize_class_instance(obj: Any) -> dict[str, Any]:
 21      """
 22      Serializes an object that has a `to_dict` method into a dictionary.
 23  
 24      :param obj:
 25          The object to be serialized.
 26      :returns:
 27          A dictionary representation of the object.
 28      :raises SerializationError:
 29          If the object does not have a `to_dict` method.
 30      """
 31      if not hasattr(obj, "to_dict"):
 32          raise SerializationError(f"Object of class '{type(obj).__name__}' does not have a 'to_dict' method")
 33  
 34      output = obj.to_dict()
 35      return {"type": generate_qualified_class_name(type(obj)), "data": output}
 36  
 37  
 38  def deserialize_class_instance(data: dict[str, Any]) -> Any:
 39      """
 40      Deserializes an object from a dictionary representation generated by `auto_serialize_class_instance`.
 41  
 42      :param data:
 43          The dictionary to deserialize from.
 44      :returns:
 45          The deserialized object.
 46      :raises DeserializationError:
 47          If the serialization data is malformed, the class type cannot be imported, or the
 48          class does not have a `from_dict` method.
 49      """
 50      if "type" not in data:
 51          raise DeserializationError("Missing 'type' in serialization data")
 52      if "data" not in data:
 53          raise DeserializationError("Missing 'data' in serialization data")
 54  
 55      try:
 56          obj_class = import_class_by_name(data["type"])
 57      except ImportError as e:
 58          raise DeserializationError(f"Class '{data['type']}' not correctly imported") from e
 59  
 60      if not hasattr(obj_class, "from_dict"):
 61          raise DeserializationError(f"Class '{data['type']}' does not have a 'from_dict' method")
 62  
 63      return obj_class.from_dict(data["data"])
 64  
 65  
 66  def _serialize_value_with_schema(payload: Any) -> dict[str, Any]:  # noqa: PLR0911
 67      """
 68      Serializes a value into a schema-aware format suitable for storage or transmission.
 69  
 70      The output format separates the schema information from the actual data, making it easier
 71      to deserialize complex nested structures correctly.
 72  
 73      The function handles:
 74      - Objects with to_dict() methods (e.g. dataclasses)
 75      - Objects with __dict__ attributes
 76      - Dictionaries
 77      - Lists, tuples, and sets. Lists with mixed types are not supported.
 78      - Primitive types (str, int, float, bool, None)
 79  
 80      :param payload: The value to serialize (can be any type)
 81      :returns: The serialized dict representation of the given value. Contains two keys:
 82          - "serialization_schema": Contains type information for each field.
 83          - "serialized_data": Contains the actual data in a simplified format.
 84  
 85      """
 86      # Handle pydantic
 87      if isinstance(payload, pydantic.BaseModel):
 88          type_name = generate_qualified_class_name(type(payload))
 89          return {"serialization_schema": {"type": type_name}, "serialized_data": payload.model_dump()}
 90  
 91      # Handle dictionary case - iterate through fields
 92      if isinstance(payload, dict):
 93          schema: dict[str, Any] = {}
 94          data: dict[str, Any] = {}
 95  
 96          for field, val in payload.items():
 97              # Recursively serialize each field
 98              serialized_value = _serialize_value_with_schema(val)
 99              schema[field] = serialized_value["serialization_schema"]
100              data[field] = serialized_value["serialized_data"]
101  
102          return {"serialization_schema": {"type": "object", "properties": schema}, "serialized_data": data}
103  
104      # Handle array case - iterate through elements
105      if isinstance(payload, (list, tuple, set)):
106          # Serialize each item in the array
107          serialized_list = []
108          for item in payload:
109              serialized_value = _serialize_value_with_schema(item)
110              serialized_list.append(serialized_value["serialized_data"])
111  
112          # Determine item type from first element (if any)
113          # NOTE: We do not support mixed-type lists
114          if payload:
115              first = next(iter(payload))
116              item_schema = _serialize_value_with_schema(first)
117              base_schema = {"type": "array", "items": item_schema["serialization_schema"]}
118          else:
119              base_schema = {"type": "array", "items": {}}
120  
121          # Add JSON Schema properties to infer sets and tuples
122          if isinstance(payload, set):
123              base_schema["uniqueItems"] = True
124          elif isinstance(payload, tuple):
125              base_schema["minItems"] = len(payload)
126              base_schema["maxItems"] = len(payload)
127  
128          return {"serialization_schema": base_schema, "serialized_data": serialized_list}
129  
130      # Handle Haystack style objects (e.g. dataclasses and Components)
131      if hasattr(payload, "to_dict") and callable(payload.to_dict):
132          type_name = generate_qualified_class_name(type(payload))
133          schema = {"type": type_name}
134          return {"serialization_schema": schema, "serialized_data": payload.to_dict()}
135  
136      # Handle callable functions serialization
137      if callable(payload) and not isinstance(payload, type):
138          serialized = serialize_callable(payload)
139          return {"serialization_schema": {"type": "typing.Callable"}, "serialized_data": serialized}
140  
141      # Handle Enums
142      if isinstance(payload, Enum):
143          type_name = generate_qualified_class_name(type(payload))
144          return {"serialization_schema": {"type": type_name}, "serialized_data": payload.name}
145  
146      # Handle arbitrary objects with __dict__
147      if hasattr(payload, "__dict__"):
148          type_name = generate_qualified_class_name(type(payload))
149          schema = {"type": type_name}
150          serialized_data = {}
151          for key, value in vars(payload).items():
152              serialized_value = _serialize_value_with_schema(value)
153              serialized_data[key] = serialized_value["serialized_data"]
154          return {"serialization_schema": schema, "serialized_data": serialized_data}
155  
156      # Handle primitives
157      schema = {"type": _primitive_schema_type(payload)}
158      return {"serialization_schema": schema, "serialized_data": payload}
159  
160  
def _primitive_schema_type(value: Any) -> str:
    """
    Looks up the schema type name for a primitive value.

    The entries of `_PRIMITIVE_TO_SCHEMA_MAP` are tried in order and the first type the value is an
    instance of decides the result.

    :param value: The value to classify.
    :returns: The matching schema type name, or "string" (after logging a warning) if nothing matches.
    """
    matched = next(
        (
            schema_name
            for primitive_type, schema_name in _PRIMITIVE_TO_SCHEMA_MAP.items()
            if isinstance(value, primitive_type)
        ),
        None,
    )
    if matched is not None:
        return matched

    # Nothing matched: warn and treat the value as a string.
    logger.warning(
        "Unsupported primitive type '{value_type}', falling back to 'string'", value_type=type(value).__name__
    )
    return "string"
172  
173  
def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any:
    """
    Restores a value from its schema-aware serialized form.

    The input is expected to look like:
      {
         "serialization_schema": {"type": "integer"} or {"type": "object", "properties": {...}},
         "serialized_data": <the actual data>
      }

    The "type" of the schema selects how the data is restored:
    - "object": a dict, every field is restored with the schema stored for it under "properties"
    - "array": a list, a set (if "uniqueItems" is True) or a tuple (if "minItems" and "maxItems"
      are both set); every element is restored with the single schema stored under "items"
    - one of the primitive schema types: the data is returned unchanged
    - "typing.Callable": the data is passed to `deserialize_callable`
    - anything else is treated as a qualified class name and passed to `_deserialize_value`

    NOTE: For array types we only support homogeneous lists (all elements of the same type).

    :param serialized: The serialized dict with schema and data.
    :returns: The deserialized value in its original form.
    :raises DeserializationError:
        If the input is empty, one of the two top-level keys is missing, or the schema has no "type".
    """
    if not (serialized and "serialization_schema" in serialized and "serialized_data" in serialized):
        raise DeserializationError(
            f"Invalid format of passed serialized payload. Expected a dictionary with keys "
            f"'serialization_schema' and 'serialized_data'. Got: {serialized}"
        )

    type_spec = serialized["serialization_schema"]
    raw = serialized["serialized_data"]
    kind = type_spec.get("type")

    if not kind:
        # A schema without "type" cannot be interpreted; the message points at the old format.
        raise DeserializationError(
            "Missing 'type' key in 'serialization_schema'. This likely indicates that you're using a serialized "
            "State object created with a version of Haystack older than 2.15.0. "
            "Support for the old serialization format is removed in Haystack 2.16.0. "
            "Please upgrade to the new serialization format to ensure forward compatibility."
        )

    # Dictionary: restore every field with the schema recorded for it
    if kind == "object":
        field_specs = type_spec["properties"]
        return {
            name: _deserialize_value_with_schema({"serialization_schema": field_specs[name], "serialized_data": item})
            for name, item in raw.items()
        }

    # Array: restore the elements, then pick the container type from the schema hints
    if kind == "array":
        # "items" is looked up per element so that an empty array never touches it.
        elements = [
            _deserialize_value_with_schema({"serialization_schema": type_spec["items"], "serialized_data": element})
            for element in raw
        ]
        if type_spec.get("uniqueItems") is True:
            return set(elements)
        if type_spec.get("minItems") is not None and type_spec.get("maxItems") is not None:
            return tuple(elements)
        return elements

    # Primitives are stored as they are
    if kind in _PRIMITIVE_TO_SCHEMA_MAP.values():
        return raw

    # Callables
    if kind == "typing.Callable":
        return deserialize_callable(raw)

    # Everything else is a qualified class name
    return _deserialize_value({"type": kind, "data": raw})
250  
251  
def _deserialize_value(value: dict[str, Any]) -> Any:
    """
    Helper function to deserialize values from their envelope format {"type": T, "data": D}.

    This handles, in this order:
    - Custom classes with a callable `from_dict` (e.g. Haystack dataclasses and Components)
    - Pydantic models (via `model_validate`)
    - Enums (looked up by member name)
    - Fallback for arbitrary classes: a blank instance is created with `cls.__new__(cls)` (so `__init__`
      is not run) and every entry of the data dict is set on it as an attribute

    In the fallback case an attribute value is only deserialized recursively if it is itself an envelope,
    i.e. a dict containing both "type" and "data". Every other value is set as it is. This matters because
    `_serialize_value_with_schema` stores plain serialized data for the attributes of `__dict__` objects,
    so treating each attribute as an envelope would fail on the first primitive attribute.

    NOTE(review): the call to `import_class_by_name` is not wrapped here, so whatever it raises for a type
    that cannot be imported propagates unchanged (`deserialize_class_instance` catches `ImportError` from
    the same helper) - confirm whether callers expect a `DeserializationError` instead.

    :param value: The envelope to deserialize. Must contain the keys "type" (qualified class name) and "data".
    :returns:
        The deserialized value
    :raises DeserializationError:
        If the data cannot be validated by the Pydantic model, or is not a valid member name of the Enum.
    """
    # 1) Envelope case
    value_type = value["type"]
    payload = value["data"]

    # Custom class where value_type is a qualified class name
    cls = import_class_by_name(value_type)

    # try from_dict (e.g. Haystack dataclasses and Components)
    if hasattr(cls, "from_dict") and callable(cls.from_dict):
        return cls.from_dict(payload)

    # handle pydantic models
    if issubclass(cls, pydantic.BaseModel):
        try:
            return cls.model_validate(payload)
        except Exception as e:
            raise DeserializationError(
                f"Failed to deserialize data '{payload}' into Pydantic model '{value_type}'"
            ) from e

    # handle enum types
    if issubclass(cls, Enum):
        try:
            return cls[payload]
        except Exception as e:
            raise DeserializationError(f"Value '{payload}' is not a valid member of Enum '{value_type}'") from e

    # fallback: set attributes on a blank instance
    deserialized_payload: dict[str, Any] = {}
    for attr_name, attr_value in payload.items():
        # Only recurse into real envelopes; plain values (primitives, lists, ordinary dicts) are kept as they are.
        is_envelope = isinstance(attr_value, dict) and "type" in attr_value and "data" in attr_value
        deserialized_payload[attr_name] = _deserialize_value(attr_value) if is_envelope else attr_value

    instance = cls.__new__(cls)
    for attr_name, attr_value in deserialized_payload.items():
        setattr(instance, attr_name, attr_value)
    return instance