test_mlflow_version_comp.py
import os
import subprocess
import sys
import uuid
from pathlib import Path

import numpy as np
import sklearn
from pyspark.sql import SparkSession
from sklearn.linear_model import LinearRegression

import mlflow
from mlflow.models import Model


def check_load(model_uri: str) -> None:
    """Load the model at ``model_uri`` through several loaders and check predictions.

    ``Model.load`` is called only to confirm the model metadata can be read; the
    sklearn and pyfunc flavors must both predict ``3.0`` for the input ``[[1, 2]]``.
    """
    Model.load(model_uri)
    model = mlflow.sklearn.load_model(model_uri)
    np.testing.assert_array_equal(model.predict([[1, 2]]), [3.0])
    model = mlflow.pyfunc.load_model(model_uri)
    np.testing.assert_array_equal(model.predict([[1, 2]]), [3.0])


def check_register(model_uri: str) -> None:
    """Register ``model_uri`` as ``"model"`` and check the registered version predicts ``3.0``."""
    mv = mlflow.register_model(model_uri, "model")
    model = mlflow.pyfunc.load_model(f"models:/{mv.name}/{mv.version}")
    np.testing.assert_array_equal(model.predict([[1, 2]]), [3.0])


def check_list_artifacts_with_run_id_and_path(run_id: str, path: str) -> None:
    """Check artifact listing by run ID and path via both the client and fluent APIs.

    Both ``MlflowClient.list_artifacts`` and ``mlflow.artifacts.list_artifacts`` must
    return ``model/MLmodel`` and ``model/test.txt`` for ``path``, and an empty list
    for a path that does not exist.
    """
    # List artifacts with the client API
    client = mlflow.MlflowClient()
    artifacts = [a.path for a in client.list_artifacts(run_id=run_id, path=path)]
    # Ensure both run and model artifacts are listed
    assert "model/MLmodel" in artifacts
    assert "model/test.txt" in artifacts
    # List artifacts with the fluent API. This must not repeat the client call above,
    # otherwise the fluent code path is only exercised for the "unknown" case below.
    artifacts = [
        a.path for a in mlflow.artifacts.list_artifacts(run_id=run_id, artifact_path=path)
    ]
    assert "model/MLmodel" in artifacts
    assert "model/test.txt" in artifacts
    # Non-existing artifact path should return an empty list
    assert len(client.list_artifacts(run_id=run_id, path="unknown")) == 0
    assert len(mlflow.artifacts.list_artifacts(run_id=run_id, artifact_path="unknown")) == 0


def check_list_artifacts_with_model_uri(model_uri: str) -> None:
    """Check that listing by artifact URI returns ``model/MLmodel`` and ``model/test.txt``."""
    artifacts = [a.path for a in mlflow.artifacts.list_artifacts(artifact_uri=model_uri)]
    assert "model/MLmodel" in artifacts
    assert "model/test.txt" in artifacts


def check_download_artifacts_with_run_id_and_path(run_id: str, path: str, tmp_path: Path) -> None:
    """Download ``path`` of run ``run_id`` via the fluent and client APIs.

    Each download goes to a fresh, uniquely named directory under ``tmp_path`` and
    must contain both ``MLmodel`` and ``test.txt`` as top-level files.
    """
    out_path = mlflow.artifacts.download_artifacts(
        run_id=run_id, artifact_path=path, dst_path=tmp_path / str(uuid.uuid4())
    )
    files = [f.name for f in Path(out_path).iterdir() if f.is_file()]
    assert "MLmodel" in files
    assert "test.txt" in files
    client = mlflow.MlflowClient()
    out_path = client.download_artifacts(
        run_id=run_id, path=path, dst_path=tmp_path / str(uuid.uuid4())
    )
    files = [f.name for f in Path(out_path).iterdir() if f.is_file()]
    assert "MLmodel" in files
    assert "test.txt" in files


def check_download_artifacts_with_model_uri(model_uri: str, tmp_path: Path) -> None:
    """Download ``model_uri`` into a fresh directory under ``tmp_path`` and check its files."""
    out_path = mlflow.artifacts.download_artifacts(
        artifact_uri=model_uri, dst_path=tmp_path / str(uuid.uuid4())
    )
    files = [f.name for f in Path(out_path).iterdir() if f.is_file()]
    # Ensure both run and model artifacts are downloaded
    assert "MLmodel" in files
    assert "test.txt" in files


def check_evaluate(model_uri: str) -> None:
    """Evaluate ``model_uri`` as a regressor and check ``mean_squared_error`` is reported."""
    # Model evaluation
    eval_res = mlflow.models.evaluate(
        model=model_uri,
        data=np.array([[1, 2]]),
        targets=np.array([3]),
        model_type="regressor",
    )
    assert "mean_squared_error" in eval_res.metrics


def check_spark_udf(model_uri: str) -> None:
    """Apply ``model_uri`` as a Spark UDF and check it predicts ``3.0``.

    The check is skipped on Windows (``os.name == "nt"``); see the inline note on
    the failure observed there.
    """
    # Spark UDF
    if os.name != "nt":
        with SparkSession.builder.getOrCreate() as spark:
            udf = mlflow.pyfunc.spark_udf(
                spark,
                model_uri,
                result_type="double",
                env_manager="local",
            )
            df = spark.createDataFrame([[1, 2]], ["col1", "col2"])
            # This line fails with the following error on Windows:
            #   File ".../pyspark\python\lib\pyspark.zip\pyspark\serializers.py", line 472, in loads
            #     return cloudpickle.loads(obj, encoding=encoding)
            # ModuleNotFoundError: No module named 'pandas'
            pred = df.select(udf("col1", "col2").alias("pred")).collect()
            assert [row.pred for row in pred] == [3.0]


def test_mlflow_2_x_comp(tmp_path: Path) -> None:
    """Check that a model logged by MLflow 2.x is usable from the current MLflow.

    A subprocess running ``mlflow<3.0`` (via ``uv run``) creates the SQLite tracking
    DB and experiment, logs a model and writes the run ID to ``out.txt``. The current
    process then runs every ``check_*`` helper against ``runs:/<run_id>/model``.
    """
    tracking_uri = f"sqlite:///{tmp_path / 'mlflow.db'}"
    artifact_location = (tmp_path / "artifacts").as_uri()

    out_file = tmp_path / "out.txt"
    # Log a model using MLflow 2.x (let 2.x create the DB and experiment)
    py_ver = ".".join(map(str, sys.version_info[:2]))
    subprocess.check_call(
        [
            "uv",
            "run",
            "--isolated",
            "--no-project",
            "--index-strategy=unsafe-first-match",
            f"--python={py_ver}",
            # Use mlflow 2.x
            "--with=mlflow<3.0",
            # Pin numpy and sklearn versions to ensure the model can be loaded
            f"--with=numpy=={np.__version__}",
            f"--with=scikit-learn=={sklearn.__version__}",
            "python",
            # Use the isolated mode to ignore mlflow in the repository
            "-I",
            "-c",
            """
import sys
import mlflow
from sklearn.linear_model import LinearRegression

assert mlflow.__version__.startswith("2."), mlflow.__version__

tracking_uri, artifact_location, out = sys.argv[1:]
mlflow.set_tracking_uri(tracking_uri)
exp_id = mlflow.create_experiment("test", artifact_location=artifact_location)
mlflow.set_experiment(experiment_id=exp_id)

fitted_model= LinearRegression().fit([[1, 2]], [3])
with mlflow.start_run() as run:
    mlflow.log_text("test", "model/test.txt")
    model_info = mlflow.sklearn.log_model(fitted_model, artifact_path="model")
    assert model_info.model_uri.startswith("runs:/")
    with open(out, "w") as f:
        f.write(run.info.run_id)
""",
            tracking_uri,
            artifact_location,
            out_file,
        ],
    )

    # 3.x opens the 2.x-created DB (migration happens automatically)
    mlflow.set_tracking_uri(tracking_uri)
    run_id = out_file.read_text().strip()
    model_uri = f"runs:/{run_id}/model"
    check_load(model_uri=model_uri)
    check_register(model_uri=model_uri)
    check_list_artifacts_with_run_id_and_path(run_id=run_id, path="model")
    check_list_artifacts_with_model_uri(model_uri=model_uri)
    check_download_artifacts_with_run_id_and_path(run_id=run_id, path="model", tmp_path=tmp_path)
    check_download_artifacts_with_model_uri(model_uri=model_uri, tmp_path=tmp_path)
    check_evaluate(model_uri=model_uri)
    check_spark_udf(model_uri=model_uri)


def test_mlflow_3_x_comp(tmp_path: Path) -> None:
    """Check a model logged by the current MLflow via both ``runs:/`` and ``models:/`` URIs."""
    tracking_uri = f"sqlite:///{tmp_path / 'mlflow.db'}"
    mlflow.set_tracking_uri(tracking_uri)
    artifact_location = (tmp_path / "artifacts").as_uri()
    exp_id = mlflow.create_experiment("test", artifact_location=artifact_location)
    mlflow.set_experiment(experiment_id=exp_id)

    fitted_model = LinearRegression().fit([[1, 2]], [3])
    with mlflow.start_run() as run:
        mlflow.log_text("test", "model/test.txt")
        model_info = mlflow.sklearn.log_model(fitted_model, name="model")

    # Runs URI
    run_id = run.info.run_id
    runs_model_uri = f"runs:/{run_id}/model"
    check_load(model_uri=runs_model_uri)
    check_register(model_uri=runs_model_uri)
    check_list_artifacts_with_run_id_and_path(run_id=run_id, path="model")
    check_list_artifacts_with_model_uri(model_uri=runs_model_uri)
    check_download_artifacts_with_run_id_and_path(run_id=run_id, path="model", tmp_path=tmp_path)
    check_download_artifacts_with_model_uri(model_uri=runs_model_uri, tmp_path=tmp_path)
    check_evaluate(model_uri=runs_model_uri)
    check_spark_udf(model_uri=runs_model_uri)

    # Models URI
    logged_model_uri = f"models:/{model_info.model_id}"
    check_load(model_uri=logged_model_uri)
    check_register(model_uri=logged_model_uri)
    artifacts = [a.path for a in mlflow.artifacts.list_artifacts(artifact_uri=logged_model_uri)]
    assert "MLmodel" in artifacts
    out_path = mlflow.artifacts.download_artifacts(
        artifact_uri=logged_model_uri, dst_path=tmp_path / str(uuid.uuid4())
    )
    files = [f.name for f in Path(out_path).iterdir() if f.is_file()]
    assert "MLmodel" in files
    check_evaluate(model_uri=logged_model_uri)
    check_spark_udf(model_uri=logged_model_uri)


def test_run_and_model_has_artifact_with_same_name(tmp_path: Path) -> None:
    """Check behavior when a run artifact and a model artifact share ``model/MLmodel``.

    Listing must report both files, while downloading must yield a single
    ``MLmodel`` whose content is the model's (it contains the model ID).
    """
    fitted_model = LinearRegression().fit([[1, 2]], [3])
    with mlflow.start_run() as run:
        mlflow.log_text("", artifact_file="model/MLmodel")
        info = mlflow.sklearn.log_model(fitted_model, name="model")

    client = mlflow.MlflowClient()
    artifacts = client.list_artifacts(run_id=run.info.run_id, path="model")
    mlmodel_files = [a.path for a in artifacts if a.path.endswith("MLmodel")]
    # Both run and model artifacts should be listed
    assert len(mlmodel_files) == 2
    out = mlflow.artifacts.download_artifacts(
        run_id=run.info.run_id,
        artifact_path="model",
        dst_path=tmp_path / str(uuid.uuid4()),
    )
    mlmodel_files = list(Path(out).rglob("MLmodel"))
    assert len(mlmodel_files) == 1
    # The model MLmodel file should overwrite the run MLmodel file
    assert info.model_id in mlmodel_files[0].read_text()