# tests/test_mlflow_version_comp.py
  1  import os
  2  import subprocess
  3  import sys
  4  import uuid
  5  from pathlib import Path
  6  
  7  import numpy as np
  8  import sklearn
  9  from pyspark.sql import SparkSession
 10  from sklearn.linear_model import LinearRegression
 11  
 12  import mlflow
 13  from mlflow.models import Model
 14  
 15  
 16  def check_load(model_uri: str) -> None:
 17      Model.load(model_uri)
 18      model = mlflow.sklearn.load_model(model_uri)
 19      np.testing.assert_array_equal(model.predict([[1, 2]]), [3.0])
 20      model = mlflow.pyfunc.load_model(model_uri)
 21      np.testing.assert_array_equal(model.predict([[1, 2]]), [3.0])
 22  
 23  
 24  def check_register(model_uri: str) -> None:
 25      mv = mlflow.register_model(model_uri, "model")
 26      model = mlflow.pyfunc.load_model(f"models:/{mv.name}/{mv.version}")
 27      np.testing.assert_array_equal(model.predict([[1, 2]]), [3.0])
 28  
 29  
 30  def check_list_artifacts_with_run_id_and_path(run_id: str, path: str) -> None:
 31      # List artifacts
 32      client = mlflow.MlflowClient()
 33      artifacts = [a.path for a in client.list_artifacts(run_id=run_id, path=path)]
 34      # Ensure both run and model artifacts are listed
 35      assert "model/MLmodel" in artifacts
 36      assert "model/test.txt" in artifacts
 37      artifacts = [a.path for a in client.list_artifacts(run_id=run_id, path=path)]
 38      assert "model/MLmodel" in artifacts
 39      assert "model/test.txt" in artifacts
 40      # Non-existing artifact path should return an empty list
 41      assert len(client.list_artifacts(run_id=run_id, path="unknown")) == 0
 42      assert len(mlflow.artifacts.list_artifacts(run_id=run_id, artifact_path="unknown")) == 0
 43  
 44  
 45  def check_list_artifacts_with_model_uri(model_uri: str) -> None:
 46      artifacts = [a.path for a in mlflow.artifacts.list_artifacts(artifact_uri=model_uri)]
 47      assert "model/MLmodel" in artifacts
 48      assert "model/test.txt" in artifacts
 49  
 50  
 51  def check_download_artifacts_with_run_id_and_path(run_id: str, path: str, tmp_path: Path) -> None:
 52      out_path = mlflow.artifacts.download_artifacts(
 53          run_id=run_id, artifact_path=path, dst_path=tmp_path / str(uuid.uuid4())
 54      )
 55      files = [f.name for f in Path(out_path).iterdir() if f.is_file()]
 56      assert "MLmodel" in files
 57      assert "test.txt" in files
 58      client = mlflow.MlflowClient()
 59      out_path = client.download_artifacts(
 60          run_id=run_id, path=path, dst_path=tmp_path / str(uuid.uuid4())
 61      )
 62      files = [f.name for f in Path(out_path).iterdir() if f.is_file()]
 63      assert "MLmodel" in files
 64      assert "test.txt" in files
 65  
 66  
 67  def check_download_artifacts_with_model_uri(model_uri: str, tmp_path: Path) -> None:
 68      out_path = mlflow.artifacts.download_artifacts(
 69          artifact_uri=model_uri, dst_path=tmp_path / str(uuid.uuid4())
 70      )
 71      files = [f.name for f in Path(out_path).iterdir() if f.is_file()]
 72      # Ensure both run and model artifacts are downloaded
 73      assert "MLmodel" in files
 74      assert "test.txt" in files
 75  
 76  
 77  def check_evaluate(model_uri: str) -> None:
 78      # Model evaluation
 79      eval_res = mlflow.models.evaluate(
 80          model=model_uri,
 81          data=np.array([[1, 2]]),
 82          targets=np.array([3]),
 83          model_type="regressor",
 84      )
 85      assert "mean_squared_error" in eval_res.metrics
 86  
 87  
 88  def check_spark_udf(model_uri: str) -> None:
 89      # Spark UDF
 90      if os.name != "nt":
 91          with SparkSession.builder.getOrCreate() as spark:
 92              udf = mlflow.pyfunc.spark_udf(
 93                  spark,
 94                  model_uri,
 95                  result_type="double",
 96                  env_manager="local",
 97              )
 98              df = spark.createDataFrame([[1, 2]], ["col1", "col2"])
 99              # This line fails with the following error on Windows:
100              #   File ".../pyspark\python\lib\pyspark.zip\pyspark\serializers.py", line 472, in loads
101              #     return cloudpickle.loads(obj, encoding=encoding)
102              # ModuleNotFoundError: No module named 'pandas'
103              pred = df.select(udf("col1", "col2").alias("pred")).collect()
104              assert [row.pred for row in pred] == [3.0]
105  
106  
def test_mlflow_2_x_comp(tmp_path: Path) -> None:
    """Log a model with MLflow 2.x in an isolated subprocess, then read it back
    with the MLflow version under test to verify cross-version compatibility.

    The subprocess (via ``uv run --with=mlflow<3.0``) creates the SQLite
    tracking DB and the experiment, so the outer process exercises the
    3.x code path of opening a 2.x-created store.
    """
    tracking_uri = f"sqlite:///{tmp_path / 'mlflow.db'}"
    artifact_location = (tmp_path / "artifacts").as_uri()

    # The subprocess writes the run ID here so the outer test can locate the run.
    out_file = tmp_path / "out.txt"
    # Log a model using MLflow 2.x (let 2.x create the DB and experiment)
    py_ver = ".".join(map(str, sys.version_info[:2]))
    subprocess.check_call(
        [
            "uv",
            "run",
            "--isolated",
            "--no-project",
            "--index-strategy=unsafe-first-match",
            # Match the host interpreter so pickled artifacts stay compatible.
            f"--python={py_ver}",
            # Use mlflow 2.x
            "--with=mlflow<3.0",
            # Pin numpy and sklearn versions to ensure the model can be loaded
            f"--with=numpy=={np.__version__}",
            f"--with=scikit-learn=={sklearn.__version__}",
            "python",
            # Use the isolated mode to ignore mlflow in the repository
            "-I",
            "-c",
            """
import sys
import mlflow
from sklearn.linear_model import LinearRegression

assert mlflow.__version__.startswith("2."), mlflow.__version__

tracking_uri, artifact_location, out = sys.argv[1:]
mlflow.set_tracking_uri(tracking_uri)
exp_id = mlflow.create_experiment("test", artifact_location=artifact_location)
mlflow.set_experiment(experiment_id=exp_id)

fitted_model= LinearRegression().fit([[1, 2]], [3])
with mlflow.start_run() as run:
    mlflow.log_text("test", "model/test.txt")
    model_info = mlflow.sklearn.log_model(fitted_model, artifact_path="model")
    assert model_info.model_uri.startswith("runs:/")
    with open(out, "w") as f:
        f.write(run.info.run_id)
""",
            tracking_uri,
            artifact_location,
            out_file,
        ],
    )

    # 3.x opens the 2.x-created DB (migration happens automatically)
    mlflow.set_tracking_uri(tracking_uri)
    run_id = out_file.read_text().strip()
    model_uri = f"runs:/{run_id}/model"
    # Exercise load/register/artifact-listing/download/evaluate/Spark-UDF paths
    # against the model logged by MLflow 2.x.
    check_load(model_uri=model_uri)
    check_register(model_uri=model_uri)
    check_list_artifacts_with_run_id_and_path(run_id=run_id, path="model")
    check_list_artifacts_with_model_uri(model_uri=model_uri)
    check_download_artifacts_with_run_id_and_path(run_id=run_id, path="model", tmp_path=tmp_path)
    check_download_artifacts_with_model_uri(model_uri=model_uri, tmp_path=tmp_path)
    check_evaluate(model_uri=model_uri)
    check_spark_udf(model_uri=model_uri)
169  
170  
def test_mlflow_3_x_comp(tmp_path: Path) -> None:
    """End-to-end checks for models logged by the MLflow version under test,
    addressed via both ``runs:/`` and ``models:/`` URIs."""
    tracking_uri = f"sqlite:///{tmp_path / 'mlflow.db'}"
    mlflow.set_tracking_uri(tracking_uri)
    artifact_location = (tmp_path / "artifacts").as_uri()
    exp_id = mlflow.create_experiment("test", artifact_location=artifact_location)
    mlflow.set_experiment(experiment_id=exp_id)

    regressor = LinearRegression().fit([[1, 2]], [3])
    with mlflow.start_run() as run:
        mlflow.log_text("test", "model/test.txt")
        logged = mlflow.sklearn.log_model(regressor, name="model")

    # Runs URI
    run_id = run.info.run_id
    run_scoped_uri = f"runs:/{run_id}/model"
    check_load(model_uri=run_scoped_uri)
    check_register(model_uri=run_scoped_uri)
    check_list_artifacts_with_run_id_and_path(run_id=run_id, path="model")
    check_list_artifacts_with_model_uri(model_uri=run_scoped_uri)
    check_download_artifacts_with_run_id_and_path(run_id=run_id, path="model", tmp_path=tmp_path)
    check_download_artifacts_with_model_uri(model_uri=run_scoped_uri, tmp_path=tmp_path)
    check_evaluate(model_uri=run_scoped_uri)
    check_spark_udf(model_uri=run_scoped_uri)

    # Models URI
    model_scoped_uri = f"models:/{logged.model_id}"
    check_load(model_uri=model_scoped_uri)
    check_register(model_uri=model_scoped_uri)
    listed = {a.path for a in mlflow.artifacts.list_artifacts(artifact_uri=model_scoped_uri)}
    assert "MLmodel" in listed
    download_dir = mlflow.artifacts.download_artifacts(
        artifact_uri=model_scoped_uri, dst_path=tmp_path / str(uuid.uuid4())
    )
    downloaded = {f.name for f in Path(download_dir).iterdir() if f.is_file()}
    assert "MLmodel" in downloaded
    check_evaluate(model_uri=model_scoped_uri)
    check_spark_udf(model_uri=model_scoped_uri)
208  
209  
def test_run_and_model_has_artifact_with_same_name(tmp_path: Path) -> None:
    """A run artifact and a model artifact may share the path ``model/MLmodel``;
    listing shows both, while download lets the model file win."""
    regressor = LinearRegression().fit([[1, 2]], [3])
    with mlflow.start_run() as run:
        mlflow.log_text("", artifact_file="model/MLmodel")
        info = mlflow.sklearn.log_model(regressor, name="model")

    client = mlflow.MlflowClient()
    listed = client.list_artifacts(run_id=run.info.run_id, path="model")
    mlmodel_paths = [entry.path for entry in listed if entry.path.endswith("MLmodel")]
    # Both run and model artifacts should be listed
    assert len(mlmodel_paths) == 2
    download_dir = mlflow.artifacts.download_artifacts(
        run_id=run.info.run_id,
        artifact_path="model",
        dst_path=tmp_path / str(uuid.uuid4()),
    )
    downloaded = list(Path(download_dir).rglob("MLmodel"))
    assert len(downloaded) == 1
    # The model MLmodel file should overwrite the run MLmodel file
    assert info.model_id in downloaded[0].read_text()