# save.py
"""Functions for saving DSPY models to MLflow."""

import json
import logging
import os
from pathlib import Path
from typing import Any

import cloudpickle
import yaml
from packaging.version import Version

import mlflow
from mlflow import pyfunc
from mlflow.dspy.constant import FLAVOR_NAME
from mlflow.dspy.wrapper import DspyChatModelWrapper, DspyModelWrapper
from mlflow.entities.model_registry.prompt import Prompt
from mlflow.exceptions import INVALID_PARAMETER_VALUE, MlflowException
from mlflow.models import (
    Model,
    ModelInputExample,
    ModelSignature,
    infer_pip_requirements,
)
from mlflow.models.dependencies_schemas import _get_dependencies_schemas
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.rag_signatures import SIGNATURE_FOR_LLM_INFERENCE_TASK
from mlflow.models.resources import Resource, _ResourceBuilder
from mlflow.models.signature import _infer_signature_from_input_example
from mlflow.models.utils import _save_example
from mlflow.tracing.provider import trace_disabled
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.types.schema import DataType
from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
from mlflow.utils.environment import (
    _CONDA_ENV_FILE_NAME,
    _CONSTRAINTS_FILE_NAME,
    _PYTHON_ENV_FILE_NAME,
    _REQUIREMENTS_FILE_NAME,
    _mlflow_conda_env,
    _process_conda_env,
    _process_pip_requirements,
    _PythonEnv,
)
from mlflow.utils.file_utils import get_total_file_size, write_to
from mlflow.utils.model_utils import (
    _validate_and_copy_code_paths,
    _validate_and_prepare_target_save_path,
)
from mlflow.utils.requirements_utils import _get_pinned_requirement

# Artifact layout constants: the model lives under `<path>/data/model[.pkl]`.
_MODEL_SAVE_PATH = "model"
_MODEL_DATA_PATH = "data"
_MODEL_CONFIG_FILE_NAME = "model_config.json"
_DSPY_SETTINGS_FILE_NAME = "dspy_config.pkl"
_DSPY_RM_FILE_NAME = "dspy_rm.pkl"

_logger = logging.getLogger(__name__)


def get_default_pip_requirements():
    """
    Returns:
        A list of default pip requirements for MLflow Models produced by Dspy flavor. Calls to
        `save_model()` and `log_model()` produce a pip environment that, at minimum, contains these
        requirements.
    """
    return [_get_pinned_requirement("dspy")]


def get_default_conda_env():
    """
    Returns:
        The default Conda environment for MLflow Models produced by calls to `save_model()` and
        `log_model()`.
    """
    return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())


@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
@trace_disabled  # Suppress traces for internal predict calls while logging model
def save_model(
    model,
    path: str,
    task: str | None = None,
    model_config: dict[str, Any] | None = None,
    code_paths: list[str] | None = None,
    mlflow_model: Model | None = None,
    conda_env: list[str] | str | None = None,
    signature: ModelSignature | None = None,
    input_example: ModelInputExample | None = None,
    pip_requirements: list[str] | str | None = None,
    extra_pip_requirements: list[str] | str | None = None,
    metadata: dict[str, Any] | None = None,
    resources: str | Path | list[Resource] | None = None,
    use_dspy_model_save: bool = False,
):
    """
    Save a Dspy model.

    This method saves a Dspy model along with metadata such as model signature and conda
    environments to local file system. This method is called inside `mlflow.dspy.log_model()`.

    Args:
        model: an instance of `dspy.Module`. The Dspy model/module to be saved.
        path: local path where the MLflow model is to be saved.
        task: defaults to None. The task type of the model. Can only be `llm/v1/chat` or None for
            now.
        model_config: keyword arguments to be passed to the Dspy Module at instantiation.
        code_paths: {{ code_paths }}
        mlflow_model: an instance of `mlflow.models.Model`, defaults to None. MLflow model
            configuration to which to add the Dspy model metadata. If None, a blank instance will
            be created.
        conda_env: {{ conda_env }}
        signature: {{ signature }}
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}
        resources: A list of model resources or a resources.yaml file containing a list of
            resources required to serve the model.
        use_dspy_model_save: Whether to save the Dspy model by dspy builtin `dspy.Module.save`
            method.
    """

    import dspy

    from mlflow.transformers.llm_inference_utils import (
        _LLM_INFERENCE_TASK_KEY,
        _METADATA_LLM_INFERENCE_TASK_KEY,
    )
    from mlflow.utils.databricks_utils import is_in_databricks_runtime

    if signature:
        num_inputs = len(signature.inputs.inputs)
        if num_inputs == 0:
            raise MlflowException(
                "The model signature's input schema must contain at least one field.",
                error_code=INVALID_PARAMETER_VALUE,
            )
    if task and task not in SIGNATURE_FOR_LLM_INFERENCE_TASK:
        raise MlflowException(
            # Fix: first fragment was missing its f-prefix, so the literal text
            # "{task}" was emitted instead of the offending task value.
            f"Invalid task: {task} at `mlflow.dspy.save_model()` call. The task must be None or one "
            f"of: {list(SIGNATURE_FOR_LLM_INFERENCE_TASK.keys())}",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if not use_dspy_model_save and not is_in_databricks_runtime():
        # Fix: added the missing space after "deserialization." and aligned the
        # stated version requirement with the enforced check below
        # (`Version(dspy.__version__) <= Version("3.1.0")` rejects 3.1.0 itself).
        _logger.warning(
            "Saving DSPy model by Pickle or CloudPickle format requires exercising "
            "caution because these formats rely on Python's object serialization mechanism, "
            "which can execute arbitrary code during deserialization. "
            "The recommended alternative is to set 'use_dspy_model_save' to True "
            "(requiring dspy > 3.1.0) to save the "
            "DSPy model using the DSPy builtin saving method."
        )

    # Fix: normalize and validate the target path unconditionally. Previously this
    # only happened when an input_example was provided, so a bare `save_model(...)`
    # call skipped the "path must not already contain files" validation entirely.
    path = os.path.abspath(path)
    _validate_and_prepare_target_save_path(path)

    if mlflow_model is None:
        mlflow_model = Model()
    if signature is not None:
        mlflow_model.signature = signature
    saved_example = None
    if input_example is not None:
        saved_example = _save_example(mlflow_model, input_example, path)
    if metadata is not None:
        mlflow_model.metadata = metadata

    with _get_dependencies_schemas() as dependencies_schemas:
        schema = dependencies_schemas.to_dict()
        if schema is not None:
            if mlflow_model.metadata is None:
                mlflow_model.metadata = {}
            mlflow_model.metadata.update(schema)

    model_data_subpath = _MODEL_DATA_PATH
    # Construct new data folder in existing path.
    data_path = os.path.join(path, model_data_subpath)
    os.makedirs(data_path, exist_ok=True)
    model_subpath = os.path.join(model_data_subpath, _MODEL_SAVE_PATH)
    if not use_dspy_model_save:
        # Set the model path to end with ".pkl" as we use cloudpickle for serialization.
        model_subpath += ".pkl"

    model_path = os.path.join(path, model_subpath)

    if use_dspy_model_save:
        if Version(dspy.__version__) <= Version("3.1.0"):
            raise MlflowException(
                "'use_dspy_model_save' option is only supported for DSPy version > 3.1.0."
            )
        # `dspy.Module.save(..., save_program=True)` expects a directory.
        os.makedirs(model_path, exist_ok=True)

    # Dspy has a global context `dspy.settings`, and we need to save it along with the model.
    dspy_settings = dict(dspy.settings.config)

    # Don't save the trace in the model, which is only useful during the training phase.
    dspy_settings.pop("trace", None)

    # Store both dspy model and settings in `DspyChatModelWrapper` or `DspyModelWrapper` for
    # serialization.
    if task == "llm/v1/chat":
        wrapped_dspy_model = DspyChatModelWrapper(model, dspy_settings, model_config)
    else:
        wrapped_dspy_model = DspyModelWrapper(model, dspy_settings, model_config)

    flavor_options = {
        "model_path": model_subpath,
    }

    if task:
        if mlflow_model.signature is None:
            # Fall back to the canonical signature for the LLM inference task.
            mlflow_model.signature = SIGNATURE_FOR_LLM_INFERENCE_TASK[task]
        flavor_options.update({_LLM_INFERENCE_TASK_KEY: task})
        if mlflow_model.metadata:
            mlflow_model.metadata[_METADATA_LLM_INFERENCE_TASK_KEY] = task
        else:
            mlflow_model.metadata = {_METADATA_LLM_INFERENCE_TASK_KEY: task}

    if saved_example and mlflow_model.signature is None:
        # Infer the signature by running the wrapped model on the saved example.
        signature = _infer_signature_from_input_example(saved_example, wrapped_dspy_model)
        mlflow_model.signature = signature

    streamable = False
    # Set the output schema to the model wrapper to use it for streaming
    if mlflow_model.signature and mlflow_model.signature.outputs:
        wrapped_dspy_model.output_schema = mlflow_model.signature.outputs
        # DSPy streaming only supports string outputs.
        if all(spec.type == DataType.string for spec in mlflow_model.signature.outputs):
            streamable = True

    if use_dspy_model_save:
        wrapped_dspy_model.model.save(model_path, save_program=True)

        if model_config:
            with open(os.path.join(data_path, _MODEL_CONFIG_FILE_NAME), "w") as f:
                json.dump(model_config, f)

        dspy.settings.save(
            os.path.join(data_path, _DSPY_SETTINGS_FILE_NAME), exclude_keys=["trace"]
        )
    else:
        with open(model_path, "wb") as f:
            cloudpickle.dump(wrapped_dspy_model, f)

    code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)

    # Add flavor info to `mlflow_model`.
    mlflow_model.add_flavor(FLAVOR_NAME, code=code_dir_subpath, **flavor_options)
    # Add loader_module, data and env data to `mlflow_model`.
    pyfunc.add_to_model(
        mlflow_model,
        loader_module="mlflow.dspy",
        code=code_dir_subpath,
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        streamable=streamable,
    )

    # Add model file size to `mlflow_model`.
    if size := get_total_file_size(path):
        mlflow_model.model_size_bytes = size

    # Add resources if specified.
    if resources is not None:
        if isinstance(resources, (Path, str)):
            serialized_resource = _ResourceBuilder.from_yaml_file(resources)
        else:
            serialized_resource = _ResourceBuilder.from_resources(resources)

        mlflow_model.resources = serialized_resource

    # Save mlflow_model to path/MLmodel.
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    if conda_env is None:
        if pip_requirements is None:
            default_reqs = get_default_pip_requirements()
            # To ensure `_load_pyfunc` can successfully load the model during the dependency
            # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
            inferred_reqs = infer_pip_requirements(path, FLAVOR_NAME, fallback=default_reqs)
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary.
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`.
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))


@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
@trace_disabled  # Suppress traces for internal predict calls while logging model
def log_model(
    dspy_model,
    artifact_path: str | None = None,
    task: str | None = None,
    model_config: dict[str, Any] | None = None,
    code_paths: list[str] | None = None,
    conda_env: list[str] | str | None = None,
    signature: ModelSignature | None = None,
    input_example: ModelInputExample | None = None,
    registered_model_name: str | None = None,
    await_registration_for: int = DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements: list[str] | str | None = None,
    extra_pip_requirements: list[str] | str | None = None,
    metadata: dict[str, Any] | None = None,
    resources: str | Path | list[Resource] | None = None,
    prompts: list[str | Prompt] | None = None,
    name: str | None = None,
    params: dict[str, Any] | None = None,
    tags: dict[str, Any] | None = None,
    model_type: str | None = None,
    step: int = 0,
    model_id: str | None = None,
    use_dspy_model_save: bool = False,
):
    """
    Log a Dspy model along with metadata to MLflow.

    This method saves a Dspy model along with metadata such as model signature and conda
    environments to MLflow.

    Args:
        dspy_model: an instance of `dspy.Module`. The Dspy model to be saved.
        artifact_path: Deprecated. Use `name` instead.
        task: defaults to None. The task type of the model. Can only be `llm/v1/chat` or None for
            now.
        model_config: keyword arguments to be passed to the Dspy Module at instantiation.
        code_paths: {{ code_paths }}
        conda_env: {{ conda_env }}
        signature: {{ signature }}
        input_example: {{ input_example }}
        registered_model_name: defaults to None. If set, create a model version under
            `registered_model_name`, also create a registered model if one with the given name does
            not exist.
        await_registration_for: defaults to
            `mlflow.tracking._model_registry.DEFAULT_AWAIT_MAX_SLEEP_SECONDS`. Number of
            seconds to wait for the model version to finish being created and is in ``READY``
            status. By default, the function waits for five minutes. Specify 0 or None to skip
            waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: Custom metadata dictionary passed to the model and stored in the MLmodel
            file.
        resources: A list of model resources or a resources.yaml file containing a list of
            resources required to serve the model.
        prompts: {{ prompts }}
        name: {{ name }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}
        use_dspy_model_save: Whether to save the Dspy model by dspy builtin `dspy.Module.save`
            method.

    .. code-block:: python
        :caption: Example

        import dspy
        import mlflow
        from mlflow.models import ModelSignature
        from mlflow.types.schema import ColSpec, Schema

        # Set up the LM.
        lm = dspy.LM(model="openai/gpt-4o-mini", max_tokens=250)
        dspy.settings.configure(lm=lm)


        class CoT(dspy.Module):
            def __init__(self):
                super().__init__()
                self.prog = dspy.ChainOfThought("question -> answer")

            def forward(self, question):
                return self.prog(question=question)


        dspy_model = CoT()

        mlflow.set_tracking_uri("http://127.0.0.1:5000")
        mlflow.set_experiment("test-dspy-logging")

        from mlflow.dspy import log_model

        input_schema = Schema([ColSpec("string")])
        output_schema = Schema([ColSpec("string")])
        signature = ModelSignature(inputs=input_schema, outputs=output_schema)

        with mlflow.start_run():
            log_model(
                dspy_model,
                "model",
                input_example="what is 2 + 2?",
                signature=signature,
            )
    """
    # Delegate to `Model.log`, which calls `save_model` above with these kwargs.
    return Model.log(
        artifact_path=artifact_path,
        name=name,
        flavor=mlflow.dspy,
        model=dspy_model,
        task=task,
        model_config=model_config,
        code_paths=code_paths,
        conda_env=conda_env,
        registered_model_name=registered_model_name,
        signature=signature,
        input_example=input_example,
        await_registration_for=await_registration_for,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
        metadata=metadata,
        resources=resources,
        prompts=prompts,
        params=params,
        tags=tags,
        model_type=model_type,
        step=step,
        model_id=model_id,
        use_dspy_model_save=use_dspy_model_save,
    )