/ tests / models / test_cli.py
test_cli.py
   1  import json
   2  import os
   3  import re
   4  import shutil
   5  import subprocess
   6  import sys
   7  import warnings
   8  from dataclasses import dataclass
   9  from io import BytesIO, StringIO
  10  from pathlib import Path
  11  from unittest import mock
  12  
  13  import numpy as np
  14  import pandas as pd
  15  import pytest
  16  import sklearn
  17  import sklearn.datasets
  18  import sklearn.neighbors
  19  from click.testing import CliRunner
  20  from packaging.requirements import Requirement
  21  
  22  import mlflow
  23  import mlflow.models.cli as models_cli
  24  import mlflow.sklearn
  25  from mlflow.environment_variables import MLFLOW_DISABLE_ENV_MANAGER_CONDA_WARNING
  26  from mlflow.exceptions import MlflowException
  27  from mlflow.models.flavor_backend_registry import get_flavor_backend
  28  from mlflow.models.model import get_model_requirements_files, update_model_requirements
  29  from mlflow.models.utils import load_serving_example
  30  from mlflow.protos.databricks_pb2 import BAD_REQUEST, ErrorCode
  31  from mlflow.pyfunc.backend import PyFuncBackend
  32  from mlflow.pyfunc.scoring_server import (
  33      CONTENT_TYPE_CSV,
  34      CONTENT_TYPE_JSON,
  35  )
  36  from mlflow.store.artifact.models_artifact_repo import ModelsArtifactRepository
  37  from mlflow.utils import PYTHON_VERSION
  38  from mlflow.utils import env_manager as _EnvManager
  39  from mlflow.utils.conda import _get_conda_env_name
  40  from mlflow.utils.environment import (
  41      _get_requirements_from_file,
  42      _mlflow_conda_env,
  43  )
  44  from mlflow.utils.file_utils import TempDir
  45  from mlflow.utils.process import ShellCommandException
  46  
  47  from tests.helper_functions import (
  48      PROTOBUF_REQUIREMENT,
  49      RestEndpoint,
  50      get_safe_port,
  51      pyfunc_build_image,
  52      pyfunc_generate_dockerfile,
  53      pyfunc_serve_and_score_model,
  54      pyfunc_serve_from_docker_image,
  55      pyfunc_serve_from_docker_image_with_env_override,
  56  )
  57  
# NB: for now, windows tests do not have conda available, so force the local
# env manager there and skip conda-specific behavior.
no_conda: list[str] = ["--env-manager", "local"] if sys.platform == "win32" else []

# NB: need to install mlflow since the pip version does not have mlflow models cli.
# Only relevant when a conda env is actually created (i.e. `no_conda` is empty).
install_mlflow: list[str] = ["--install-mlflow"] if not no_conda else []

# Extra CLI flags appended to most `mlflow models ...` invocations in this module.
extra_options: list[str] = no_conda + install_mlflow
  65  
  66  
  67  def env_with_tracking_uri() -> dict[str, str]:
  68      return {**os.environ, "MLFLOW_TRACKING_URI": mlflow.get_tracking_uri()}
  69  
  70  
  71  @pytest.fixture(scope="module")
  72  def iris_data() -> tuple[np.ndarray, np.ndarray]:
  73      iris = sklearn.datasets.load_iris()
  74      x = iris.data[:, :2]
  75      y = iris.target
  76      return x, y
  77  
  78  
  79  @pytest.fixture(scope="module")
  80  def sk_model(iris_data: tuple[np.ndarray, np.ndarray]) -> sklearn.neighbors.KNeighborsClassifier:
  81      x, y = iris_data
  82      knn_model = sklearn.neighbors.KNeighborsClassifier()
  83      knn_model.fit(x, y)
  84      return knn_model
  85  
  86  
  87  @pytest.mark.allow_infer_pip_requirements_fallback
  88  def test_mlflow_is_not_installed_unless_specified():
  89      if no_conda:
  90          pytest.skip("This test requires conda.")
  91      with TempDir(chdr=True) as tmp:
  92          fake_model_path = tmp.path("fake_model")
  93          mlflow.pyfunc.save_model(fake_model_path, loader_module=__name__)
  94          # Overwrite the logged `conda.yaml` to remove mlflow.
  95          _mlflow_conda_env(path=os.path.join(fake_model_path, "conda.yaml"), install_mlflow=False)
  96          # The following should fail because there should be no mlflow in the env:
  97          prc = subprocess.run(
  98              [
  99                  sys.executable,
 100                  "-m",
 101                  "mlflow",
 102                  "models",
 103                  "predict",
 104                  "-m",
 105                  fake_model_path,
 106                  "--env-manager",
 107                  "conda",
 108              ],
 109              stderr=subprocess.PIPE,
 110              cwd=tmp.path(""),
 111              check=False,
 112              text=True,
 113              env=env_with_tracking_uri(),
 114          )
 115          assert prc.returncode != 0
 116          if PYTHON_VERSION.startswith("3"):
 117              assert "ModuleNotFoundError: No module named 'mlflow'" in prc.stderr
 118          else:
 119              assert "ImportError: No module named mlflow.pyfunc.scoring_server" in prc.stderr
 120  
 121  
 122  def test_model_with_no_deployable_flavors_fails_pollitely():
 123      from mlflow.models import Model
 124  
 125      with TempDir(chdr=True) as tmp:
 126          m = Model(
 127              artifact_path=None,
 128              run_id=None,
 129              utc_time_created="now",
 130              flavors={"some": {}, "useless": {}, "flavors": {}},
 131          )
 132          os.mkdir(tmp.path("model"))
 133          m.save(tmp.path("model", "MLmodel"))
 134          # The following should fail because there should be no suitable flavor
 135          prc = subprocess.run(
 136              [sys.executable, "-m", "mlflow", "models", "predict", "-m", tmp.path("model")],
 137              stderr=subprocess.PIPE,
 138              cwd=tmp.path(""),
 139              check=False,
 140              text=True,
 141              env=env_with_tracking_uri(),
 142          )
 143          assert "No suitable flavor backend was found for the model." in prc.stderr
 144  
 145  
 146  def test_serve_uvicorn_opts(iris_data, sk_model):
 147      if sys.platform == "win32":
 148          pytest.skip("This test requires gunicorn which is not available on windows.")
 149      with mlflow.start_run():
 150          x, _ = iris_data
 151          model_info = mlflow.sklearn.log_model(
 152              sk_model, name="model", registered_model_name="test", input_example=pd.DataFrame(x)
 153          )
 154  
 155      model_uris = ["models:/test/None", model_info.model_uri]
 156      for model_uri in model_uris:
 157          with TempDir() as tpm:
 158              output_file_path = tpm.path("stdout")
 159              inference_payload = load_serving_example(model_uri)
 160              with open(output_file_path, "w") as output_file:
 161                  scoring_response = pyfunc_serve_and_score_model(
 162                      model_uri,
 163                      inference_payload,
 164                      content_type=CONTENT_TYPE_JSON,
 165                      stdout=output_file,
 166                      extra_args=["-w", "3", "--env-manager", "local"],
 167                  )
 168              with open(output_file_path) as output_file:
 169                  stdout = output_file.read()
 170          actual = pd.read_json(BytesIO(scoring_response.content), orient="records")
 171          actual = actual[actual.columns[0]].values
 172          expected = sk_model.predict(x)
 173          assert all(expected == actual)
 174          expected_command_pattern = re.compile(
 175              r"uvicorn.*--workers 3.*mlflow\.pyfunc\.scoring_server\.app:app"
 176          )
 177          assert expected_command_pattern.search(stdout) is not None
 178  
 179  
@dataclass
class PredictTestData:
    """Bundle of paths, data, and model handles shared by the `mlflow models predict` tests."""

    model_uri: str  # runs:/<run_id>/model URI of the logged model
    model_registry_uri: str  # models:/<name>/<version> URI pointing at the same model
    input_json_path: Path  # request payload written in `dataframe_split` JSON form
    input_csv_path: Path  # the same feature matrix written as CSV
    output_json_path: Path  # path the CLI is asked to write predictions to
    x: np.ndarray  # raw feature matrix used to build both input files
    sk_model: sklearn.base.BaseEstimator  # fitted model, for computing expected outputs
 189  
 190  
 191  @pytest.fixture
 192  def predict_test_setup(
 193      iris_data: tuple[np.ndarray, np.ndarray],
 194      sk_model: sklearn.neighbors.KNeighborsClassifier,
 195      tmp_path: Path,
 196  ) -> PredictTestData:
 197      with mlflow.start_run() as active_run:
 198          mlflow.sklearn.log_model(sk_model, name="model", registered_model_name="impredicting")
 199          model_uri = f"runs:/{active_run.info.run_id}/model"
 200  
 201      model_registry_uri = "models:/impredicting/None"
 202      input_json_path = tmp_path / "input.json"
 203      input_csv_path = tmp_path / "input.csv"
 204      output_json_path = tmp_path / "output.json"
 205  
 206      x, _ = iris_data
 207      with open(input_json_path, "w") as f:
 208          json.dump({"dataframe_split": pd.DataFrame(x).to_dict(orient="split")}, f)
 209      pd.DataFrame(x).to_csv(input_csv_path, index=False)
 210  
 211      return PredictTestData(
 212          model_uri=model_uri,
 213          model_registry_uri=model_registry_uri,
 214          input_json_path=input_json_path,
 215          input_csv_path=input_csv_path,
 216          output_json_path=output_json_path,
 217          x=x,
 218          sk_model=sk_model,
 219      )
 220  
 221  
 222  def test_predict_with_model_registry_uri(predict_test_setup: PredictTestData) -> None:
 223      setup = predict_test_setup
 224      subprocess.check_call(
 225          [
 226              sys.executable,
 227              "-m",
 228              "mlflow",
 229              "models",
 230              "predict",
 231              "-m",
 232              setup.model_registry_uri,
 233              "-i",
 234              setup.input_json_path,
 235              "-o",
 236              setup.output_json_path,
 237              "--env-manager",
 238              "local",
 239          ],
 240          env=env_with_tracking_uri(),
 241      )
 242      actual = pd.read_json(setup.output_json_path, orient="records")
 243      actual = actual[actual.columns[0]].values
 244      expected = setup.sk_model.predict(setup.x)
 245      assert all(expected == actual)
 246  
 247  
 248  def test_predict_with_conda_and_install_mlflow(predict_test_setup: PredictTestData) -> None:
 249      setup = predict_test_setup
 250      subprocess.check_call(
 251          [
 252              sys.executable,
 253              "-m",
 254              "mlflow",
 255              "models",
 256              "predict",
 257              "-m",
 258              setup.model_uri,
 259              "-i",
 260              setup.input_json_path,
 261              "-o",
 262              setup.output_json_path,
 263              *extra_options,
 264          ],
 265          env=env_with_tracking_uri(),
 266      )
 267      actual = pd.read_json(setup.output_json_path, orient="records")
 268      actual = actual[actual.columns[0]].values
 269      expected = setup.sk_model.predict(setup.x)
 270      assert all(expected == actual)
 271  
 272  
 273  def test_predict_explicit_json_format_default_orient(predict_test_setup: PredictTestData) -> None:
 274      setup = predict_test_setup
 275      subprocess.check_call(
 276          [
 277              sys.executable,
 278              "-m",
 279              "mlflow",
 280              "models",
 281              "predict",
 282              "-m",
 283              setup.model_uri,
 284              "-i",
 285              setup.input_json_path,
 286              "-o",
 287              setup.output_json_path,
 288              "-t",
 289              "json",
 290              *extra_options,
 291          ],
 292          env=env_with_tracking_uri(),
 293      )
 294      actual = pd.read_json(setup.output_json_path, orient="records")
 295      actual = actual[actual.columns[0]].values
 296      expected = setup.sk_model.predict(setup.x)
 297      assert all(expected == actual)
 298  
 299  
def test_predict_explicit_json_format_split_orient(predict_test_setup: PredictTestData) -> None:
    # NOTE(review): this test is currently identical, command-for-command, to
    # test_predict_explicit_json_format_default_orient. The comment in the
    # original code says it should exercise orient == split, but no
    # split-specific flag or payload differs; the shared input file already
    # uses the `dataframe_split` payload. Confirm whether this still adds
    # coverage or should be given a distinct split-oriented payload.
    setup = predict_test_setup
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "mlflow",
            "models",
            "predict",
            "-m",
            setup.model_uri,
            "-i",
            setup.input_json_path,
            "-o",
            setup.output_json_path,
            "-t",
            "json",
            *extra_options,
        ],
        env=env_with_tracking_uri(),
    )
    actual = pd.read_json(setup.output_json_path, orient="records")
    actual = actual[actual.columns[0]].values
    expected = setup.sk_model.predict(setup.x)
    assert all(expected == actual)
 327  
 328  
 329  def test_predict_stdin_stdout(predict_test_setup: PredictTestData) -> None:
 330      setup = predict_test_setup
 331      stdout = subprocess.check_output(
 332          [
 333              sys.executable,
 334              "-m",
 335              "mlflow",
 336              "models",
 337              "predict",
 338              "-m",
 339              setup.model_uri,
 340              "-t",
 341              "json",
 342              *extra_options,
 343          ],
 344          input=setup.input_json_path.read_text(),
 345          env=env_with_tracking_uri(),
 346          text=True,
 347      )
 348      predictions = re.search(r"{\"predictions\": .*}", stdout).group(0)
 349      actual = pd.read_json(StringIO(predictions), orient="records")
 350      actual = actual[actual.columns[0]].values
 351      expected = setup.sk_model.predict(setup.x)
 352      assert all(expected == actual)
 353      # NB: We do not test orient=records here because records may loose column ordering.
 354      # orient == records is tested in other test with simpler model.
 355  
 356  
 357  def test_predict_csv_format(predict_test_setup: PredictTestData) -> None:
 358      setup = predict_test_setup
 359      subprocess.check_call(
 360          [
 361              sys.executable,
 362              "-m",
 363              "mlflow",
 364              "models",
 365              "predict",
 366              "-m",
 367              setup.model_uri,
 368              "-i",
 369              setup.input_csv_path,
 370              "-o",
 371              setup.output_json_path,
 372              "-t",
 373              "csv",
 374              *extra_options,
 375          ],
 376          env=env_with_tracking_uri(),
 377      )
 378      actual = pd.read_json(setup.output_json_path, orient="records")
 379      actual = actual[actual.columns[0]].values
 380      expected = setup.sk_model.predict(setup.x)
 381      assert all(expected == actual)
 382  
 383  
 384  def test_predict_check_content_type(iris_data, sk_model, tmp_path):
 385      with mlflow.start_run():
 386          mlflow.sklearn.log_model(sk_model, name="model", registered_model_name="impredicting")
 387      model_registry_uri = "models:/impredicting/None"
 388      input_json_path = tmp_path / "input.json"
 389      input_csv_path = tmp_path / "input.csv"
 390      output_json_path = tmp_path / "output.json"
 391  
 392      x, _ = iris_data
 393      with input_json_path.open("w") as f:
 394          json.dump({"dataframe_split": pd.DataFrame(x).to_dict(orient="split")}, f)
 395  
 396      pd.DataFrame(x).to_csv(input_csv_path, index=False)
 397  
 398      # Throw errors for invalid content_type
 399      prc = subprocess.run(
 400          [
 401              sys.executable,
 402              "-m",
 403              "mlflow",
 404              "models",
 405              "predict",
 406              "-m",
 407              model_registry_uri,
 408              "-i",
 409              input_json_path,
 410              "-o",
 411              output_json_path,
 412              "-t",
 413              "invalid",
 414              "--env-manager",
 415              "local",
 416          ],
 417          stdout=subprocess.PIPE,
 418          stderr=subprocess.PIPE,
 419          env=env_with_tracking_uri(),
 420          check=False,
 421      )
 422      assert prc.returncode != 0
 423      assert "Content type must be one of json or csv." in prc.stderr.decode("utf-8")
 424  
 425  
 426  def test_predict_check_input_path(iris_data, sk_model, tmp_path):
 427      with mlflow.start_run():
 428          mlflow.sklearn.log_model(sk_model, name="model", registered_model_name="impredicting")
 429      model_registry_uri = "models:/impredicting/None"
 430      input_json_path = tmp_path / "input with space.json"
 431      input_csv_path = tmp_path / "input.csv"
 432      output_json_path = tmp_path / "output.json"
 433  
 434      x, _ = iris_data
 435      with input_json_path.open("w") as f:
 436          json.dump({"dataframe_split": pd.DataFrame(x).to_dict(orient="split")}, f)
 437  
 438      pd.DataFrame(x).to_csv(input_csv_path, index=False)
 439  
 440      # Valid input path with space
 441      prc = subprocess.run(
 442          [
 443              sys.executable,
 444              "-m",
 445              "mlflow",
 446              "models",
 447              "predict",
 448              "-m",
 449              model_registry_uri,
 450              "-i",
 451              f"{input_json_path}",
 452              "-o",
 453              output_json_path,
 454              "--env-manager",
 455              "local",
 456          ],
 457          stdout=subprocess.PIPE,
 458          stderr=subprocess.PIPE,
 459          env=env_with_tracking_uri(),
 460          check=False,
 461          text=True,
 462      )
 463      assert prc.returncode == 0
 464  
 465      # Throw errors for invalid input_path
 466      prc = subprocess.run(
 467          [
 468              sys.executable,
 469              "-m",
 470              "mlflow",
 471              "models",
 472              "predict",
 473              "-m",
 474              model_registry_uri,
 475              "-i",
 476              f'{input_json_path}"; echo ThisIsABug! "',
 477              "-o",
 478              output_json_path,
 479              "--env-manager",
 480              "local",
 481          ],
 482          stdout=subprocess.PIPE,
 483          stderr=subprocess.PIPE,
 484          env=env_with_tracking_uri(),
 485          check=False,
 486          text=True,
 487      )
 488      assert prc.returncode != 0
 489      assert "ThisIsABug!" not in prc.stdout
 490      assert "FileNotFoundError" in prc.stderr
 491  
 492      prc = subprocess.run(
 493          [
 494              sys.executable,
 495              "-m",
 496              "mlflow",
 497              "models",
 498              "predict",
 499              "-m",
 500              model_registry_uri,
 501              "-i",
 502              f'{input_csv_path}"; echo ThisIsABug! "',
 503              "-o",
 504              output_json_path,
 505              "-t",
 506              "csv",
 507              "--env-manager",
 508              "local",
 509          ],
 510          stdout=subprocess.PIPE,
 511          stderr=subprocess.PIPE,
 512          env=env_with_tracking_uri(),
 513          check=False,
 514          text=True,
 515      )
 516      assert prc.returncode != 0
 517      assert "ThisIsABug!" not in prc.stdout
 518      assert "FileNotFoundError" in prc.stderr
 519  
 520  
 521  def test_predict_check_output_path(iris_data, sk_model, tmp_path):
 522      with mlflow.start_run():
 523          mlflow.sklearn.log_model(sk_model, name="model", registered_model_name="impredicting")
 524      model_registry_uri = "models:/impredicting/None"
 525      input_json_path = tmp_path / "input.json"
 526      input_csv_path = tmp_path / "input.csv"
 527      output_json_path = tmp_path / "output.json"
 528  
 529      x, _ = iris_data
 530      with input_json_path.open("w") as f:
 531          json.dump({"dataframe_split": pd.DataFrame(x).to_dict(orient="split")}, f)
 532  
 533      pd.DataFrame(x).to_csv(input_csv_path, index=False)
 534  
 535      prc = subprocess.run(
 536          [
 537              sys.executable,
 538              "-m",
 539              "mlflow",
 540              "models",
 541              "predict",
 542              "-m",
 543              model_registry_uri,
 544              "-i",
 545              input_json_path,
 546              "-o",
 547              f'{output_json_path}"; echo ThisIsABug! "',
 548              "--env-manager",
 549              "local",
 550          ],
 551          stdout=subprocess.PIPE,
 552          stderr=subprocess.PIPE,
 553          env=env_with_tracking_uri(),
 554          check=False,
 555          text=True,
 556      )
 557      assert prc.returncode == 0
 558      assert "ThisIsABug!" not in prc.stdout
 559  
 560  
 561  def test_prepare_env_passes(sk_model):
 562      if no_conda:
 563          pytest.skip("This test requires conda.")
 564  
 565      with TempDir(chdr=True):
 566          with mlflow.start_run() as active_run:
 567              mlflow.sklearn.log_model(sk_model, name="model")
 568              model_uri = f"runs:/{active_run.info.run_id}/model"
 569  
 570          # With conda
 571          subprocess.run(
 572              [
 573                  sys.executable,
 574                  "-m",
 575                  "mlflow",
 576                  "models",
 577                  "prepare-env",
 578                  "-m",
 579                  model_uri,
 580              ],
 581              env=env_with_tracking_uri(),
 582              check=True,
 583          )
 584  
 585          # Should be idempotent
 586          subprocess.run(
 587              [
 588                  sys.executable,
 589                  "-m",
 590                  "mlflow",
 591                  "models",
 592                  "prepare-env",
 593                  "-m",
 594                  model_uri,
 595              ],
 596              env=env_with_tracking_uri(),
 597              check=True,
 598          )
 599  
 600  
 601  def test_prepare_env_fails(sk_model):
 602      if no_conda:
 603          pytest.skip("This test requires conda.")
 604  
 605      with TempDir(chdr=True):
 606          with mlflow.start_run() as active_run:
 607              mlflow.sklearn.log_model(
 608                  sk_model, name="model", pip_requirements=["does-not-exist-dep==abc"]
 609              )
 610              model_uri = f"runs:/{active_run.info.run_id}/model"
 611  
 612          # With conda - should fail due to bad conda environment.
 613          prc = subprocess.run(
 614              [
 615                  sys.executable,
 616                  "-m",
 617                  "mlflow",
 618                  "models",
 619                  "prepare-env",
 620                  "-m",
 621                  model_uri,
 622              ],
 623              env=env_with_tracking_uri(),
 624              check=False,
 625          )
 626          assert prc.returncode != 0
 627  
 628  
 629  @pytest.mark.parametrize("enable_mlserver", [True, False])
 630  def test_generate_dockerfile(sk_model, enable_mlserver, tmp_path):
 631      with mlflow.start_run() as active_run:
 632          if enable_mlserver:
 633              mlflow.sklearn.log_model(
 634                  sk_model, name="model", extra_pip_requirements=["/opt/mlflow", PROTOBUF_REQUIREMENT]
 635              )
 636          else:
 637              mlflow.sklearn.log_model(sk_model, name="model")
 638          model_uri = f"runs:/{active_run.info.run_id}/model"
 639      extra_args = ["--install-mlflow"]
 640      if enable_mlserver:
 641          extra_args.append("--enable-mlserver")
 642  
 643      output_directory = tmp_path.joinpath("output_directory")
 644      pyfunc_generate_dockerfile(
 645          output_directory,
 646          model_uri,
 647          extra_args=extra_args,
 648          env=env_with_tracking_uri(),
 649      )
 650      assert output_directory.is_dir()
 651      assert output_directory.joinpath("Dockerfile").exists()
 652      assert output_directory.joinpath("model_dir").is_dir()
 653      # Assert file is not empty
 654      assert output_directory.joinpath("Dockerfile").stat().st_size != 0
 655  
 656  
 657  @pytest.mark.parametrize("enable_mlserver", [True, False])
 658  def test_build_docker(iris_data, sk_model, enable_mlserver):
 659      with mlflow.start_run() as active_run:
 660          if enable_mlserver:
 661              mlflow.sklearn.log_model(
 662                  sk_model, name="model", extra_pip_requirements=["/opt/mlflow", PROTOBUF_REQUIREMENT]
 663              )
 664          else:
 665              mlflow.sklearn.log_model(sk_model, name="model", extra_pip_requirements=["/opt/mlflow"])
 666          model_uri = f"runs:/{active_run.info.run_id}/model"
 667  
 668      x, _ = iris_data
 669      df = pd.DataFrame(x)
 670  
 671      extra_args = ["--install-mlflow"]
 672      if enable_mlserver:
 673          extra_args.append("--enable-mlserver")
 674  
 675      image_name = pyfunc_build_image(
 676          model_uri,
 677          extra_args=extra_args,
 678          env=env_with_tracking_uri(),
 679      )
 680      host_port = get_safe_port()
 681      scoring_proc = pyfunc_serve_from_docker_image(image_name, host_port)
 682      _validate_with_rest_endpoint(scoring_proc, host_port, df, x, sk_model, enable_mlserver)
 683  
 684  
 685  def test_build_docker_virtualenv(iris_data, sk_model):
 686      with mlflow.start_run():
 687          model_info = mlflow.sklearn.log_model(
 688              sk_model, name="model", extra_pip_requirements=["/opt/mlflow"]
 689          )
 690  
 691      x, _ = iris_data
 692      df = pd.DataFrame(iris_data[0])
 693  
 694      extra_args = ["--install-mlflow", "--env-manager", "virtualenv"]
 695      image_name = pyfunc_build_image(
 696          model_info.model_uri,
 697          extra_args=extra_args,
 698          env=env_with_tracking_uri(),
 699      )
 700      host_port = get_safe_port()
 701      scoring_proc = pyfunc_serve_from_docker_image(image_name, host_port)
 702      _validate_with_rest_endpoint(scoring_proc, host_port, df, x, sk_model)
 703  
 704  
 705  @pytest.mark.parametrize("enable_mlserver", [True, False])
 706  def test_build_docker_with_env_override(iris_data, sk_model, enable_mlserver):
 707      with mlflow.start_run() as active_run:
 708          if enable_mlserver:
 709              mlflow.sklearn.log_model(
 710                  sk_model, name="model", extra_pip_requirements=["/opt/mlflow", PROTOBUF_REQUIREMENT]
 711              )
 712          else:
 713              mlflow.sklearn.log_model(sk_model, name="model", extra_pip_requirements=["/opt/mlflow"])
 714          model_uri = f"runs:/{active_run.info.run_id}/model"
 715      x, _ = iris_data
 716      df = pd.DataFrame(x)
 717  
 718      extra_args = ["--install-mlflow"]
 719      if enable_mlserver:
 720          extra_args.append("--enable-mlserver")
 721  
 722      image_name = pyfunc_build_image(
 723          model_uri,
 724          extra_args=extra_args,
 725          env=env_with_tracking_uri(),
 726      )
 727      host_port = get_safe_port()
 728      scoring_proc = pyfunc_serve_from_docker_image_with_env_override(image_name, host_port)
 729      _validate_with_rest_endpoint(scoring_proc, host_port, df, x, sk_model, enable_mlserver)
 730  
 731  
 732  def test_build_docker_without_model_uri(iris_data, sk_model, tmp_path):
 733      model_path = tmp_path.joinpath("model")
 734      mlflow.sklearn.save_model(sk_model, model_path, extra_pip_requirements=["/opt/mlflow"])
 735      image_name = pyfunc_build_image(model_uri=None)
 736      host_port = get_safe_port()
 737      scoring_proc = pyfunc_serve_from_docker_image_with_env_override(
 738          image_name,
 739          host_port,
 740          extra_docker_run_options=["-v", f"{model_path}:/opt/ml/model"],
 741      )
 742      x = iris_data[0]
 743      df = pd.DataFrame(x)
 744      _validate_with_rest_endpoint(scoring_proc, host_port, df, x, sk_model)
 745  
 746  
 747  def _validate_with_rest_endpoint(scoring_proc, host_port, df, x, sk_model, enable_mlserver=False):
 748      with RestEndpoint(proc=scoring_proc, port=host_port, validate_version=False) as endpoint:
 749          for content_type in [CONTENT_TYPE_JSON, CONTENT_TYPE_CSV]:
 750              scoring_response = endpoint.invoke(df, content_type)
 751              assert scoring_response.status_code == 200, (
 752                  f"Failed to serve prediction, got response {scoring_response.text}"
 753              )
 754              np.testing.assert_array_equal(
 755                  np.array(json.loads(scoring_response.text)["predictions"]), sk_model.predict(x)
 756              )
 757          # Try examples of bad input, verify we get a non-200 status code
 758          for content_type in [CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_JSON]:
 759              scoring_response = endpoint.invoke(data="", content_type=content_type)
 760              expected_status_code = 500 if enable_mlserver else 400
 761              assert scoring_response.status_code == expected_status_code, (
 762                  f"Expected server failure with error code {expected_status_code}, "
 763                  f"got response with status code {scoring_response.status_code} "
 764                  f"and body {scoring_response.text}"
 765              )
 766  
 767              if enable_mlserver:
 768                  # MLServer returns a different set of errors.
 769                  # Skip these assertions until this issue gets tackled:
 770                  # https://github.com/SeldonIO/MLServer/issues/360)
 771                  continue
 772  
 773              scoring_response_dict = json.loads(scoring_response.content)
 774              assert "error_code" in scoring_response_dict
 775              assert scoring_response_dict["error_code"] == ErrorCode.Name(BAD_REQUEST)
 776              assert "message" in scoring_response_dict
 777  
 778  
 779  def test_env_manager_warning_for_use_of_conda(monkeypatch):
 780      with mock.patch("mlflow.models.cli.get_flavor_backend") as mock_get_flavor_backend:
 781          with pytest.warns(UserWarning, match=r"Use of conda is discouraged"):
 782              CliRunner().invoke(
 783                  models_cli.serve,
 784                  ["--model-uri", "model", "--env-manager", "conda"],
 785                  catch_exceptions=False,
 786              )
 787  
 788          with warnings.catch_warnings():
 789              warnings.simplefilter("error")
 790              monkeypatch.setenv(MLFLOW_DISABLE_ENV_MANAGER_CONDA_WARNING.name, "TRUE")
 791              CliRunner().invoke(
 792                  models_cli.serve,
 793                  ["--model-uri", "model", "--env-manager", "conda"],
 794                  catch_exceptions=False,
 795              )
 796  
 797          assert mock_get_flavor_backend.call_count == 2
 798  
 799  
 800  def test_env_manager_unsupported_value():
 801      with pytest.raises(MlflowException, match=r"Invalid value for `env_manager`"):
 802          CliRunner().invoke(
 803              models_cli.serve,
 804              ["--model-uri", "model", "--env-manager", "abc"],
 805              catch_exceptions=False,
 806          )
 807  
 808  
 809  def test_host_invalid_value():
 810      class MyModel(mlflow.pyfunc.PythonModel):
 811          def predict(self, context, model_input):
 812              return model_input
 813  
 814      with mlflow.start_run():
 815          model_info = mlflow.pyfunc.log_model(
 816              name="test_model", python_model=MyModel(), registered_model_name="model"
 817          )
 818  
 819      with mock.patch(
 820          "mlflow.models.cli.get_flavor_backend",
 821          return_value=PyFuncBackend({}, env_manager=_EnvManager.VIRTUALENV),
 822      ):
 823          with pytest.raises(ShellCommandException, match=r"Non-zero exit code: 1"):
 824              CliRunner().invoke(
 825                  models_cli.serve,
 826                  ["--model-uri", model_info.model_uri, "--host", "localhost & echo BUG"],
 827                  catch_exceptions=False,
 828              )
 829  
 830  
 831  def test_change_conda_env_root_location(tmp_path, sk_model):
 832      def _test_model(env_root_path, model_path, sklearn_ver):
 833          env_root_path.mkdir(exist_ok=True)
 834  
 835          mlflow.sklearn.save_model(
 836              sk_model, str(model_path), pip_requirements=[f"scikit-learn=={sklearn_ver}"]
 837          )
 838  
 839          env = get_flavor_backend(
 840              str(model_path),
 841              env_manager=_EnvManager.CONDA,
 842              install_mlflow=False,
 843              env_root_dir=str(env_root_path),
 844          ).prepare_env(model_uri=str(model_path))
 845  
 846          conda_env_name = _get_conda_env_name(
 847              str(model_path / "conda.yaml"), env_root_dir=env_root_path
 848          )
 849          env_path = env_root_path / "conda_envs" / conda_env_name
 850          assert env_path.exists()
 851  
 852          python_exec_path = str(env_path / "bin" / "python")
 853  
 854          # Test execution of command under the correct activated python env.
 855          env.execute(
 856              command=f"python -c \"import sys; assert sys.executable == '{python_exec_path}'; "
 857              f"import sklearn; assert sklearn.__version__ == '{sklearn_ver}'\"",
 858          )
 859  
 860          # Cleanup model path and Conda environment to prevent out of space failures on CI
 861          shutil.rmtree(model_path)
 862          shutil.rmtree(env_path)
 863  
 864      env_root1_path = tmp_path / "root1"
 865      env_root2_path = tmp_path / "root2"
 866  
 867      # Test with model1_path
 868      model1_path = tmp_path / "model1"
 869  
 870      _test_model(env_root1_path, model1_path, "1.4.0")
 871      _test_model(env_root2_path, model1_path, "1.4.0")
 872  
 873      # Test with model2_path
 874      model2_path = tmp_path / "model2"
 875      _test_model(env_root1_path, model2_path, "1.4.2")
 876  
 877  
 878  @pytest.mark.parametrize(
 879      ("input_schema", "output_schema", "params_schema"),
 880      [(True, False, False), (False, True, False), (False, False, True)],
 881  )
 882  def test_signature_enforcement_with_model_serving(input_schema, output_schema, params_schema):
 883      class MyModel(mlflow.pyfunc.PythonModel):
 884          def predict(self, context, model_input, params=None):
 885              return ["test"]
 886  
 887      input_data = ["test_input"] if input_schema else None
 888      output_data = ["test_output"] if output_schema else None
 889      params = {"test": "test"} if params_schema else None
 890  
 891      signature = mlflow.models.infer_signature(
 892          model_input=input_data, model_output=output_data, params=params
 893      )
 894  
 895      with mlflow.start_run():
 896          model_info = mlflow.pyfunc.log_model(
 897              name="test_model", python_model=MyModel(), signature=signature
 898          )
 899  
 900      inference_payload = json.dumps({"inputs": ["test"]})
 901  
 902      # Serve and score the model
 903      scoring_result = pyfunc_serve_and_score_model(
 904          model_uri=model_info.model_uri,
 905          data=inference_payload,
 906          content_type=CONTENT_TYPE_JSON,
 907          extra_args=["--env-manager", "local"],
 908      )
 909      scoring_result.raise_for_status()
 910  
 911      # Assert the prediction result
 912      assert json.loads(scoring_result.content)["predictions"] == ["test"]
 913  
 914  
 915  def assert_base_model_reqs():
 916      """
 917      Helper function for testing model requirements. Asserts that the
 918      contents of requirements.txt and conda.yaml are as expected, then
 919      returns their filepaths so mutations can be performed.
 920      """
 921      import cloudpickle
 922  
 923      class MyModel(mlflow.pyfunc.PythonModel):
 924          def predict(self, context, model_input, params=None):
 925              return ["test"]
 926  
 927      with mlflow.start_run():
 928          model_info = mlflow.pyfunc.log_model(name="model", python_model=MyModel())
 929  
 930      resolved_uri = ModelsArtifactRepository.get_underlying_uri(model_info.model_uri)
 931      local_paths = get_model_requirements_files(resolved_uri)
 932  
 933      requirements_txt_file = local_paths.requirements
 934      conda_env_file = local_paths.conda
 935  
 936      reqs = _get_requirements_from_file(requirements_txt_file)
 937      assert Requirement(f"mlflow=={mlflow.__version__}") in reqs
 938      assert Requirement(f"cloudpickle=={cloudpickle.__version__}") in reqs
 939  
 940      reqs = _get_requirements_from_file(conda_env_file)
 941      assert Requirement(f"mlflow=={mlflow.__version__}") in reqs
 942      assert Requirement(f"cloudpickle=={cloudpickle.__version__}") in reqs
 943  
 944      return model_info.model_uri
 945  
 946  
 947  def test_update_requirements_cli_adds_reqs_successfully():
 948      import cloudpickle
 949  
 950      model_uri = assert_base_model_reqs()
 951  
 952      CliRunner().invoke(
 953          models_cli.update_pip_requirements,
 954          ["-m", f"{model_uri}", "add", "mlflow>=2.9, !=2.9.0", "coolpackage[extra]==8.8.8"],
 955          catch_exceptions=False,
 956      )
 957  
 958      resolved_uri = ModelsArtifactRepository.get_underlying_uri(model_uri)
 959      local_paths = get_model_requirements_files(resolved_uri)
 960  
 961      # the tool should overwrite mlflow, add coolpackage, and leave cloudpickle alone
 962      reqs = _get_requirements_from_file(local_paths.requirements)
 963      assert Requirement("mlflow!=2.9.0,>=2.9") in reqs
 964      assert Requirement("coolpackage[extra]==8.8.8") in reqs
 965      assert Requirement(f"cloudpickle=={cloudpickle.__version__}") in reqs
 966  
 967      reqs = _get_requirements_from_file(local_paths.conda)
 968      assert Requirement("mlflow!=2.9.0,>=2.9") in reqs
 969      assert Requirement("coolpackage[extra]==8.8.8") in reqs
 970      assert Requirement(f"cloudpickle=={cloudpickle.__version__}") in reqs
 971  
 972  
 973  def test_update_requirements_cli_removes_reqs_successfully():
 974      import cloudpickle
 975  
 976      model_uri = assert_base_model_reqs()
 977  
 978      CliRunner().invoke(
 979          models_cli.update_pip_requirements,
 980          ["-m", f"{model_uri}", "remove", "mlflow"],
 981          catch_exceptions=False,
 982      )
 983  
 984      resolved_uri = ModelsArtifactRepository.get_underlying_uri(model_uri)
 985      local_paths = get_model_requirements_files(resolved_uri)
 986  
 987      # the tool should remove mlflow and leave cloudpickle alone
 988      reqs = _get_requirements_from_file(local_paths.requirements)
 989      assert reqs == [Requirement(f"cloudpickle=={cloudpickle.__version__}")]
 990  
 991      reqs = _get_requirements_from_file(local_paths.conda)
 992      assert reqs == [Requirement(f"cloudpickle=={cloudpickle.__version__}")]
 993  
 994  
 995  def test_update_requirements_cli_throws_on_incompatible_input():
 996      model_uri = assert_base_model_reqs()
 997  
 998      with pytest.raises(
 999          MlflowException, match="The specified requirements versions are incompatible"
1000      ):
1001          CliRunner().invoke(
1002              models_cli.update_pip_requirements,
1003              ["-m", f"{model_uri}", "add", "mlflow<2.6", "mlflow>2.7"],
1004              catch_exceptions=False,
1005          )
1006  
1007  
def test_update_model_requirements_add():
    """update_model_requirements("add") must merge new requirements into both
    requirements.txt and conda.yaml."""
    import cloudpickle

    model_uri = assert_base_model_reqs()
    update_model_requirements(
        model_uri, "add", ["mlflow>=2.9, !=2.9.0", "coolpackage[extra]==8.8.8"]
    )

    local_paths = get_model_requirements_files(
        ModelsArtifactRepository.get_underlying_uri(model_uri)
    )

    # The tool should overwrite mlflow, add coolpackage, and leave cloudpickle alone.
    for req_file in (local_paths.requirements, local_paths.conda):
        reqs = _get_requirements_from_file(req_file)
        assert Requirement("mlflow!=2.9.0,>=2.9") in reqs
        assert Requirement("coolpackage[extra]==8.8.8") in reqs
        assert Requirement(f"cloudpickle=={cloudpickle.__version__}") in reqs
1029  
1030  
def test_update_model_requirements_remove():
    """update_model_requirements("remove") must drop the named requirement from
    both requirements.txt and conda.yaml."""
    import cloudpickle

    model_uri = assert_base_model_reqs()
    update_model_requirements(model_uri, "remove", ["mlflow"])

    local_paths = get_model_requirements_files(
        ModelsArtifactRepository.get_underlying_uri(model_uri)
    )

    # The tool should remove mlflow and leave cloudpickle alone.
    expected = [Requirement(f"cloudpickle=={cloudpickle.__version__}")]
    assert _get_requirements_from_file(local_paths.requirements) == expected
    assert _get_requirements_from_file(local_paths.conda) == expected