/ tests / utils / test_uri.py
test_uri.py
  1  import os
  2  import pathlib
  3  import posixpath
  4  
  5  import pytest
  6  
  7  from mlflow.exceptions import MlflowException
  8  from mlflow.store.db.db_types import DATABASE_ENGINES
  9  from mlflow.utils.os import is_windows
 10  from mlflow.utils.uri import (
 11      add_databricks_profile_info_to_artifact_uri,
 12      append_to_uri_path,
 13      append_to_uri_query_params,
 14      dbfs_hdfs_uri_to_fuse_path,
 15      extract_and_normalize_path,
 16      extract_db_type_from_uri,
 17      get_databricks_profile_uri_from_artifact_uri,
 18      get_db_info_from_uri,
 19      get_uri_scheme,
 20      is_databricks_acled_artifacts_uri,
 21      is_databricks_uri,
 22      is_fuse_or_uc_volumes_uri,
 23      is_http_uri,
 24      is_local_uri,
 25      is_valid_dbfs_uri,
 26      remove_databricks_profile_info_from_artifact_uri,
 27      resolve_uri_if_local,
 28      strip_scheme,
 29      validate_path_is_safe,
 30      validate_path_within_directory,
 31  )
 32  
 33  
 34  def test_extract_db_type_from_uri():
 35      uri = "{}://username:password@host:port/database"
 36      for legit_db in DATABASE_ENGINES:
 37          assert legit_db == extract_db_type_from_uri(uri.format(legit_db))
 38          assert legit_db == get_uri_scheme(uri.format(legit_db))
 39  
 40          with_driver = legit_db + "+driver-string"
 41          assert legit_db == extract_db_type_from_uri(uri.format(with_driver))
 42          assert legit_db == get_uri_scheme(uri.format(with_driver))
 43  
 44      for unsupported_db in ["a", "aa", "sql"]:
 45          with pytest.raises(MlflowException, match="Invalid database engine"):
 46              extract_db_type_from_uri(unsupported_db)
 47  
 48  
 49  @pytest.mark.parametrize(
 50      ("server_uri", "result"),
 51      [
 52          ("databricks://aAbB", ("aAbB", None)),
 53          ("databricks://aAbB/", ("aAbB", None)),
 54          ("databricks://aAbB/path", ("aAbB", None)),
 55          ("databricks://profile:prefix", ("profile", "prefix")),
 56          ("databricks://profile:prefix/extra", ("profile", "prefix")),
 57          ("nondatabricks://profile:prefix", (None, None)),
 58          ("databricks://profile", ("profile", None)),
 59          ("databricks://profile/", ("profile", None)),
 60          ("databricks-uc://profile:prefix", ("profile", "prefix")),
 61          ("databricks-uc://profile:prefix/extra", ("profile", "prefix")),
 62          ("databricks-uc://profile", ("profile", None)),
 63          ("databricks-uc://profile/", ("profile", None)),
 64      ],
 65  )
 66  def test_get_db_info_from_uri(server_uri, result):
 67      assert get_db_info_from_uri(server_uri) == result
 68  
 69  
 70  @pytest.mark.parametrize(
 71      "server_uri",
 72      ["databricks:/profile:prefix", "databricks:/", "databricks://"],
 73  )
 74  def test_get_db_info_from_uri_errors_no_netloc(server_uri):
 75      with pytest.raises(MlflowException, match="URI is formatted incorrectly"):
 76          get_db_info_from_uri(server_uri)
 77  
 78  
 79  @pytest.mark.parametrize(
 80      "server_uri",
 81      [
 82          "databricks://profile:prefix:extra",
 83          "databricks://profile:prefix:extra  ",
 84          "databricks://profile:prefix extra",
 85          "databricks://profile:prefix  ",
 86          "databricks://profile ",
 87          "databricks://profile:",
 88          "databricks://profile: ",
 89      ],
 90  )
 91  def test_get_db_info_from_uri_errors_invalid_profile(server_uri):
 92      with pytest.raises(MlflowException, match="Unsupported Databricks profile"):
 93          get_db_info_from_uri(server_uri)
 94  
 95  
 96  def test_is_local_uri():
 97      assert is_local_uri("mlruns")
 98      assert is_local_uri("./mlruns")
 99      assert is_local_uri("file:///foo/mlruns")
100      assert is_local_uri("file:foo/mlruns")
101      assert is_local_uri("file://./mlruns")
102      assert is_local_uri("file://localhost/mlruns")
103      assert is_local_uri("file://localhost:5000/mlruns")
104      assert is_local_uri("file://127.0.0.1/mlruns")
105      assert is_local_uri("file://127.0.0.1:5000/mlruns")
106      assert is_local_uri("//proc/self/root")
107      assert is_local_uri("/proc/self/root")
108  
109      assert not is_local_uri("https://whatever")
110      assert not is_local_uri("http://whatever")
111      assert not is_local_uri("databricks")
112      assert not is_local_uri("databricks:whatever")
113      assert not is_local_uri("databricks://whatever")
114  
115      with pytest.raises(MlflowException, match="is not a valid remote uri."):
116          is_local_uri("file://myhostname/path/to/file")
117  
118  
@pytest.mark.skipif(not is_windows(), reason="Windows-only test")
def test_is_local_uri_windows():
    """On Windows, drive-letter paths are local while a ``\\\\server\\...`` path is not."""
    drive_letter_uris = (
        "C:\\foo\\mlruns",
        "C:/foo/mlruns",
        "file:///C:\\foo\\mlruns",
    )
    for drive_letter_uri in drive_letter_uris:
        assert is_local_uri(drive_letter_uri)

    # Path beginning with a double backslash followed by a server name.
    assert not is_local_uri("\\\\server\\aa\\bb")
125  
126  
def test_is_databricks_uri():
    """``is_databricks_uri`` is true for ``databricks`` URIs and false for other inputs."""
    for databricks_uri in ("databricks", "databricks:whatever", "databricks://whatever"):
        assert is_databricks_uri(databricks_uri)

    for other_uri in ("mlruns", "http://whatever"):
        assert not is_databricks_uri(other_uri)
133  
134  
def test_is_http_uri():
    """``is_http_uri`` is true for http/https URIs and false for the other samples."""
    for http_uri in ("http://whatever", "https://whatever"):
        assert is_http_uri(http_uri)

    for other_uri in ("file://whatever", "databricks://whatever", "mlruns"):
        assert not is_http_uri(other_uri)
141  
142  
def validate_append_to_uri_path_test_cases(cases):
    """Assert ``append_to_uri_path`` produces the expected URI for every case.

    Args:
        cases: Iterable of ``(input_uri, input_path, expected_output_uri)`` triples.

    Each case is checked twice: once passing the path as a single argument and once
    passing the ``(head, tail)`` pair produced by ``posixpath.split`` as separate arguments.
    """
    for base_uri, suffix, expected in cases:
        joined_whole = append_to_uri_path(base_uri, suffix)
        assert joined_whole == expected

        head_and_tail = posixpath.split(suffix)
        joined_in_parts = append_to_uri_path(base_uri, *head_and_tail)
        assert joined_in_parts == expected
147  
148  
def test_append_to_uri_path_joins_uri_paths_and_posixpaths_correctly():
    """Join plain POSIX paths, ``file:`` URIs and ``s3://`` URIs with relative/absolute suffixes."""
    # Each entry is (input_uri, input_path, expected_output_uri).
    join_cases = [
        ("", "path", "path"),
        ("", "/path", "/path"),
        ("path", "", "path/"),
        ("path", "subpath", "path/subpath"),
        ("path/", "subpath", "path/subpath"),
        ("path/", "/subpath", "path/subpath"),
        ("path", "/subpath", "path/subpath"),
        ("/path", "/subpath", "/path/subpath"),
        ("//path", "/subpath", "//path/subpath"),
        ("///path", "/subpath", "///path/subpath"),
        ("/path", "/subpath/subdir", "/path/subpath/subdir"),
        ("file:path", "", "file:path/"),
        ("file:path/", "", "file:path/"),
        ("file:path", "subpath", "file:path/subpath"),
        ("file:path", "/subpath", "file:path/subpath"),
        ("file:/", "", "file:///"),
        ("file:/path", "/subpath", "file:///path/subpath"),
        ("file:///", "", "file:///"),
        ("file:///", "subpath", "file:///subpath"),
        ("file:///path", "/subpath", "file:///path/subpath"),
        ("file:///path/", "subpath", "file:///path/subpath"),
        ("file:///path", "subpath", "file:///path/subpath"),
        ("s3://", "", "s3:"),
        ("s3://", "subpath", "s3:subpath"),
        ("s3://", "/subpath", "s3:/subpath"),
        ("s3://host", "subpath", "s3://host/subpath"),
        ("s3://host", "/subpath", "s3://host/subpath"),
        ("s3://host/", "subpath", "s3://host/subpath"),
        ("s3://host/", "/subpath", "s3://host/subpath"),
        ("s3://host", "subpath/subdir", "s3://host/subpath/subdir"),
    ]
    validate_append_to_uri_path_test_cases(join_cases)
182  
183  
def test_append_to_uri_path_handles_special_uri_characters_in_posixpaths():
    """
    Certain characters are treated specially when parsing and interpreting URIs. However, in the
    case where a URI input for `append_to_uri_path` is simply a POSIX path, these characters should
    not receive special treatment. This test case verifies that `append_to_uri_path` properly joins
    POSIX paths containing these characters.
    """

    def create_char_case(special_char):
        # Build a helper that substitutes `special_char` for every `{c}` placeholder
        # in the strings of a single (input_uri, input_path, expected) case.
        def char_case(*case_args):
            return tuple(item.format(c=special_char) for item in case_args)

        return char_case

    for special_char in [
        ".",
        "-",
        "+",
        ":",
        "?",
        "@",
        "&",
        "$",
        "%",
        "/",
        "[",
        "]",
        "(",
        ")",
        "*",
        "'",
        ",",
    ]:
        char_case = create_char_case(special_char)
        # Each case appears exactly once; a repeated entry would only re-run an
        # identical assertion.
        validate_append_to_uri_path_test_cases([
            char_case("", "{c}subpath", "{c}subpath"),
            char_case("", "/{c}subpath", "/{c}subpath"),
            char_case("dirwith{c}{c}chars", "", "dirwith{c}{c}chars/"),
            char_case("dirwith{c}{c}chars", "subpath", "dirwith{c}{c}chars/subpath"),
            char_case("{c}{c}charsdir", "", "{c}{c}charsdir/"),
            char_case("/{c}{c}charsdir", "", "/{c}{c}charsdir/"),
            char_case("/{c}{c}charsdir", "subpath", "/{c}{c}charsdir/subpath"),
        ])

    # Paths mixing several special characters at once.
    validate_append_to_uri_path_test_cases([
        ("#?charsdir:", ":?subpath#", "#?charsdir:/:?subpath#"),
        ("/#--+charsdir.//:", "/../:?subpath#", "/#--+charsdir.//:/../:?subpath#"),
        ("$@''(,", ")]*%", "$@''(,/)]*%"),
    ])
234  
235  
@pytest.mark.parametrize(
    "uri",
    [
        # query string contains '..' (and its encoded form) are considered invalid
        "https://example.com?..",
        "https://example.com?/path/../path/../path",
        "https://example.com?key=value&../../path",
        "https://example.com?key=value&%2E%2E%2Fpath",
        "https://example.com?key=value&%252E%252E%252Fpath",
    ],
)
def test_append_to_uri_throws_for_malicious_query_string_in_uri(uri):
    """``append_to_uri_path`` raises when the URI's query string contains ``..`` in any form."""
    expected_message = r"Invalid query string"
    with pytest.raises(MlflowException, match=expected_message):
        append_to_uri_path(uri)
250  
251  
@pytest.mark.parametrize(
    ("uri", "existing_query_params", "query_params", "expected"),
    [
        ("https://example.com", "", [("key", "value")], "https://example.com?key=value"),
        (
            "https://example.com",
            "existing_key=existing_value",
            [("new_key", "new_value")],
            "https://example.com?existing_key=existing_value&new_key=new_value",
        ),
        (
            "https://example.com",
            "",
            [("key1", "value1"), ("key2", "value2"), ("key3", "value3")],
            "https://example.com?key1=value1&key2=value2&key3=value3",
        ),
        (
            "https://example.com",
            "",
            [("key", "value with spaces"), ("key2", "special#characters")],
            "https://example.com?key=value+with+spaces&key2=special%23characters",
        ),
        ("", "", [("key", "value")], "?key=value"),
        ("https://example.com", "", [], "https://example.com"),
        (
            "https://example.com",
            "",
            [("key1", 123), ("key2", 456)],
            "https://example.com?key1=123&key2=456",
        ),
        (
            "https://example.com?existing_key=existing_value",
            "",
            [("existing_key", "new_value"), ("existing_key", "new_value_2")],
            "https://example.com?existing_key=existing_value&existing_key=new_value&existing_key=new_value_2",
        ),
        (
            "s3://bucket/key",
            "prev1=foo&prev2=bar",
            [("param1", "value1"), ("param2", "value2")],
            "s3://bucket/key?prev1=foo&prev2=bar&param1=value1&param2=value2",
        ),
        (
            "s3://bucket/key?existing_param=existing_value",
            "",
            [("new_param", "new_value")],
            "s3://bucket/key?existing_param=existing_value&new_param=new_value",
        ),
    ],
)
def test_append_to_uri_query_params_appends_as_expected(
    uri, existing_query_params, query_params, expected
):
    """New query parameters are appended after any query string already on the URI.

    When ``existing_query_params`` is non-empty it is first attached to ``uri`` with a
    ``?`` separator, then ``query_params`` are passed as positional ``(key, value)`` pairs.
    """
    starting_uri = f"{uri}?{existing_query_params}" if existing_query_params else uri
    appended = append_to_uri_query_params(starting_uri, *query_params)
    assert appended == expected
310  
311  
def test_append_to_uri_path_preserves_uri_schemes_hosts_queries_and_fragments():
    """Appending a path keeps the scheme, authority, query string and fragment of the URI."""
    # Each entry is (input_uri, input_path, expected_output_uri).
    preservation_cases = [
        ("dbscheme+dbdriver:", "", "dbscheme+dbdriver:"),
        ("dbscheme+dbdriver:", "subpath", "dbscheme+dbdriver:subpath"),
        ("dbscheme+dbdriver:path", "subpath", "dbscheme+dbdriver:path/subpath"),
        ("dbscheme+dbdriver://host/path", "/subpath", "dbscheme+dbdriver://host/path/subpath"),
        ("dbscheme+dbdriver:///path", "subpath", "dbscheme+dbdriver:/path/subpath"),
        ("dbscheme+dbdriver:?somequery", "subpath", "dbscheme+dbdriver:subpath?somequery"),
        ("dbscheme+dbdriver:?somequery", "/subpath", "dbscheme+dbdriver:/subpath?somequery"),
        ("dbscheme+dbdriver:/?somequery", "subpath", "dbscheme+dbdriver:/subpath?somequery"),
        ("dbscheme+dbdriver://?somequery", "subpath", "dbscheme+dbdriver:subpath?somequery"),
        ("dbscheme+dbdriver:///?somequery", "/subpath", "dbscheme+dbdriver:/subpath?somequery"),
        ("dbscheme+dbdriver:#somefrag", "subpath", "dbscheme+dbdriver:subpath#somefrag"),
        ("dbscheme+dbdriver:#somefrag", "/subpath", "dbscheme+dbdriver:/subpath#somefrag"),
        ("dbscheme+dbdriver:/#somefrag", "subpath", "dbscheme+dbdriver:/subpath#somefrag"),
        ("dbscheme+dbdriver://#somefrag", "subpath", "dbscheme+dbdriver:subpath#somefrag"),
        ("dbscheme+dbdriver:///#somefrag", "/subpath", "dbscheme+dbdriver:/subpath#somefrag"),
        (
            "dbscheme+dbdriver://root:password?creds=creds",
            "subpath",
            "dbscheme+dbdriver://root:password/subpath?creds=creds",
        ),
        (
            "dbscheme+dbdriver://root:password/path/?creds=creds",
            "/subpath/anotherpath",
            "dbscheme+dbdriver://root:password/path/subpath/anotherpath?creds=creds",
        ),
        (
            "dbscheme+dbdriver://root:password///path/?creds=creds",
            "subpath/anotherpath",
            "dbscheme+dbdriver://root:password///path/subpath/anotherpath?creds=creds",
        ),
        (
            "dbscheme+dbdriver://root:password///path/?creds=creds",
            "/subpath",
            "dbscheme+dbdriver://root:password///path/subpath?creds=creds",
        ),
        (
            "dbscheme+dbdriver://root:password#myfragment",
            "/subpath",
            "dbscheme+dbdriver://root:password/subpath#myfragment",
        ),
        (
            "dbscheme+dbdriver://root:password//path/#fragmentwith$pecial@",
            "subpath/anotherpath",
            "dbscheme+dbdriver://root:password//path/subpath/anotherpath#fragmentwith$pecial@",
        ),
        (
            "dbscheme+dbdriver://root:password@host?creds=creds#fragmentwith$pecial@",
            "subpath",
            "dbscheme+dbdriver://root:password@host/subpath?creds=creds#fragmentwith$pecial@",
        ),
        (
            "dbscheme+dbdriver://root:password@host.com/path?creds=creds#*frag@*",
            "subpath/dir",
            "dbscheme+dbdriver://root:password@host.com/path/subpath/dir?creds=creds#*frag@*",
        ),
        (
            "dbscheme-dbdriver://root:password@host.com/path?creds=creds#*frag@*",
            "subpath/dir",
            "dbscheme-dbdriver://root:password@host.com/path/subpath/dir?creds=creds#*frag@*",
        ),
        (
            "dbscheme+dbdriver://root:password@host.com/path?creds=creds,param=value#*frag@*",
            "subpath/dir",
            "dbscheme+dbdriver://root:password@host.com/path/subpath/dir?"
            "creds=creds,param=value#*frag@*",
        ),
    ]
    validate_append_to_uri_path_test_cases(preservation_cases)
381  
382  
def test_extract_and_normalize_path():
    """All slash variants of the same ``dbfs:`` URI normalize to one relative path.

    The inputs differ in the number of slashes after the scheme, in repeated slashes
    between segments, and in trailing slashes.
    """
    expected_path = "databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts"
    dbfs_uri_variants = (
        "dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        "dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        "dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        "dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/",
        "dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//",
        "dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//",
    )
    for dbfs_uri in dbfs_uri_variants:
        assert extract_and_normalize_path(dbfs_uri) == expected_path
415  
416  
def test_is_databricks_acled_artifacts_uri():
    """URIs under ``databricks/mlflow-tracking`` are detected regardless of slash variations;
    a URI under ``databricks/mlflow`` is not.
    """
    acled_uris = (
        "dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        "dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        "dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        "dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/",
        "dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//",
        "dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//",
    )
    for acled_uri in acled_uris:
        assert is_databricks_acled_artifacts_uri(acled_uri)

    non_acled_uri = "dbfs:/databricks/mlflow//EXP_ID//RUN_ID///artifacts//"
    assert not is_databricks_acled_artifacts_uri(non_acled_uri)
439  
440  
441  def _get_databricks_profile_uri_test_cases():
442      # Each test case is (uri, result, result_scheme)
443      test_case_groups = [
444          [
445              # URIs with no databricks profile info -> return None
446              ("ftp://user:pass@realhost:port/path/to/nowhere", None, result_scheme),
447              ("dbfs:/path/to/nowhere", None, result_scheme),
448              ("dbfs://nondatabricks/path/to/nowhere", None, result_scheme),
449              ("dbfs://incorrect:netloc:format/path/to/nowhere", None, result_scheme),
450              # URIs with legit databricks profile info
451              (f"dbfs://{result_scheme}", result_scheme, result_scheme),
452              (f"dbfs://{result_scheme}/", result_scheme, result_scheme),
453              (f"dbfs://{result_scheme}/path/to/nowhere", result_scheme, result_scheme),
454              (f"dbfs://{result_scheme}:port/path/to/nowhere", result_scheme, result_scheme),
455              (f"dbfs://@{result_scheme}/path/to/nowhere", result_scheme, result_scheme),
456              (f"dbfs://@{result_scheme}:port/path/to/nowhere", result_scheme, result_scheme),
457              (
458                  f"dbfs://profile@{result_scheme}/path/to/nowhere",
459                  f"{result_scheme}://profile",
460                  result_scheme,
461              ),
462              (
463                  f"dbfs://profile@{result_scheme}:port/path/to/nowhere",
464                  f"{result_scheme}://profile",
465                  result_scheme,
466              ),
467              (
468                  f"dbfs://scope:key_prefix@{result_scheme}/path/abc",
469                  f"{result_scheme}://scope:key_prefix",
470                  result_scheme,
471              ),
472              (
473                  f"dbfs://scope:key_prefix@{result_scheme}:port/path/abc",
474                  f"{result_scheme}://scope:key_prefix",
475                  result_scheme,
476              ),
477              # Doesn't care about the scheme of the artifact URI
478              (
479                  f"runs://scope:key_prefix@{result_scheme}/path/abc",
480                  f"{result_scheme}://scope:key_prefix",
481                  result_scheme,
482              ),
483              (
484                  f"models://scope:key_prefix@{result_scheme}/path/abc",
485                  f"{result_scheme}://scope:key_prefix",
486                  result_scheme,
487              ),
488              (
489                  f"s3://scope:key_prefix@{result_scheme}/path/abc",
490                  f"{result_scheme}://scope:key_prefix",
491                  result_scheme,
492              ),
493          ]
494          for result_scheme in ["databricks", "databricks-uc"]
495      ]
496      return [test_case for test_case_group in test_case_groups for test_case in test_case_group]
497  
498  
@pytest.mark.parametrize(
    ("uri", "result", "result_scheme"),
    _get_databricks_profile_uri_test_cases(),
)
def test_get_databricks_profile_uri_from_artifact_uri(uri, result, result_scheme):
    """The profile URI extracted from an artifact URI matches the expected value (or None)."""
    profile_uri = get_databricks_profile_uri_from_artifact_uri(uri, result_scheme=result_scheme)
    assert profile_uri == result
504  
505  
@pytest.mark.parametrize(
    "uri",
    [
        # Treats secret key prefixes with ":" to be invalid
        "dbfs://incorrect:netloc:format@databricks/path/a",
        "dbfs://scope::key_prefix@databricks/path/abc",
        "dbfs://scope:key_prefix:@databricks/path/abc",
    ],
)
def test_get_databricks_profile_uri_from_artifact_uri_error_cases(uri):
    """Artifact URIs whose profile part contains extra colons raise an unsupported-profile error."""
    expected_message = "Unsupported Databricks profile"
    with pytest.raises(MlflowException, match=expected_message):
        get_databricks_profile_uri_from_artifact_uri(uri)
518  
519  
@pytest.mark.parametrize(
    ("uri", "result"),
    [
        # URIs with no databricks profile info should stay the same
        (
            "ftp://user:pass@realhost:port/path/nowhere",
            "ftp://user:pass@realhost:port/path/nowhere",
        ),
        ("dbfs:/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://nondatabricks/path/to/nowhere", "dbfs://nondatabricks/path/to/nowhere"),
        ("dbfs://incorrect:netloc:format/path/", "dbfs://incorrect:netloc:format/path/"),
        # URIs with legit databricks profile info
        ("dbfs://databricks", "dbfs:"),
        ("dbfs://databricks/", "dbfs:/"),
        ("dbfs://databricks/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://databricks:port/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://@databricks/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://@databricks:port/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://profile@databricks/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://profile@databricks:port/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://scope:key_prefix@databricks/path/abc", "dbfs:/path/abc"),
        ("dbfs://scope:key_prefix@databricks:port/path/abc", "dbfs:/path/abc"),
        # Treats secret key prefixes with ":" to be valid
        ("dbfs://incorrect:netloc:format@databricks/path/to/nowhere", "dbfs:/path/to/nowhere"),
        # Doesn't care about the scheme of the artifact URI
        ("runs://scope:key_prefix@databricks/path/abc", "runs:/path/abc"),
        ("models://scope:key_prefix@databricks/path/abc", "models:/path/abc"),
        ("s3://scope:key_prefix@databricks/path/abc", "s3:/path/abc"),
    ],
)
def test_remove_databricks_profile_info_from_artifact_uri(uri, result):
    """Databricks profile info is stripped from the authority; other URIs are left unchanged."""
    stripped_uri = remove_databricks_profile_info_from_artifact_uri(uri)
    assert stripped_uri == result
552  
553  
@pytest.mark.parametrize(
    ("artifact_uri", "profile_uri", "result"),
    [
        # test various profile URIs
        ("dbfs:/path/a/b", "databricks", "dbfs://databricks/path/a/b"),
        ("dbfs:/path/a/b/", "databricks", "dbfs://databricks/path/a/b/"),
        ("dbfs:/path/a/b/", "databricks://Profile", "dbfs://Profile@databricks/path/a/b/"),
        ("dbfs:/path/a/b/", "databricks://profile/", "dbfs://profile@databricks/path/a/b/"),
        ("dbfs:/path/a/b/", "databricks://scope:key", "dbfs://scope:key@databricks/path/a/b/"),
        (
            "dbfs:/path/a/b/",
            "databricks://scope:key/random_stuff",
            "dbfs://scope:key@databricks/path/a/b/",
        ),
        ("dbfs:/path/a/b/", "nondatabricks://profile", "dbfs:/path/a/b/"),
        # test various artifact schemes
        ("runs:/path/a/b/", "databricks://Profile", "runs://Profile@databricks/path/a/b/"),
        ("runs:/path/a/b/", "nondatabricks://profile", "runs:/path/a/b/"),
        ("models:/path/a/b/", "databricks://profile", "models://profile@databricks/path/a/b/"),
        ("models:/path/a/b/", "nondatabricks://Profile", "models:/path/a/b/"),
        ("s3:/path/a/b/", "databricks://Profile", "s3:/path/a/b/"),
        ("s3:/path/a/b/", "nondatabricks://profile", "s3:/path/a/b/"),
        ("ftp:/path/a/b/", "databricks://profile", "ftp:/path/a/b/"),
        ("ftp:/path/a/b/", "nondatabricks://Profile", "ftp:/path/a/b/"),
        # test artifact URIs already with authority
        ("ftp://user:pass@host:port/a/b", "databricks://Profile", "ftp://user:pass@host:port/a/b"),
        ("ftp://user:pass@host:port/a/b", "nothing://Profile", "ftp://user:pass@host:port/a/b"),
        ("dbfs://databricks", "databricks://OtherProfile", "dbfs://databricks"),
        ("dbfs://databricks", "nondatabricks://Profile", "dbfs://databricks"),
        ("dbfs://databricks/path/a/b", "databricks://OtherProfile", "dbfs://databricks/path/a/b"),
        ("dbfs://databricks/path/a/b", "nondatabricks://Profile", "dbfs://databricks/path/a/b"),
        ("dbfs://@databricks/path/a/b", "databricks://OtherProfile", "dbfs://@databricks/path/a/b"),
        ("dbfs://@databricks/path/a/b", "nondatabricks://Profile", "dbfs://@databricks/path/a/b"),
        (
            "dbfs://profile@databricks/pp",
            "databricks://OtherProfile",
            "dbfs://profile@databricks/pp",
        ),
        (
            "dbfs://profile@databricks/path",
            "databricks://profile",
            "dbfs://profile@databricks/path",
        ),
        (
            "dbfs://profile@databricks/path",
            "nondatabricks://Profile",
            "dbfs://profile@databricks/path",
        ),
    ],
)
def test_add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri, result):
    """Profile info from ``profile_uri`` is injected into the artifact URI only in the expected
    cases; the remaining cases return the artifact URI unchanged.
    """
    augmented_uri = add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri)
    assert augmented_uri == result
606  
607  
@pytest.mark.parametrize(
    ("artifact_uri", "profile_uri"),
    [
        ("dbfs:/path/a/b", "databricks://not:legit:auth"),
        ("dbfs:/path/a/b/", "databricks://scope::key"),
        ("dbfs:/path/a/b/", "databricks://scope:key:/"),
        ("dbfs:/path/a/b/", "databricks://scope:key "),
    ],
)
def test_add_databricks_profile_info_to_artifact_uri_errors(artifact_uri, profile_uri):
    """Malformed profile URIs raise an unsupported-profile error when added to an artifact URI."""
    expected_message = "Unsupported Databricks profile"
    with pytest.raises(MlflowException, match=expected_message):
        add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri)
620  
621  
@pytest.mark.parametrize(
    ("uri", "result"),
    [
        ("dbfs:/path/a/b", True),
        ("dbfs://databricks/a/b", True),
        ("dbfs://@databricks/a/b", True),
        ("dbfs://profile@databricks/a/b", True),
        ("dbfs://scope:key@databricks/a/b", True),
        ("dbfs://scope:key:@databricks/a/b", False),
        ("dbfs://scope::key@databricks/a/b", False),
        ("dbfs://profile@notdatabricks/a/b", False),
        ("dbfs://scope:key@notdatabricks/a/b", False),
        ("dbfs://scope:key/a/b", False),
        ("dbfs://notdatabricks/a/b", False),
        ("s3:/path/a/b", False),
        ("ftp://user:pass@host:port/path/a/b", False),
        ("ftp://user:pass@databricks/path/a/b", False),
    ],
)
def test_is_valid_dbfs_uri(uri, result):
    """``is_valid_dbfs_uri`` returns the expected boolean for each sample URI."""
    validity = is_valid_dbfs_uri(uri)
    assert validity == result
643  
644  
@pytest.mark.parametrize(
    ("uri", "result"),
    [
        ("/tmp/path", "/dbfs/tmp/path"),
        ("dbfs:/path", "/dbfs/path"),
        ("dbfs:/path/a/b", "/dbfs/path/a/b"),
        ("dbfs:/dbfs/123/abc", "/dbfs/dbfs/123/abc"),
    ],
)
def test_dbfs_hdfs_uri_to_fuse_path(uri, result):
    """Absolute paths and ``dbfs:/`` URIs are mapped to the expected path under ``/dbfs``."""
    fuse_path = dbfs_hdfs_uri_to_fuse_path(uri)
    assert fuse_path == result
656  
657  
@pytest.mark.parametrize(
    "path",
    [
        "some/relative/local/path",
        "s3:/some/s3/path",
        "C:/cool/windows/path",
    ],
)
def test_dbfs_hdfs_uri_to_fuse_path_raises(path):
    """Inputs that are neither absolute paths nor ``dbfs:/`` URIs are rejected."""
    expected_message = "did not start with expected DBFS URI prefix"
    with pytest.raises(MlflowException, match=expected_message):
        dbfs_hdfs_uri_to_fuse_path(path)
665  
666  
def _assert_resolve_uri_if_local(input_uri, expected_uri):
    """Assert that ``resolve_uri_if_local(input_uri)`` equals the formatted ``expected_uri``.

    Args:
        input_uri: URI or path handed to ``resolve_uri_if_local``.
        expected_uri: Template that may contain ``{cwd}`` and ``{drive}`` placeholders,
            filled from the current working directory before comparison.
    """
    working_dir = pathlib.Path.cwd()
    cwd, drive = working_dir.as_posix(), working_dir.drive
    if is_windows():
        # On Windows the cwd gets a leading slash and the drive a trailing one.
        cwd, drive = f"/{cwd}", f"{drive}/"
    resolved = resolve_uri_if_local(input_uri)
    assert resolved == expected_uri.format(cwd=cwd, drive=drive)
674  
675  
@pytest.mark.skipif(is_windows(), reason="This test fails on Windows")
@pytest.mark.parametrize(
    ("input_uri", "expected_uri"),
    [
        ("my/path", "{cwd}/my/path"),
        ("#my/path?a=b", "{cwd}/#my/path?a=b"),
        ("file://localhost/my/path", "file://localhost/my/path"),
        ("file:///my/path", "file:///{drive}my/path"),
        ("file:my/path", "file://{cwd}/my/path"),
        ("/home/my/path", "/home/my/path"),
        ("dbfs://databricks/a/b", "dbfs://databricks/a/b"),
        ("s3://host/my/path", "s3://host/my/path"),
    ],
)
def test_resolve_uri_if_local(input_uri, expected_uri):
    """Non-Windows expectations for ``resolve_uri_if_local``.

    ``{cwd}`` and ``{drive}`` in ``expected_uri`` are filled in by the shared helper.
    """
    _assert_resolve_uri_if_local(input_uri=input_uri, expected_uri=expected_uri)
692  
693  
@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
@pytest.mark.parametrize(
    ("input_uri", "expected_uri"),
    [
        ("my/path", "file://{cwd}/my/path"),
        ("#my/path?a=b", "file://{cwd}/#my/path?a=b"),
        ("\\myhostname/my/path", "file:///{drive}myhostname/my/path"),
        ("file:///my/path", "file:///{drive}my/path"),
        ("file:my/path", "file://{cwd}/my/path"),
        ("/home/my/path", "file:///{drive}home/my/path"),
        ("dbfs://databricks/a/b", "dbfs://databricks/a/b"),
        ("s3://host/my/path", "s3://host/my/path"),
    ],
)
def test_resolve_uri_if_local_on_windows(input_uri, expected_uri):
    """Windows expectations for ``resolve_uri_if_local``.

    ``{cwd}`` and ``{drive}`` in ``expected_uri`` are filled in by the shared helper.
    """
    _assert_resolve_uri_if_local(input_uri=input_uri, expected_uri=expected_uri)
710  
711  
# URIs for which ``is_fuse_or_uc_volumes_uri`` is expected to return a truthy value.
_FUSE_OR_UC_VOLUMES_URIS = [
    "/dbfs/my_path",
    "dbfs:/my_path",
    "/Volumes/my_path",
    "/.fuse-mounts/my_path",
    # Variants with repeated slashes
    "//dbfs////my_path",
    "///Volumes/",
    "dbfs://my///path",
    # Variants with different letter case
    "/volumes/path/to/file",
    "/volumes/",
    "DBFS:/my/path",
]


@pytest.mark.parametrize("uri", _FUSE_OR_UC_VOLUMES_URIS)
def test_correctly_detect_fuse_and_uc_uris(uri):
    """Each listed URI must be detected by ``is_fuse_or_uc_volumes_uri``."""
    detected = is_fuse_or_uc_volumes_uri(uri)
    assert detected
729  
730  
# URIs for which ``is_fuse_or_uc_volumes_uri`` is expected to return a falsy value.
_NON_FUSE_OR_UC_VOLUMES_URIS = [
    "/My_Volumes/my_path",
    "s3a:/my_path",
    "Volumes/my_path",
    "Volume:/my_path",
    "dbfs/my_path",
    "/fuse-mounts/my_path",
]


@pytest.mark.parametrize("uri", _NON_FUSE_OR_UC_VOLUMES_URIS)
def test_negative_detection(uri):
    """None of the listed URIs may be detected by ``is_fuse_or_uc_volumes_uri``."""
    detected = is_fuse_or_uc_volumes_uri(uri)
    assert not detected
744  
745  
# Paths that ``validate_path_is_safe`` is expected to accept without raising.
_SAFE_PATHS = [
    "path",
    "path/",
    "path/to/file",
    "dog%step%100%timestamp%100",
    "dog+step+100+timestamp+100",
]


@pytest.mark.parametrize("path", _SAFE_PATHS)
def test_validate_path_is_safe_good(path):
    """``validate_path_is_safe`` must not raise for any of the listed paths."""
    validate_path_is_safe(path)
758  
759  
# Paths that ``validate_path_is_safe`` is expected to accept on Windows.
_SAFE_WINDOWS_PATHS = [
    # relative path from current directory of C: drive
    ".../...//",
]


@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
@pytest.mark.parametrize("path", _SAFE_WINDOWS_PATHS)
def test_validate_path_is_safe_windows_good(path):
    """``validate_path_is_safe`` must not raise for the listed paths (Windows only)."""
    validate_path_is_safe(path)
770  
771  
# Paths that ``validate_path_is_safe`` is expected to reject on non-Windows platforms.
_UNSAFE_PATHS = [
    "/path",
    "../path",
    "../../path",
    "./../path",
    "path/../to/file",
    "path/../../to/file",
    "file://a#/..//tmp",
    "file://a%23/..//tmp/",
    "/etc/passwd",
    "/etc/passwd%00.jpg",
    "/etc/passwd%00.html",
    "/etc/passwd%00.txt",
    "/etc/passwd%00.php",
    "/etc/passwd%00.asp",
    "/file://etc/passwd",
    # Encoded paths with '..'
    "%2E%2E%2Fpath",
    "%2E%2E%2F%2E%2E%2Fpath",
    # Some URIs are passed to urllib.parse.urlparse after validation,
    # which strips out some whitespace characters. If they are further
    # decoded, this could result in a path that is not safe.
    # In this example, %2%0952e -> %2\t52e -> %252e -> %2e -> .
    "%2%0952e%2%0952e/%2%0A52e%2%0A52e/path",
]


@pytest.mark.skipif(is_windows(), reason="This test does not pass on Windows")
@pytest.mark.parametrize("path", _UNSAFE_PATHS)
def test_validate_path_is_safe_bad(path):
    """Each listed path must make ``validate_path_is_safe`` raise "Invalid path"."""
    with pytest.raises(MlflowException, match="Invalid path"):
        validate_path_is_safe(path)
804  
805  
# Paths that ``validate_path_is_safe`` is expected to reject on Windows.
_UNSAFE_WINDOWS_PATHS = [
    # Parent-directory traversal with forward slashes and backslashes
    r"../path",
    r"../../path",
    r"./../path",
    r"path/../to/file",
    r"path/../../to/file",
    r"..\path",
    r"..\..\path",
    r".\..\path",
    r"path\..\to\file",
    r"path\..\..\to\file",
    # Drive-relative paths
    r"C:path",
    r"C:path/",
    r"C:path/to/file",
    r"C:../path/to/file",
    r"C:\path",
    r"C:/path",
    r"C:\path\to\file",
    r"C:\path/to/file",
    r"C:\path\..\to\file",
    r"C:/path/../to/file",
    # UNC (Universal Naming Convention) paths
    r"\\path\to\file",
    r"\\path/to/file",
    r"\\.\\C:\path\to\file",
    r"\\?\C:\path\to\file",
    r"\\?\UNC/path/to/file",
    # Other potentially attackable paths
    r"/etc/password",
    r"/path",
    r"/etc/passwd%00.jpg",
    r"/etc/passwd%00.html",
    r"/etc/passwd%00.txt",
    r"/etc/passwd%00.php",
    r"/etc/passwd%00.asp",
    r"/Windows/no/such/path",
    r"/file://etc/passwd",
    r"/file:c:/passwd",
    r"/file://d:/windows/win.ini",
    r"/file://./windows/win.ini",
    r"file://c:/boot.ini",
    r"file://C:path",
    r"file://C:path/",
    r"file://C:path/to/file",
    r"file:///C:/Windows/System32/",
    r"file:///etc/passwd",
    r"file:///d:/windows/repair/sam",
    r"file:///proc/version",
    r"file:///inetpub/wwwroot/global.asa",
    r"/file://../windows/win.ini",
    r"../etc/passwd",
    r"..\Windows\System32\\",
    r"C:\Windows\System32\\",
    r"/etc/passwd",
    r"::Windows\System32",
    r"..\..\..\..\Windows\System32\\",
    r"../Windows/System32",
    r"....\\",
    r"\\?\C:\Windows\System32\\",
    r"\\.\C:\Windows\System32\\",
    r"\\UNC\Server\Share\\",
    r"\\Server\Share\folder\\",
    r"\\127.0.0.1\c$\Windows\\",
    r"\\localhost\c$\Windows\\",
    r"\\smbserver\share\path\\",
    r"..\\?\C:\Windows\System32\\",
    r"C:/Windows/../Windows/System32/",
    r"C:\Windows\..\Windows\System32\\",
    r"../../../../../../../../../../../../Windows/System32",
    r"../../../../../../../../../../../../etc/passwd",
    r"../../../../../../../../../../../../var/www/html/index.html",
    r"../../../../../../../../../../../../usr/local/etc/openvpn/server.conf",
    r"../../../../../../../../../../../../Program Files (x86)",
    r"/../../../../../../../../../../../../Windows/System32",
    r"/Windows\../etc/passwd",
    r"/Windows\..\Windows\System32\\",
    r"/Windows\..\Windows\System32\cmd.exe",
    r"/Windows\..\Windows\System32\msconfig.exe",
    r"/Windows\..\Windows\System32\regedit.exe",
    r"/Windows\..\Windows\System32\taskmgr.exe",
    r"/Windows\..\Windows\System32\control.exe",
    r"/Windows\..\Windows\System32\services.msc",
    r"/Windows\..\Windows\System32\diskmgmt.msc",
    r"/Windows\..\Windows\System32\eventvwr.msc",
    r"/Windows/System32/drivers/etc/hosts",
]


@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
@pytest.mark.parametrize("path", _UNSAFE_WINDOWS_PATHS)
def test_validate_path_is_safe_windows_bad(path):
    """Each listed path must make ``validate_path_is_safe`` raise "Invalid path"."""
    with pytest.raises(MlflowException, match="Invalid path"):
        validate_path_is_safe(path)
900  
901  
# (URI, expected result of ``strip_scheme``) pairs.
_STRIP_SCHEME_CASES = [
    ("file:///path", "/path"),
    ("file://host/path", "//host/path"),
    ("file://host", "//host"),
]


@pytest.mark.parametrize(("uri", "expected"), _STRIP_SCHEME_CASES)
def test_strip_scheme(uri: str, expected: str):
    """``strip_scheme`` must return the expected string for each ``file`` URI."""
    stripped = strip_scheme(uri)
    assert stripped == expected
912  
913  
def test_validate_path_within_directory_allows_valid_paths(tmp_path):
    """A plain path below the base directory is returned unchanged."""
    artifact_root = tmp_path / "artifacts"
    artifact_root.mkdir()
    # The target itself is never created; only the base directory exists.
    target = artifact_root.joinpath("subdir", "file.txt")
    validated = validate_path_within_directory(str(artifact_root), str(target))
    assert validated == str(target)
920  
921  
def test_validate_path_within_directory_blocks_symlink_escape(tmp_path):
    """A path that goes through a symlink to a sibling directory is rejected."""
    artifact_root = tmp_path / "artifacts"
    outside_dir = tmp_path / "external"
    artifact_root.mkdir()
    outside_dir.mkdir()
    (outside_dir / "secret.txt").write_text("SECRET")

    # ``artifacts/leak`` points at the directory outside the artifact root.
    leak = artifact_root / "leak"
    os.symlink(str(outside_dir), str(leak))

    escaped = leak / "secret.txt"
    with pytest.raises(MlflowException, match="resolved path is outside the artifact directory"):
        validate_path_within_directory(str(artifact_root), str(escaped))
934  
935  
def test_validate_path_within_directory_blocks_parent_symlink(tmp_path):
    """A path that goes through a symlink to the parent, then ``..``, is rejected."""
    artifact_root = tmp_path / "artifacts"
    artifact_root.mkdir()

    # ``artifacts/parent`` points back at the directory containing ``artifacts``.
    parent_link = artifact_root / "parent"
    os.symlink(str(tmp_path), str(parent_link))

    escaped = parent_link.joinpath("artifacts", "..", "external")
    with pytest.raises(MlflowException, match="resolved path is outside the artifact directory"):
        validate_path_within_directory(str(artifact_root), str(escaped))
944  
945  
def test_validate_path_within_directory_allows_internal_symlink(tmp_path):
    """A symlink whose target is inside the base directory is returned unchanged."""
    artifact_root = tmp_path / "artifacts"
    artifact_root.mkdir()
    target_file = artifact_root / "real_file.txt"
    target_file.write_text("CONTENT")

    # Both the link and its target are inside the artifact root.
    link = artifact_root / "link"
    os.symlink(str(target_file), str(link))

    validated = validate_path_within_directory(str(artifact_root), str(link))
    assert validated == str(link)
955  
956  
def test_validate_path_within_directory_allows_base_dir_itself(tmp_path):
    """Passing the base directory as the path returns it unchanged."""
    artifact_root = tmp_path / "artifacts"
    artifact_root.mkdir()
    root_str = str(artifact_root)
    assert validate_path_within_directory(root_str, root_str) == root_str
962  
963  
def test_validate_path_within_directory_allows_subdirectory_symlink(tmp_path):
    """A path through a symlink to a subdirectory of the base is returned unchanged."""
    artifact_root = tmp_path / "artifacts"
    artifact_root.mkdir()
    nested_dir = artifact_root / "subdir"
    nested_dir.mkdir()
    (nested_dir / "file.txt").write_text("CONTENT")

    # ``artifacts/link_to_subdir`` points at ``artifacts/subdir``.
    dir_link = artifact_root / "link_to_subdir"
    os.symlink(str(nested_dir), str(dir_link))

    via_link = dir_link / "file.txt"
    validated = validate_path_within_directory(str(artifact_root), str(via_link))
    assert validated == str(via_link)