callback.py
1 import logging 2 import threading 3 from collections import defaultdict 4 from functools import wraps 5 from typing import Any 6 7 import dspy 8 from dspy.utils.callback import BaseCallback 9 10 import mlflow 11 from mlflow.dspy.constant import FLAVOR_NAME 12 from mlflow.dspy.util import ( 13 log_dspy_lm_state, 14 log_dspy_module_params, 15 sanitize_params, 16 save_dspy_module_state, 17 ) 18 from mlflow.entities import SpanStatusCode, SpanType 19 from mlflow.entities.run_status import RunStatus 20 from mlflow.entities.span_event import SpanEvent 21 from mlflow.exceptions import MlflowException 22 from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey 23 from mlflow.tracing.fluent import start_span_no_context 24 from mlflow.tracing.provider import detach_span_from_context, set_span_in_context 25 from mlflow.tracing.utils import maybe_set_prediction_context 26 from mlflow.tracing.utils.token import SpanWithToken 27 from mlflow.utils import _get_fully_qualified_class_name 28 from mlflow.utils.autologging_utils import ( 29 get_autologging_config, 30 ) 31 from mlflow.version import IS_TRACING_SDK_ONLY 32 33 _logger = logging.getLogger(__name__) 34 _lock = threading.Lock() 35 36 37 def skip_if_trace_disabled(func): 38 @wraps(func) 39 def wrapper(*args, **kwargs): 40 if get_autologging_config(FLAVOR_NAME, "log_traces"): 41 func(*args, **kwargs) 42 43 return wrapper 44 45 46 def _convert_signature(val): 47 # serialization of dspy.Signature is quite slow, so we should convert it to string 48 if isinstance(val, type) and issubclass(val, dspy.Signature): 49 return repr(val) 50 return val 51 52 53 class MlflowCallback(BaseCallback): 54 """Callback for generating MLflow traces for DSPy components""" 55 56 def __init__(self, dependencies_schema: dict[str, Any] | None = None): 57 self._dependencies_schema = dependencies_schema 58 # call_id: (LiveSpan, OTel token) 59 self._call_id_to_span: dict[str, SpanWithToken] = {} 60 self._call_id_to_module: dict[str, Any] = {} 61 62 ###### state management for optimization process ###### 63 # The current callback logic assumes there is no optimization running in parallel. 64 # The state management may not work when multiple optimizations are running in parallel. 65 # optimizer_stack_level is used to determine if the callback is called within compile 66 # we cannot use boolean flag because the callback can be nested 67 self.optimizer_stack_level = 0 68 # call_id: (key, step) 69 self._call_id_to_metric_key: dict[str, tuple[str, int]] = {} 70 self._evaluation_counter = defaultdict(int) 71 self._disabled_eval_call_ids = set() 72 self._eval_runs_started: set[str] = set() 73 74 def set_dependencies_schema(self, dependencies_schema: dict[str, Any]): 75 if self._dependencies_schema: 76 raise MlflowException( 77 "Dependencies schema should be set only once to the callback.", 78 error_code=MlflowException.INVALID_PARAMETER_VALUE, 79 ) 80 self._dependencies_schema = dependencies_schema 81 82 @skip_if_trace_disabled 83 def on_module_start(self, call_id: str, instance: Any, inputs: dict[str, Any]): 84 span_type = self._get_span_type_for_module(instance) 85 attributes = self._get_span_attribute_for_module(instance) 86 87 # The __call__ method of dspy.Module has a signature of (self, *args, **kwargs), 88 # while all built-in modules only accepts keyword arguments. To avoid recording 89 # empty "args" key in the inputs, we remove it if it's empty. 90 if "args" in inputs and not inputs["args"]: 91 inputs.pop("args") 92 93 self._start_span( 94 call_id, 95 name=f"{instance.__class__.__name__}.forward", 96 span_type=span_type, 97 inputs=self._unpack_kwargs(inputs), 98 attributes=attributes, 99 ) 100 self._call_id_to_module[call_id] = instance 101 102 @skip_if_trace_disabled 103 def on_module_end(self, call_id: str, outputs: Any | None, exception: Exception | None = None): 104 instance = self._call_id_to_module.pop(call_id) 105 attributes = {} 106 107 if _get_fully_qualified_class_name(instance) == "dspy.retrieve.databricks_rm.DatabricksRM": 108 from mlflow.entities.document import Document 109 110 if isinstance(outputs, dspy.Prediction): 111 # Convert outputs to MLflow document format to make it compatible with 112 # agent evaluation. 113 num_docs = len(outputs.doc_ids) 114 doc_uris = outputs.doc_uris if outputs.doc_uris is not None else [None] * num_docs 115 outputs = [ 116 Document( 117 page_content=doc_content, 118 metadata={ 119 "doc_id": doc_id, 120 "doc_uri": doc_uri, 121 } 122 | extra_column_dict, 123 id=doc_id, 124 ).to_dict() 125 for doc_content, doc_id, doc_uri, extra_column_dict in zip( 126 outputs.docs, 127 outputs.doc_ids, 128 doc_uris, 129 outputs.extra_columns, 130 ) 131 ] 132 else: 133 # NB: DSPy's Prediction object is a customized dictionary-like object, but its repr 134 # is not easy to read on UI. Therefore, we unpack it to a dictionary. 135 # https://github.com/stanfordnlp/dspy/blob/6fe693528323c9c10c82d90cb26711a985e18b29/dspy/primitives/prediction.py#L21-L28 136 if isinstance(outputs, dspy.Prediction): 137 usage_by_model = ( 138 outputs.get_lm_usage() if hasattr(outputs, "get_lm_usage") else None 139 ) 140 outputs = outputs.toDict() 141 if usage_by_model: 142 usage_data = { 143 TokenUsageKey.INPUT_TOKENS: 0, 144 TokenUsageKey.OUTPUT_TOKENS: 0, 145 TokenUsageKey.TOTAL_TOKENS: 0, 146 } 147 for usage in usage_by_model.values(): 148 usage_data[TokenUsageKey.INPUT_TOKENS] += usage.get("prompt_tokens", 0) 149 usage_data[TokenUsageKey.OUTPUT_TOKENS] += usage.get("completion_tokens", 0) 150 usage_data[TokenUsageKey.TOTAL_TOKENS] += usage.get("total_tokens", 0) 151 attributes[SpanAttributeKey.CHAT_USAGE] = usage_data 152 # TODO: the span may not contain model name so we cannot calculate cost 153 self._end_span(call_id, outputs, exception, attributes) 154 155 @skip_if_trace_disabled 156 def on_lm_start(self, call_id: str, instance: Any, inputs: dict[str, Any]): 157 span_type = ( 158 SpanType.CHAT_MODEL if getattr(instance, "model_type", None) == "chat" else SpanType.LLM 159 ) 160 161 filtered_kwargs = sanitize_params(instance.kwargs) 162 attributes = { 163 **filtered_kwargs, 164 "model": instance.model, 165 "model_type": instance.model_type, 166 "cache": instance.cache, 167 SpanAttributeKey.MESSAGE_FORMAT: "dspy", 168 SpanAttributeKey.MODEL: instance.model, 169 } 170 match instance.model.split("/", 1): 171 case [provider, _]: 172 attributes[SpanAttributeKey.MODEL_PROVIDER] = provider 173 174 inputs = self._unpack_kwargs(inputs) 175 176 self._start_span( 177 call_id, 178 name=f"{instance.__class__.__name__}.__call__", 179 span_type=span_type, 180 inputs=inputs, 181 attributes=attributes, 182 ) 183 184 @skip_if_trace_disabled 185 def on_lm_end(self, call_id: str, outputs: Any | None, exception: Exception | None = None): 186 self._end_span(call_id, outputs, exception) 187 188 @skip_if_trace_disabled 189 def on_adapter_format_start(self, call_id: str, instance: Any, inputs: dict[str, Any]): 190 self._start_span( 191 call_id, 192 name=f"{instance.__class__.__name__}.format", 193 span_type=SpanType.PARSER, 194 inputs=self._unpack_kwargs(inputs), 195 attributes={}, 196 ) 197 198 @skip_if_trace_disabled 199 def on_adapter_format_end( 200 self, call_id: str, outputs: Any | None, exception: Exception | None = None 201 ): 202 self._end_span(call_id, outputs, exception) 203 204 @skip_if_trace_disabled 205 def on_adapter_parse_start(self, call_id: str, instance: Any, inputs: dict[str, Any]): 206 self._start_span( 207 call_id, 208 name=f"{instance.__class__.__name__}.parse", 209 span_type=SpanType.PARSER, 210 inputs=self._unpack_kwargs(inputs), 211 attributes={}, 212 ) 213 214 @skip_if_trace_disabled 215 def on_adapter_parse_end( 216 self, call_id: str, outputs: Any | None, exception: Exception | None = None 217 ): 218 self._end_span(call_id, outputs, exception) 219 220 @skip_if_trace_disabled 221 def on_tool_start(self, call_id: str, instance: Any, inputs: dict[str, Any]): 222 # DSPy uses the special "finish" tool to signal the end of the agent. 223 if instance.name == "finish": 224 return 225 226 inputs = self._unpack_kwargs(inputs) 227 # Tools are always called with keyword arguments only. 228 inputs.pop("args", None) 229 230 self._start_span( 231 call_id, 232 name=f"Tool.{instance.name}", 233 span_type=SpanType.TOOL, 234 inputs=inputs, 235 attributes={ 236 "name": instance.name, 237 "description": instance.desc, 238 "args": instance.args, 239 }, 240 ) 241 242 @skip_if_trace_disabled 243 def on_tool_end(self, call_id: str, outputs: Any | None, exception: Exception | None = None): 244 if call_id in self._call_id_to_span: 245 self._end_span(call_id, outputs, exception) 246 247 def on_evaluate_start(self, call_id: str, instance: Any, inputs: dict[str, Any]): 248 """ 249 Callback handler at the beginning of evaluation call. Available with DSPy>=2.6.9. 250 This callback starts a nested run for each evaluation call inside optimization. 251 If called outside optimization and no active run exists, it creates a new run. 252 """ 253 if not get_autologging_config(FLAVOR_NAME, "log_evals"): 254 return 255 256 key = "eval" 257 if callback_metadata := inputs.get("callback_metadata"): 258 if "metric_key" in callback_metadata: 259 key = callback_metadata["metric_key"] 260 if callback_metadata.get("disable_logging"): 261 self._disabled_eval_call_ids.add(call_id) 262 return 263 started_run = False 264 if self.optimizer_stack_level > 0: 265 with _lock: 266 # we may want to include optimizer_stack_level in the key 267 # to handle nested optimization 268 step = self._evaluation_counter[key] 269 self._evaluation_counter[key] += 1 270 self._call_id_to_metric_key[call_id] = (key, step) 271 mlflow.start_run(run_name=f"{key}_{step}", nested=True) 272 started_run = True 273 elif mlflow.active_run() is None: 274 mlflow.start_run(run_name=key, nested=True) 275 started_run = True 276 277 if started_run: 278 self._eval_runs_started.add(call_id) 279 if program := inputs.get("program"): 280 save_dspy_module_state(program, "model.json") 281 log_dspy_module_params(program) 282 283 # Log the current DSPy LM state 284 log_dspy_lm_state() 285 286 def on_evaluate_end( 287 self, 288 call_id: str, 289 outputs: Any, 290 exception: Exception | None = None, 291 ): 292 """ 293 Callback handler at the end of evaluation call. Available with DSPy>=2.6.9. 294 This callback logs the evaluation score to the individual run 295 and add eval metric to the parent run if called inside optimization. 296 """ 297 if not get_autologging_config(FLAVOR_NAME, "log_evals"): 298 return 299 if call_id in self._disabled_eval_call_ids: 300 self._disabled_eval_call_ids.discard(call_id) 301 return 302 run_started = call_id in self._eval_runs_started 303 if exception: 304 if run_started: 305 mlflow.end_run(status=RunStatus.to_string(RunStatus.FAILED)) 306 self._eval_runs_started.discard(call_id) 307 return 308 score = None 309 if isinstance(outputs, float): 310 score = outputs 311 elif isinstance(outputs, tuple): 312 score = outputs[0] 313 elif isinstance(outputs, dspy.Prediction): 314 score = float(outputs) 315 try: 316 mlflow.log_table(self._generate_result_table(outputs.results), "result_table.json") 317 except Exception: 318 _logger.debug("Failed to log result table.", exc_info=True) 319 if score is not None: 320 mlflow.log_metric("eval", score) 321 322 if run_started: 323 mlflow.end_run() 324 self._eval_runs_started.discard(call_id) 325 # Log the evaluation score to the parent run if called inside optimization 326 if self.optimizer_stack_level > 0 and mlflow.active_run() is not None: 327 if call_id not in self._call_id_to_metric_key: 328 return 329 key, step = self._call_id_to_metric_key.pop(call_id) 330 if score is not None: 331 mlflow.log_metric( 332 key, 333 score, 334 step=step, 335 ) 336 337 def reset(self): 338 self._call_id_to_metric_key: dict[str, tuple[str, int]] = {} 339 self._evaluation_counter = defaultdict(int) 340 self._eval_runs_started = set() 341 342 def _start_span( 343 self, 344 call_id: str, 345 name: str, 346 span_type: SpanType, 347 inputs: dict[str, Any], 348 attributes: dict[str, Any], 349 ): 350 if not IS_TRACING_SDK_ONLY: 351 from mlflow.pyfunc.context import get_prediction_context 352 353 prediction_context = get_prediction_context() 354 if prediction_context and self._dependencies_schema: 355 prediction_context.update(**self._dependencies_schema) 356 else: 357 prediction_context = None 358 359 with maybe_set_prediction_context(prediction_context): 360 span = start_span_no_context( 361 name=name, 362 span_type=span_type, 363 parent_span=mlflow.get_current_active_span(), 364 inputs=inputs, 365 attributes=attributes, 366 ) 367 368 token = set_span_in_context(span) 369 self._call_id_to_span[call_id] = SpanWithToken(span, token) 370 371 return span 372 373 def _end_span( 374 self, 375 call_id: str, 376 outputs: Any | None, 377 exception: Exception | None = None, 378 attributes: dict[str, Any] | None = None, 379 ): 380 st = self._call_id_to_span.pop(call_id, None) 381 382 if not st.span: 383 _logger.warning(f"Failed to end a span. Span not found for call_id: {call_id}") 384 return 385 386 status = SpanStatusCode.OK if exception is None else SpanStatusCode.ERROR 387 388 if exception: 389 st.span.add_event(SpanEvent.from_exception(exception)) 390 391 if attributes: 392 st.span.set_attributes(attributes) 393 394 try: 395 st.span.end(outputs=outputs, status=status) 396 finally: 397 detach_span_from_context(st.token) 398 399 def _get_span_type_for_module(self, instance): 400 if isinstance(instance, dspy.Retrieve): 401 return SpanType.RETRIEVER 402 elif isinstance(instance, dspy.ReAct): 403 return SpanType.AGENT 404 elif isinstance(instance, dspy.Predict): 405 return SpanType.LLM 406 elif isinstance(instance, dspy.Adapter): 407 return SpanType.PARSER 408 else: 409 return SpanType.CHAIN 410 411 def _get_span_attribute_for_module(self, instance): 412 if isinstance(instance, dspy.Predict): 413 return {"signature": instance.signature.signature} 414 elif isinstance(instance, dspy.ChainOfThought): 415 if hasattr(instance, "signature"): 416 signature = instance.signature.signature 417 else: 418 signature = instance.predict.signature.signature 419 420 attributes = {"signature": signature} 421 if hasattr(instance, "extended_signature"): 422 attributes["extended_signature"] = instance.extended_signature.signature 423 return attributes 424 return {} 425 426 def _unpack_kwargs(self, inputs: dict[str, Any]) -> dict[str, Any]: 427 """Unpacks the kwargs from the inputs dictionary""" 428 # NB: Not using pop() to avoid modifying the original inputs dictionary 429 kwargs = inputs.get("kwargs", {}) 430 inputs_wo_kwargs = {k: v for k, v in inputs.items() if k != "kwargs"} 431 merged = inputs_wo_kwargs | kwargs 432 return {k: _convert_signature(v) for k, v in merged.items()} 433 434 def _generate_result_table( 435 self, outputs: list[tuple[dspy.Example, dspy.Prediction, Any]] 436 ) -> dict[str, list[Any]]: 437 result = {"score": []} 438 for i, (example, prediction, score) in enumerate(outputs): 439 for k, v in example.items(): 440 if f"example_{k}" not in result: 441 result[f"example_{k}"] = [None] * i 442 result[f"example_{k}"].append(v) 443 444 for k, v in prediction.items(): 445 if f"pred_{k}" not in result: 446 result[f"pred_{k}"] = [None] * i 447 result[f"pred_{k}"].append(v) 448 449 result["score"].append(score) 450 451 for k, v in result.items(): 452 if len(v) != i + 1: 453 result[k].append(None) 454 455 return result