output_adapter.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import ast
import contextlib
from collections.abc import Callable
from typing import Any, TypeAlias

import jinja2.runtime
from jinja2 import TemplateSyntaxError
from jinja2.nativetypes import NativeEnvironment
from jinja2.sandbox import SandboxedEnvironment

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
from haystack.utils.jinja2_extensions import _extract_template_variables_and_assignments

logger = logging.getLogger(__name__)


class OutputAdaptationException(Exception):
    """Exception raised when there is an error during output adaptation."""


@component
class OutputAdapter:
    """
    Adapts output of a Component using Jinja templates.

    Usage example:
    ```python
    from haystack import Document
    from haystack.components.converters import OutputAdapter

    adapter = OutputAdapter(template="{{ documents[0].content }}", output_type=str)
    documents = [Document(content="Test content")]
    result = adapter.run(documents=documents)

    assert result["output"] == "Test content"
    ```
    """

    def __init__(
        self,
        template: str,
        output_type: TypeAlias,
        custom_filters: dict[str, Callable] | None = None,
        unsafe: bool = False,
    ) -> None:
        """
        Create an OutputAdapter component.

        :param template:
            A Jinja template that defines how to adapt the input data.
            The variables in the template define the input of this instance.
            e.g.
            With this template:
            ```
            {{ documents[0].content }}
            ```
            The Component input will be `documents`.
        :param output_type:
            The type of output this instance will return.
        :param custom_filters:
            A dictionary of custom Jinja filters used in the template.
        :param unsafe:
            Enable execution of arbitrary code in the Jinja template.
            This should only be used if you trust the source of the template as it can lead to remote code execution.
        :raises ValueError: If `template` is not syntactically valid Jinja.
        """
        # Copy the mapping so later changes to the caller's dict do not leak into this instance.
        self.custom_filters = {**(custom_filters or {})}
        self._unsafe = unsafe

        if self._unsafe:
            msg = (
                "Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. "
                "Use this only if you trust the source of the template."
            )
            logger.warning(msg)
        # Safe mode: sandboxed environment that raises on undefined variables (StrictUndefined).
        # Unsafe mode: NativeEnvironment, which renders to native Python objects rather than strings.
        # The environment is kept on the instance and reused by `run`.
        self._env = (
            NativeEnvironment() if self._unsafe else SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
        )

        try:
            self._env.parse(template)  # Validate template syntax
            self.template = template
        except TemplateSyntaxError as e:
            raise ValueError(f"Invalid Jinja template '{template}': {e}") from e

        for name, filter_func in self.custom_filters.items():
            self._env.filters[name] = filter_func

        # Variables the template reads but does not assign itself become the component's inputs.
        assigned_variables, template_variables = _extract_template_variables_and_assignments(
            env=self._env, template=self.template
        )
        input_names = template_variables - assigned_variables

        component.set_input_types(self, **dict.fromkeys(input_names, Any))
        component.set_output_types(self, output=output_type)
        self.output_type = output_type

    def run(self, **kwargs: Any) -> dict[str, Any]:
        """
        Renders the Jinja template with the provided inputs.

        :param kwargs:
            Must contain all variables used in the `template` string.
        :returns:
            A dictionary with the following keys:
            - `output`: Rendered Jinja template.

        :raises ValueError: If no inputs are provided.
        :raises OutputAdaptationException: If template rendering fails.
        """
        if not kwargs:
            raise ValueError("No input data provided for output adaptation")
        # Filters are re-applied on every call; presumably so that changes made to
        # `self.custom_filters` after construction are picked up (verify intent).
        for name, filter_func in self.custom_filters.items():
            self._env.filters[name] = filter_func
        adapted_outputs = {}
        try:
            adapted_output_template = self._env.from_string(self.template)
            output_result = adapted_output_template.render(**kwargs)
            # In unsafe mode the default (non-strict) Undefined is used, so rendering an undefined
            # variable can yield an Undefined object instead of raising; treat that as an error.
            if isinstance(output_result, jinja2.runtime.Undefined):
                raise OutputAdaptationException(f"Undefined variable in the template {self.template}; kwargs: {kwargs}")  # noqa: TRY301

            # In safe mode the sandbox renders a string. Try to turn it back into a Python literal
            # (list, dict, number, ...) so structured outputs survive rendering; if the string is not
            # a valid literal, keep it as-is. Only literal structures are supported, not user types.
            if not self._unsafe:
                with contextlib.suppress(Exception):
                    output_result = ast.literal_eval(output_result)

            adapted_outputs["output"] = output_result
        except Exception as e:
            raise OutputAdaptationException(f"Error adapting {self.template} with {kwargs}: {e}") from e
        return adapted_outputs

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
        return default_to_dict(
            self,
            template=self.template,
            output_type=serialize_type(self.output_type),
            custom_filters=se_filters,
            unsafe=self._unsafe,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "OutputAdapter":
        """
        Deserializes the component from a dictionary.

        The input dictionary is not modified.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        # Work on copies: writing deserialized objects back into the caller's dict would make a
        # second from_dict call on the same data try to deserialize already-deserialized values.
        init_params = {**data.get("init_parameters", {})}
        init_params["output_type"] = deserialize_type(init_params["output_type"])

        custom_filters = init_params.get("custom_filters", {})
        if custom_filters:
            init_params["custom_filters"] = {
                name: deserialize_callable(filter_func) if filter_func else None
                for name, filter_func in custom_filters.items()
            }
        return default_from_dict(cls, {**data, "init_parameters": init_params})