/ mlflow / dspy / util.py
util.py
  1  import json
  2  import logging
  3  import tempfile
  4  from collections import defaultdict
  5  from pathlib import Path
  6  from typing import Any
  7  
  8  import dspy
  9  from dspy import Example
 10  
 11  import mlflow
 12  from mlflow.entities import LoggedModelOutput
 13  
 14  _logger = logging.getLogger(__name__)
 15  
 16  EXCLUDE_LM_PARAMS = {"api_key", "api_base", "azure_ad_token", "client_secret", "azure_password"}
 17  
 18  
 19  def save_dspy_module_state(program, file_name: str = "model.json"):
 20      """
 21      Save states of dspy `Module` to a temporary directory and log it as an artifact.
 22  
 23      Args:
 24          program: The dspy `Module` to be saved.
 25          file_name: The name of the file to save the dspy module state. Default is `model.json`.
 26      """
 27      try:
 28          with tempfile.TemporaryDirectory() as tmp_dir:
 29              path = Path(tmp_dir, file_name)
 30              program.save(path)
 31              mlflow.log_artifact(path)
 32      except Exception as e:
 33          _logger.warning(f"Failed to save dspy module state: {e}")
 34  
 35  
 36  def log_dspy_module_params(program):
 37      """
 38      Log the parameters of the dspy `Module` as run parameters.
 39  
 40      Args:
 41          program: The dspy `Module` to be logged.
 42      """
 43      try:
 44          states = program.dump_state()
 45          flat_state_dict = _flatten_dspy_module_state(
 46              states, exclude_keys=("metadata", "lm", "traces", "train")
 47          )
 48          mlflow.log_params({
 49              f"{program.__class__.__name__}.{k}": v for k, v in flat_state_dict.items()
 50          })
 51      except Exception as e:
 52          _logger.warning(f"Failed to log dspy module params: {e}")
 53  
 54  
 55  def log_dspy_dataset(dataset: list["Example"], file_name: str):
 56      """
 57      Log the DSPy dataset as a table.
 58  
 59      Args:
 60          dataset: The dataset to be logged.
 61          file_name: The name of the file to save the dataset.
 62      """
 63      result = defaultdict(list)
 64      try:
 65          for example in dataset:
 66              for k, v in example.items():
 67                  result[k].append(v)
 68          mlflow.log_table(result, file_name)
 69      except Exception as e:
 70          _logger.warning(f"Failed to log dataset: {e}")
 71  
 72  
 73  def log_dspy_lm_state():
 74      """
 75      Log the current DSPy LM state as run parameters.
 76      This logs the language model configuration from dspy.settings.lm as a JSON string.
 77      """
 78      try:
 79          if dspy.settings.lm is None:
 80              return
 81  
 82          lm = dspy.settings.lm
 83  
 84          lm_attributes = sanitize_params(getattr(lm, "kwargs", {}))
 85  
 86          for attr in ["model", "model_type", "cache", "temperature", "max_tokens"]:
 87              value = getattr(lm, attr, None)
 88              if value is not None:
 89                  lm_attributes[attr] = value
 90  
 91          if lm_attributes:
 92              mlflow.log_param("lm_params", json.dumps(lm_attributes, sort_keys=True))
 93  
 94      except Exception as e:
 95          _logger.warning(f"Failed to log DSPy LM state: {e}")
 96  
 97  
 98  def _flatten_dspy_module_state(
 99      d, parent_key="", sep=".", exclude_keys: set[str] | None = None
100  ) -> dict[str, Any]:
101      """
102      Flattens a nested dictionary and accumulates the key names.
103  
104      Args:
105          d: The dictionary or list to flatten.
106          parent_key: The base key used in recursion. Defaults to "".
107          sep: Separator for nested keys. Defaults to '.'.
108          exclude_keys: Keys to exclude from the flattened dictionary. Defaults to ().
109  
110      Returns:
111          dict: A flattened dictionary with accumulated keys.
112  
113      Example:
114          >>> _flatten_dspy_module_state({"a": {"b": [5, 6]}})
115          {'a.b.0': 5, 'a.b.1': 6}
116      """
117      items: dict[str, Any] = {}
118  
119      if isinstance(d, dict):
120          for k, v in d.items():
121              if exclude_keys and k in exclude_keys:
122                  continue
123              new_key = f"{parent_key}{sep}{k}" if parent_key else k
124              if isinstance(v, Example):
125                  # Don't flatten Example objects further even if it has dict or list values
126                  v = {key: str(value) for key, value in v.items()}
127              items.update(_flatten_dspy_module_state(v, new_key, sep))
128      elif isinstance(d, list):
129          for i, v in enumerate(d):
130              new_key = f"{parent_key}{sep}{i}" if parent_key else str(i)
131              if isinstance(v, Example):
132                  # Don't flatten Example objects further even if it has dict or list values
133                  v = {key: str(value) for key, value in v.items()}
134              items.update(_flatten_dspy_module_state(v, new_key, sep))
135      else:
136          if d is not None:
137              items[parent_key] = d
138  
139      return items
140  
141  
def log_dummy_model_outputs():
    """
    Create a logged model named "dspy" for the active run and record it as a model output.

    The logged model is created with the active run's ID as its source run and the DSPy
    flavor name, then logged as an output at step 0. Any exception raised along the way
    is reported at debug level and is not propagated.
    """
    try:
        # NOTE(review): these are imported inside the function, presumably to avoid a
        # circular import at module load time -- confirm before moving them to the top.
        from mlflow.dspy.autolog import FLAVOR_NAME
        from mlflow.tracking.fluent import _create_logged_model

        current_run = mlflow.active_run()
        placeholder_model = _create_logged_model(
            name="dspy", source_run_id=current_run.info.run_id, flavor=FLAVOR_NAME
        )
        model_output = LoggedModelOutput(model_id=placeholder_model.model_id, step=0)
        mlflow.log_outputs(models=[model_output])
    except Exception as exc:
        _logger.debug(f"Failed to log a dummy DSPy model outputs: {exc}")
152  
153  
def sanitize_params(params: dict[str, Any]) -> dict[str, Any]:
    """
    Return a copy of `params` without the sensitive parameters.

    Keys listed in `EXCLUDE_LM_PARAMS` are dropped; every other entry is kept unchanged.
    The input dictionary is not modified.
    """
    safe_params: dict[str, Any] = {}
    for name, value in params.items():
        if name in EXCLUDE_LM_PARAMS:
            continue
        safe_params[name] = value
    return safe_params