type_serialization.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import builtins
import importlib
import inspect
import sys
import typing
from threading import Lock
from types import GenericAlias, ModuleType, NoneType, UnionType
from typing import Any, Union, get_args

from haystack.core.errors import DeserializationError

# Serializes the dynamic imports performed by `thread_safe_import`.
_import_lock = Lock()


def _is_union_type(target: Any) -> bool:
    """
    Check if target is a Union type.

    This handles both `typing.Union[X, Y]` and `X | Y` syntax from PEP 604,
    including parameterized types like `Optional[str]`.

    :param target:
        The object to inspect.
    :returns:
        True if `target` is the bare `Union`/`UnionType` or its origin is one of them.
    """
    if target is Union or target is UnionType:
        return True
    origin = typing.get_origin(target)
    return origin is Union or origin is UnionType


def _build_pep604_union_type(types: list[type | UnionType]) -> type | UnionType:
    """
    Build a union type from a list of types using PEP 604 syntax (X | Y).

    :param types:
        The union members, in order. Must contain at least one element.
    :returns:
        The members folded left-to-right with the `|` operator.
    """
    result = types[0]
    for t in types[1:]:
        result = result | t
    return result


def serialize_type(target: Any) -> str:
    """
    Serializes a type or an instance to its string representation, including the module name.

    This function handles types, instances of types, and special typing objects.
    It assumes that non-typing objects will have a '__name__' attribute.

    :param target:
        The object to serialize, can be an instance or a type.
    :return:
        The string representation of the type.
    """
    if target is NoneType:
        return "None"

    args = get_args(target)

    if isinstance(target, UnionType):
        return " | ".join(serialize_type(a) for a in args)

    name = getattr(target, "__name__", str(target)).removeprefix("typing.")
    # When the name came from str(target) it may carry the generic parameters; keep only the bare name.
    name = name.split("[", 1)[0]

    # Get module name
    module = inspect.getmodule(target)
    module_name = ""
    # We omit the module name for builtins to not clutter the output
    if module and hasattr(module, "__name__") and module.__name__ != "builtins":
        module_name = module.__name__

    qualified_name = f"{module_name}.{name}" if module_name else name
    if not args:
        return qualified_name

    # For typing generics, convert PEP 604 union types (X | Y) to typing.Union when serializing.
    # This avoids issues with Python's internal cache, where List[Union[str, int]] and List[str | int] are treated
    # as the same key. GenericAlias (builtins like list[...]) can keep the PEP 604 syntax.
    is_typing_generic = not isinstance(target, GenericAlias)

    # "Optional[X]" already implies None, so its NoneType argument is redundant and omitted.
    # For every other generic (e.g. Union[str, int, None], Dict[str, None]) NoneType is a real argument;
    # dropping it would change the meaning of the type or make the string impossible to deserialize.
    drop_none = name == "Optional" and _is_union_type(target)

    serialized_args = []
    for arg in args:
        if drop_none and arg is NoneType:
            continue
        if is_typing_generic and isinstance(arg, UnionType):
            arg = Union[tuple(get_args(arg))]  # noqa: UP007
        serialized_args.append(serialize_type(arg))

    return f"{qualified_name}[{', '.join(serialized_args)}]"


def _split_top_level(text: str, separator: str) -> list[str]:
    """
    Split `text` on a single-character `separator`, ignoring separators nested inside square brackets.

    :param text: The string to split.
    :param separator: The single character to split on.
    :returns: The raw (unstripped) segments; always contains at least one element.
    """
    segments = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char == "[":
            depth += 1
        elif char == "]":
            depth -= 1
        elif char == separator and depth == 0:
            segments.append(text[start:index])
            start = index + 1
    segments.append(text[start:])
    return segments


def _parse_generic_args(args_str: str) -> list[str]:
    """
    Parse the comma-separated arguments of a generic (e.g., "str, dict[str, int]") into individual type strings.

    Commas nested inside square brackets do not split.

    :param args_str: The text found between the outermost square brackets of a generic type string.
    :returns: A list of stripped argument strings. A completely empty final segment is omitted.
    """
    *leading, last = _split_top_level(args_str, ",")
    args = [segment.strip() for segment in leading]
    if last:
        args.append(last.strip())
    return args


def _parse_pep604_union_args(union_str: str) -> list[str]:
    """
    Parse a PEP 604 union string (e.g., "str | int | None") into individual type strings.

    Handles nested generics properly, e.g., "list[str] | dict[str, int] | None".

    :param union_str: The union string to parse
    :returns: A list of individual type strings
    """
    *leading, last = _split_top_level(union_str, "|")
    args = [segment.strip() for segment in leading]
    last = last.strip()
    # A blank final segment (including whitespace-only) is not a union member.
    if last:
        args.append(last)
    return args


def deserialize_type(type_str: str) -> Any:
    """
    Deserializes a type given its full import path as a string, including nested generic types.

    This function will dynamically import the module if it's not already imported
    and then retrieve the type object from it. It also handles nested generic types like
    `list[dict[int, str]]`.

    Note that resolving a dotted path may import the named module, which executes that module's
    top-level code. Only pass type strings that come from a trusted source.

    :param type_str:
        The string representation of the type's full import path.
    :returns:
        The deserialized type object.
    :raises DeserializationError:
        If the type cannot be deserialized due to missing module or type, or if the generic or union
        expression cannot be rebuilt from its parts.
    """
    # Handle PEP 604 union syntax at the top level (e.g., "str | int", "str | None")
    pep604_union_args = _parse_pep604_union_args(type_str)
    if len(pep604_union_args) > 1:
        deserialized_args = [deserialize_type(arg) for arg in pep604_union_args]
        try:
            return _build_pep604_union_type(deserialized_args)
        except TypeError as e:
            # Raised when the operands do not support `|`, e.g. "None | None".
            raise DeserializationError(f"Could not build a union type from {deserialized_args}") from e

    # Handle generics (including Union[X, Y])
    if "[" in type_str and type_str.endswith("]"):
        main_type_str, generics_str = type_str.split("[", 1)
        generics_str = generics_str[:-1]

        main_type = deserialize_type(main_type_str)
        generic_args = [deserialize_type(arg) for arg in _parse_generic_args(generics_str)]

        # "list[]" has nothing to subscript with; fail with the documented error instead of an IndexError.
        if not generic_args:
            raise DeserializationError(f"Could not find any type arguments in: {type_str}")

        # Reconstruct
        try:
            return main_type[tuple(generic_args) if len(generic_args) > 1 else generic_args[0]]
        except (TypeError, AttributeError) as e:
            raise DeserializationError(f"Could not apply arguments {generic_args} to type {main_type}") from e

    # Handle non-generic types
    # First, check if there's a module prefix
    if "." in type_str:
        module_name, _, type_name = type_str.rpartition(".")

        # An empty or relative module path makes importlib raise ValueError/TypeError rather than ImportError.
        if not module_name or module_name.startswith("."):
            raise DeserializationError(f"Could not import the module: {module_name}")

        module = sys.modules.get(module_name)
        if module is None:
            try:
                module = thread_safe_import(module_name)
            except ImportError as e:
                raise DeserializationError(f"Could not import the module: {module_name}") from e

        # Get the class from the module
        if hasattr(module, type_name):
            return getattr(module, type_name)

        raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")

    # No module prefix, check builtins and typing
    # First check builtins
    if hasattr(builtins, type_str):
        return getattr(builtins, type_str)

    # Then check typing
    if hasattr(typing, type_str):
        return getattr(typing, type_str)

    # Special case for NoneType
    if type_str == "NoneType":
        return NoneType

    # Special case for None
    if type_str == "None":
        return None

    raise DeserializationError(f"Could not deserialize type: {type_str}")


def thread_safe_import(module_name: str) -> ModuleType:
    """
    Import a module in a thread-safe manner.

    Importing modules in a multi-threaded environment can lead to race conditions.
    This function ensures that the module is imported in a thread-safe manner without having impact
    on the performance of the import for single-threaded environments.

    :param module_name: the module to import
    :returns: the imported module, as returned by `importlib.import_module`
    """
    with _import_lock:
        return importlib.import_module(module_name)