prompt_render.py
import ast
import re
from collections import ChainMap
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union

from evidently.legacy.core import new_id
from evidently.llm.utils.blocks import NoopOutputFormat
from evidently.llm.utils.blocks import OutputFormatBlock
from evidently.llm.utils.blocks import PromptBlock

# Placeholders recognised when a template is scanned: ``{name}``, ``{obj.attr}``, ...
placeholders_re = re.compile(r"\{([a-zA-Z0-9_. ]+)}")

# Any single-brace ``{...}`` span without nested braces. Deliberately broader than
# ``placeholders_re``: which spans actually get replaced is decided by the mapping.
_substitution_re = re.compile(r"{([^{}]+)}")


def substitute_placeholders(template: str, mapping: Dict[str, Any]) -> str:
    """Replace every ``{key}`` span in *template* whose key is present in *mapping*.

    Values are converted with ``str``. Spans whose key is not in *mapping* are
    left exactly as they appear in the template.
    """

    def replacer(match):
        key = match.group(1)
        return str(mapping.get(key, match.group(0)))  # leave unchanged if not found

    return _substitution_re.sub(replacer, template)


class PreparedTemplate:
    """A template string together with its open placeholders and optional output format."""

    def __init__(
        self, template: str, placeholders: Optional[Set[str]] = None, output_format: Optional[OutputFormatBlock] = None
    ):
        """Store the template; when *placeholders* is ``None`` they are scanned from *template*."""
        self.template = template
        if placeholders is None:
            placeholders = set(placeholders_re.findall(template))
        self.placeholders = placeholders
        self._output_format = output_format

    @property
    def output_format(self) -> OutputFormatBlock:
        """The configured output format, or a fresh ``NoopOutputFormat`` when none was set."""
        return self._output_format or NoopOutputFormat()

    @output_format.setter
    def output_format(self, output_format: OutputFormatBlock):
        self._output_format = output_format

    @property
    def has_output_format(self) -> bool:
        """Whether an output format was explicitly set on this template."""
        return self._output_format is not None

    def render(self, values: Dict[str, Any]) -> str:
        """Render the template with ``str.format``.

        Raises:
            KeyError: if *values* lacks one of ``self.placeholders``.
        """
        return self.template.format(**{p: values[p] for p in self.placeholders})

    def render_partial(self, values: Dict[str, Any]) -> "PreparedTemplate":
        """Return a new template with the placeholders found in *values* filled in.

        Placeholders missing from *values* stay in the text and are re-detected
        as placeholders of the returned template.
        """
        known = {p: values[p] for p in self.placeholders if p in values}
        return PreparedTemplate(
            template=substitute_placeholders(self.template, known),
            placeholders=None,
            # Forward the raw attribute, not the property: the property substitutes a
            # NoopOutputFormat, which would make `has_output_format` True on the copy
            # even though no format was ever set.
            output_format=self._output_format,
        )

    def __repr__(self) -> str:
        # Placeholders are a set; sort them so the repr is stable between runs.
        ph = ", ".join(sorted(self.placeholders))
        return f"PreparedTemplate[{ph}]\n```\n{self.template}\n```"


PromptCommandCallable = Callable[..., Union[PromptBlock, List[PromptBlock]]]

# Module-wide registry of commands callable from ``{% ... %}`` template tags.
_prompt_command_registry: Dict[str, PromptCommandCallable] = {
    "output_json": PromptBlock.json_output,
    "output_string_list": PromptBlock.string_list_output,
    "output_string": PromptBlock.string_output,
}


def prompt_command(f: Union[str, PromptCommandCallable]):
    """Register a function in the module-wide command registry.

    Usable both as ``@prompt_command`` (registered under the function's
    ``__name__``) and as ``@prompt_command("name")`` (registered under the
    given name). The decorated function is returned unchanged.
    """
    name = f if isinstance(f, str) else f.__name__

    def dec(func: PromptCommandCallable):
        _prompt_command_registry[name] = func
        return func

    return dec(f) if callable(f) else dec


def get_placeholder_var_values(variables: Dict[str, Any], placeholders: Iterable[str]) -> Dict[str, str]:
    """Resolve dotted placeholders against *variables* and render them to text.

    For a placeholder ``a.b.c`` the value is ``variables["a"].b.c``. Lists are
    rendered element by element and joined with newlines, ``PromptBlock``
    instances via ``render()``, anything else via ``str``. Placeholders whose
    first segment is not a key of *variables* are skipped.
    """

    def _get_value(obj, path: List[str]):
        for key in path:
            obj = getattr(obj, key)
        return obj

    def _render(obj) -> str:
        if isinstance(obj, list):
            return "\n".join(_render(item) for item in obj)
        if isinstance(obj, PromptBlock):
            return obj.render()
        return str(obj)

    res = {}
    for ph in placeholders:
        var_name, *var_path = ph.split(".")
        if var_name not in variables:
            continue
        res[ph] = _render(_get_value(variables[var_name], var_path))

    return res


class TemplateRenderer:
    """Expands ``{% command(...) %}`` tags and variable placeholders in a template."""

    def __init__(
        self,
        template: str,
        holder: Any,
        variables: Optional[Dict[str, Any]] = None,
        commands: Optional[Dict[str, PromptCommandCallable]] = None,
    ):
        """Create a renderer.

        Args:
            template: template text.
            holder: object exposed to the template as variable ``self``; method-style
                command calls are looked up on it.
            variables: extra variables. Copied, so the caller's dict is not modified.
            commands: available commands; an empty or missing mapping falls back to the
                module-wide registry.
        """
        self.template = template
        self.holder = holder
        # Copy before injecting "self" so the caller's dict is left untouched.
        self.vars = {**(variables or {}), "self": holder}
        # Overlay: `add_command` writes to the private first map only, while lookups
        # still fall through to the live base mapping (e.g. later `@prompt_command`s).
        self.commands = ChainMap({}, commands or _prompt_command_registry)

    def add_var(self, name: str, value: Any):
        """Add or replace a template variable on this renderer."""
        self.vars[name] = value

    def add_command(self, name: str, command: PromptCommandCallable):
        """Add or replace a command on this renderer only."""
        self.commands[name] = command

    @staticmethod
    def extract_command_calls(template: str) -> Tuple[str, Dict[str, str]]:
        """Replace each ``{% ... %}`` tag with a unique ``{_command_<hex>}`` placeholder.

        Returns:
            The rewritten template and a mapping from generated placeholder name
            to the tag's inner text (surrounding whitespace removed).
        """
        pattern = r"{%\s*(.*?)\s*%}"
        mapping = {}

        def replacer(match):
            content = match.group(1)
            random_key = f"_command_{new_id().hex}"
            mapping[random_key] = content
            return f"{{{random_key}}}"

        replaced_string = re.sub(pattern, replacer, template)
        return replaced_string, mapping

    def prepare(self) -> PreparedTemplate:
        """Run all command tags, substitute known variables, and return the result.

        If a command yields an ``OutputFormatBlock`` it becomes the template's
        output format (the last one wins). Placeholders that match no variable
        remain open in the returned template.
        """
        template, command_calls = self.extract_command_calls(self.template)
        prepared = PreparedTemplate(template)
        command_values = {}
        for name, command in command_calls.items():
            blocks = self._parse_command_to_blocks(command)
            command_values[name] = "\n".join(str(b) for b in blocks)
            for b in blocks:
                if isinstance(b, OutputFormatBlock):
                    prepared.output_format = b
        prepared = prepared.render_partial(command_values)

        var_values = get_placeholder_var_values(self.vars, prepared.placeholders)
        return prepared.render_partial(var_values)

    def _parse_command_to_blocks(self, cmd: str) -> List[PromptBlock]:
        """Parse and execute one command call, always returning a list of results.

        Attribute-style calls (``x.method(...)``) are dispatched to the method of
        that name on ``self.holder``; the receiver expression itself is not
        inspected. Plain calls are looked up in ``self.commands``.

        Raises:
            ValueError: if the call cannot be parsed or the command is unknown.
        """
        func_name, args, kwargs, is_method = self._parse_function_call(cmd)
        if is_method:
            func = getattr(self.holder, func_name)
        elif func_name in self.commands:
            func = self.commands[func_name]
        else:
            # Report the commands this renderer actually knows, not the global registry.
            raise ValueError(
                f"Unknown function call `{func_name}`. Available functions: {list(self.commands.keys())}"
            )
        result = func(*args, **kwargs)
        return result if isinstance(result, list) else [result]

    def _parse_function_call(self, call_string) -> Tuple[str, List[Any], Dict, bool]:
        """Parse ``name(args...)`` or ``obj.name(args...)`` without evaluating it.

        Returns:
            ``(func_name, args, kwargs, is_method)``. Positional arguments are
            resolved by ``_parse_function_call_arg``; keyword arguments must be
            literals (``ast.literal_eval``).

        Raises:
            ValueError: on invalid syntax, a non-call expression, or an
                unsupported callee form.
        """
        try:
            node = ast.parse(call_string, mode="eval").body
        except SyntaxError as err:
            raise ValueError("Invalid function call syntax") from err

        if not isinstance(node, ast.Call):
            raise ValueError("The string is not a valid function call")

        if isinstance(node.func, ast.Name):
            is_method = False
            func_name = node.func.id
        elif isinstance(node.func, ast.Attribute):
            is_method = True
            func_name = node.func.attr
        else:
            raise ValueError("Unsupported function call format")

        args = [self._parse_function_call_arg(arg) for arg in node.args]
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}

        return func_name, args, kwargs, is_method

    def _parse_function_call_arg(self, arg_node: ast.expr) -> Any:
        """Resolve one positional argument of a command call.

        * A literal (string, number, ...) is returned as-is.
        * ``name`` / ``name.attr...`` is resolved against ``self.vars``.
        * A bare name that is not a known variable is returned as its own text.

        Raises:
            KeyError: for a dotted path whose root is not a known variable.
            NotImplementedError: for any other kind of expression.
        """
        # A literal must stay a literal: a string such as "self" must not be
        # swapped for the variable that happens to carry the same name.
        if isinstance(arg_node, ast.Constant):
            return arg_node.value

        def rec(node: ast.AST) -> List[Any]:
            if isinstance(node, ast.Attribute):
                return rec(node.value) + [node.attr]
            if isinstance(node, ast.Name):
                return [node.id]
            if isinstance(node, ast.Constant):
                return [node.value]
            raise NotImplementedError(f"Cannot parse {node}")

        first, *path = rec(arg_node)
        if first not in self.vars:
            if not path:
                return first
            raise KeyError(f"Variable '{first}' is not defined")
        obj = self.vars[first]
        for attr in path:
            obj = getattr(obj, attr)
        return obj