load.py
import json
import logging
import os

import cloudpickle

from mlflow.dspy.save import (
    _DSPY_SETTINGS_FILE_NAME,
    _MODEL_CONFIG_FILE_NAME,
    _MODEL_DATA_PATH,
)
from mlflow.dspy.wrapper import DspyChatModelWrapper, DspyModelWrapper
from mlflow.environment_variables import MLFLOW_ALLOW_PICKLE_DESERIALIZATION
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.models.dependencies_schemas import _get_dependencies_schema_from_model
from mlflow.models.model import _update_active_model_id_based_on_mlflow_model
from mlflow.tracing.provider import trace_disabled
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.databricks_utils import (
    is_in_databricks_model_serving_environment,
    is_in_databricks_runtime,
)
from mlflow.utils.model_utils import (
    _add_code_from_conf_to_system_path,
    _get_flavor_configuration,
)

# Model file used when the flavor configuration does not record a ``model_path``.
_DEFAULT_MODEL_PATH = "data/model.pkl"
_logger = logging.getLogger(__name__)


def _set_dependency_schema_to_tracer(model_path, callbacks):
    """
    Set dependency schemas from the saved model metadata to the tracer
    to propagate it to inference traces.

    Args:
        model_path: Local directory of the downloaded MLflow model; its metadata is
            read with ``Model.load``.
        callbacks: The DSPy callbacks configured for the model. ``None`` is treated as
            an empty sequence. Only the first ``MlflowCallback`` found receives the
            dependency schema.
    """
    from mlflow.dspy.callback import MlflowCallback

    tracer = next((cb for cb in (callbacks or []) if isinstance(cb, MlflowCallback)), None)
    if tracer is None:
        # No MLflow tracer is registered, so there is nothing to attach the schema to.
        return

    model = Model.load(model_path)
    tracer.set_dependencies_schema(_get_dependencies_schema_from_model(model))


def _load_model(model_uri, dst_path=None):
    """
    Download a DSPy model logged with MLflow and rebuild its wrapper.

    Args:
        model_uri: The location, in URI format, of the MLflow model.
        dst_path: Optional local directory to download the model artifacts into.

    Returns:
        If the flavor's ``model_path`` ends with ``.pkl``, the wrapper object unpickled
        from that file. Otherwise a ``DspyChatModelWrapper`` (when the flavor's
        ``inference_task`` is ``"llm/v1/chat"``) or a ``DspyModelWrapper`` built from the
        program loaded with ``dspy.load``, the saved DSPy settings and the optional
        model config.

    Raises:
        MlflowException: If the model is stored as a pickle while
            ``MLFLOW_ALLOW_PICKLE_DESERIALIZATION`` evaluates to false and the code is
            not running in a Databricks runtime or Databricks model serving environment.
    """
    import dspy

    local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
    mlflow_model = Model.load(local_model_path)
    flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name="dspy")

    # Presumably puts code logged with the model on sys.path so that deserialization
    # below can import it — confirm against `_add_code_from_conf_to_system_path`.
    _add_code_from_conf_to_system_path(local_model_path, flavor_conf)
    model_path = flavor_conf.get("model_path", _DEFAULT_MODEL_PATH)
    task = flavor_conf.get("inference_task")

    if model_path.endswith(".pkl"):
        # Unpickling can execute arbitrary code, so it is opt-in outside Databricks.
        if (
            not MLFLOW_ALLOW_PICKLE_DESERIALIZATION.get()
            and not is_in_databricks_runtime()
            and not is_in_databricks_model_serving_environment()
        ):
            raise MlflowException(
                "Deserializing model using pickle is disallowed, but this model is saved "
                "in pickle format. To address this issue, you need to set environment variable "
                "'MLFLOW_ALLOW_PICKLE_DESERIALIZATION' to 'true', or save the model with "
                "'use_dspy_model_save=True' like "
                "`mlflow.dspy.save_model(model, path, use_dspy_model_save=True)`."
            )

        with open(os.path.join(local_model_path, model_path), "rb") as f:
            loaded_wrapper = cloudpickle.load(f)
    else:
        # NOTE(review): `allow_pickle=True` presumably lets dspy deserialize pickled
        # program state, and this branch is not gated by
        # MLFLOW_ALLOW_PICKLE_DESERIALIZATION — confirm this is intended.
        model = dspy.load(os.path.join(local_model_path, model_path), allow_pickle=True)

        dspy_settings = dspy.load_settings(
            os.path.join(local_model_path, _MODEL_DATA_PATH, _DSPY_SETTINGS_FILE_NAME)
        )

        # The model config file is optional; absent means "no config".
        model_config_file = os.path.join(
            local_model_path, _MODEL_DATA_PATH, _MODEL_CONFIG_FILE_NAME
        )
        if os.path.exists(model_config_file):
            # JSON is UTF-8 by specification; do not rely on the platform default encoding.
            with open(model_config_file, encoding="utf-8") as f:
                model_config = json.load(f)
        else:
            model_config = None

        wrapper_cls = DspyChatModelWrapper if task == "llm/v1/chat" else DspyModelWrapper
        loaded_wrapper = wrapper_cls(model, dspy_settings, model_config)

    # Settings without a "callbacks" entry simply mean there is no tracer to configure;
    # they must not make loading fail with a KeyError.
    _set_dependency_schema_to_tracer(
        local_model_path, loaded_wrapper.dspy_settings.get("callbacks")
    )
    _update_active_model_id_based_on_mlflow_model(mlflow_model)
    return loaded_wrapper


@trace_disabled  # Suppress traces for internal calls while loading model
def load_model(model_uri, dst_path=None):
    """
    Load a DSPy model from a run.

    This function also sets the global DSPy settings (`dspy.settings`) from the saved
    settings.

    Args:
        model_uri: The location, in URI format, of the MLflow model. For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``mlflow-artifacts:/path/to/model``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#
            artifact-locations>`_.
        dst_path: The local filesystem path to utilize for downloading the model artifact.
            This directory must already exist if provided. If unspecified, a local output
            path will be created.

    Returns:
        A `dspy.Module` instance, representing the DSPy model.
    """
    import dspy

    wrapper = _load_model(model_uri, dst_path)

    # Set the global dspy settings for reproducing the model's behavior when the model is
    # loaded via `mlflow.dspy.load_model`. Note that for the model to be loaded as pyfunc,
    # settings will be set in the wrapper's `predict` method via local context to avoid the
    # "dspy.settings can only be changed by the thread that initially configured it" error
    # in Databricks model serving.
    dspy.settings.configure(**wrapper.dspy_settings)

    return wrapper.model


def _load_pyfunc(path):
    """Load the DSPy model stored at ``path`` and return its wrapper for pyfunc use."""
    return _load_model(path)