# test_entry_point.py
"""
Tests for ``EntryPoint`` parameter resolution, command construction, and handling of
``path``- and ``uri``-typed parameters.
"""

import os
from shlex import quote
from unittest import mock

import pytest

from mlflow.exceptions import ExecutionException
from mlflow.projects._project_spec import EntryPoint
from mlflow.utils.file_utils import TempDir, path_to_local_file_uri

from tests.projects.utils import TEST_PROJECT_DIR, load_project

# Patch target for the artifact-download helper, so tests can observe (and fake) downloads.
_DOWNLOAD_TARGET = "mlflow.tracking.artifact_utils._download_artifact_from_uri"


def test_entry_point_compute_params():
    """
    Tests that EntryPoint correctly computes a final set of parameters to use when running a
    project.
    """
    project = load_project()
    entry_point = project.get_entry_point("greeter")
    with TempDir() as tmp:
        # ``with TempDir() as tmp`` yields the TempDir object itself, not a path; pass its
        # filesystem path as the storage directory, as the other tests in this module do.
        storage_dir = tmp.path()

        # Pass extra "excitement" param, use default value for `greeting` param
        params, extra_params = entry_point.compute_parameters(
            {"name": "friend", "excitement": 10}, storage_dir
        )
        assert params == {"name": "friend", "greeting": "hi"}
        # Extra (undeclared) parameters are passed through with stringified values.
        assert extra_params == {"excitement": "10"}

        # Don't pass extra "excitement" param, pass value for `greeting`
        params, extra_params = entry_point.compute_parameters(
            {"name": "friend", "greeting": "hello"}, storage_dir
        )
        assert params == {"name": "friend", "greeting": "hello"}
        assert extra_params == {}

        # Raise exception on missing required parameter
        with pytest.raises(
            ExecutionException, match="No value given for missing parameters: 'name'"
        ):
            entry_point.compute_parameters({}, storage_dir)


def test_entry_point_compute_command():
    """
    Tests that EntryPoint correctly computes the command to execute in order to run the entry
    point.
    """
    project = load_project()
    entry_point = project.get_entry_point("greeter")
    with TempDir() as tmp:
        storage_dir = tmp.path()
        command = entry_point.compute_command({"name": "friend", "excitement": 10}, storage_dir)
        assert command == "python greeter.py hi friend --excitement 10"
        with pytest.raises(
            ExecutionException, match="No value given for missing parameters: 'name'"
        ):
            entry_point.compute_command({}, storage_dir)
        # Test shell escaping: a value containing shell metacharacters must come back quoted
        # exactly as ``shlex.quote`` would quote it.
        name_value = "friend; echo 'hi'"
        command = entry_point.compute_command({"name": name_value}, storage_dir)
        assert command == f"python greeter.py {quote('hi')} {quote(name_value)}"


def test_path_parameter():
    """
    Tests that MLflow file-download APIs get called when necessary for arguments of type `path`.
    """
    project = load_project()
    entry_point = project.get_entry_point("line_count")
    with mock.patch(_DOWNLOAD_TARGET, return_value=0) as download_uri_mock:
        # Verify that we don't attempt to call download_uri when passing a local file to a
        # parameter of type "path"
        with TempDir() as tmp:
            dst_dir = tmp.path()
            local_path = os.path.join(TEST_PROJECT_DIR, "MLproject")
            params, _ = entry_point.compute_parameters(
                user_parameters={"path": local_path}, storage_dir=dst_dir
            )
            assert params["path"] == os.path.abspath(local_path)
            assert download_uri_mock.call_count == 0

            # A local-file URI for the same file also resolves locally, without a download.
            params, _ = entry_point.compute_parameters(
                user_parameters={"path": path_to_local_file_uri(local_path)}, storage_dir=dst_dir
            )
            assert params["path"] == os.path.abspath(local_path)
            assert download_uri_mock.call_count == 0

        # Verify that we raise an exception when passing a non-existent local file to a
        # parameter of type "path"
        with TempDir() as tmp:
            dst_dir = tmp.path()
            with pytest.raises(ExecutionException, match="no such file or directory"):
                entry_point.compute_parameters(
                    user_parameters={"path": os.path.join(dst_dir, "some/nonexistent/file")},
                    storage_dir=dst_dir,
                )

        # Verify that we do call `download_uri` when passing a URI to a parameter of type "path"
        file_to_download = "images.tgz"
        for i, prefix in enumerate(["dbfs:/", "s3://", "gs://"]):
            with TempDir() as tmp:
                dst_dir = tmp.path()
                download_path = f"{dst_dir}/{file_to_download}"
                download_uri_mock.return_value = download_path
                # Remote URIs are not filesystem paths, so build them with plain string
                # formatting rather than ``os.path.join``.
                remote_uri = f"{prefix}{file_to_download}"
                params, _ = entry_point.compute_parameters(
                    user_parameters={"path": remote_uri},
                    storage_dir=dst_dir,
                )
                assert params["path"] == download_path
                # The mock is shared across iterations, so calls accumulate: one per prefix.
                assert download_uri_mock.call_count == i + 1


def test_uri_parameter():
    """
    Tests that parameters of type `uri` are not downloaded locally, and that passing a plain
    local path (rather than a URI) to such a parameter raises an ExecutionException.
    """
    project = load_project()
    entry_point = project.get_entry_point("download_uri")
    with mock.patch(_DOWNLOAD_TARGET) as download_uri_mock, TempDir() as tmp:
        dst_dir = tmp.path()
        # Test that we don't attempt to locally download parameters of type URI. Build the
        # file URI with ``path_to_local_file_uri`` (as test_path_parameter does) instead of
        # hand-concatenating "file://" with the path, which is not a valid URI for
        # Windows-style paths.
        entry_point.compute_command(
            user_parameters={"uri": path_to_local_file_uri(dst_dir)}, storage_dir=dst_dir
        )
        assert download_uri_mock.call_count == 0
        # Test that we raise an exception if a local path is passed to a parameter of type URI
        with pytest.raises(ExecutionException, match="Expected URI for parameter uri"):
            entry_point.compute_command(user_parameters={"uri": dst_dir}, storage_dir=dst_dir)


def test_params():
    """
    Tests parameter validation and resolution for an EntryPoint built directly from a parameter
    spec. It checks four behaviors:

    - omitting a parameter that has no default raises;
    - declared defaults fill in unspecified values;
    - undeclared parameters are returned separately as extras;
    - all resolved values are stringified.
    """
    param_spec = {
        "alpha": "float",
        "l1_ratio": {"type": "float", "default": 0.1},
        "l2_ratio": {"type": "float", "default": 0.0003},
        "random_str": {"type": "string", "default": "hello"},
    }
    entry_point = EntryPoint("entry_point_name", param_spec, "command_name script.py")

    # "alpha" declares no default, so leaving it out fails -- even when unrelated
    # parameters are supplied.
    user_1 = {}
    with pytest.raises(ExecutionException, match="No value given for missing parameters"):
        entry_point._validate_parameters(user_1)

    user_2 = {"beta": 0.004}
    with pytest.raises(ExecutionException, match="No value given for missing parameters"):
        entry_point._validate_parameters(user_2)

    # Required value supplied; defaults fill the rest and "gamma" is returned as an extra.
    user_3 = {"alpha": 0.004, "gamma": 0.89}
    expected_final_3 = {
        "alpha": "0.004",
        "l1_ratio": "0.1",
        "l2_ratio": "0.0003",
        "random_str": "hello",
    }
    expected_extra_3 = {"gamma": "0.89"}
    final_3, extra_3 = entry_point.compute_parameters(user_3, None)
    assert expected_extra_3 == extra_3
    assert expected_final_3 == final_3

    # A user value overrides one default while the others still apply.
    user_4 = {"alpha": 0.004, "l1_ratio": 0.0008, "random_str_2": "hello"}
    expected_final_4 = {
        "alpha": "0.004",
        "l1_ratio": "0.0008",
        "l2_ratio": "0.0003",
        "random_str": "hello",
    }
    expected_extra_4 = {"random_str_2": "hello"}
    final_4, extra_4 = entry_point.compute_parameters(user_4, None)
    assert expected_extra_4 == extra_4
    assert expected_final_4 == final_4

    # Negative numbers and overridden string defaults are accepted; no extras.
    user_5 = {"alpha": -0.99, "random_str": "hi"}
    expected_final_5 = {
        "alpha": "-0.99",
        "l1_ratio": "0.1",
        "l2_ratio": "0.0003",
        "random_str": "hi",
    }
    expected_extra_5 = {}
    final_5, extra_5 = entry_point.compute_parameters(user_5, None)
    assert expected_final_5 == final_5
    assert expected_extra_5 == extra_5

    # Parameter names are case-sensitive: "ALPHA" does not match "alpha" and is an extra.
    user_6 = {"alpha": 0.77, "ALPHA": 0.89}
    expected_final_6 = {
        "alpha": "0.77",
        "l1_ratio": "0.1",
        "l2_ratio": "0.0003",
        "random_str": "hello",
    }
    expected_extra_6 = {"ALPHA": "0.89"}
    final_6, extra_6 = entry_point.compute_parameters(user_6, None)
    assert expected_extra_6 == extra_6
    assert expected_final_6 == final_6


def test_path_params():
    """
    Tests resolution of `uri`- and `path`-typed parameters with remote (s3) values. The cases
    below check three behaviors:

    - `uri` values are never downloaded;
    - remote `path` values are left as-is when no storage directory is given;
    - remote `path` values are downloaded exactly once when a storage directory is given.
    """
    data_file = "s3://path.test/resources/data_file.csv"
    param_spec = {
        "constants": {"type": "uri", "default": "s3://path.test/b1"},
        "data": {"type": "path", "default": data_file},
    }
    entry_point = EntryPoint("entry_point_name", param_spec, "command_name script.py")

    # No storage directory: defaults come back unchanged and nothing is downloaded.
    with mock.patch(_DOWNLOAD_TARGET, return_value=None) as download_uri_mock:
        final_1, extra_1 = entry_point.compute_parameters({}, None)
        assert final_1 == {"constants": "s3://path.test/b1", "data": data_file}
        assert extra_1 == {}
        assert download_uri_mock.call_count == 0

    # Overriding the `uri` parameter still downloads nothing; "alpha" is an extra.
    with mock.patch(_DOWNLOAD_TARGET) as download_uri_mock:
        user_2 = {"alpha": 0.001, "constants": "s3://path.test/b_two"}
        final_2, extra_2 = entry_point.compute_parameters(user_2, None)
        assert final_2 == {"constants": "s3://path.test/b_two", "data": data_file}
        assert extra_2 == {"alpha": "0.001"}
        assert download_uri_mock.call_count == 0

    # With a storage directory, the default remote `path` value is downloaded once.
    with mock.patch(_DOWNLOAD_TARGET) as download_uri_mock, TempDir() as tmp:
        dest_path = tmp.path()
        download_path = f"{dest_path}/data_file.csv"
        download_uri_mock.return_value = download_path
        user_3 = {"alpha": 0.001}
        final_3, extra_3 = entry_point.compute_parameters(user_3, dest_path)
        assert final_3 == {"constants": "s3://path.test/b1", "data": download_path}
        assert extra_3 == {"alpha": "0.001"}
        assert download_uri_mock.call_count == 1

    # A user-supplied remote `path` value is downloaded in place of the default.
    with mock.patch(_DOWNLOAD_TARGET) as download_uri_mock, TempDir() as tmp:
        dest_path = tmp.path()
        download_path = f"{dest_path}/images.tgz"
        download_uri_mock.return_value = download_path
        user_4 = {"data": "s3://another.example.test/data_stash/images.tgz"}
        final_4, extra_4 = entry_point.compute_parameters(user_4, dest_path)
        assert final_4 == {"constants": "s3://path.test/b1", "data": download_path}
        assert extra_4 == {}
        assert download_uri_mock.call_count == 1