# mlflow/dspy/load.py
  1  import json
  2  import logging
  3  import os
  4  
  5  import cloudpickle
  6  
  7  from mlflow.dspy.save import (
  8      _DSPY_SETTINGS_FILE_NAME,
  9      _MODEL_CONFIG_FILE_NAME,
 10      _MODEL_DATA_PATH,
 11  )
 12  from mlflow.dspy.wrapper import DspyChatModelWrapper, DspyModelWrapper
 13  from mlflow.environment_variables import MLFLOW_ALLOW_PICKLE_DESERIALIZATION
 14  from mlflow.exceptions import MlflowException
 15  from mlflow.models import Model
 16  from mlflow.models.dependencies_schemas import _get_dependencies_schema_from_model
 17  from mlflow.models.model import _update_active_model_id_based_on_mlflow_model
 18  from mlflow.tracing.provider import trace_disabled
 19  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 20  from mlflow.utils.databricks_utils import (
 21      is_in_databricks_model_serving_environment,
 22      is_in_databricks_runtime,
 23  )
 24  from mlflow.utils.model_utils import (
 25      _add_code_from_conf_to_system_path,
 26      _get_flavor_configuration,
 27  )
 28  
 29  _DEFAULT_MODEL_PATH = "data/model.pkl"
 30  _logger = logging.getLogger(__name__)
 31  
 32  
 33  def _set_dependency_schema_to_tracer(model_path, callbacks):
 34      """
 35      Set dependency schemas from the saved model metadata to the tracer
 36      to propagate it to inference traces.
 37      """
 38      from mlflow.dspy.callback import MlflowCallback
 39  
 40      tracer = next((cb for cb in callbacks if isinstance(cb, MlflowCallback)), None)
 41      if tracer is None:
 42          return
 43  
 44      model = Model.load(model_path)
 45      tracer.set_dependencies_schema(_get_dependencies_schema_from_model(model))
 46  
 47  
 48  def _load_model(model_uri, dst_path=None):
 49      import dspy
 50  
 51      local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
 52      mlflow_model = Model.load(local_model_path)
 53      flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name="dspy")
 54  
 55      _add_code_from_conf_to_system_path(local_model_path, flavor_conf)
 56      model_path = flavor_conf.get("model_path", _DEFAULT_MODEL_PATH)
 57      task = flavor_conf.get("inference_task")
 58  
 59      if model_path.endswith(".pkl"):
 60          if (
 61              not MLFLOW_ALLOW_PICKLE_DESERIALIZATION.get()
 62              and not is_in_databricks_runtime()
 63              and not is_in_databricks_model_serving_environment()
 64          ):
 65              raise MlflowException(
 66                  "Deserializing model using pickle is disallowed, but this model is saved "
 67                  "in pickle format. To address this issue, you need to set environment variable "
 68                  "'MLFLOW_ALLOW_PICKLE_DESERIALIZATION' to 'true', or save the model with "
 69                  "'use_dspy_model_save=True' like "
 70                  "`mlflow.dspy.save_model(model, path, use_dspy_model_save=True)`."
 71              )
 72  
 73          with open(os.path.join(local_model_path, model_path), "rb") as f:
 74              loaded_wrapper = cloudpickle.load(f)
 75      else:
 76          model = dspy.load(os.path.join(local_model_path, model_path), allow_pickle=True)
 77  
 78          dspy_settings = dspy.load_settings(
 79              os.path.join(local_model_path, _MODEL_DATA_PATH, _DSPY_SETTINGS_FILE_NAME)
 80          )
 81  
 82          model_config_file = os.path.join(
 83              local_model_path, _MODEL_DATA_PATH, _MODEL_CONFIG_FILE_NAME
 84          )
 85          if os.path.exists(model_config_file):
 86              with open(model_config_file) as f:
 87                  model_config = json.load(f)
 88          else:
 89              model_config = None
 90  
 91          if task == "llm/v1/chat":
 92              loaded_wrapper = DspyChatModelWrapper(model, dspy_settings, model_config)
 93          else:
 94              loaded_wrapper = DspyModelWrapper(model, dspy_settings, model_config)
 95  
 96      _set_dependency_schema_to_tracer(local_model_path, loaded_wrapper.dspy_settings["callbacks"])
 97      _update_active_model_id_based_on_mlflow_model(mlflow_model)
 98      return loaded_wrapper
 99  
100  
@trace_disabled  # Suppress traces for internal calls while loading model
def load_model(model_uri, dst_path=None):
    """
    Load a Dspy model from a run.

    This function will also set the global dspy settings `dspy.settings` by the saved settings.

    Args:
        model_uri: The location, in URI format, of the MLflow model. For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``mlflow-artifacts:/path/to/model``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#
            artifact-locations>`_.
        dst_path: The local filesystem path to utilize for downloading the model artifact.
            This directory must already exist if provided. If unspecified, a local output
            path will be created.

    Returns:
        An `dspy.module` instance, representing the dspy model.
    """
    import dspy

    wrapper = _load_model(model_uri, dst_path)

    # Configure the global dspy settings so the model behaves as it did when
    # saved. This only applies to models loaded via `mlflow.dspy.load_model`;
    # when loaded as a pyfunc, the wrapper applies settings per prediction via
    # a local context instead, which avoids the "dspy.settings can only be
    # changed by the thread that initially configured it" error in Databricks
    # model serving.
    dspy.settings.configure(**wrapper.dspy_settings)

    return wrapper.model
139  
140  
def _load_pyfunc(path):
    """Pyfunc-flavor entry point: load the model wrapper from a local ``path``."""
    wrapper = _load_model(path)
    return wrapper