# mlflow/dspy/callback.py
  1  import logging
  2  import threading
  3  from collections import defaultdict
  4  from functools import wraps
  5  from typing import Any
  6  
  7  import dspy
  8  from dspy.utils.callback import BaseCallback
  9  
 10  import mlflow
 11  from mlflow.dspy.constant import FLAVOR_NAME
 12  from mlflow.dspy.util import (
 13      log_dspy_lm_state,
 14      log_dspy_module_params,
 15      sanitize_params,
 16      save_dspy_module_state,
 17  )
 18  from mlflow.entities import SpanStatusCode, SpanType
 19  from mlflow.entities.run_status import RunStatus
 20  from mlflow.entities.span_event import SpanEvent
 21  from mlflow.exceptions import MlflowException
 22  from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey
 23  from mlflow.tracing.fluent import start_span_no_context
 24  from mlflow.tracing.provider import detach_span_from_context, set_span_in_context
 25  from mlflow.tracing.utils import maybe_set_prediction_context
 26  from mlflow.tracing.utils.token import SpanWithToken
 27  from mlflow.utils import _get_fully_qualified_class_name
 28  from mlflow.utils.autologging_utils import (
 29      get_autologging_config,
 30  )
 31  from mlflow.version import IS_TRACING_SDK_ONLY
 32  
 33  _logger = logging.getLogger(__name__)
 34  _lock = threading.Lock()
 35  
 36  
 37  def skip_if_trace_disabled(func):
 38      @wraps(func)
 39      def wrapper(*args, **kwargs):
 40          if get_autologging_config(FLAVOR_NAME, "log_traces"):
 41              func(*args, **kwargs)
 42  
 43      return wrapper
 44  
 45  
 46  def _convert_signature(val):
 47      # serialization of dspy.Signature is quite slow, so we should convert it to string
 48      if isinstance(val, type) and issubclass(val, dspy.Signature):
 49          return repr(val)
 50      return val
 51  
 52  
 53  class MlflowCallback(BaseCallback):
 54      """Callback for generating MLflow traces for DSPy components"""
 55  
 56      def __init__(self, dependencies_schema: dict[str, Any] | None = None):
 57          self._dependencies_schema = dependencies_schema
 58          # call_id: (LiveSpan, OTel token)
 59          self._call_id_to_span: dict[str, SpanWithToken] = {}
 60          self._call_id_to_module: dict[str, Any] = {}
 61  
 62          ###### state management for optimization process ######
 63          # The current callback logic assumes there is no optimization running in parallel.
 64          # The state management may not work when multiple optimizations are running in parallel.
 65          # optimizer_stack_level is used to determine if the callback is called within compile
 66          # we cannot use boolean flag because the callback can be nested
 67          self.optimizer_stack_level = 0
 68          # call_id: (key, step)
 69          self._call_id_to_metric_key: dict[str, tuple[str, int]] = {}
 70          self._evaluation_counter = defaultdict(int)
 71          self._disabled_eval_call_ids = set()
 72          self._eval_runs_started: set[str] = set()
 73  
 74      def set_dependencies_schema(self, dependencies_schema: dict[str, Any]):
 75          if self._dependencies_schema:
 76              raise MlflowException(
 77                  "Dependencies schema should be set only once to the callback.",
 78                  error_code=MlflowException.INVALID_PARAMETER_VALUE,
 79              )
 80          self._dependencies_schema = dependencies_schema
 81  
 82      @skip_if_trace_disabled
 83      def on_module_start(self, call_id: str, instance: Any, inputs: dict[str, Any]):
 84          span_type = self._get_span_type_for_module(instance)
 85          attributes = self._get_span_attribute_for_module(instance)
 86  
 87          # The __call__ method of dspy.Module has a signature of (self, *args, **kwargs),
 88          # while all built-in modules only accepts keyword arguments. To avoid recording
 89          # empty "args" key in the inputs, we remove it if it's empty.
 90          if "args" in inputs and not inputs["args"]:
 91              inputs.pop("args")
 92  
 93          self._start_span(
 94              call_id,
 95              name=f"{instance.__class__.__name__}.forward",
 96              span_type=span_type,
 97              inputs=self._unpack_kwargs(inputs),
 98              attributes=attributes,
 99          )
100          self._call_id_to_module[call_id] = instance
101  
102      @skip_if_trace_disabled
103      def on_module_end(self, call_id: str, outputs: Any | None, exception: Exception | None = None):
104          instance = self._call_id_to_module.pop(call_id)
105          attributes = {}
106  
107          if _get_fully_qualified_class_name(instance) == "dspy.retrieve.databricks_rm.DatabricksRM":
108              from mlflow.entities.document import Document
109  
110              if isinstance(outputs, dspy.Prediction):
111                  # Convert outputs to MLflow document format to make it compatible with
112                  # agent evaluation.
113                  num_docs = len(outputs.doc_ids)
114                  doc_uris = outputs.doc_uris if outputs.doc_uris is not None else [None] * num_docs
115                  outputs = [
116                      Document(
117                          page_content=doc_content,
118                          metadata={
119                              "doc_id": doc_id,
120                              "doc_uri": doc_uri,
121                          }
122                          | extra_column_dict,
123                          id=doc_id,
124                      ).to_dict()
125                      for doc_content, doc_id, doc_uri, extra_column_dict in zip(
126                          outputs.docs,
127                          outputs.doc_ids,
128                          doc_uris,
129                          outputs.extra_columns,
130                      )
131                  ]
132          else:
133              # NB: DSPy's Prediction object is a customized dictionary-like object, but its repr
134              # is not easy to read on UI. Therefore, we unpack it to a dictionary.
135              # https://github.com/stanfordnlp/dspy/blob/6fe693528323c9c10c82d90cb26711a985e18b29/dspy/primitives/prediction.py#L21-L28
136              if isinstance(outputs, dspy.Prediction):
137                  usage_by_model = (
138                      outputs.get_lm_usage() if hasattr(outputs, "get_lm_usage") else None
139                  )
140                  outputs = outputs.toDict()
141                  if usage_by_model:
142                      usage_data = {
143                          TokenUsageKey.INPUT_TOKENS: 0,
144                          TokenUsageKey.OUTPUT_TOKENS: 0,
145                          TokenUsageKey.TOTAL_TOKENS: 0,
146                      }
147                      for usage in usage_by_model.values():
148                          usage_data[TokenUsageKey.INPUT_TOKENS] += usage.get("prompt_tokens", 0)
149                          usage_data[TokenUsageKey.OUTPUT_TOKENS] += usage.get("completion_tokens", 0)
150                          usage_data[TokenUsageKey.TOTAL_TOKENS] += usage.get("total_tokens", 0)
151                      attributes[SpanAttributeKey.CHAT_USAGE] = usage_data
152                      # TODO: the span may not contain model name so we cannot calculate cost
153          self._end_span(call_id, outputs, exception, attributes)
154  
155      @skip_if_trace_disabled
156      def on_lm_start(self, call_id: str, instance: Any, inputs: dict[str, Any]):
157          span_type = (
158              SpanType.CHAT_MODEL if getattr(instance, "model_type", None) == "chat" else SpanType.LLM
159          )
160  
161          filtered_kwargs = sanitize_params(instance.kwargs)
162          attributes = {
163              **filtered_kwargs,
164              "model": instance.model,
165              "model_type": instance.model_type,
166              "cache": instance.cache,
167              SpanAttributeKey.MESSAGE_FORMAT: "dspy",
168              SpanAttributeKey.MODEL: instance.model,
169          }
170          match instance.model.split("/", 1):
171              case [provider, _]:
172                  attributes[SpanAttributeKey.MODEL_PROVIDER] = provider
173  
174          inputs = self._unpack_kwargs(inputs)
175  
176          self._start_span(
177              call_id,
178              name=f"{instance.__class__.__name__}.__call__",
179              span_type=span_type,
180              inputs=inputs,
181              attributes=attributes,
182          )
183  
184      @skip_if_trace_disabled
185      def on_lm_end(self, call_id: str, outputs: Any | None, exception: Exception | None = None):
186          self._end_span(call_id, outputs, exception)
187  
188      @skip_if_trace_disabled
189      def on_adapter_format_start(self, call_id: str, instance: Any, inputs: dict[str, Any]):
190          self._start_span(
191              call_id,
192              name=f"{instance.__class__.__name__}.format",
193              span_type=SpanType.PARSER,
194              inputs=self._unpack_kwargs(inputs),
195              attributes={},
196          )
197  
198      @skip_if_trace_disabled
199      def on_adapter_format_end(
200          self, call_id: str, outputs: Any | None, exception: Exception | None = None
201      ):
202          self._end_span(call_id, outputs, exception)
203  
204      @skip_if_trace_disabled
205      def on_adapter_parse_start(self, call_id: str, instance: Any, inputs: dict[str, Any]):
206          self._start_span(
207              call_id,
208              name=f"{instance.__class__.__name__}.parse",
209              span_type=SpanType.PARSER,
210              inputs=self._unpack_kwargs(inputs),
211              attributes={},
212          )
213  
214      @skip_if_trace_disabled
215      def on_adapter_parse_end(
216          self, call_id: str, outputs: Any | None, exception: Exception | None = None
217      ):
218          self._end_span(call_id, outputs, exception)
219  
220      @skip_if_trace_disabled
221      def on_tool_start(self, call_id: str, instance: Any, inputs: dict[str, Any]):
222          # DSPy uses the special "finish" tool to signal the end of the agent.
223          if instance.name == "finish":
224              return
225  
226          inputs = self._unpack_kwargs(inputs)
227          # Tools are always called with keyword arguments only.
228          inputs.pop("args", None)
229  
230          self._start_span(
231              call_id,
232              name=f"Tool.{instance.name}",
233              span_type=SpanType.TOOL,
234              inputs=inputs,
235              attributes={
236                  "name": instance.name,
237                  "description": instance.desc,
238                  "args": instance.args,
239              },
240          )
241  
242      @skip_if_trace_disabled
243      def on_tool_end(self, call_id: str, outputs: Any | None, exception: Exception | None = None):
244          if call_id in self._call_id_to_span:
245              self._end_span(call_id, outputs, exception)
246  
247      def on_evaluate_start(self, call_id: str, instance: Any, inputs: dict[str, Any]):
248          """
249          Callback handler at the beginning of evaluation call. Available with DSPy>=2.6.9.
250          This callback starts a nested run for each evaluation call inside optimization.
251          If called outside optimization and no active run exists, it creates a new run.
252          """
253          if not get_autologging_config(FLAVOR_NAME, "log_evals"):
254              return
255  
256          key = "eval"
257          if callback_metadata := inputs.get("callback_metadata"):
258              if "metric_key" in callback_metadata:
259                  key = callback_metadata["metric_key"]
260              if callback_metadata.get("disable_logging"):
261                  self._disabled_eval_call_ids.add(call_id)
262                  return
263          started_run = False
264          if self.optimizer_stack_level > 0:
265              with _lock:
266                  # we may want to include optimizer_stack_level in the key
267                  # to handle nested optimization
268                  step = self._evaluation_counter[key]
269                  self._evaluation_counter[key] += 1
270              self._call_id_to_metric_key[call_id] = (key, step)
271              mlflow.start_run(run_name=f"{key}_{step}", nested=True)
272              started_run = True
273          elif mlflow.active_run() is None:
274              mlflow.start_run(run_name=key, nested=True)
275              started_run = True
276  
277          if started_run:
278              self._eval_runs_started.add(call_id)
279          if program := inputs.get("program"):
280              save_dspy_module_state(program, "model.json")
281              log_dspy_module_params(program)
282  
283          # Log the current DSPy LM state
284          log_dspy_lm_state()
285  
286      def on_evaluate_end(
287          self,
288          call_id: str,
289          outputs: Any,
290          exception: Exception | None = None,
291      ):
292          """
293          Callback handler at the end of evaluation call. Available with DSPy>=2.6.9.
294          This callback logs the evaluation score to the individual run
295          and add eval metric to the parent run if called inside optimization.
296          """
297          if not get_autologging_config(FLAVOR_NAME, "log_evals"):
298              return
299          if call_id in self._disabled_eval_call_ids:
300              self._disabled_eval_call_ids.discard(call_id)
301              return
302          run_started = call_id in self._eval_runs_started
303          if exception:
304              if run_started:
305                  mlflow.end_run(status=RunStatus.to_string(RunStatus.FAILED))
306                  self._eval_runs_started.discard(call_id)
307              return
308          score = None
309          if isinstance(outputs, float):
310              score = outputs
311          elif isinstance(outputs, tuple):
312              score = outputs[0]
313          elif isinstance(outputs, dspy.Prediction):
314              score = float(outputs)
315              try:
316                  mlflow.log_table(self._generate_result_table(outputs.results), "result_table.json")
317              except Exception:
318                  _logger.debug("Failed to log result table.", exc_info=True)
319          if score is not None:
320              mlflow.log_metric("eval", score)
321  
322          if run_started:
323              mlflow.end_run()
324              self._eval_runs_started.discard(call_id)
325          # Log the evaluation score to the parent run if called inside optimization
326          if self.optimizer_stack_level > 0 and mlflow.active_run() is not None:
327              if call_id not in self._call_id_to_metric_key:
328                  return
329              key, step = self._call_id_to_metric_key.pop(call_id)
330              if score is not None:
331                  mlflow.log_metric(
332                      key,
333                      score,
334                      step=step,
335                  )
336  
337      def reset(self):
338          self._call_id_to_metric_key: dict[str, tuple[str, int]] = {}
339          self._evaluation_counter = defaultdict(int)
340          self._eval_runs_started = set()
341  
342      def _start_span(
343          self,
344          call_id: str,
345          name: str,
346          span_type: SpanType,
347          inputs: dict[str, Any],
348          attributes: dict[str, Any],
349      ):
350          if not IS_TRACING_SDK_ONLY:
351              from mlflow.pyfunc.context import get_prediction_context
352  
353              prediction_context = get_prediction_context()
354              if prediction_context and self._dependencies_schema:
355                  prediction_context.update(**self._dependencies_schema)
356          else:
357              prediction_context = None
358  
359          with maybe_set_prediction_context(prediction_context):
360              span = start_span_no_context(
361                  name=name,
362                  span_type=span_type,
363                  parent_span=mlflow.get_current_active_span(),
364                  inputs=inputs,
365                  attributes=attributes,
366              )
367  
368          token = set_span_in_context(span)
369          self._call_id_to_span[call_id] = SpanWithToken(span, token)
370  
371          return span
372  
373      def _end_span(
374          self,
375          call_id: str,
376          outputs: Any | None,
377          exception: Exception | None = None,
378          attributes: dict[str, Any] | None = None,
379      ):
380          st = self._call_id_to_span.pop(call_id, None)
381  
382          if not st.span:
383              _logger.warning(f"Failed to end a span. Span not found for call_id: {call_id}")
384              return
385  
386          status = SpanStatusCode.OK if exception is None else SpanStatusCode.ERROR
387  
388          if exception:
389              st.span.add_event(SpanEvent.from_exception(exception))
390  
391          if attributes:
392              st.span.set_attributes(attributes)
393  
394          try:
395              st.span.end(outputs=outputs, status=status)
396          finally:
397              detach_span_from_context(st.token)
398  
399      def _get_span_type_for_module(self, instance):
400          if isinstance(instance, dspy.Retrieve):
401              return SpanType.RETRIEVER
402          elif isinstance(instance, dspy.ReAct):
403              return SpanType.AGENT
404          elif isinstance(instance, dspy.Predict):
405              return SpanType.LLM
406          elif isinstance(instance, dspy.Adapter):
407              return SpanType.PARSER
408          else:
409              return SpanType.CHAIN
410  
411      def _get_span_attribute_for_module(self, instance):
412          if isinstance(instance, dspy.Predict):
413              return {"signature": instance.signature.signature}
414          elif isinstance(instance, dspy.ChainOfThought):
415              if hasattr(instance, "signature"):
416                  signature = instance.signature.signature
417              else:
418                  signature = instance.predict.signature.signature
419  
420              attributes = {"signature": signature}
421              if hasattr(instance, "extended_signature"):
422                  attributes["extended_signature"] = instance.extended_signature.signature
423              return attributes
424          return {}
425  
426      def _unpack_kwargs(self, inputs: dict[str, Any]) -> dict[str, Any]:
427          """Unpacks the kwargs from the inputs dictionary"""
428          # NB: Not using pop() to avoid modifying the original inputs dictionary
429          kwargs = inputs.get("kwargs", {})
430          inputs_wo_kwargs = {k: v for k, v in inputs.items() if k != "kwargs"}
431          merged = inputs_wo_kwargs | kwargs
432          return {k: _convert_signature(v) for k, v in merged.items()}
433  
434      def _generate_result_table(
435          self, outputs: list[tuple[dspy.Example, dspy.Prediction, Any]]
436      ) -> dict[str, list[Any]]:
437          result = {"score": []}
438          for i, (example, prediction, score) in enumerate(outputs):
439              for k, v in example.items():
440                  if f"example_{k}" not in result:
441                      result[f"example_{k}"] = [None] * i
442                  result[f"example_{k}"].append(v)
443  
444              for k, v in prediction.items():
445                  if f"pred_{k}" not in result:
446                      result[f"pred_{k}"] = [None] * i
447                  result[f"pred_{k}"].append(v)
448  
449              result["score"].append(score)
450  
451              for k, v in result.items():
452                  if len(v) != i + 1:
453                      result[k].append(None)
454  
455          return result