/ haystack / components / converters / output_adapter.py
output_adapter.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import ast
  6  import contextlib
  7  from collections.abc import Callable
  8  from typing import Any, TypeAlias
  9  
 10  import jinja2.runtime
 11  from jinja2 import TemplateSyntaxError
 12  from jinja2.nativetypes import NativeEnvironment
 13  from jinja2.sandbox import SandboxedEnvironment
 14  
 15  from haystack import component, default_from_dict, default_to_dict, logging
 16  from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
 17  from haystack.utils.jinja2_extensions import _extract_template_variables_and_assignments
 18  
 19  logger = logging.getLogger(__name__)
 20  
 21  
 22  class OutputAdaptationException(Exception):
 23      """Exception raised when there is an error during output adaptation."""
 24  
 25  
 26  @component
 27  class OutputAdapter:
 28      """
 29      Adapts output of a Component using Jinja templates.
 30  
 31      Usage example:
 32      ```python
 33      from haystack import Document
 34      from haystack.components.converters import OutputAdapter
 35  
 36      adapter = OutputAdapter(template="{{ documents[0].content }}", output_type=str)
 37      documents = [Document(content="Test content")]
 38      result = adapter.run(documents=documents)
 39  
 40      assert result["output"] == "Test content"
 41      ```
 42      """
 43  
 44      def __init__(
 45          self,
 46          template: str,
 47          output_type: TypeAlias,
 48          custom_filters: dict[str, Callable] | None = None,
 49          unsafe: bool = False,
 50      ) -> None:
 51          """
 52          Create an OutputAdapter component.
 53  
 54          :param template:
 55              A Jinja template that defines how to adapt the input data.
 56              The variables in the template define the input of this instance.
 57              e.g.
 58              With this template:
 59              ```
 60              {{ documents[0].content }}
 61              ```
 62              The Component input will be `documents`.
 63          :param output_type:
 64              The type of output this instance will return.
 65          :param custom_filters:
 66              A dictionary of custom Jinja filters used in the template.
 67          :param unsafe:
 68              Enable execution of arbitrary code in the Jinja template.
 69              This should only be used if you trust the source of the template as it can be lead to remote code execution.
 70          """
 71          self.custom_filters = {**(custom_filters or {})}
 72          input_types: set[str] = set()
 73  
 74          self._unsafe = unsafe
 75  
 76          if self._unsafe:
 77              msg = (
 78                  "Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. "
 79                  "Use this only if you trust the source of the template."
 80              )
 81              logger.warning(msg)
 82          self._env = (
 83              NativeEnvironment() if self._unsafe else SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
 84          )
 85  
 86          try:
 87              self._env.parse(template)  # Validate template syntax
 88              self.template = template
 89          except TemplateSyntaxError as e:
 90              raise ValueError(f"Invalid Jinja template '{template}': {e}") from e
 91  
 92          for name, filter_func in self.custom_filters.items():
 93              self._env.filters[name] = filter_func
 94  
 95          # b) extract variables in the template
 96          assigned_variables, template_variables = _extract_template_variables_and_assignments(
 97              env=self._env, template=self.template
 98          )
 99          route_input_names = template_variables - assigned_variables
100          input_types.update(route_input_names)
101  
102          # the env is not needed, discarded automatically
103          component.set_input_types(self, **dict.fromkeys(input_types, Any))
104          component.set_output_types(self, output=output_type)
105          self.output_type = output_type
106  
107      def run(self, **kwargs: Any) -> dict[str, Any]:
108          """
109          Renders the Jinja template with the provided inputs.
110  
111          :param kwargs:
112              Must contain all variables used in the `template` string.
113          :returns:
114              A dictionary with the following keys:
115              - `output`: Rendered Jinja template.
116  
117          :raises OutputAdaptationException: If template rendering fails.
118          """
119          # check if kwargs are empty
120          if not kwargs:
121              raise ValueError("No input data provided for output adaptation")
122          for name, filter_func in self.custom_filters.items():
123              self._env.filters[name] = filter_func
124          adapted_outputs = {}
125          try:
126              adapted_output_template = self._env.from_string(self.template)
127              output_result = adapted_output_template.render(**kwargs)
128              if isinstance(output_result, jinja2.runtime.Undefined):
129                  raise OutputAdaptationException(f"Undefined variable in the template {self.template}; kwargs: {kwargs}")  # noqa: TRY301
130  
131              # We suppress the exception in case the output is already a string, otherwise
132              # we try to evaluate it and would fail.
133              # This must be done cause the output could be different literal structures.
134              # This doesn't support any user types.
135              with contextlib.suppress(Exception):
136                  if not self._unsafe:
137                      output_result = ast.literal_eval(output_result)
138  
139              adapted_outputs["output"] = output_result
140          except Exception as e:
141              raise OutputAdaptationException(f"Error adapting {self.template} with {kwargs}: {e}") from e
142          return adapted_outputs
143  
144      def to_dict(self) -> dict[str, Any]:
145          """
146          Serializes the component to a dictionary.
147  
148          :returns:
149              Dictionary with serialized data.
150          """
151          se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
152          return default_to_dict(
153              self,
154              template=self.template,
155              output_type=serialize_type(self.output_type),
156              custom_filters=se_filters,
157              unsafe=self._unsafe,
158          )
159  
160      @classmethod
161      def from_dict(cls, data: dict[str, Any]) -> "OutputAdapter":
162          """
163          Deserializes the component from a dictionary.
164  
165          :param data:
166              The dictionary to deserialize from.
167          :returns:
168              The deserialized component.
169          """
170          init_params = data.get("init_parameters", {})
171          init_params["output_type"] = deserialize_type(init_params["output_type"])
172  
173          custom_filters = init_params.get("custom_filters", {})
174          if custom_filters:
175              init_params["custom_filters"] = {
176                  name: deserialize_callable(filter_func) if filter_func else None
177                  for name, filter_func in custom_filters.items()
178              }
179          return default_from_dict(cls, data)