# mlflow/dspy/save.py
  1  """Functions for saving DSPY models to MLflow."""
  2  
  3  import json
  4  import logging
  5  import os
  6  from pathlib import Path
  7  from typing import Any
  8  
  9  import cloudpickle
 10  import yaml
 11  from packaging.version import Version
 12  
 13  import mlflow
 14  from mlflow import pyfunc
 15  from mlflow.dspy.constant import FLAVOR_NAME
 16  from mlflow.dspy.wrapper import DspyChatModelWrapper, DspyModelWrapper
 17  from mlflow.entities.model_registry.prompt import Prompt
 18  from mlflow.exceptions import INVALID_PARAMETER_VALUE, MlflowException
 19  from mlflow.models import (
 20      Model,
 21      ModelInputExample,
 22      ModelSignature,
 23      infer_pip_requirements,
 24  )
 25  from mlflow.models.dependencies_schemas import _get_dependencies_schemas
 26  from mlflow.models.model import MLMODEL_FILE_NAME
 27  from mlflow.models.rag_signatures import SIGNATURE_FOR_LLM_INFERENCE_TASK
 28  from mlflow.models.resources import Resource, _ResourceBuilder
 29  from mlflow.models.signature import _infer_signature_from_input_example
 30  from mlflow.models.utils import _save_example
 31  from mlflow.tracing.provider import trace_disabled
 32  from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
 33  from mlflow.types.schema import DataType
 34  from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
 35  from mlflow.utils.environment import (
 36      _CONDA_ENV_FILE_NAME,
 37      _CONSTRAINTS_FILE_NAME,
 38      _PYTHON_ENV_FILE_NAME,
 39      _REQUIREMENTS_FILE_NAME,
 40      _mlflow_conda_env,
 41      _process_conda_env,
 42      _process_pip_requirements,
 43      _PythonEnv,
 44  )
 45  from mlflow.utils.file_utils import get_total_file_size, write_to
 46  from mlflow.utils.model_utils import (
 47      _validate_and_copy_code_paths,
 48      _validate_and_prepare_target_save_path,
 49  )
 50  from mlflow.utils.requirements_utils import _get_pinned_requirement
 51  
 52  _MODEL_SAVE_PATH = "model"
 53  _MODEL_DATA_PATH = "data"
 54  _MODEL_CONFIG_FILE_NAME = "model_config.json"
 55  _DSPY_SETTINGS_FILE_NAME = "dspy_config.pkl"
 56  _DSPY_RM_FILE_NAME = "dspy_rm.pkl"
 57  
 58  _logger = logging.getLogger(__name__)
 59  
 60  
 61  def get_default_pip_requirements():
 62      """
 63      Returns:
 64          A list of default pip requirements for MLflow Models produced by Dspy flavor. Calls to
 65          `save_model()` and `log_model()` produce a pip environment that, at minimum, contains these
 66          requirements.
 67      """
 68      return [_get_pinned_requirement("dspy")]
 69  
 70  
 71  def get_default_conda_env():
 72      """
 73      Returns:
 74          The default Conda environment for MLflow Models produced by calls to `save_model()` and
 75          `log_model()`.
 76      """
 77      return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())
 78  
 79  
 80  @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
 81  @trace_disabled  # Suppress traces for internal predict calls while logging model
 82  def save_model(
 83      model,
 84      path: str,
 85      task: str | None = None,
 86      model_config: dict[str, Any] | None = None,
 87      code_paths: list[str] | None = None,
 88      mlflow_model: Model | None = None,
 89      conda_env: list[str] | str | None = None,
 90      signature: ModelSignature | None = None,
 91      input_example: ModelInputExample | None = None,
 92      pip_requirements: list[str] | str | None = None,
 93      extra_pip_requirements: list[str] | str | None = None,
 94      metadata: dict[str, Any] | None = None,
 95      resources: str | Path | list[Resource] | None = None,
 96      use_dspy_model_save: bool = False,
 97  ):
 98      """
 99      Save a Dspy model.
100  
101      This method saves a Dspy model along with metadata such as model signature and conda
102      environments to local file system. This method is called inside `mlflow.dspy.log_model()`.
103  
104      Args:
105          model: an instance of `dspy.Module`. The Dspy model/module to be saved.
106          path: local path where the MLflow model is to be saved.
107          task: defaults to None. The task type of the model. Can only be `llm/v1/chat` or None for
108              now.
109          model_config: keyword arguments to be passed to the Dspy Module at instantiation.
110          code_paths: {{ code_paths }}
111          mlflow_model: an instance of `mlflow.models.Model`, defaults to None. MLflow model
112              configuration to which to add the Dspy model metadata. If None, a blank instance will
113              be created.
114          conda_env: {{ conda_env }}
115          signature: {{ signature }}
116          input_example: {{ input_example }}
117          pip_requirements: {{ pip_requirements }}
118          extra_pip_requirements: {{ extra_pip_requirements }}
119          metadata: {{ metadata }}
120          resources: A list of model resources or a resources.yaml file containing a list of
121              resources required to serve the model.
122          use_dspy_model_save: Whether to save the Dspy model by dspy builtin `dspy.Module.save`
123              method.
124      """
125  
126      import dspy
127  
128      from mlflow.transformers.llm_inference_utils import (
129          _LLM_INFERENCE_TASK_KEY,
130          _METADATA_LLM_INFERENCE_TASK_KEY,
131      )
132      from mlflow.utils.databricks_utils import is_in_databricks_runtime
133  
134      if signature:
135          num_inputs = len(signature.inputs.inputs)
136          if num_inputs == 0:
137              raise MlflowException(
138                  "The model signature's input schema must contain at least one field.",
139                  error_code=INVALID_PARAMETER_VALUE,
140              )
141      if task and task not in SIGNATURE_FOR_LLM_INFERENCE_TASK:
142          raise MlflowException(
143              "Invalid task: {task} at `mlflow.dspy.save_model()` call. The task must be None or one "
144              f"of: {list(SIGNATURE_FOR_LLM_INFERENCE_TASK.keys())}",
145              error_code=INVALID_PARAMETER_VALUE,
146          )
147      if not use_dspy_model_save and not is_in_databricks_runtime():
148          _logger.warning(
149              "Saving DSPy model by Pickle or CloudPickle format requires exercising "
150              "caution because these formats rely on Python's object serialization mechanism, "
151              "which can execute arbitrary code during deserialization."
152              "The recommended alternative is to set 'use_dspy_model_save' to True "
153              "(requiring dspy >= 3.1.0) to save the "
154              "DSPy model using the DSPy builtin saving method."
155          )
156  
157      if mlflow_model is None:
158          mlflow_model = Model()
159      if signature is not None:
160          mlflow_model.signature = signature
161      saved_example = None
162      if input_example is not None:
163          path = os.path.abspath(path)
164          _validate_and_prepare_target_save_path(path)
165          saved_example = _save_example(mlflow_model, input_example, path)
166      if metadata is not None:
167          mlflow_model.metadata = metadata
168  
169      with _get_dependencies_schemas() as dependencies_schemas:
170          schema = dependencies_schemas.to_dict()
171          if schema is not None:
172              if mlflow_model.metadata is None:
173                  mlflow_model.metadata = {}
174              mlflow_model.metadata.update(schema)
175  
176      model_data_subpath = _MODEL_DATA_PATH
177      # Construct new data folder in existing path.
178      data_path = os.path.join(path, model_data_subpath)
179      os.makedirs(data_path, exist_ok=True)
180      model_subpath = os.path.join(model_data_subpath, _MODEL_SAVE_PATH)
181      if not use_dspy_model_save:
182          # Set the model path to end with ".pkl" as we use cloudpickle for serialization.
183          model_subpath += ".pkl"
184  
185      model_path = os.path.join(path, model_subpath)
186  
187      if use_dspy_model_save:
188          if Version(dspy.__version__) <= Version("3.1.0"):
189              raise MlflowException(
190                  "'use_dspy_model_save' option is only supported for DSPy version > 3.1.0."
191              )
192          os.makedirs(model_path, exist_ok=True)
193  
194      # Dspy has a global context `dspy.settings`, and we need to save it along with the model.
195      dspy_settings = dict(dspy.settings.config)
196  
197      # Don't save the trace in the model, which is only useful during the training phase.
198      dspy_settings.pop("trace", None)
199  
200      # Store both dspy model and settings in `DspyChatModelWrapper` or `DspyModelWrapper` for
201      # serialization.
202      if task == "llm/v1/chat":
203          wrapped_dspy_model = DspyChatModelWrapper(model, dspy_settings, model_config)
204      else:
205          wrapped_dspy_model = DspyModelWrapper(model, dspy_settings, model_config)
206  
207      flavor_options = {
208          "model_path": model_subpath,
209      }
210  
211      if task:
212          if mlflow_model.signature is None:
213              mlflow_model.signature = SIGNATURE_FOR_LLM_INFERENCE_TASK[task]
214          flavor_options.update({_LLM_INFERENCE_TASK_KEY: task})
215          if mlflow_model.metadata:
216              mlflow_model.metadata[_METADATA_LLM_INFERENCE_TASK_KEY] = task
217          else:
218              mlflow_model.metadata = {_METADATA_LLM_INFERENCE_TASK_KEY: task}
219  
220      if saved_example and mlflow_model.signature is None:
221          signature = _infer_signature_from_input_example(saved_example, wrapped_dspy_model)
222          mlflow_model.signature = signature
223  
224      streamable = False
225      # Set the output schema to the model wrapper to use it for streaming
226      if mlflow_model.signature and mlflow_model.signature.outputs:
227          wrapped_dspy_model.output_schema = mlflow_model.signature.outputs
228          # DSPy streaming only supports string outputs.
229          if all(spec.type == DataType.string for spec in mlflow_model.signature.outputs):
230              streamable = True
231  
232      if use_dspy_model_save:
233          wrapped_dspy_model.model.save(model_path, save_program=True)
234  
235          if model_config:
236              with open(os.path.join(data_path, _MODEL_CONFIG_FILE_NAME), "w") as f:
237                  json.dump(model_config, f)
238  
239          dspy.settings.save(
240              os.path.join(data_path, _DSPY_SETTINGS_FILE_NAME), exclude_keys=["trace"]
241          )
242      else:
243          with open(model_path, "wb") as f:
244              cloudpickle.dump(wrapped_dspy_model, f)
245  
246      code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)
247  
248      # Add flavor info to `mlflow_model`.
249      mlflow_model.add_flavor(FLAVOR_NAME, code=code_dir_subpath, **flavor_options)
250      # Add loader_module, data and env data to `mlflow_model`.
251      pyfunc.add_to_model(
252          mlflow_model,
253          loader_module="mlflow.dspy",
254          code=code_dir_subpath,
255          conda_env=_CONDA_ENV_FILE_NAME,
256          python_env=_PYTHON_ENV_FILE_NAME,
257          streamable=streamable,
258      )
259  
260      # Add model file size to `mlflow_model`.
261      if size := get_total_file_size(path):
262          mlflow_model.model_size_bytes = size
263  
264      # Add resources if specified.
265      if resources is not None:
266          if isinstance(resources, (Path, str)):
267              serialized_resource = _ResourceBuilder.from_yaml_file(resources)
268          else:
269              serialized_resource = _ResourceBuilder.from_resources(resources)
270  
271          mlflow_model.resources = serialized_resource
272  
273      # Save mlflow_model to path/MLmodel.
274      mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))
275  
276      if conda_env is None:
277          if pip_requirements is None:
278              default_reqs = get_default_pip_requirements()
279              # To ensure `_load_pyfunc` can successfully load the model during the dependency
280              # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
281              inferred_reqs = infer_pip_requirements(path, FLAVOR_NAME, fallback=default_reqs)
282              default_reqs = sorted(set(inferred_reqs).union(default_reqs))
283          else:
284              default_reqs = None
285          conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
286              default_reqs,
287              pip_requirements,
288              extra_pip_requirements,
289          )
290      else:
291          conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)
292  
293      with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
294          yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
295  
296      # Save `constraints.txt` if necessary.
297      if pip_constraints:
298          write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))
299  
300      # Save `requirements.txt`.
301      write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))
302  
303      _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
304  
305  
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
@trace_disabled  # Suppress traces for internal predict calls while logging model
def log_model(
    dspy_model,
    artifact_path: str | None = None,
    task: str | None = None,
    model_config: dict[str, Any] | None = None,
    code_paths: list[str] | None = None,
    conda_env: list[str] | str | None = None,
    signature: ModelSignature | None = None,
    input_example: ModelInputExample | None = None,
    registered_model_name: str | None = None,
    await_registration_for: int = DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements: list[str] | str | None = None,
    extra_pip_requirements: list[str] | str | None = None,
    metadata: dict[str, Any] | None = None,
    resources: str | Path | list[Resource] | None = None,
    prompts: list[str | Prompt] | None = None,
    name: str | None = None,
    params: dict[str, Any] | None = None,
    tags: dict[str, Any] | None = None,
    model_type: str | None = None,
    step: int = 0,
    model_id: str | None = None,
    use_dspy_model_save: bool = False,
):
    """
    Log a Dspy model along with metadata to MLflow.

    This method saves a Dspy model along with metadata such as model signature and conda
    environments to MLflow.

    Args:
        dspy_model: an instance of `dspy.Module`. The Dspy model to be saved.
        artifact_path: Deprecated. Use `name` instead.
        task: defaults to None. The task type of the model. Can only be `llm/v1/chat` or None for
            now.
        model_config: keyword arguments to be passed to the Dspy Module at instantiation.
        code_paths: {{ code_paths }}
        conda_env: {{ conda_env }}
        signature: {{ signature }}
        input_example: {{ input_example }}
        registered_model_name: defaults to None. If set, create a model version under
            `registered_model_name`, also create a registered model if one with the given name does
            not exist.
        await_registration_for: defaults to
            `mlflow.tracking._model_registry.DEFAULT_AWAIT_MAX_SLEEP_SECONDS`. Number of
            seconds to wait for the model version to finish being created and is in ``READY``
            status. By default, the function waits for five minutes. Specify 0 or None to skip
            waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: Custom metadata dictionary passed to the model and stored in the MLmodel
            file.
        resources: A list of model resources or a resources.yaml file containing a list of
            resources required to serve the model.
        prompts: {{ prompts }}
        name: {{ name }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}
        use_dspy_model_save: Whether to save the Dspy model by dspy builtin `dspy.Module.save`
            method.

    .. code-block:: python
        :caption: Example

        import dspy
        import mlflow
        from mlflow.models import ModelSignature
        from mlflow.types.schema import ColSpec, Schema

        # Set up the LM.
        lm = dspy.LM(model="openai/gpt-4o-mini", max_tokens=250)
        dspy.settings.configure(lm=lm)


        class CoT(dspy.Module):
            def __init__(self):
                super().__init__()
                self.prog = dspy.ChainOfThought("question -> answer")

            def forward(self, question):
                return self.prog(question=question)


        dspy_model = CoT()

        mlflow.set_tracking_uri("http://127.0.0.1:5000")
        mlflow.set_experiment("test-dspy-logging")

        from mlflow.dspy import log_model

        input_schema = Schema([ColSpec("string")])
        output_schema = Schema([ColSpec("string")])
        signature = ModelSignature(inputs=input_schema, outputs=output_schema)

        with mlflow.start_run():
            log_model(
                dspy_model,
                "model",
                input_example="what is 2 + 2?",
                signature=signature,
            )
    """
    # Collect every logging option and delegate to the generic Model.log entry point,
    # which drives `mlflow.dspy.save_model` under the hood via the `flavor` argument.
    log_kwargs = {
        "artifact_path": artifact_path,
        "name": name,
        "flavor": mlflow.dspy,
        "model": dspy_model,
        "task": task,
        "model_config": model_config,
        "code_paths": code_paths,
        "conda_env": conda_env,
        "registered_model_name": registered_model_name,
        "signature": signature,
        "input_example": input_example,
        "await_registration_for": await_registration_for,
        "pip_requirements": pip_requirements,
        "extra_pip_requirements": extra_pip_requirements,
        "metadata": metadata,
        "resources": resources,
        "prompts": prompts,
        "params": params,
        "tags": tags,
        "model_type": model_type,
        "step": step,
        "model_id": model_id,
        "use_dspy_model_save": use_dspy_model_save,
    }
    return Model.log(**log_kwargs)