/ haystack / utils / jinja2_chat_extension.py
jinja2_chat_extension.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import json
  6  from collections.abc import Callable
  7  from typing import Any
  8  
  9  from jinja2 import TemplateSyntaxError, nodes
 10  from jinja2.ext import Extension
 11  from markupsafe import Markup
 12  
 13  from haystack import logging
 14  from haystack.dataclasses.chat_message import (
 15      ChatMessage,
 16      ChatMessageContentT,
 17      ChatRole,
 18      FileContent,
 19      ImageContent,
 20      ReasoningContent,
 21      TextContent,
 22      ToolCall,
 23      ToolCallResult,
 24      _deserialize_content_part,
 25      _serialize_content_part,
 26  )
 27  
# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Sentinel tags that delimit a JSON-serialized structured content part inside rendered template output.
# `templatize_part` emits them and `ChatMessageExtension._parse_content_parts` looks for them.
START_TAG = "<haystack_content_part>"
END_TAG = "</haystack_content_part>"

# HTML-entity forms of the sentinel tags. `_escape_sentinel_tags` substitutes these for tags found in
# values that are not `Markup`, so such text is not matched as a sentinel tag by the parser.
ESCAPED_START_TAG = "&lt;haystack_content_part&gt;"
ESCAPED_END_TAG = "&lt;/haystack_content_part&gt;"
 35  
 36  
 37  def _escape_sentinel_tags(value: object) -> str:
 38      """
 39      Jinja2 `finalize` callback that prevents sentinel tag injection.
 40  
 41      Called automatically on every `{{ }}` expression result during template rendering.
 42      Legitimate structured content from the `templatize_part` filter is wrapped in `Markup` and passes.
 43      Any other value containing sentinel tags has those tags replaced with harmless HTML entities so that
 44      `_parse_content_parts` will not treat them as structured content.
 45      """
 46      if isinstance(value, Markup):
 47          return value
 48  
 49      return str(value).replace(START_TAG, ESCAPED_START_TAG).replace(END_TAG, ESCAPED_END_TAG)
 50  
 51  
 52  class ChatMessageExtension(Extension):
 53      """
 54      A Jinja2 extension for creating structured chat messages with mixed content types.
 55  
 56      This extension provides a custom `{% message %}` tag that allows creating chat messages
 57      with different attributes (role, name, meta) and mixed content types (text, images, etc.).
 58  
 59      Inspired by [Banks](https://github.com/masci/banks).
 60  
 61      Example:
 62      ```
 63      {% message role="system" %}
 64      You are a helpful assistant. You like to talk with {{user_name}}.
 65      {% endmessage %}
 66  
 67      {% message role="user" %}
 68      Hello! I am {{user_name}}. Please describe the images.
 69      {% for image in images %}
 70      {{ image | templatize_part }}
 71      {% endfor %}
 72      {% endmessage %}
 73      ```
 74  
 75      ### How it works
 76      1. The `{% message %}` tag is used to define a chat message.
 77      2. The message can contain text and other structured content parts.
 78      3. To include a structured content part in the message, the `| templatize_part` filter is used.
 79         The filter serializes the content part into a JSON string and wraps it in a `<haystack_content_part>` tag.
 80      4. The `_build_chat_message_json` method of the extension parses the message content parts,
 81         converts them into a ChatMessage object and serializes it to a JSON string.
 82      5. The obtained JSON string is usable in the ChatPromptBuilder component, where templates are rendered to actual
 83         ChatMessage objects.
 84      """
 85  
 86      SUPPORTED_ROLES = [role.value for role in ChatRole]
 87  
 88      tags = {"message"}
 89  
 90      def __init__(self, environment: Any) -> None:
 91          super().__init__(environment)
 92          environment.finalize = _escape_sentinel_tags
 93          environment.filters["templatize_part"] = templatize_part
 94  
 95      def parse(self, parser: Any) -> nodes.Node | list[nodes.Node]:
 96          """
 97          Parse the message tag and its attributes in the Jinja2 template.
 98  
 99          This method handles the parsing of role (mandatory), name (optional), meta (optional) and message body content.
100  
101          :param parser: The Jinja2 parser instance
102          :return: A CallBlock node containing the parsed message configuration
103          :raises TemplateSyntaxError: If an invalid role is provided
104          """
105          lineno = next(parser.stream).lineno
106  
107          # Parse role attribute (mandatory)
108          parser.stream.expect("name:role")
109          parser.stream.expect("assign")
110          role_expr = parser.parse_expression()
111  
112          if isinstance(role_expr, nodes.Const):
113              role = role_expr.value
114              if role not in self.SUPPORTED_ROLES:
115                  raise TemplateSyntaxError(f"Role must be one of: {', '.join(self.SUPPORTED_ROLES)}", lineno)
116  
117          # Parse optional name attribute
118          name_expr = None
119          if parser.stream.current.test("name:name"):
120              parser.stream.skip()
121              parser.stream.expect("assign")
122              name_expr = parser.parse_expression()
123              if not isinstance(name_expr.value, str):
124                  raise TemplateSyntaxError("name must be a string", lineno)
125  
126          # Parse optional meta attribute
127          meta_expr = None
128          if parser.stream.current.test("name:meta"):
129              parser.stream.skip()
130              parser.stream.expect("assign")
131              meta_expr = parser.parse_expression()
132              if not isinstance(meta_expr, nodes.Dict):
133                  raise TemplateSyntaxError("meta must be a dictionary", lineno)
134  
135          # Parse message body
136          body = parser.parse_statements(("name:endmessage",), drop_needle=True)
137  
138          # Build message node with all parameters
139          return nodes.CallBlock(
140              self.call_method(
141                  name="_build_chat_message_json",
142                  args=[role_expr, name_expr or nodes.Const(None), meta_expr or nodes.Dict([])],
143              ),
144              [],
145              [],
146              body,
147          ).set_lineno(lineno)
148  
149      def _build_chat_message_json(self, role: str, name: str | None, meta: dict, caller: Callable[[], str]) -> str:
150          """
151          Build a ChatMessage object from template content and serialize it to a JSON string.
152  
153          This method is called by Jinja2 when processing a `{% message %}` tag.
154          It takes the rendered content from the template, converts XML blocks into ChatMessageContentT objects,
155          creates a ChatMessage object and serializes it to a JSON string.
156  
157          :param role: The role of the message
158          :param name: Optional name for the message sender
159          :param meta: Optional metadata dictionary
160          :param caller: Callable that returns the rendered content
161          :return: A JSON string representation of the ChatMessage object
162          """
163  
164          content = caller()
165          parts = self._parse_content_parts(content)
166          if not parts:
167              raise ValueError(
168                  f"Message template produced content that couldn't be parsed into any message parts. "
169                  f"Content: '{content!r}'"
170              )
171  
172          chat_message = self._validate_build_chat_message(parts=parts, role=role, meta=meta, name=name)
173  
174          return json.dumps(chat_message.to_dict()) + "\n"
175  
176      @staticmethod
177      def _parse_content_parts(content: str) -> list[ChatMessageContentT]:
178          """
179          Parse a string into a sequence of ChatMessageContentT objects.
180  
181          This method handles:
182          - Plain text content, converted to TextContent objects
183          - Structured content parts wrapped in `<haystack_content_part>` tags, converted to ChatMessageContentT objects
184  
185          :param content: Input string containing mixed text and content parts
186          :return: A list of ChatMessageContentT objects
187          :raises ValueError: If the content is empty or contains only whitespace characters or if a
188                              `<haystack_content_part>` tag is found without a matching closing tag.
189          """
190          if not content.strip():
191              raise ValueError(
192                  f"Message content in template is empty or contains only whitespace characters. Content: {content!r}"
193              )
194  
195          parts: list[ChatMessageContentT] = []
196          cursor = 0
197          total_length = len(content)
198  
199          while cursor < total_length:
200              tag_start = content.find(START_TAG, cursor)
201  
202              if tag_start == -1:
203                  # No more tags, add remaining text if any
204                  remaining_text = content[cursor:].strip()
205                  if remaining_text:
206                      parts.append(TextContent(text=remaining_text))
207                  break
208  
209              # Add text before tag if any
210              if tag_start > cursor:
211                  plain_text = content[cursor:tag_start].strip()
212                  if plain_text:
213                      parts.append(TextContent(text=plain_text))
214  
215              content_start = tag_start + len(START_TAG)
216              tag_end = content.find(END_TAG, content_start)
217  
218              if tag_end == -1:
219                  raise ValueError(
220                      f"Found unclosed <haystack_content_part> tag at position {tag_start}. "
221                      f"Content: '{content[tag_start : tag_start + 50]}...'"
222                  )
223  
224              json_content = content[content_start:tag_end]
225              data = json.loads(json_content)
226              parts.append(_deserialize_content_part(data))
227  
228              cursor = tag_end + len(END_TAG)
229  
230          return parts
231  
232      @staticmethod
233      def _validate_build_chat_message(
234          parts: list[ChatMessageContentT], role: str, meta: dict, name: str | None = None
235      ) -> ChatMessage:
236          """
237          Validate the parts of a chat message and build a ChatMessage object.
238  
239          :param parts: Content parts of the message
240          :param role: The role of the message
241          :param meta: The metadata of the message
242          :param name: The optional name of the message
243          :return: A ChatMessage object
244  
245          :raises ValueError: If content parts don't allow to build a valid ChatMessage object or the role is not
246                              supported
247          """
248  
249          if role == "user":
250              valid_parts = [part for part in parts if isinstance(part, (TextContent, str, ImageContent, FileContent))]
251              if len(parts) != len(valid_parts):
252                  raise ValueError(
253                      "User message must contain only TextContent, string, ImageContent or FileContent parts."
254                  )
255              return ChatMessage.from_user(meta=meta, name=name, content_parts=valid_parts)
256  
257          if role == "system":
258              if not isinstance(parts[0], TextContent):
259                  raise ValueError("System message must contain a text part.")
260              text = parts[0].text
261              if len(parts) > 1:
262                  raise ValueError("System message must contain only one text part.")
263              return ChatMessage.from_system(meta=meta, name=name, text=text)
264  
265          if role == "assistant":
266              texts = [part.text for part in parts if isinstance(part, TextContent)]
267              tool_calls = [part for part in parts if isinstance(part, ToolCall)]
268              reasoning = [part for part in parts if isinstance(part, ReasoningContent)]
269              if len(texts) > 1:
270                  raise ValueError("Assistant message must contain one text part at most.")
271              if len(texts) == 0 and len(tool_calls) == 0:
272                  raise ValueError("Assistant message must contain at least one text or tool call part.")
273              if len(parts) > len(texts) + len(tool_calls) + len(reasoning):
274                  raise ValueError("Assistant message must contain only text, tool call or reasoning parts.")
275              return ChatMessage.from_assistant(
276                  meta=meta,
277                  name=name,
278                  text=texts[0] if texts else None,
279                  tool_calls=tool_calls or None,
280                  reasoning=reasoning[0] if reasoning else None,
281              )
282  
283          if role == "tool":
284              tool_call_results = [part for part in parts if isinstance(part, ToolCallResult)]
285              if len(tool_call_results) == 0 or len(tool_call_results) > 1 or len(parts) > len(tool_call_results):
286                  raise ValueError("Tool message must contain only one tool call result.")
287  
288              tool_result = tool_call_results[0].result
289              origin = tool_call_results[0].origin
290              error = tool_call_results[0].error
291  
292              return ChatMessage.from_tool(meta=meta, tool_result=tool_result, origin=origin, error=error)
293  
294          raise ValueError(f"Unsupported role: {role}")
295  
296  
def templatize_part(value: ChatMessageContentT) -> Markup:
    """
    Jinja filter to convert a ChatMessageContentT object into a JSON string wrapped in special XML content tags.

    Every `<` character in the JSON payload is written as the JSON escape sequence `\\u003c`. The decoded value
    is unchanged, but the payload can no longer contain a sentinel tag, so a content part whose text happens
    to include `</haystack_content_part>` does not terminate the wrapper prematurely.

    :param value: The ChatMessageContentT object to convert
    :return: A JSON string wrapped in special XML content tags marked as safe
    :raises: Whatever `_serialize_content_part` raises for values it cannot serialize
             (presumably non-ChatMessageContentT values; confirm against its implementation)
    """
    payload = json.dumps(_serialize_content_part(value))
    # In JSON text "<" can only occur inside string literals, where "\u003c" is an equivalent spelling.
    payload = payload.replace("<", "\\u003c")
    return Markup(f"{START_TAG}{payload}{END_TAG}")