/ src / evidently / llm / utils / prompt_render.py
prompt_render.py
  1  import ast
  2  import re
  3  from typing import Any
  4  from typing import Callable
  5  from typing import Dict
  6  from typing import Iterable
  7  from typing import List
  8  from typing import Optional
  9  from typing import Set
 10  from typing import Tuple
 11  from typing import Union
 12  
 13  from evidently.legacy.core import new_id
 14  from evidently.llm.utils.blocks import NoopOutputFormat
 15  from evidently.llm.utils.blocks import OutputFormatBlock
 16  from evidently.llm.utils.blocks import PromptBlock
 17  
# Matches single-brace placeholders such as ``{name}`` or ``{obj.attr}``. The
# captured group is the text between the braces, which may contain ASCII
# letters, digits, "_", "." and spaces.
placeholders_re = re.compile(r"\{([a-zA-Z0-9_. ]+)}")
 19  
 20  
 21  def substitute_placeholders(template, mapping):
 22      pattern = r"{([^{}]+)}"
 23  
 24      def replacer(match):
 25          key = match.group(1)
 26          return str(mapping.get(key, match.group(0)))  # leave unchanged if not found
 27  
 28      return re.sub(pattern, replacer, template)
 29  
 30  
 31  class PreparedTemplate:
 32      def __init__(
 33          self, template: str, placeholders: Optional[Set[str]] = None, output_format: Optional[OutputFormatBlock] = None
 34      ):
 35          self.template = template
 36          if placeholders is None:
 37              placeholders = set(placeholders_re.findall(template))
 38          self.placeholders = placeholders
 39          self._output_format = output_format
 40  
 41      @property
 42      def output_format(self) -> OutputFormatBlock:
 43          return self._output_format or NoopOutputFormat()
 44  
 45      @output_format.setter
 46      def output_format(self, output_format: OutputFormatBlock):
 47          self._output_format = output_format
 48  
 49      @property
 50      def has_output_format(self) -> bool:
 51          return self._output_format is not None
 52  
 53      def render(self, values: Dict[str, Any]) -> str:
 54          return self.template.format(**{p: values[p] for p in self.placeholders})
 55  
 56      def render_partial(self, values: Dict[str, Any]) -> "PreparedTemplate":
 57          ph = {p: values.get(p, f"{{{p}}}") for p in self.placeholders}
 58          return PreparedTemplate(
 59              template=substitute_placeholders(self.template, ph), placeholders=None, output_format=self.output_format
 60          )
 61  
 62      def __repr__(self) -> str:
 63          ph = ", ".join(self.placeholders)
 64          return f"PreparedTemplate[{ph}]\n```\n{self.template}\n```"
 65  
 66  
# Signature shared by prompt commands: any arguments, returning either a single
# PromptBlock or a list of them.
PromptCommandCallable = Callable[..., Union[PromptBlock, List[PromptBlock]]]

# Module-level registry mapping a command name to its callable. It starts with
# the built-in output-format commands; `prompt_command` adds further entries.
_prompt_command_registry: Dict[str, PromptCommandCallable] = {
    "output_json": PromptBlock.json_output,
    "output_string_list": PromptBlock.string_list_output,
    "output_string": PromptBlock.string_output,
}
 74  
 75  
 76  def prompt_command(f: Union[str, PromptCommandCallable]):
 77      name = f if isinstance(f, str) else f.__name__
 78  
 79      def dec(func: PromptCommandCallable):
 80          _prompt_command_registry[name] = func
 81          return func
 82  
 83      return dec(f) if callable(f) else dec
 84  
 85  
 86  def get_placeholder_var_values(variables: Dict[str, Any], placeholders: Iterable[str]) -> Dict[str, str]:
 87      res = {}
 88  
 89      def _get_value(o, path: List[str]):
 90          if len(path) == 0:
 91              return o
 92          key, *path = path
 93          return _get_value(getattr(o, key), path)
 94  
 95      def _render(o) -> str:
 96          if isinstance(o, list):
 97              return "\n".join(_render(o) for o in o)
 98          if isinstance(o, PromptBlock):
 99              return o.render()
100          return str(o)
101  
102      for ph in placeholders:
103          var_name, *var_path = ph.split(".")
104          if var_name not in variables:
105              continue
106          var_value = _get_value(variables[var_name], var_path)
107          res[ph] = _render(var_value)
108  
109      return res
110  
111  
class TemplateRenderer:
    """Expand ``{% command(...) %}`` calls and ``{variable}`` placeholders in a template.

    Commands are looked up either on ``holder`` (for ``obj.method(...)`` style
    calls) or in ``self.commands`` (for plain ``name(...)`` calls). Variables
    are looked up in ``self.vars``, which always contains ``"self"`` bound to
    ``holder``.
    """

    def __init__(
        self,
        template: str,
        holder: Any,
        variables: Optional[Dict[str, Any]] = None,
        commands: Optional[Dict[str, "PromptCommandCallable"]] = None,
    ):
        """Create a renderer.

        Args:
            template: the raw template text.
            holder: object exposed to the template as ``self``; method-style
                commands are called on it.
            variables: extra variables available to placeholders and command
                arguments. The dict is copied, not modified.
            commands: commands available to the template. When empty or
                ``None`` the module-level registry is used. The dict is
                copied, not modified.
        """
        self.template = template
        self.holder = holder
        # Copy: "self" is injected below and add_var() writes here; neither should
        # leak into the dict the caller passed in.
        self.vars: Dict[str, Any] = dict(variables) if variables else {}
        self.vars["self"] = holder
        if commands:
            self.commands = dict(commands)
        else:
            # Overlay an instance-local dict on top of the shared registry: lookups
            # still see commands registered globally (even later on), while
            # add_command() only writes to the overlay instead of the global dict.
            from collections import ChainMap

            self.commands = ChainMap({}, _prompt_command_registry)

    def add_var(self, name: str, value: Any):
        """Make ``value`` available to the template under ``name``."""
        self.vars[name] = value

    def add_command(self, name: str, command: "PromptCommandCallable"):
        """Make ``command`` callable from the template as ``name`` (this renderer only)."""
        self.commands[name] = command

    @staticmethod
    def extract_command_calls(template: str):
        """Replace each ``{% ... %}`` span with a unique ``{_command_<hex>}`` placeholder.

        Returns:
            A ``(template, mapping)`` pair where ``mapping`` maps every
            generated placeholder name to the command text it replaced
            (surrounding whitespace stripped by the pattern).
        """
        pattern = r"{%\s*(.*?)\s*%}"
        mapping = {}

        def replacer(match):
            content = match.group(1)
            random_key = f"_command_{new_id().hex}"
            mapping[random_key] = content
            return f"{{{random_key}}}"

        replaced_string = re.sub(pattern, replacer, template)
        return replaced_string, mapping

    def prepare(self) -> "PreparedTemplate":
        """Run all commands, substitute known variables and return the prepared template.

        Each command's resulting blocks are converted with ``str`` and joined
        with newlines. If a command yields an ``OutputFormatBlock`` it becomes
        the template's output format (the last one encountered wins).
        Placeholders that match no variable are left in place.
        """
        template, command_calls = self.extract_command_calls(self.template)
        prepared = PreparedTemplate(template)
        command_values = {}
        for name, command in command_calls.items():
            blocks = self._parse_command_to_blocks(command)
            command_values[name] = "\n".join(str(b) for b in blocks)
            for b in blocks:
                if isinstance(b, OutputFormatBlock):
                    prepared.output_format = b
        prepared = prepared.render_partial(command_values)

        var_values = get_placeholder_var_values(self.vars, prepared.placeholders)
        return prepared.render_partial(var_values)

    def _parse_command_to_blocks(self, cmd: str) -> List["PromptBlock"]:
        """Parse ``cmd``, call the function it names and return the result as a list.

        Raises:
            ValueError: if ``cmd`` is not a valid call or names an unknown command.
        """
        func_name, args, kwargs, is_method = self._parse_function_call(cmd)
        if is_method:
            func = getattr(self.holder, func_name)
        elif func_name in self.commands:
            func = self.commands[func_name]
        else:
            # Report the commands this renderer actually consults, not the global registry.
            raise ValueError(
                f"Unknown function call `{func_name}`. Available functions: {list(self.commands.keys())}"
            )
        result = func(*args, **kwargs)
        return result if isinstance(result, list) else [result]

    def _parse_function_call(self, call_string) -> Tuple[str, List[Any], Dict[str, Any], bool]:
        """Split a call expression into ``(name, args, kwargs, is_method)``.

        ``is_method`` is True for attribute-style calls (``x.name(...)``); only
        the attribute name is returned, the object it is accessed on is not
        inspected. Positional arguments are resolved with
        ``_parse_function_call_arg``; keyword arguments must be Python literals.

        Raises:
            ValueError: if the string is not a single, supported call expression.
        """
        try:
            node = ast.parse(call_string, mode="eval").body
        except SyntaxError as e:
            raise ValueError("Invalid function call syntax") from e

        if not isinstance(node, ast.Call):
            raise ValueError("The string is not a valid function call")

        if isinstance(node.func, ast.Name):
            is_method = False
            func_name = node.func.id
        elif isinstance(node.func, ast.Attribute):
            is_method = True
            func_name = node.func.attr
        else:
            raise ValueError("Unsupported function call format")

        args = [self._parse_function_call_arg(arg) for arg in node.args]
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}

        return func_name, args, kwargs, is_method

    def _parse_function_call_arg(self, arg_node: ast.expr) -> Any:
        """Resolve one positional argument node to a value.

        * A literal constant evaluates to itself.
        * A bare name evaluates to the variable of that name, or to the name
          itself (as a string) when no such variable exists.
        * A dotted path evaluates to the attribute chain starting from the
          variable named by its first segment.

        Raises:
            KeyError: if a dotted path starts with an undefined variable.
            NotImplementedError: for any other kind of expression.
        """
        # A literal is always itself, even when its text happens to equal a variable
        # name (e.g. the string "self" must not turn into the holder object).
        if isinstance(arg_node, ast.Constant):
            return arg_node.value

        def rec(node: ast.AST) -> List[Any]:
            if isinstance(node, ast.Attribute):
                return rec(node.value) + [node.attr]
            if isinstance(node, ast.Name):
                return [node.id]
            if isinstance(node, ast.Constant):
                return [node.value]
            raise NotImplementedError(f"Cannot parse {node}")

        first, *path = rec(arg_node)
        if first not in self.vars:
            if not path:
                return first
            raise KeyError(f"Variable '{first}' is not defined")
        obj = self.vars[first]
        for attr in path:
            obj = getattr(obj, attr)
        return obj