jinja2_chat_extension.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import html
import json
from collections.abc import Callable
from typing import Any

from jinja2 import TemplateSyntaxError, nodes
from jinja2.ext import Extension
from markupsafe import Markup

from haystack import logging
from haystack.dataclasses.chat_message import (
    ChatMessage,
    ChatMessageContentT,
    ChatRole,
    FileContent,
    ImageContent,
    ReasoningContent,
    TextContent,
    ToolCall,
    ToolCallResult,
    _deserialize_content_part,
    _serialize_content_part,
)

logger = logging.getLogger(__name__)

# Sentinel tags that delimit a serialized structured content part inside rendered template output.
START_TAG = "<haystack_content_part>"
END_TAG = "</haystack_content_part>"

# HTML-entity-encoded variants of the sentinel tags. They are derived with `html.escape` so that they can
# never be equal to the raw tags: if they were, `_escape_sentinel_tags` would be a no-op and untrusted
# template variables could inject structured content parts.
ESCAPED_START_TAG = html.escape(START_TAG, quote=False)
ESCAPED_END_TAG = html.escape(END_TAG, quote=False)


def _escape_sentinel_tags(value: object) -> str:
    """
    Jinja2 `finalize` callback that prevents sentinel tag injection.

    Called automatically on every `{{ }}` expression result during template rendering.
    Legitimate structured content from the `templatize_part` filter is wrapped in `Markup` and passes.
    Any other value containing sentinel tags has those tags replaced with harmless HTML entities so that
    `_parse_content_parts` will not treat them as structured content.

    :param value: The result of a `{{ }}` expression
    :return: The value unchanged if it is `Markup`, otherwise its string form with sentinel tags escaped
    """
    if isinstance(value, Markup):
        # Only `templatize_part` produces Markup containing the sentinel tags; trust it as-is.
        return value

    return str(value).replace(START_TAG, ESCAPED_START_TAG).replace(END_TAG, ESCAPED_END_TAG)


class ChatMessageExtension(Extension):
    """
    A Jinja2 extension for creating structured chat messages with mixed content types.

    This extension provides a custom `{% message %}` tag that allows creating chat messages
    with different attributes (role, name, meta) and mixed content types (text, images, etc.).

    Inspired by [Banks](https://github.com/masci/banks).

    Example:
    ```
    {% message role="system" %}
    You are a helpful assistant. You like to talk with {{user_name}}.
    {% endmessage %}

    {% message role="user" %}
    Hello! I am {{user_name}}. Please describe the images.
    {% for image in images %}
    {{ image | templatize_part }}
    {% endfor %}
    {% endmessage %}
    ```

    ### How it works
    1. The `{% message %}` tag is used to define a chat message.
    2. The message can contain text and other structured content parts.
    3. To include a structured content part in the message, the `| templatize_part` filter is used.
       The filter serializes the content part into a JSON string and wraps it in a `<haystack_content_part>` tag.
    4. The `_build_chat_message_json` method of the extension parses the message content parts,
       converts them into a ChatMessage object and serializes it to a JSON string.
    5. The obtained JSON string is usable in the ChatPromptBuilder component, where templates are rendered to actual
       ChatMessage objects.
    """

    SUPPORTED_ROLES = [role.value for role in ChatRole]

    tags = {"message"}

    def __init__(self, environment: Any) -> None:
        """
        Register the extension on a Jinja2 environment.

        Besides the default extension setup, this installs `_escape_sentinel_tags` as the environment's
        `finalize` callback and registers the `templatize_part` filter.

        :param environment: The Jinja2 environment the extension is attached to
        """
        super().__init__(environment)
        environment.finalize = _escape_sentinel_tags
        environment.filters["templatize_part"] = templatize_part

    def parse(self, parser: Any) -> nodes.Node | list[nodes.Node]:
        """
        Parse the message tag and its attributes in the Jinja2 template.

        This method handles the parsing of role (mandatory), name (optional), meta (optional) and message body content.

        :param parser: The Jinja2 parser instance
        :return: A CallBlock node containing the parsed message configuration
        :raises TemplateSyntaxError: If an invalid role is provided, if `name` is not a string literal or
            if `meta` is not a dictionary literal
        """
        lineno = next(parser.stream).lineno

        # Parse role attribute (mandatory)
        parser.stream.expect("name:role")
        parser.stream.expect("assign")
        role_expr = parser.parse_expression()

        # A literal role can be validated at parse time; a dynamic role expression is validated
        # when the message is built.
        if isinstance(role_expr, nodes.Const):
            role = role_expr.value
            if role not in self.SUPPORTED_ROLES:
                raise TemplateSyntaxError(f"Role must be one of: {', '.join(self.SUPPORTED_ROLES)}", lineno)

        # Parse optional name attribute
        name_expr = None
        if parser.stream.current.test("name:name"):
            parser.stream.skip()
            parser.stream.expect("assign")
            name_expr = parser.parse_expression()
            # Only constant nodes carry a `.value`; check the node type first so that a non-literal
            # expression yields a TemplateSyntaxError instead of an AttributeError.
            if not isinstance(name_expr, nodes.Const) or not isinstance(name_expr.value, str):
                raise TemplateSyntaxError("name must be a string", lineno)

        # Parse optional meta attribute
        meta_expr = None
        if parser.stream.current.test("name:meta"):
            parser.stream.skip()
            parser.stream.expect("assign")
            meta_expr = parser.parse_expression()
            if not isinstance(meta_expr, nodes.Dict):
                raise TemplateSyntaxError("meta must be a dictionary", lineno)

        # Parse message body
        body = parser.parse_statements(("name:endmessage",), drop_needle=True)

        # Build message node with all parameters
        return nodes.CallBlock(
            self.call_method(
                name="_build_chat_message_json",
                args=[role_expr, name_expr or nodes.Const(None), meta_expr or nodes.Dict([])],
            ),
            [],
            [],
            body,
        ).set_lineno(lineno)

    def _build_chat_message_json(self, role: str, name: str | None, meta: dict, caller: Callable[[], str]) -> str:
        """
        Build a ChatMessage object from template content and serialize it to a JSON string.

        This method is called by Jinja2 when processing a `{% message %}` tag.
        It takes the rendered content from the template, converts XML blocks into ChatMessageContentT objects,
        creates a ChatMessage object and serializes it to a JSON string.

        :param role: The role of the message
        :param name: Optional name for the message sender
        :param meta: Optional metadata dictionary
        :param caller: Callable that returns the rendered content
        :return: A JSON string representation of the ChatMessage object, terminated by a newline
        :raises ValueError: If the rendered content cannot be parsed into message parts or the parts do not
            form a valid message for the given role
        """
        content = caller()
        parts = self._parse_content_parts(content)
        if not parts:
            raise ValueError(
                f"Message template produced content that couldn't be parsed into any message parts. "
                f"Content: '{content!r}'"
            )

        chat_message = self._validate_build_chat_message(parts=parts, role=role, meta=meta, name=name)

        return json.dumps(chat_message.to_dict()) + "\n"

    @staticmethod
    def _parse_content_parts(content: str) -> list[ChatMessageContentT]:
        """
        Parse a string into a sequence of ChatMessageContentT objects.

        This method handles:
        - Plain text content, converted to TextContent objects
        - Structured content parts wrapped in `<haystack_content_part>` tags, converted to ChatMessageContentT objects

        :param content: Input string containing mixed text and content parts
        :return: A list of ChatMessageContentT objects
        :raises ValueError: If the content is empty or contains only whitespace characters or if a
            `<haystack_content_part>` tag is found without a matching closing tag.
        """
        if not content.strip():
            raise ValueError(
                "Message content in template is empty or contains only whitespace characters. "
                f"Content: {content!r}"
            )

        parts: list[ChatMessageContentT] = []
        cursor = 0
        total_length = len(content)

        while cursor < total_length:
            tag_start = content.find(START_TAG, cursor)

            if tag_start == -1:
                # No more tags, add remaining text if any
                remaining_text = content[cursor:].strip()
                if remaining_text:
                    parts.append(TextContent(text=remaining_text))
                break

            # Add text before tag if any
            if tag_start > cursor:
                plain_text = content[cursor:tag_start].strip()
                if plain_text:
                    parts.append(TextContent(text=plain_text))

            content_start = tag_start + len(START_TAG)
            tag_end = content.find(END_TAG, content_start)

            if tag_end == -1:
                raise ValueError(
                    f"Found unclosed <haystack_content_part> tag at position {tag_start}. "
                    f"Content: '{content[tag_start : tag_start + 50]}...'"
                )

            # The text between the tags is the JSON payload written by `templatize_part`.
            json_content = content[content_start:tag_end]
            data = json.loads(json_content)
            parts.append(_deserialize_content_part(data))

            # Continue scanning right after the closing tag.
            cursor = tag_end + len(END_TAG)

        return parts

    @staticmethod
    def _validate_build_chat_message(
        parts: list[ChatMessageContentT], role: str, meta: dict, name: str | None = None
    ) -> ChatMessage:
        """
        Validate the parts of a chat message and build a ChatMessage object.

        :param parts: Content parts of the message
        :param role: The role of the message
        :param meta: The metadata of the message
        :param name: The optional name of the message
        :return: A ChatMessage object

        :raises ValueError: If content parts don't allow to build a valid ChatMessage object or the role is not
            supported
        """
        if role == "user":
            valid_parts = [part for part in parts if isinstance(part, (TextContent, str, ImageContent, FileContent))]
            if len(parts) != len(valid_parts):
                raise ValueError(
                    "User message must contain only TextContent, string, ImageContent or FileContent parts."
                )
            return ChatMessage.from_user(meta=meta, name=name, content_parts=valid_parts)

        if role == "system":
            # Guard against an empty list so that callers get a ValueError rather than an IndexError.
            if not parts or not isinstance(parts[0], TextContent):
                raise ValueError("System message must contain a text part.")
            if len(parts) > 1:
                raise ValueError("System message must contain only one text part.")
            return ChatMessage.from_system(meta=meta, name=name, text=parts[0].text)

        if role == "assistant":
            texts = [part.text for part in parts if isinstance(part, TextContent)]
            tool_calls = [part for part in parts if isinstance(part, ToolCall)]
            reasoning = [part for part in parts if isinstance(part, ReasoningContent)]
            if len(texts) > 1:
                raise ValueError("Assistant message must contain one text part at most.")
            if not texts and not tool_calls:
                raise ValueError("Assistant message must contain at least one text or tool call part.")
            # Any part that is not text, a tool call or reasoning is not allowed.
            if len(parts) > len(texts) + len(tool_calls) + len(reasoning):
                raise ValueError("Assistant message must contain only text, tool call or reasoning parts.")
            return ChatMessage.from_assistant(
                meta=meta,
                name=name,
                text=texts[0] if texts else None,
                tool_calls=tool_calls or None,
                # Only the first reasoning part is used.
                reasoning=reasoning[0] if reasoning else None,
            )

        if role == "tool":
            tool_call_results = [part for part in parts if isinstance(part, ToolCallResult)]
            # Exactly one part is allowed and it must be a tool call result.
            if len(tool_call_results) != 1 or len(parts) != 1:
                raise ValueError("Tool message must contain only one tool call result.")

            tool_call_result = tool_call_results[0]
            return ChatMessage.from_tool(
                meta=meta,
                tool_result=tool_call_result.result,
                origin=tool_call_result.origin,
                error=tool_call_result.error,
            )

        raise ValueError(f"Unsupported role: {role}")


def templatize_part(value: ChatMessageContentT) -> Markup:
    """
    Jinja filter to convert a ChatMessageContentT object into a JSON string wrapped in special XML content tags.

    The result is returned as `Markup` so that the `_escape_sentinel_tags` finalize callback leaves the
    sentinel tags intact.

    :param value: The ChatMessageContentT object to convert
    :return: A JSON string wrapped in special XML content tags marked as safe
    :raises ValueError: If the value is not an instance of ChatMessageContentT
    """
    return Markup(f"{START_TAG}{json.dumps(_serialize_content_part(value))}{END_TAG}")