/ tests / projects / test_project_spec.py
test_project_spec.py
  1  import os
  2  import textwrap
  3  
  4  import pytest
  5  
  6  from mlflow.exceptions import ExecutionException
  7  from mlflow.projects import _project_spec
  8  
  9  from tests.projects.utils import load_project
 10  
 11  
 12  def test_project_get_entry_point():
 13      project = load_project()
 14      entry_point = project.get_entry_point("greeter")
 15      assert entry_point.name == "greeter"
 16      assert entry_point.command == "python greeter.py {greeting} {name}"
 17      # Validate parameters
 18      assert set(entry_point.parameters.keys()) == {"name", "greeting"}
 19      name_param = entry_point.parameters["name"]
 20      assert name_param.type == "string"
 21      assert name_param.default is None
 22      greeting_param = entry_point.parameters["greeting"]
 23      assert greeting_param.type == "string"
 24      assert greeting_param.default == "hi"
 25  
 26  
 27  def test_project_get_unspecified_entry_point():
 28      project = load_project()
 29      entry_point = project.get_entry_point("my_script.py")
 30      assert entry_point.name == "my_script.py"
 31      assert entry_point.command == "python my_script.py"
 32      assert entry_point.parameters == {}
 33      entry_point = project.get_entry_point("my_script.sh")
 34      assert entry_point.name == "my_script.sh"
 35      assert entry_point.command == "{} my_script.sh".format(os.environ.get("SHELL", "bash"))
 36      assert entry_point.parameters == {}
 37      with pytest.raises(ExecutionException, match="Could not find my_program.scala"):
 38          project.get_entry_point("my_program.scala")
 39  
 40  
 41  @pytest.mark.parametrize(
 42      (
 43          # Contents of MLproject file. If None, no MLproject file will be written.
 44          "mlproject",
 45          # Path to conda environment file. If None, no conda environment file will be written.
 46          "conda_env_path",
 47          # Contents of conda environment file (written if conda_env_path is not None).
 48          "conda_env_contents",
 49          # Path to MLproject file. If None, the MLproject file will be written to "MLproject".
 50          "mlproject_path",
 51      ),
 52      [
 53          (None, None, "", None),
 54          ("key: value", "conda.yaml", "hi", "MLproject"),
 55          ("conda_env: some-env.yaml", "some-env.yaml", "hi", "mlproject"),
 56      ],
 57  )
 58  def test_load_project(tmp_path, mlproject, conda_env_path, conda_env_contents, mlproject_path):
 59      """
 60      Test that we can load a project with various combinations of an MLproject / conda.yaml file
 61      """
 62      if mlproject:
 63          tmp_path.joinpath(mlproject_path).write_text(mlproject)
 64      if conda_env_path:
 65          tmp_path.joinpath(conda_env_path).write_text(conda_env_contents)
 66      project = _project_spec.load_project(str(tmp_path))
 67      assert project._entry_points == {}
 68      expected_env_path = str(tmp_path.joinpath(conda_env_path)) if conda_env_path else None
 69      assert project.env_config_path == expected_env_path
 70      if conda_env_path:
 71          with open(project.env_config_path) as f:
 72              assert f.read() == conda_env_contents
 73  
 74  
 75  def test_load_docker_project(tmp_path):
 76      tmp_path.joinpath("MLproject").write_text(
 77          textwrap.dedent(
 78              """
 79      docker_env:
 80          image: some-image
 81      """
 82          )
 83      )
 84      project = _project_spec.load_project(str(tmp_path))
 85      assert project._entry_points == {}
 86      assert project.env_config_path is None
 87      assert project.docker_env.get("image") == "some-image"
 88  
 89  
 90  def test_load_virtualenv_project(tmp_path):
 91      tmp_path.joinpath("MLproject").write_text("python_env: python_env.yaml")
 92      python_env = tmp_path.joinpath("python_env.yaml")
 93      python_env.write_text("python: 3.8.15")
 94      project = _project_spec.load_project(tmp_path)
 95      assert project._entry_points == {}
 96      assert python_env.samefile(project.env_config_path)
 97  
 98  
 99  @pytest.mark.parametrize(
100      ("invalid_project_contents", "expected_error_msg"),
101      [
102          (
103              textwrap.dedent(
104                  """
105      docker_env:
106          image: some-image
107      conda_env: some-file.yaml
108      """
109              ),
110              "cannot contain multiple environment fields",
111          ),
112          (
113              textwrap.dedent(
114                  """
115      docker_env:
116          not-image-attribute: blah
117      """
118              ),
119              "no image attribute found",
120          ),
121      ],
122  )
123  def test_load_invalid_project(tmp_path, invalid_project_contents, expected_error_msg):
124      tmp_path.joinpath("MLproject").write_text(invalid_project_contents)
125      with pytest.raises(ExecutionException, match=expected_error_msg) as e:
126          _project_spec.load_project(str(tmp_path))
127      assert expected_error_msg in str(e.value)