/ haystack / utils / type_serialization.py
type_serialization.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import builtins
  6  import importlib
  7  import inspect
  8  import sys
  9  import typing
 10  from threading import Lock
 11  from types import GenericAlias, ModuleType, NoneType, UnionType
 12  from typing import Any, Union, get_args
 13  
 14  from haystack.core.errors import DeserializationError
 15  
# Module-level lock acquired by `thread_safe_import` around every `importlib.import_module` call it makes.
_import_lock = Lock()
 17  
 18  
 19  def _is_union_type(target: Any) -> bool:
 20      """
 21      Check if target is a Union type.
 22  
 23      This handles both `typing.Union[X, Y]` and `X | Y` syntax from PEP 604,
 24      including parameterized types like `Optional[str]`.
 25      """
 26      if target is Union or target is UnionType:
 27          return True
 28      origin = typing.get_origin(target)
 29      return origin is Union or origin is UnionType
 30  
 31  
 32  def _build_pep604_union_type(types: list[type | UnionType]) -> type | UnionType:
 33      """Build a union type from a list of types using PEP 604 syntax (X | Y)."""
 34      result = types[0]
 35      for t in types[1:]:
 36          result = result | t
 37      return result
 38  
 39  
 40  def serialize_type(target: Any) -> str:
 41      """
 42      Serializes a type or an instance to its string representation, including the module name.
 43  
 44      This function handles types, instances of types, and special typing objects.
 45      It assumes that non-typing objects will have a '__name__' attribute.
 46  
 47      :param target:
 48          The object to serialize, can be an instance or a type.
 49      :return:
 50          The string representation of the type.
 51      """
 52      if target is NoneType:
 53          return "None"
 54  
 55      args = get_args(target)
 56  
 57      if isinstance(target, UnionType):
 58          return " | ".join([serialize_type(a) for a in args])
 59  
 60      name = getattr(target, "__name__", str(target))
 61      if name.startswith("typing."):
 62          name = name[7:]
 63      if "[" in name:
 64          name = name.split("[")[0]
 65  
 66      # Get module name
 67      module = inspect.getmodule(target)
 68      module_name = ""
 69      # We omit the module name for builtins to not clutter the output
 70      if module and hasattr(module, "__name__") and module.__name__ != "builtins":
 71          module_name = f"{module.__name__}"
 72  
 73      if args:
 74          # For typing generics, convert PEP 604 union types (X | Y) to typing.Union when serializing.
 75          # This avoids issues with Python's internal cache, where List[Union[str, int]] and List[str | int] are treated
 76          # as the same key. GenericAlias (builtins like list[...]) can keep the PEP 604 syntax.
 77          is_typing_generic = not isinstance(target, GenericAlias)
 78          args_str = ", ".join(
 79              serialize_type(Union[tuple(get_args(a))] if is_typing_generic and isinstance(a, UnionType) else a)  # noqa: UP007
 80              for a in args
 81              if a is not NoneType
 82          )
 83          return f"{module_name}.{name}[{args_str}]" if module_name else f"{name}[{args_str}]"
 84  
 85      return f"{module_name}.{name}" if module_name else f"{name}"
 86  
 87  
 88  def _parse_generic_args(args_str: str) -> list[str]:
 89      args = []
 90      bracket_count = 0
 91      current_arg = ""
 92  
 93      for char in args_str:
 94          if char == "[":
 95              bracket_count += 1
 96          elif char == "]":
 97              bracket_count -= 1
 98  
 99          if char == "," and bracket_count == 0:
100              args.append(current_arg.strip())
101              current_arg = ""
102          else:
103              current_arg += char
104  
105      if current_arg:
106          args.append(current_arg.strip())
107  
108      return args
109  
110  
def _parse_pep604_union_args(union_str: str) -> list[str]:
    """
    Split a PEP 604 union string (e.g., "str | int | None") into its member type strings.

    A `|` only separates members when it is not enclosed in square brackets, so nested generics such as
    "list[str] | dict[str, int] | None" are split correctly and "list[str | int]" stays a single member.

    :param union_str: The union string to parse
    :returns: A list of individual type strings, each stripped of surrounding whitespace
    """
    members: list[str] = []
    depth = 0
    member_start = 0

    for position, symbol in enumerate(union_str):
        if symbol == "|" and depth == 0:
            members.append(union_str[member_start:position].strip())
            member_start = position + 1
        elif symbol == "[":
            depth += 1
        elif symbol == "]":
            depth -= 1

    # A trailing member that is empty or whitespace-only is not reported.
    trailing = union_str[member_start:].strip()
    if trailing:
        members.append(trailing)

    return members
140  
141  
def deserialize_type(type_str: str) -> Any:
    """
    Deserializes a type given its full import path as a string, including nested generic types.

    This function will dynamically import the module if it's not already imported
    and then retrieve the type object from it. It also handles nested generic types like
    `list[dict[int, str]]` and top-level PEP 604 unions like `str | None`.

    :param type_str:
        The string representation of the type's full import path.
    :returns:
        The deserialized type object.
    :raises DeserializationError:
        If the type cannot be deserialized due to missing module or type, or if a generic has
        no arguments or arguments that cannot be applied to it.
    """
    # Handle PEP 604 union syntax at the top level (e.g., "str | int", "str | None")
    pep604_union_args = _parse_pep604_union_args(type_str)
    if len(pep604_union_args) > 1:
        deserialized_args = [deserialize_type(arg) for arg in pep604_union_args]
        return _build_pep604_union_type(deserialized_args)

    # Handle generics (including Union[X, Y])
    if "[" in type_str and type_str.endswith("]"):
        main_type_str, generics_str = type_str.split("[", 1)
        generics_str = generics_str[:-1]

        main_type = deserialize_type(main_type_str)
        generic_args = [deserialize_type(arg) for arg in _parse_generic_args(generics_str)]

        # Empty brackets (e.g. "list[]") leave nothing to subscript with; report this as a
        # deserialization problem instead of letting an IndexError escape from `generic_args[0]`.
        if not generic_args:
            raise DeserializationError(f"Could not apply arguments {generic_args} to type {main_type}")

        # Reconstruct
        try:
            return main_type[tuple(generic_args) if len(generic_args) > 1 else generic_args[0]]
        except (TypeError, AttributeError) as e:
            raise DeserializationError(f"Could not apply arguments {generic_args} to type {main_type}") from e

    # Handle non-generic types
    # First, check if there's a module prefix
    if "." in type_str:
        module_name, _, type_name = type_str.rpartition(".")

        module = sys.modules.get(module_name)
        if module is None:
            try:
                module = thread_safe_import(module_name)
            except (ImportError, ValueError) as e:
                # ValueError is raised by importlib for an empty module name (e.g. for ".Foo").
                raise DeserializationError(f"Could not import the module: {module_name}") from e

        # Get the class from the module
        if hasattr(module, type_name):
            return getattr(module, type_name)

        raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")

    # No module prefix, check builtins and typing
    # First check builtins
    if hasattr(builtins, type_str):
        return getattr(builtins, type_str)

    # Then check typing
    if hasattr(typing, type_str):
        return getattr(typing, type_str)

    # Special case for NoneType
    if type_str == "NoneType":
        return NoneType

    # Special case for None
    if type_str == "None":
        return None

    raise DeserializationError(f"Could not deserialize type: {type_str}")
215  
216  
def thread_safe_import(module_name: str) -> ModuleType:
    """
    Import a module while holding the module-level import lock.

    Every call made through this function acquires the same lock, so the underlying
    `importlib.import_module` calls issued from here never run concurrently with each other.

    :param module_name: the module to import
    :returns: the imported module
    """
    with _import_lock:
        imported_module = importlib.import_module(module_name)
    return imported_module
227      with _import_lock:
228          return importlib.import_module(module_name)