/ tests / projects / test_entry_point.py
test_entry_point.py
  1  import os
  2  from shlex import quote
  3  from unittest import mock
  4  
  5  import pytest
  6  
  7  from mlflow.exceptions import ExecutionException
  8  from mlflow.projects._project_spec import EntryPoint
  9  from mlflow.utils.file_utils import TempDir, path_to_local_file_uri
 10  
 11  from tests.projects.utils import TEST_PROJECT_DIR, load_project
 12  
 13  
 14  def test_entry_point_compute_params():
 15      """
 16      Tests that EntryPoint correctly computes a final set of parameters to use when running a project
 17      """
 18      project = load_project()
 19      entry_point = project.get_entry_point("greeter")
 20      # Pass extra "excitement" param, use default value for `greeting` param
 21      with TempDir() as storage_dir:
 22          params, extra_params = entry_point.compute_parameters(
 23              {"name": "friend", "excitement": 10}, storage_dir
 24          )
 25          assert params == {"name": "friend", "greeting": "hi"}
 26          assert extra_params == {"excitement": "10"}
 27          # Don't pass extra "excitement" param, pass value for `greeting`
 28          params, extra_params = entry_point.compute_parameters(
 29              {"name": "friend", "greeting": "hello"}, storage_dir
 30          )
 31          assert params == {"name": "friend", "greeting": "hello"}
 32          assert extra_params == {}
 33          # Raise exception on missing required parameter
 34          with pytest.raises(
 35              ExecutionException, match="No value given for missing parameters: 'name'"
 36          ):
 37              entry_point.compute_parameters({}, storage_dir)
 38  
 39  
 40  def test_entry_point_compute_command():
 41      """
 42      Tests that EntryPoint correctly computes the command to execute in order to run the entry point.
 43      """
 44      project = load_project()
 45      entry_point = project.get_entry_point("greeter")
 46      with TempDir() as tmp:
 47          storage_dir = tmp.path()
 48          command = entry_point.compute_command({"name": "friend", "excitement": 10}, storage_dir)
 49          assert command == "python greeter.py hi friend --excitement 10"
 50          with pytest.raises(
 51              ExecutionException, match="No value given for missing parameters: 'name'"
 52          ):
 53              entry_point.compute_command({}, storage_dir)
 54          # Test shell escaping
 55          name_value = "friend; echo 'hi'"
 56          command = entry_point.compute_command({"name": name_value}, storage_dir)
 57          assert command == "python greeter.py {} {}".format(quote("hi"), quote(name_value))
 58  
 59  
 60  def test_path_parameter():
 61      """
 62      Tests that MLflow file-download APIs get called when necessary for arguments of type `path`.
 63      """
 64      project = load_project()
 65      entry_point = project.get_entry_point("line_count")
 66      with mock.patch(
 67          "mlflow.tracking.artifact_utils._download_artifact_from_uri", return_value=0
 68      ) as download_uri_mock:
 69          # Verify that we don't attempt to call download_uri when passing a local file to a
 70          # parameter of type "path"
 71          with TempDir() as tmp:
 72              dst_dir = tmp.path()
 73              local_path = os.path.join(TEST_PROJECT_DIR, "MLproject")
 74              params, _ = entry_point.compute_parameters(
 75                  user_parameters={"path": local_path}, storage_dir=dst_dir
 76              )
 77              assert params["path"] == os.path.abspath(local_path)
 78              assert download_uri_mock.call_count == 0
 79  
 80              params, _ = entry_point.compute_parameters(
 81                  user_parameters={"path": path_to_local_file_uri(local_path)}, storage_dir=dst_dir
 82              )
 83              assert params["path"] == os.path.abspath(local_path)
 84              assert download_uri_mock.call_count == 0
 85  
 86          # Verify that we raise an exception when passing a non-existent local file to a
 87          # parameter of type "path"
 88          with TempDir() as tmp:
 89              dst_dir = tmp.path()
 90              with pytest.raises(ExecutionException, match="no such file or directory"):
 91                  entry_point.compute_parameters(
 92                      user_parameters={"path": os.path.join(dst_dir, "some/nonexistent/file")},
 93                      storage_dir=dst_dir,
 94                  )
 95          # Verify that we do call `download_uri` when passing a URI to a parameter of type "path"
 96          for i, prefix in enumerate(["dbfs:/", "s3://", "gs://"]):
 97              with TempDir() as tmp:
 98                  dst_dir = tmp.path()
 99                  file_to_download = "images.tgz"
100                  download_path = f"{dst_dir}/{file_to_download}"
101                  download_uri_mock.return_value = download_path
102                  params, _ = entry_point.compute_parameters(
103                      user_parameters={"path": os.path.join(prefix, file_to_download)},
104                      storage_dir=dst_dir,
105                  )
106                  assert params["path"] == download_path
107                  assert download_uri_mock.call_count == i + 1
108  
109  
def test_uri_parameter():
    """
    Checks handling of a `uri`-typed parameter on the "download_uri" entry point.
    """
    uri_entry_point = load_project().get_entry_point("download_uri")
    patch_target = "mlflow.tracking.artifact_utils._download_artifact_from_uri"
    with mock.patch(patch_target) as fake_download:
        with TempDir() as tmp_dir:
            storage = tmp_dir.path()

            # A URI value is passed through; the download helper must not be called for it.
            uri_entry_point.compute_command(
                user_parameters={"uri": f"file://{storage}"}, storage_dir=storage
            )
            assert fake_download.call_count == 0

            # A bare local path is not a URI and must be rejected.
            with pytest.raises(ExecutionException, match="Expected URI for parameter uri"):
                uri_entry_point.compute_command(
                    user_parameters={"uri": storage}, storage_dir=storage
                )
128  
129  
def test_params():
    """
    Exercises required-parameter validation and default filling on a hand-built EntryPoint.
    """
    param_spec = {
        "alpha": "float",
        "l1_ratio": {"type": "float", "default": 0.1},
        "l2_ratio": {"type": "float", "default": 0.0003},
        "random_str": {"type": "string", "default": "hello"},
    }
    entry_point = EntryPoint("entry_point_name", param_spec, "command_name script.py")

    # `alpha` is never supplied in these inputs, so validation must fail for both.
    for incomplete in ({}, {"beta": 0.004}):
        with pytest.raises(ExecutionException, match="No value given for missing parameters"):
            entry_point._validate_parameters(incomplete)

    # Values expected for the declared defaults when the user does not override them.
    stock = {"l1_ratio": "0.1", "l2_ratio": "0.0003", "random_str": "hello"}

    # Each case: (user-supplied parameters, expected final parameters, expected extra parameters)
    cases = [
        (
            {"alpha": 0.004, "gamma": 0.89},
            {**stock, "alpha": "0.004"},
            {"gamma": "0.89"},
        ),
        (
            {"alpha": 0.004, "l1_ratio": 0.0008, "random_str_2": "hello"},
            {**stock, "alpha": "0.004", "l1_ratio": "0.0008"},
            {"random_str_2": "hello"},
        ),
        (
            {"alpha": -0.99, "random_str": "hi"},
            {**stock, "alpha": "-0.99", "random_str": "hi"},
            {},
        ),
        (
            # Keys differing only in case are distinct: "ALPHA" lands in the extras.
            {"alpha": 0.77, "ALPHA": 0.89},
            {**stock, "alpha": "0.77"},
            {"ALPHA": "0.89"},
        ),
    ]
    for supplied, want_final, want_extra in cases:
        got_final, got_extra = entry_point.compute_parameters(supplied, None)
        assert want_extra == got_extra
        assert want_final == got_final
194  
195  
def test_path_params():
    """
    Checks resolution of `uri`- and `path`-typed parameters against a mocked download helper.
    """
    remote_csv = "s3://path.test/resources/data_file.csv"
    default_uri = "s3://path.test/b1"
    patch_target = "mlflow.tracking.artifact_utils._download_artifact_from_uri"
    entry_point = EntryPoint(
        "entry_point_name",
        {
            "constants": {"type": "uri", "default": default_uri},
            "data": {"type": "path", "default": remote_csv},
        },
        "command_name script.py",
    )

    # No user parameters and storage_dir=None: defaults come back as-is, nothing is downloaded.
    with mock.patch(patch_target, return_value=None) as fake_download:
        resolved, leftover = entry_point.compute_parameters({}, None)
        assert resolved == {"constants": default_uri, "data": remote_csv}
        assert leftover == {}
        assert fake_download.call_count == 0

    # Overriding the uri parameter (plus an undeclared one), still with storage_dir=None.
    with mock.patch(patch_target) as fake_download:
        supplied = {"alpha": 0.001, "constants": "s3://path.test/b_two"}
        resolved, leftover = entry_point.compute_parameters(supplied, None)
        assert resolved == {"constants": "s3://path.test/b_two", "data": remote_csv}
        assert leftover == {"alpha": "0.001"}
        assert fake_download.call_count == 0

    # With a real storage dir, the default `data` value resolves to the helper's return value.
    with mock.patch(patch_target) as fake_download:
        with TempDir() as tmp_dir:
            storage = tmp_dir.path()
            downloaded_to = f"{storage}/data_file.csv"
            fake_download.return_value = downloaded_to
            supplied = {"alpha": 0.001}
            resolved, leftover = entry_point.compute_parameters(supplied, storage)
            assert resolved == {"constants": default_uri, "data": downloaded_to}
            assert leftover == {"alpha": "0.001"}
            assert fake_download.call_count == 1

    # A user-supplied `data` URI is likewise resolved through exactly one helper call.
    with mock.patch(patch_target) as fake_download:
        with TempDir() as tmp_dir:
            storage = tmp_dir.path()
            downloaded_to = f"{storage}/images.tgz"
            fake_download.return_value = downloaded_to
            supplied = {"data": "s3://another.example.test/data_stash/images.tgz"}
            resolved, leftover = entry_point.compute_parameters(supplied, storage)
            assert resolved == {"constants": default_uri, "data": downloaded_to}
            assert leftover == {}
            assert fake_download.call_count == 1