# tests/pyfunc/test_logged_models.py
  1  import json
  2  import os
  3  from concurrent.futures import ThreadPoolExecutor
  4  
  5  import pytest
  6  
  7  import mlflow
  8  from mlflow.entities.logged_model_status import LoggedModelStatus
  9  from mlflow.exceptions import MlflowException
 10  from mlflow.models import Model
 11  from mlflow.tracing.constant import TraceMetadataKey
 12  from mlflow.utils.mlflow_tags import MLFLOW_MODEL_IS_EXTERNAL
 13  
 14  
 15  class DummyModel(mlflow.pyfunc.PythonModel):
 16      def predict(self, model_input):
 17          return len(model_input) * [0]
 18  
 19  
 20  class TraceModel(mlflow.pyfunc.PythonModel):
 21      @mlflow.trace
 22      def predict(self, model_input):
 23          return len(model_input) * [0]
 24  
 25  
 26  def test_model_id_tracking():
 27      model = TraceModel()
 28      model.predict([1, 2, 3])
 29      trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
 30      assert TraceMetadataKey.MODEL_ID not in trace.info.request_metadata
 31  
 32      with mlflow.start_run():
 33          info = mlflow.pyfunc.log_model(name="my_model", python_model=model)
 34          # Log another model to ensure that the model ID is correctly associated with the first model
 35          mlflow.pyfunc.log_model(name="another_model", python_model=model)
 36  
 37      model = mlflow.pyfunc.load_model(info.model_uri)
 38      model.predict([4, 5, 6])
 39  
 40      trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
 41      assert trace is not None
 42      assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == info.model_id
 43  
 44  
 45  def test_model_id_tracking_evaluate():
 46      with mlflow.start_run():
 47          info = mlflow.pyfunc.log_model(name="my_model", python_model=TraceModel())
 48  
 49      mlflow.evaluate(model=info.model_uri, data=[[1, 2, 3]], model_type="regressor", targets=[1])
 50      trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
 51      assert trace is not None
 52      assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == info.model_id
 53  
 54  
 55  def test_model_id_tracking_thread_safety():
 56      models = []
 57      for _ in range(5):
 58          with mlflow.start_run():
 59              info = mlflow.pyfunc.log_model(
 60                  name="my_model",
 61                  python_model=TraceModel(),
 62                  pip_requirements=[],  # to skip dependency inference
 63              )
 64              model = mlflow.pyfunc.load_model(info.model_uri)
 65              models.append(model)
 66  
 67      def predict(idx, model) -> None:
 68          model.predict([idx])
 69  
 70      with ThreadPoolExecutor(
 71          max_workers=len(models), thread_name_prefix="test-logged-models"
 72      ) as executor:
 73          futures = [executor.submit(predict, idx, model) for idx, model in enumerate(models)]
 74          for f in futures:
 75              f.result()
 76  
 77      traces = mlflow.search_traces(return_type="list")
 78      assert len(traces) == len(models)
 79      for trace in traces:
 80          trace_inputs = trace.info.request_metadata["mlflow.traceInputs"]
 81          index = json.loads(trace_inputs)["model_input"][0]
 82          model_id = trace.info.request_metadata["mlflow.modelId"]
 83          assert model_id == models[index].model_id
 84  
 85  
 86  def test_run_params_are_logged_to_model():
 87      with mlflow.start_run():
 88          mlflow.log_params({"a": 1})
 89          mlflow.pyfunc.log_model(name="my_model", python_model=DummyModel())
 90  
 91      model = mlflow.last_logged_model()
 92      assert model.params == {"a": "1"}
 93  
 94  
 95  def test_run_metrics_are_logged_to_model():
 96      with mlflow.start_run():
 97          mlflow.log_metrics({"a": 1, "b": 2})
 98          mlflow.pyfunc.log_model(name="my_model", python_model=DummyModel())
 99  
100      model = mlflow.last_logged_model()
101      assert [(m.key, m.value) for m in model.metrics] == [("a", 1), ("b", 2)]
102  
103  
def test_log_model_finalizes_existing_pending_model():
    """Logging a model into a PENDING logged model transitions it to READY."""
    pending = mlflow.initialize_logged_model(name="testmodel")
    assert pending.status == LoggedModelStatus.PENDING

    mlflow.pyfunc.log_model(python_model=DummyModel(), model_id=pending.model_id)

    refreshed = mlflow.get_logged_model(pending.model_id)
    assert refreshed.status == LoggedModelStatus.READY
110  
111  
def test_log_model_permits_logging_to_ready_model(tmp_path):
    """A non-external logged model that is already READY can still receive pyfunc artifacts."""
    # Start from a regular (non-external) model that has been finalized to READY.
    ready = mlflow.finalize_logged_model(
        mlflow.initialize_logged_model(name="testmodel").model_id,
        LoggedModelStatus.READY,
    )
    assert ready.status == LoggedModelStatus.READY
    assert ready.tags.get(MLFLOW_MODEL_IS_EXTERNAL, "false").lower() == "false"

    model_uri = f"models:/{ready.model_id}"

    # Logging into the READY model must succeed...
    mlflow.pyfunc.log_model(python_model=DummyModel(), model_id=ready.model_id)
    # ...and the result must be loadable as a pyfunc.
    mlflow.pyfunc.load_model(model_uri)

    # The downloaded MLmodel must now declare the python_function flavor.
    download_dir = tmp_path / "dst"
    mlflow.artifacts.download_artifacts(model_uri, dst_path=str(download_dir))
    mlflow_model = Model.load(str(download_dir / "MLmodel"))
    assert mlflow_model.flavors.get("python_function") is not None
130  
131  
def test_log_model_permits_logging_model_artifacts_to_external_models(tmp_path):
    """Logging pyfunc artifacts into an external model turns it into a regular model.

    Before logging, the model is READY and tagged as external, and its downloaded
    MLmodel metadata contains the external marker. After logging, the tag and the
    MLmodel marker are both gone, and the model loads as a pyfunc.
    """
    model = mlflow.create_external_model(name="testmodel")
    assert model.status == LoggedModelStatus.READY
    assert model.tags.get(MLFLOW_MODEL_IS_EXTERNAL) == "true"

    dst_dir_1 = os.path.join(tmp_path, "dst_1")
    mlflow.artifacts.download_artifacts(f"models:/{model.model_id}", dst_path=dst_dir_1)
    mlflow_model: Model = Model.load(os.path.join(dst_dir_1, "MLmodel"))
    # Baseline (previously loaded but never checked): the external marker must be present
    # before logging, otherwise the absence check at the end proves nothing.
    assert MLFLOW_MODEL_IS_EXTERNAL in (mlflow_model.metadata or {})

    model_info = mlflow.pyfunc.log_model(python_model=DummyModel(), model_id=model.model_id)

    # Verify that the model can now be loaded and is no longer tagged as external
    mlflow.pyfunc.load_model(model_info.model_uri)
    assert MLFLOW_MODEL_IS_EXTERNAL not in mlflow.get_logged_model(model.model_id).tags
    dst_dir_2 = os.path.join(tmp_path, "dst_2")
    mlflow.artifacts.download_artifacts(f"models:/{model.model_id}", dst_path=dst_dir_2)
    mlflow_model = Model.load(os.path.join(dst_dir_2, "MLmodel"))
    assert MLFLOW_MODEL_IS_EXTERNAL not in (mlflow_model.metadata or {})
149  
150  
def test_external_logged_model_cannot_be_loaded_with_pyfunc():
    """Loading an external model through pyfunc fails with an explanatory MlflowException."""
    external = mlflow.create_external_model(name="testmodel")
    external_uri = f"models:/{external.model_id}"

    expected_message = "This model's artifacts are external.*cannot be loaded"
    with pytest.raises(MlflowException, match=expected_message):
        mlflow.pyfunc.load_model(external_uri)