# test_logged_models.py
import json
import os
from concurrent.futures import ThreadPoolExecutor

import pytest

import mlflow
from mlflow.entities.logged_model_status import LoggedModelStatus
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.tracing.constant import TraceMetadataKey
from mlflow.utils.mlflow_tags import MLFLOW_MODEL_IS_EXTERNAL


class DummyModel(mlflow.pyfunc.PythonModel):
    """Minimal pyfunc model: predicts a zero for every input row."""

    def predict(self, model_input):
        return [0] * len(model_input)


class TraceModel(mlflow.pyfunc.PythonModel):
    """Like DummyModel, but `predict` is wrapped with `@mlflow.trace` so each
    call emits a trace."""

    @mlflow.trace
    def predict(self, model_input):
        return [0] * len(model_input)


def test_model_id_tracking():
    """A trace emitted before any model is logged carries no model ID; a trace
    emitted through a loaded logged model carries that model's ID."""
    traced = TraceModel()
    traced.predict([1, 2, 3])
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert TraceMetadataKey.MODEL_ID not in trace.info.request_metadata

    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(name="my_model", python_model=traced)
        # Log another model to ensure that the model ID is correctly associated
        # with the first model
        mlflow.pyfunc.log_model(name="another_model", python_model=traced)

    loaded = mlflow.pyfunc.load_model(info.model_uri)
    loaded.predict([4, 5, 6])

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == info.model_id


def test_model_id_tracking_evaluate():
    """Traces produced during `mlflow.evaluate` carry the evaluated model's ID."""
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(name="my_model", python_model=TraceModel())

    mlflow.evaluate(model=info.model_uri, data=[[1, 2, 3]], model_type="regressor", targets=[1])
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == info.model_id


def test_model_id_tracking_thread_safety():
    """Concurrent predictions on different loaded models must each tag their
    trace with the correct model ID (no cross-thread leakage)."""
    loaded_models = []
    for _ in range(5):
        with mlflow.start_run():
            info = mlflow.pyfunc.log_model(
                name="my_model",
                python_model=TraceModel(),
                pip_requirements=[],  # to skip dependency inference
            )
        loaded_models.append(mlflow.pyfunc.load_model(info.model_uri))

    # Each thread predicts with its own index so the resulting trace input
    # identifies which model produced it.
    def run_prediction(idx, loaded) -> None:
        loaded.predict([idx])

    with ThreadPoolExecutor(
        max_workers=len(loaded_models), thread_name_prefix="test-logged-models"
    ) as pool:
        pending = [
            pool.submit(run_prediction, idx, loaded)
            for idx, loaded in enumerate(loaded_models)
        ]
        for fut in pending:
            fut.result()

    traces = mlflow.search_traces(return_type="list")
    assert len(traces) == len(loaded_models)
    for trace in traces:
        # Recover which model this trace came from via its recorded input.
        raw_inputs = trace.info.request_metadata["mlflow.traceInputs"]
        index = json.loads(raw_inputs)["model_input"][0]
        recorded_id = trace.info.request_metadata["mlflow.modelId"]
        assert recorded_id == loaded_models[index].model_id


def test_run_params_are_logged_to_model():
    """Params logged on the active run are copied onto the logged model
    (stringified, per param storage)."""
    with mlflow.start_run():
        mlflow.log_params({"a": 1})
        mlflow.pyfunc.log_model(name="my_model", python_model=DummyModel())

    model = mlflow.last_logged_model()
    assert model.params == {"a": "1"}


def test_run_metrics_are_logged_to_model():
    """Metrics logged on the active run are attached to the logged model."""
    with mlflow.start_run():
        mlflow.log_metrics({"a": 1, "b": 2})
        mlflow.pyfunc.log_model(name="my_model", python_model=DummyModel())

    model = mlflow.last_logged_model()
    assert [(m.key, m.value) for m in model.metrics] == [("a", 1), ("b", 2)]


def test_log_model_finalizes_existing_pending_model():
    """Logging artifacts into a PENDING model flips its status to READY."""
    model = mlflow.initialize_logged_model(name="testmodel")
    assert model.status == LoggedModelStatus.PENDING
    mlflow.pyfunc.log_model(python_model=DummyModel(), model_id=model.model_id)
    updated_model = mlflow.get_logged_model(model.model_id)
    assert updated_model.status == LoggedModelStatus.READY


def test_log_model_permits_logging_to_ready_model(tmp_path):
    """Logging into an already-READY (non-external) model is allowed and
    updates its artifacts in place."""
    # Create a non-external model and finalize it to READY status
    model = mlflow.initialize_logged_model(name="testmodel")
    model = mlflow.finalize_logged_model(model.model_id, LoggedModelStatus.READY)
    assert model.status == LoggedModelStatus.READY
    assert model.tags.get(MLFLOW_MODEL_IS_EXTERNAL, "false").lower() == "false"

    # Verify we can log to the READY model
    mlflow.pyfunc.log_model(python_model=DummyModel(), model_id=model.model_id)

    # Verify the model can be loaded
    mlflow.pyfunc.load_model(f"models:/{model.model_id}")

    # Verify the model artifacts were updated
    dst_dir = os.path.join(tmp_path, "dst")
    mlflow.artifacts.download_artifacts(f"models:/{model.model_id}", dst_path=dst_dir)
    mlflow_model = Model.load(os.path.join(dst_dir, "MLmodel"))
    assert mlflow_model.flavors.get("python_function") is not None


def test_log_model_permits_logging_model_artifacts_to_external_models(tmp_path):
    """Logging artifacts into an external model converts it into a regular
    loadable model and drops the external tag."""
    model = mlflow.create_external_model(name="testmodel")
    assert model.status == LoggedModelStatus.READY
    assert model.tags.get(MLFLOW_MODEL_IS_EXTERNAL) == "true"
    dst_dir_1 = os.path.join(tmp_path, "dst_1")
    mlflow.artifacts.download_artifacts(f"models:/{model.model_id}", dst_path=dst_dir_1)
    mlflow_model: Model = Model.load(os.path.join(dst_dir_1, "MLmodel"))

    model_info = mlflow.pyfunc.log_model(python_model=DummyModel(), model_id=model.model_id)

    # Verify that the model can now be loaded and is no longer tagged as external
    mlflow.pyfunc.load_model(model_info.model_uri)
    assert MLFLOW_MODEL_IS_EXTERNAL not in mlflow.get_logged_model(model.model_id).tags
    dst_dir_2 = os.path.join(tmp_path, "dst_2")
    mlflow.artifacts.download_artifacts(f"models:/{model.model_id}", dst_path=dst_dir_2)
    mlflow_model = Model.load(os.path.join(dst_dir_2, "MLmodel"))
    assert MLFLOW_MODEL_IS_EXTERNAL not in (mlflow_model.metadata or {})


def test_external_logged_model_cannot_be_loaded_with_pyfunc():
    """Loading a purely external model through pyfunc must fail loudly."""
    model = mlflow.create_external_model(name="testmodel")
    with pytest.raises(
        MlflowException,
        match="This model's artifacts are external.*cannot be loaded",
    ):
        mlflow.pyfunc.load_model(f"models:/{model.model_id}")