test_project_spec.py
import os
import re
import textwrap
from pathlib import Path

import pytest

from mlflow.exceptions import ExecutionException
from mlflow.projects import _project_spec

from tests.projects.utils import load_project


def test_project_get_entry_point():
    """The test project's "greeter" entry point exposes its command and typed parameters."""
    project = load_project()
    entry_point = project.get_entry_point("greeter")
    assert entry_point.name == "greeter"
    assert entry_point.command == "python greeter.py {greeting} {name}"
    # Validate parameters
    assert set(entry_point.parameters.keys()) == {"name", "greeting"}
    name_param = entry_point.parameters["name"]
    assert name_param.type == "string"
    assert name_param.default is None
    greeting_param = entry_point.parameters["greeting"]
    assert greeting_param.type == "string"
    assert greeting_param.default == "hi"


def test_project_get_unspecified_entry_point():
    """
    Script names that are not declared entry points get an inferred command with no
    parameters (.py via python, .sh via $SHELL); an unsupported file raises ExecutionException.
    """
    project = load_project()
    entry_point = project.get_entry_point("my_script.py")
    assert entry_point.name == "my_script.py"
    assert entry_point.command == "python my_script.py"
    assert entry_point.parameters == {}
    entry_point = project.get_entry_point("my_script.sh")
    assert entry_point.name == "my_script.sh"
    # The expected shell comes from $SHELL, falling back to bash when it is unset.
    assert entry_point.command == f"{os.environ.get('SHELL', 'bash')} my_script.sh"
    assert entry_point.parameters == {}
    # ``match`` is a regex; escape it so the "." in the file name is matched literally.
    with pytest.raises(ExecutionException, match=re.escape("Could not find my_program.scala")):
        project.get_entry_point("my_program.scala")


@pytest.mark.parametrize(
    (
        # Contents of MLproject file. If None, no MLproject file will be written.
        "mlproject",
        # Path to conda environment file. If None, no conda environment file will be written.
        "conda_env_path",
        # Contents of conda environment file (written if conda_env_path is not None).
        "conda_env_contents",
        # File name for the MLproject file. If None, the MLproject file is written to "MLproject".
        "mlproject_path",
    ),
    [
        # Neither file present.
        (None, None, "", None),
        # MLproject without a conda_env key, alongside a conda.yaml file.
        ("key: value", "conda.yaml", "hi", "MLproject"),
        # MLproject naming a custom env file; the lowercase file name presumably checks
        # that "mlproject" is also recognized -- confirm against _project_spec.
        ("conda_env: some-env.yaml", "some-env.yaml", "hi", "mlproject"),
    ],
)
def test_load_project(tmp_path, mlproject, conda_env_path, conda_env_contents, mlproject_path):
    """
    Test that we can load a project with various combinations of an MLproject / conda.yaml file
    """
    if mlproject:
        # Fall back to the default file name documented for ``mlproject_path`` above.
        tmp_path.joinpath(mlproject_path or "MLproject").write_text(mlproject)
    if conda_env_path:
        tmp_path.joinpath(conda_env_path).write_text(conda_env_contents)
    project = _project_spec.load_project(str(tmp_path))
    assert project._entry_points == {}
    expected_env_path = str(tmp_path.joinpath(conda_env_path)) if conda_env_path else None
    assert project.env_config_path == expected_env_path
    if conda_env_path:
        # Read back with the same default encoding that write_text used above.
        assert Path(project.env_config_path).read_text() == conda_env_contents


def test_load_docker_project(tmp_path):
    """A docker_env section is exposed via ``docker_env`` and leaves env_config_path unset."""
    tmp_path.joinpath("MLproject").write_text(
        textwrap.dedent(
            """
            docker_env:
              image: some-image
            """
        )
    )
    project = _project_spec.load_project(str(tmp_path))
    assert project._entry_points == {}
    assert project.env_config_path is None
    assert project.docker_env.get("image") == "some-image"


def test_load_virtualenv_project(tmp_path):
    """A python_env file referenced from MLproject becomes the project's env_config_path."""
    tmp_path.joinpath("MLproject").write_text("python_env: python_env.yaml")
    python_env = tmp_path.joinpath("python_env.yaml")
    python_env.write_text("python: 3.8.15")
    # Passes the Path object itself rather than str(tmp_path), presumably to cover
    # os.PathLike input to load_project.
    project = _project_spec.load_project(tmp_path)
    assert project._entry_points == {}
    # samefile compares the underlying file, so str-vs-Path representation does not matter.
    assert python_env.samefile(project.env_config_path)


@pytest.mark.parametrize(
    ("invalid_project_contents", "expected_error_msg"),
    [
        (
            textwrap.dedent(
                """
                docker_env:
                  image: some-image
                conda_env: some-file.yaml
                """
            ),
            "cannot contain multiple environment fields",
        ),
        (
            textwrap.dedent(
                """
                docker_env:
                  not-image-attribute: blah
                """
            ),
            "no image attribute found",
        ),
    ],
)
def test_load_invalid_project(tmp_path, invalid_project_contents, expected_error_msg):
    """Malformed environment sections in MLproject raise ExecutionException with a clear message."""
    tmp_path.joinpath("MLproject").write_text(invalid_project_contents)
    # Escaping turns ``match`` into a literal substring search (re.search), which covers the
    # previous separate ``expected_error_msg in str(e.value)`` assertion.
    with pytest.raises(ExecutionException, match=re.escape(expected_error_msg)):
        _project_spec.load_project(str(tmp_path))