util.py
import json
import logging
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Any

import dspy
from dspy import Example

import mlflow
from mlflow.entities import LoggedModelOutput

_logger = logging.getLogger(__name__)

# LM keyword arguments that may carry credentials or endpoints; they are stripped
# by `sanitize_params` before any LM configuration is logged to MLflow.
EXCLUDE_LM_PARAMS = {"api_key", "api_base", "azure_ad_token", "client_secret", "azure_password"}


def save_dspy_module_state(program, file_name: str = "model.json"):
    """
    Save states of dspy `Module` to a temporary directory and log it as an artifact.

    Failures are logged as warnings and never raised (best-effort logging).

    Args:
        program: The dspy `Module` to be saved.
        file_name: The name of the file to save the dspy module state. Default is `model.json`.
    """
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = Path(tmp_dir, file_name)
            program.save(path)
            # Must be logged before the temporary directory is removed.
            mlflow.log_artifact(path)
    except Exception as e:
        _logger.warning(f"Failed to save dspy module state: {e}")


def log_dspy_module_params(program):
    """
    Log the parameters of the dspy `Module` as run parameters.

    The module state from `program.dump_state()` is flattened into dotted keys.
    Each key is prefixed with the program's class name. Entries named "metadata",
    "lm", "traces" or "train" are skipped at any nesting depth. Failures are
    logged as warnings and never raised.

    Args:
        program: The dspy `Module` to be logged.
    """
    try:
        states = program.dump_state()
        flat_state_dict = _flatten_dspy_module_state(
            states, exclude_keys={"metadata", "lm", "traces", "train"}
        )
        mlflow.log_params({
            f"{program.__class__.__name__}.{k}": v for k, v in flat_state_dict.items()
        })
    except Exception as e:
        _logger.warning(f"Failed to log dspy module params: {e}")


def log_dspy_dataset(dataset: list["Example"], file_name: str):
    """
    Log the DSPy dataset as a table.

    Each example becomes one row. The columns are the union of all example keys,
    in first-seen order. If an example lacks a column, its cell is `None`, so
    values never shift into another example's row. Failures are logged as
    warnings and never raised.

    Args:
        dataset: The dataset to be logged.
        file_name: The name of the file to save the dataset.
    """
    try:
        rows = [dict(example.items()) for example in dataset]
        # dict.fromkeys keeps first-seen order while de-duplicating column names.
        columns = dict.fromkeys(key for row in rows for key in row)
        result = {column: [row.get(column) for row in rows] for column in columns}
        mlflow.log_table(result, file_name)
    except Exception as e:
        _logger.warning(f"Failed to log dataset: {e}")


def log_dspy_lm_state():
    """
    Log the current DSPy LM state as run parameters.

    The configuration of `dspy.settings.lm` is logged as one JSON string under
    the param "lm_params". It combines:

    - the LM's `kwargs`, with credential-like keys removed by `sanitize_params`;
    - the `model`, `model_type`, `cache`, `temperature` and `max_tokens`
      attributes, when they are present and not None.

    Nothing is logged when no LM is configured or the collected dict is empty.
    Failures are logged as warnings and never raised.
    """
    try:
        lm = dspy.settings.lm
        if lm is None:
            return

        # `kwargs` may be missing or None; treat both as "no extra kwargs".
        lm_attributes = sanitize_params(getattr(lm, "kwargs", None) or {})

        for attr in ["model", "model_type", "cache", "temperature", "max_tokens"]:
            value = getattr(lm, attr, None)
            if value is not None:
                lm_attributes[attr] = value

        if lm_attributes:
            mlflow.log_param("lm_params", json.dumps(lm_attributes, sort_keys=True))

    except Exception as e:
        _logger.warning(f"Failed to log DSPy LM state: {e}")


def _flatten_dspy_module_state(
    d, parent_key="", sep=".", exclude_keys: set[str] | None = None
) -> dict[str, Any]:
    """
    Flattens a nested dictionary and accumulates the key names.

    Dict keys listed in `exclude_keys` are skipped at every nesting level.
    This matters for composite modules, where per-predictor entries such as
    "lm" sit below the top level.

    `dspy.Example` values are not flattened further. Their fields are
    stringified, and `exclude_keys` is not applied to them, because they hold
    user data rather than module state.

    `None` leaves are dropped.

    Args:
        d: The dictionary or list to flatten.
        parent_key: The base key used in recursion. Defaults to "".
        sep: Separator for nested keys. Defaults to '.'.
        exclude_keys: Dict keys to exclude from the flattened dictionary at any
            depth. Defaults to None (nothing excluded).

    Returns:
        dict: A flattened dictionary with accumulated keys.

    Example:
        >>> _flatten_dspy_module_state({"a": {"b": [5, 6]}})
        {'a.b.0': 5, 'a.b.1': 6}
        >>> _flatten_dspy_module_state({"p": {"lm": {"x": 1}, "y": 2}}, exclude_keys={"lm"})
        {'p.y': 2}
    """
    items: dict[str, Any] = {}

    if isinstance(d, dict):
        children = [(k, v) for k, v in d.items() if not (exclude_keys and k in exclude_keys)]
    elif isinstance(d, list):
        children = list(enumerate(d))
    else:
        if d is not None:
            items[parent_key] = d
        return items

    for k, v in children:
        # str(k): keeps keys uniform and ensures a falsy key (e.g. 0) still
        # acts as a prefix for its children.
        new_key = f"{parent_key}{sep}{k}" if parent_key else str(k)
        if isinstance(v, Example):
            # Don't flatten Example objects further even if it has dict or list values
            v = {key: str(value) for key, value in v.items()}
            items.update(_flatten_dspy_module_state(v, new_key, sep))
        else:
            items.update(_flatten_dspy_module_state(v, new_key, sep, exclude_keys))

    return items


def log_dummy_model_outputs():
    """
    Register a placeholder DSPy logged model as an output of the active run.

    A logged model named "dspy" is created with the DSPy flavor, sourced from
    the active run. It is then recorded as a run output at step 0. This is
    best-effort: when there is no active run, or anything fails, a debug
    message is emitted and nothing is raised.
    """
    try:
        from mlflow.dspy.autolog import FLAVOR_NAME
        from mlflow.tracking.fluent import _create_logged_model

        run = mlflow.active_run()
        if run is None:
            _logger.debug("No active run; skipping logging of dummy DSPy model outputs.")
            return

        logged_model = _create_logged_model(
            name="dspy", source_run_id=run.info.run_id, flavor=FLAVOR_NAME
        )
        mlflow.log_outputs(models=[LoggedModelOutput(model_id=logged_model.model_id, step=0)])
    except Exception as e:
        _logger.debug(f"Failed to log a dummy DSPy model outputs: {e}")


def sanitize_params(params: dict[str, Any]) -> dict[str, Any]:
    """
    Sanitize the parameters by removing the sensitive parameters.

    Args:
        params: Parameter mapping, e.g. LM keyword arguments.

    Returns:
        A new dict without any key listed in `EXCLUDE_LM_PARAMS`. The input is
        not modified.
    """
    return {k: v for k, v in params.items() if k not in EXCLUDE_LM_PARAMS}