test_uri.py
import os
import pathlib
import posixpath

import pytest

from mlflow.exceptions import MlflowException
from mlflow.store.db.db_types import DATABASE_ENGINES
from mlflow.utils.os import is_windows
from mlflow.utils.uri import (
    add_databricks_profile_info_to_artifact_uri,
    append_to_uri_path,
    append_to_uri_query_params,
    dbfs_hdfs_uri_to_fuse_path,
    extract_and_normalize_path,
    extract_db_type_from_uri,
    get_databricks_profile_uri_from_artifact_uri,
    get_db_info_from_uri,
    get_uri_scheme,
    is_databricks_acled_artifacts_uri,
    is_databricks_uri,
    is_fuse_or_uc_volumes_uri,
    is_http_uri,
    is_local_uri,
    is_valid_dbfs_uri,
    remove_databricks_profile_info_from_artifact_uri,
    resolve_uri_if_local,
    strip_scheme,
    validate_path_is_safe,
    validate_path_within_directory,
)


def test_extract_db_type_from_uri():
    """Every known engine is extracted from a URI, with or without a driver suffix.

    Strings that are not in ``DATABASE_ENGINES`` must raise ``MlflowException``.
    """
    template = "{}://username:password@host:port/database"
    for engine in DATABASE_ENGINES:
        # Same pair of checks for the bare engine and for "engine+driver".
        for scheme in (engine, engine + "+driver-string"):
            candidate = template.format(scheme)
            assert engine == extract_db_type_from_uri(candidate)
            assert engine == get_uri_scheme(candidate)

    for bogus in ("a", "aa", "sql"):
        with pytest.raises(MlflowException, match="Invalid database engine"):
            extract_db_type_from_uri(bogus)


@pytest.mark.parametrize(
    ("server_uri", "result"),
    [
        ("databricks://aAbB", ("aAbB", None)),
        ("databricks://aAbB/", ("aAbB", None)),
        ("databricks://aAbB/path", ("aAbB", None)),
        ("databricks://profile:prefix", ("profile", "prefix")),
        ("databricks://profile:prefix/extra", ("profile", "prefix")),
        ("nondatabricks://profile:prefix", (None, None)),
        ("databricks://profile", ("profile", None)),
        ("databricks://profile/", ("profile", None)),
        ("databricks-uc://profile:prefix", ("profile", "prefix")),
        ("databricks-uc://profile:prefix/extra", ("profile", "prefix")),
        ("databricks-uc://profile", ("profile", None)),
        ("databricks-uc://profile/", ("profile", None)),
    ],
)
def test_get_db_info_from_uri(server_uri, result):
    """``get_db_info_from_uri`` returns the expected tuple for each server URI."""
    parsed = get_db_info_from_uri(server_uri)
    assert parsed == result


@pytest.mark.parametrize(
    "server_uri",
    ["databricks:/profile:prefix", "databricks:/", "databricks://"],
)
def test_get_db_info_from_uri_errors_no_netloc(server_uri):
    """These Databricks URIs are rejected as incorrectly formatted."""
    with pytest.raises(MlflowException, match="URI is formatted incorrectly"):
        get_db_info_from_uri(server_uri)


@pytest.mark.parametrize(
    "server_uri",
    [
        "databricks://profile:prefix:extra",
        "databricks://profile:prefix:extra  ",
        "databricks://profile:prefix extra",
        "databricks://profile:prefix  ",
        "databricks://profile ",
        "databricks://profile:",
        "databricks://profile: ",
    ],
)
def test_get_db_info_from_uri_errors_invalid_profile(server_uri):
    """These profile strings are rejected as unsupported Databricks profiles."""
    with pytest.raises(MlflowException, match="Unsupported Databricks profile"):
        get_db_info_from_uri(server_uri)


def test_is_local_uri():
    """``is_local_uri`` accepts local forms, rejects remote ones, and raises on a foreign host."""
    expected_local = (
        "mlruns",
        "./mlruns",
        "file:///foo/mlruns",
        "file:foo/mlruns",
        "file://./mlruns",
        "file://localhost/mlruns",
        "file://localhost:5000/mlruns",
        "file://127.0.0.1/mlruns",
        "file://127.0.0.1:5000/mlruns",
        "//proc/self/root",
        "/proc/self/root",
    )
    for candidate in expected_local:
        assert is_local_uri(candidate)

    expected_remote = (
        "https://whatever",
        "http://whatever",
        "databricks",
        "databricks:whatever",
        "databricks://whatever",
    )
    for candidate in expected_remote:
        assert not is_local_uri(candidate)

    # A file URI naming a host other than the local machine raises instead of returning False.
    with pytest.raises(MlflowException, match="is not a valid remote uri."):
        is_local_uri("file://myhostname/path/to/file")
@pytest.mark.skipif(not is_windows(), reason="Windows-only test")
def test_is_local_uri_windows():
    """Windows drive paths count as local; a backslash UNC path does not."""
    assert is_local_uri("C:\\foo\\mlruns")
    assert is_local_uri("C:/foo/mlruns")
    assert is_local_uri("file:///C:\\foo\\mlruns")
    assert not is_local_uri("\\\\server\\aa\\bb")


def test_is_databricks_uri():
    """``is_databricks_uri`` accepts the three ``databricks`` forms and rejects others."""
    assert is_databricks_uri("databricks")
    assert is_databricks_uri("databricks:whatever")
    assert is_databricks_uri("databricks://whatever")
    assert not is_databricks_uri("mlruns")
    assert not is_databricks_uri("http://whatever")


def test_is_http_uri():
    """``is_http_uri`` accepts http/https URIs only."""
    assert is_http_uri("http://whatever")
    assert is_http_uri("https://whatever")
    assert not is_http_uri("file://whatever")
    assert not is_http_uri("databricks://whatever")
    assert not is_http_uri("mlruns")


def validate_append_to_uri_path_test_cases(cases):
    """Assert ``append_to_uri_path`` on each ``(input_uri, input_path, expected)`` case.

    Each case is checked twice: once with the path passed as a single argument
    and once with the path split by ``posixpath.split`` into two arguments.

    Args:
        cases: Iterable of ``(input_uri, input_path, expected_output_uri)`` tuples.
    """
    for input_uri, input_path, expected_output_uri in cases:
        assert append_to_uri_path(input_uri, input_path) == expected_output_uri
        assert append_to_uri_path(input_uri, *posixpath.split(input_path)) == expected_output_uri


def test_append_to_uri_path_joins_uri_paths_and_posixpaths_correctly():
    """Joining plain paths, ``file:`` URIs and ``s3:`` URIs yields the expected results."""
    validate_append_to_uri_path_test_cases([
        ("", "path", "path"),
        ("", "/path", "/path"),
        ("path", "", "path/"),
        ("path", "subpath", "path/subpath"),
        ("path/", "subpath", "path/subpath"),
        ("path/", "/subpath", "path/subpath"),
        ("path", "/subpath", "path/subpath"),
        ("/path", "/subpath", "/path/subpath"),
        ("//path", "/subpath", "//path/subpath"),
        ("///path", "/subpath", "///path/subpath"),
        ("/path", "/subpath/subdir", "/path/subpath/subdir"),
        ("file:path", "", "file:path/"),
        ("file:path/", "", "file:path/"),
        ("file:path", "subpath", "file:path/subpath"),
        ("file:path", "/subpath", "file:path/subpath"),
        ("file:/", "", "file:///"),
        ("file:/path", "/subpath", "file:///path/subpath"),
        ("file:///", "", "file:///"),
        ("file:///", "subpath", "file:///subpath"),
        ("file:///path", "/subpath", "file:///path/subpath"),
        ("file:///path/", "subpath", "file:///path/subpath"),
        ("file:///path", "subpath", "file:///path/subpath"),
        ("s3://", "", "s3:"),
        ("s3://", "subpath", "s3:subpath"),
        ("s3://", "/subpath", "s3:/subpath"),
        ("s3://host", "subpath", "s3://host/subpath"),
        ("s3://host", "/subpath", "s3://host/subpath"),
        ("s3://host/", "subpath", "s3://host/subpath"),
        ("s3://host/", "/subpath", "s3://host/subpath"),
        ("s3://host", "subpath/subdir", "s3://host/subpath/subdir"),
    ])


def test_append_to_uri_path_handles_special_uri_characters_in_posixpaths():
    """
    Certain characters are treated specially when parsing and interpreting URIs. However, in the
    case where a URI input for `append_to_uri_path` is simply a POSIX path, these characters should
    not receive special treatment. This test case verifies that `append_to_uri_path` properly joins
    POSIX paths containing these characters.
    """

    def create_char_case(special_char):
        # Returns a builder that substitutes `special_char` for every "{c}" placeholder.
        def char_case(*case_args):
            return tuple(item.format(c=special_char) for item in case_args)

        return char_case

    for special_char in [
        ".",
        "-",
        "+",
        ":",
        "?",
        "@",
        "&",
        "$",
        "%",
        "/",
        "[",
        "]",
        "(",
        ")",
        "*",
        "'",
        ",",
    ]:
        char_case = create_char_case(special_char)
        validate_append_to_uri_path_test_cases([
            char_case("", "{c}subpath", "{c}subpath"),
            char_case("", "/{c}subpath", "/{c}subpath"),
            char_case("dirwith{c}{c}chars", "", "dirwith{c}{c}chars/"),
            char_case("dirwith{c}{c}chars", "subpath", "dirwith{c}{c}chars/subpath"),
            char_case("{c}{c}charsdir", "", "{c}{c}charsdir/"),
            char_case("/{c}{c}charsdir", "", "/{c}{c}charsdir/"),
            char_case("/{c}{c}charsdir", "subpath", "/{c}{c}charsdir/subpath"),
        ])

    # Paths mixing several special characters at once.
    validate_append_to_uri_path_test_cases([
        ("#?charsdir:", ":?subpath#", "#?charsdir:/:?subpath#"),
        ("/#--+charsdir.//:", "/../:?subpath#", "/#--+charsdir.//:/../:?subpath#"),
        ("$@''(,", ")]*%", "$@''(,/)]*%"),
    ])


@pytest.mark.parametrize(
    "uri",
    [
        # Query strings containing '..' (or an encoded form of it) are considered invalid
        "https://example.com?..",
        "https://example.com?/path/../path/../path",
        "https://example.com?key=value&../../path",
        "https://example.com?key=value&%2E%2E%2Fpath",
        "https://example.com?key=value&%252E%252E%252Fpath",
    ],
)
def test_append_to_uri_throws_for_malicious_query_string_in_uri(uri):
    """``append_to_uri_path`` raises for URIs whose query string contains '..'."""
    with pytest.raises(MlflowException, match=r"Invalid query string"):
        append_to_uri_path(uri)


@pytest.mark.parametrize(
    ("uri", "existing_query_params", "query_params", "expected"),
    [
        ("https://example.com", "", [("key", "value")], "https://example.com?key=value"),
        (
            "https://example.com",
            "existing_key=existing_value",
            [("new_key", "new_value")],
            "https://example.com?existing_key=existing_value&new_key=new_value",
        ),
        (
            "https://example.com",
            "",
            [("key1", "value1"), ("key2", "value2"), ("key3", "value3")],
            "https://example.com?key1=value1&key2=value2&key3=value3",
        ),
        (
            "https://example.com",
            "",
            [("key", "value with spaces"), ("key2", "special#characters")],
            "https://example.com?key=value+with+spaces&key2=special%23characters",
        ),
        ("", "", [("key", "value")], "?key=value"),
        ("https://example.com", "", [], "https://example.com"),
        (
            "https://example.com",
            "",
            [("key1", 123), ("key2", 456)],
            "https://example.com?key1=123&key2=456",
        ),
        (
            "https://example.com?existing_key=existing_value",
            "",
            [("existing_key", "new_value"), ("existing_key", "new_value_2")],
            "https://example.com?existing_key=existing_value&existing_key=new_value&existing_key=new_value_2",
        ),
        (
            "s3://bucket/key",
            "prev1=foo&prev2=bar",
            [("param1", "value1"), ("param2", "value2")],
            # Each appended parameter is joined with a literal "&".
            "s3://bucket/key?prev1=foo&prev2=bar&param1=value1&param2=value2",
        ),
        (
            "s3://bucket/key?existing_param=existing_value",
            "",
            [("new_param", "new_value")],
            "s3://bucket/key?existing_param=existing_value&new_param=new_value",
        ),
    ],
)
def test_append_to_uri_query_params_appends_as_expected(
    uri, existing_query_params, query_params, expected
):
    """``append_to_uri_query_params`` appends the given pairs after any existing query.

    When ``existing_query_params`` is non-empty it is attached to ``uri`` with a
    ``?`` before the function under test is called.
    """
    if existing_query_params:
        uri += f"?{existing_query_params}"

    result = append_to_uri_query_params(uri, *query_params)
    assert result == expected


def test_append_to_uri_path_preserves_uri_schemes_hosts_queries_and_fragments():
    """Appending a path keeps the scheme, authority, query and fragment of the input URI."""
    validate_append_to_uri_path_test_cases([
        ("dbscheme+dbdriver:", "", "dbscheme+dbdriver:"),
        ("dbscheme+dbdriver:", "subpath", "dbscheme+dbdriver:subpath"),
        ("dbscheme+dbdriver:path", "subpath", "dbscheme+dbdriver:path/subpath"),
        ("dbscheme+dbdriver://host/path", "/subpath", "dbscheme+dbdriver://host/path/subpath"),
        ("dbscheme+dbdriver:///path", "subpath", "dbscheme+dbdriver:/path/subpath"),
        ("dbscheme+dbdriver:?somequery", "subpath", "dbscheme+dbdriver:subpath?somequery"),
        ("dbscheme+dbdriver:?somequery", "/subpath", "dbscheme+dbdriver:/subpath?somequery"),
        ("dbscheme+dbdriver:/?somequery", "subpath", "dbscheme+dbdriver:/subpath?somequery"),
        ("dbscheme+dbdriver://?somequery", "subpath", "dbscheme+dbdriver:subpath?somequery"),
        ("dbscheme+dbdriver:///?somequery", "/subpath", "dbscheme+dbdriver:/subpath?somequery"),
        ("dbscheme+dbdriver:#somefrag", "subpath", "dbscheme+dbdriver:subpath#somefrag"),
        ("dbscheme+dbdriver:#somefrag", "/subpath", "dbscheme+dbdriver:/subpath#somefrag"),
        ("dbscheme+dbdriver:/#somefrag", "subpath", "dbscheme+dbdriver:/subpath#somefrag"),
        ("dbscheme+dbdriver://#somefrag", "subpath", "dbscheme+dbdriver:subpath#somefrag"),
        ("dbscheme+dbdriver:///#somefrag", "/subpath", "dbscheme+dbdriver:/subpath#somefrag"),
        (
            "dbscheme+dbdriver://root:password?creds=creds",
            "subpath",
            "dbscheme+dbdriver://root:password/subpath?creds=creds",
        ),
        (
            "dbscheme+dbdriver://root:password/path/?creds=creds",
            "/subpath/anotherpath",
            "dbscheme+dbdriver://root:password/path/subpath/anotherpath?creds=creds",
        ),
        (
            "dbscheme+dbdriver://root:password///path/?creds=creds",
            "subpath/anotherpath",
            "dbscheme+dbdriver://root:password///path/subpath/anotherpath?creds=creds",
        ),
        (
            "dbscheme+dbdriver://root:password///path/?creds=creds",
            "/subpath",
            "dbscheme+dbdriver://root:password///path/subpath?creds=creds",
        ),
        (
            "dbscheme+dbdriver://root:password#myfragment",
            "/subpath",
            "dbscheme+dbdriver://root:password/subpath#myfragment",
        ),
        (
            "dbscheme+dbdriver://root:password//path/#fragmentwith$pecial@",
            "subpath/anotherpath",
            "dbscheme+dbdriver://root:password//path/subpath/anotherpath#fragmentwith$pecial@",
        ),
        (
            "dbscheme+dbdriver://root:password@host?creds=creds#fragmentwith$pecial@",
            "subpath",
            "dbscheme+dbdriver://root:password@host/subpath?creds=creds#fragmentwith$pecial@",
        ),
        (
            "dbscheme+dbdriver://root:password@host.com/path?creds=creds#*frag@*",
            "subpath/dir",
            "dbscheme+dbdriver://root:password@host.com/path/subpath/dir?creds=creds#*frag@*",
        ),
        (
            "dbscheme-dbdriver://root:password@host.com/path?creds=creds#*frag@*",
            "subpath/dir",
            "dbscheme-dbdriver://root:password@host.com/path/subpath/dir?creds=creds#*frag@*",
        ),
        (
            "dbscheme+dbdriver://root:password@host.com/path?creds=creds,param=value#*frag@*",
            "subpath/dir",
            "dbscheme+dbdriver://root:password@host.com/path/subpath/dir?"
            "creds=creds,param=value#*frag@*",
        ),
    ])


def test_extract_and_normalize_path():
    """All slash variants of the same ``dbfs:`` URI normalize to one relative path."""
    base_uri = "databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts"
    assert (
        extract_and_normalize_path("dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts")
        == base_uri
    )
    assert (
        extract_and_normalize_path("dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts")
        == base_uri
    )
    assert (
        extract_and_normalize_path("dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts")
        == base_uri
    )
    assert (
        extract_and_normalize_path(
            "dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/"
        )
        == base_uri
    )
    assert (
        extract_and_normalize_path(
            "dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//"
        )
        == base_uri
    )
    assert (
        extract_and_normalize_path(
            "dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//"
        )
        == base_uri
    )


def test_is_databricks_acled_artifacts_uri():
    """URIs under ``databricks/mlflow-tracking`` are detected regardless of extra slashes."""
    assert is_databricks_acled_artifacts_uri(
        "dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts"
    )
    assert is_databricks_acled_artifacts_uri(
        "dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts"
    )
    assert is_databricks_acled_artifacts_uri(
        "dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts"
    )
    assert is_databricks_acled_artifacts_uri(
        "dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/"
    )
    assert is_databricks_acled_artifacts_uri(
        "dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//"
    )
    assert is_databricks_acled_artifacts_uri(
        "dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//"
    )
    assert not is_databricks_acled_artifacts_uri(
        "dbfs:/databricks/mlflow//EXP_ID//RUN_ID///artifacts//"
    )


def _get_databricks_profile_uri_test_cases():
    """Build the parametrize cases for the profile-URI extraction test.

    The same group of cases is generated once per result scheme ("databricks"
    and "databricks-uc") and the groups are flattened into a single list.

    Returns:
        A list of ``(uri, result, result_scheme)`` tuples.
    """
    # Each test case is (uri, result, result_scheme)
    test_case_groups = [
        [
            # URIs with no databricks profile info -> return None
            ("ftp://user:pass@realhost:port/path/to/nowhere", None, result_scheme),
            ("dbfs:/path/to/nowhere", None, result_scheme),
            ("dbfs://nondatabricks/path/to/nowhere", None, result_scheme),
            ("dbfs://incorrect:netloc:format/path/to/nowhere", None, result_scheme),
            # URIs with legit databricks profile info
            (f"dbfs://{result_scheme}", result_scheme, result_scheme),
            (f"dbfs://{result_scheme}/", result_scheme, result_scheme),
            (f"dbfs://{result_scheme}/path/to/nowhere", result_scheme, result_scheme),
            (f"dbfs://{result_scheme}:port/path/to/nowhere", result_scheme, result_scheme),
            (f"dbfs://@{result_scheme}/path/to/nowhere", result_scheme, result_scheme),
            (f"dbfs://@{result_scheme}:port/path/to/nowhere", result_scheme, result_scheme),
            (
                f"dbfs://profile@{result_scheme}/path/to/nowhere",
                f"{result_scheme}://profile",
                result_scheme,
            ),
            (
                f"dbfs://profile@{result_scheme}:port/path/to/nowhere",
                f"{result_scheme}://profile",
                result_scheme,
            ),
            (
                f"dbfs://scope:key_prefix@{result_scheme}/path/abc",
                f"{result_scheme}://scope:key_prefix",
                result_scheme,
            ),
            (
                f"dbfs://scope:key_prefix@{result_scheme}:port/path/abc",
                f"{result_scheme}://scope:key_prefix",
                result_scheme,
            ),
            # Doesn't care about the scheme of the artifact URI
            (
                f"runs://scope:key_prefix@{result_scheme}/path/abc",
                f"{result_scheme}://scope:key_prefix",
                result_scheme,
            ),
            (
                f"models://scope:key_prefix@{result_scheme}/path/abc",
                f"{result_scheme}://scope:key_prefix",
                result_scheme,
            ),
            (
                f"s3://scope:key_prefix@{result_scheme}/path/abc",
                f"{result_scheme}://scope:key_prefix",
                result_scheme,
            ),
        ]
        for result_scheme in ["databricks", "databricks-uc"]
    ]
    return [test_case for test_case_group in test_case_groups for test_case in test_case_group]
@pytest.mark.parametrize(
    ("uri", "result", "result_scheme"), _get_databricks_profile_uri_test_cases()
)
def test_get_databricks_profile_uri_from_artifact_uri(uri, result, result_scheme):
    """The profile URI extracted from each artifact URI matches the expected value."""
    extracted = get_databricks_profile_uri_from_artifact_uri(uri, result_scheme=result_scheme)
    assert extracted == result


@pytest.mark.parametrize(
    "uri",
    [
        # Treats secret key prefixes with ":" to be invalid
        "dbfs://incorrect:netloc:format@databricks/path/a",
        "dbfs://scope::key_prefix@databricks/path/abc",
        "dbfs://scope:key_prefix:@databricks/path/abc",
    ],
)
def test_get_databricks_profile_uri_from_artifact_uri_error_cases(uri):
    """Artifact URIs with a malformed profile section raise ``MlflowException``."""
    with pytest.raises(MlflowException, match="Unsupported Databricks profile"):
        get_databricks_profile_uri_from_artifact_uri(uri)


@pytest.mark.parametrize(
    ("uri", "result"),
    [
        # URIs with no databricks profile info should stay the same
        (
            "ftp://user:pass@realhost:port/path/nowhere",
            "ftp://user:pass@realhost:port/path/nowhere",
        ),
        ("dbfs:/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://nondatabricks/path/to/nowhere", "dbfs://nondatabricks/path/to/nowhere"),
        ("dbfs://incorrect:netloc:format/path/", "dbfs://incorrect:netloc:format/path/"),
        # URIs with legit databricks profile info
        ("dbfs://databricks", "dbfs:"),
        ("dbfs://databricks/", "dbfs:/"),
        ("dbfs://databricks/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://databricks:port/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://@databricks/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://@databricks:port/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://profile@databricks/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://profile@databricks:port/path/to/nowhere", "dbfs:/path/to/nowhere"),
        ("dbfs://scope:key_prefix@databricks/path/abc", "dbfs:/path/abc"),
        ("dbfs://scope:key_prefix@databricks:port/path/abc", "dbfs:/path/abc"),
        # Treats secret key prefixes with ":" to be valid
        ("dbfs://incorrect:netloc:format@databricks/path/to/nowhere", "dbfs:/path/to/nowhere"),
        # Doesn't care about the scheme of the artifact URI
        ("runs://scope:key_prefix@databricks/path/abc", "runs:/path/abc"),
        ("models://scope:key_prefix@databricks/path/abc", "models:/path/abc"),
        ("s3://scope:key_prefix@databricks/path/abc", "s3:/path/abc"),
    ],
)
def test_remove_databricks_profile_info_from_artifact_uri(uri, result):
    """Stripping profile info from each artifact URI yields the expected URI."""
    stripped = remove_databricks_profile_info_from_artifact_uri(uri)
    assert stripped == result


@pytest.mark.parametrize(
    ("artifact_uri", "profile_uri", "result"),
    [
        # test various profile URIs
        ("dbfs:/path/a/b", "databricks", "dbfs://databricks/path/a/b"),
        ("dbfs:/path/a/b/", "databricks", "dbfs://databricks/path/a/b/"),
        ("dbfs:/path/a/b/", "databricks://Profile", "dbfs://Profile@databricks/path/a/b/"),
        ("dbfs:/path/a/b/", "databricks://profile/", "dbfs://profile@databricks/path/a/b/"),
        ("dbfs:/path/a/b/", "databricks://scope:key", "dbfs://scope:key@databricks/path/a/b/"),
        (
            "dbfs:/path/a/b/",
            "databricks://scope:key/random_stuff",
            "dbfs://scope:key@databricks/path/a/b/",
        ),
        ("dbfs:/path/a/b/", "nondatabricks://profile", "dbfs:/path/a/b/"),
        # test various artifact schemes
        ("runs:/path/a/b/", "databricks://Profile", "runs://Profile@databricks/path/a/b/"),
        ("runs:/path/a/b/", "nondatabricks://profile", "runs:/path/a/b/"),
        ("models:/path/a/b/", "databricks://profile", "models://profile@databricks/path/a/b/"),
        ("models:/path/a/b/", "nondatabricks://Profile", "models:/path/a/b/"),
        ("s3:/path/a/b/", "databricks://Profile", "s3:/path/a/b/"),
        ("s3:/path/a/b/", "nondatabricks://profile", "s3:/path/a/b/"),
        ("ftp:/path/a/b/", "databricks://profile", "ftp:/path/a/b/"),
        ("ftp:/path/a/b/", "nondatabricks://Profile", "ftp:/path/a/b/"),
        # test artifact URIs already with authority
        ("ftp://user:pass@host:port/a/b", "databricks://Profile", "ftp://user:pass@host:port/a/b"),
        ("ftp://user:pass@host:port/a/b", "nothing://Profile", "ftp://user:pass@host:port/a/b"),
        ("dbfs://databricks", "databricks://OtherProfile", "dbfs://databricks"),
        ("dbfs://databricks", "nondatabricks://Profile", "dbfs://databricks"),
        ("dbfs://databricks/path/a/b", "databricks://OtherProfile", "dbfs://databricks/path/a/b"),
        ("dbfs://databricks/path/a/b", "nondatabricks://Profile", "dbfs://databricks/path/a/b"),
        ("dbfs://@databricks/path/a/b", "databricks://OtherProfile", "dbfs://@databricks/path/a/b"),
        ("dbfs://@databricks/path/a/b", "nondatabricks://Profile", "dbfs://@databricks/path/a/b"),
        (
            "dbfs://profile@databricks/pp",
            "databricks://OtherProfile",
            "dbfs://profile@databricks/pp",
        ),
        (
            "dbfs://profile@databricks/path",
            "databricks://profile",
            "dbfs://profile@databricks/path",
        ),
        (
            "dbfs://profile@databricks/path",
            "nondatabricks://Profile",
            "dbfs://profile@databricks/path",
        ),
    ],
)
def test_add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri, result):
    """Adding profile info to each artifact URI yields the expected URI."""
    combined = add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri)
    assert combined == result


@pytest.mark.parametrize(
    ("artifact_uri", "profile_uri"),
    [
        ("dbfs:/path/a/b", "databricks://not:legit:auth"),
        ("dbfs:/path/a/b/", "databricks://scope::key"),
        ("dbfs:/path/a/b/", "databricks://scope:key:/"),
        ("dbfs:/path/a/b/", "databricks://scope:key "),
    ],
)
def test_add_databricks_profile_info_to_artifact_uri_errors(artifact_uri, profile_uri):
    """Malformed profile URIs raise ``MlflowException`` when added to an artifact URI."""
    with pytest.raises(MlflowException, match="Unsupported Databricks profile"):
        add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri)


@pytest.mark.parametrize(
    ("uri", "result"),
    [
        ("dbfs:/path/a/b", True),
        ("dbfs://databricks/a/b", True),
        ("dbfs://@databricks/a/b", True),
        ("dbfs://profile@databricks/a/b", True),
        ("dbfs://scope:key@databricks/a/b", True),
        ("dbfs://scope:key:@databricks/a/b", False),
        ("dbfs://scope::key@databricks/a/b", False),
        ("dbfs://profile@notdatabricks/a/b", False),
        ("dbfs://scope:key@notdatabricks/a/b", False),
        ("dbfs://scope:key/a/b", False),
        ("dbfs://notdatabricks/a/b", False),
        ("s3:/path/a/b", False),
        ("ftp://user:pass@host:port/path/a/b", False),
        ("ftp://user:pass@databricks/path/a/b", False),
    ],
)
def test_is_valid_dbfs_uri(uri, result):
    """``is_valid_dbfs_uri`` returns the expected boolean for each URI."""
    verdict = is_valid_dbfs_uri(uri)
    assert verdict == result


@pytest.mark.parametrize(
    ("uri", "result"),
    [
        ("/tmp/path", "/dbfs/tmp/path"),
        ("dbfs:/path", "/dbfs/path"),
        ("dbfs:/path/a/b", "/dbfs/path/a/b"),
        ("dbfs:/dbfs/123/abc", "/dbfs/dbfs/123/abc"),
    ],
)
def test_dbfs_hdfs_uri_to_fuse_path(uri, result):
    """Each input is converted to the expected ``/dbfs/...`` path."""
    fuse_path = dbfs_hdfs_uri_to_fuse_path(uri)
    assert fuse_path == result


@pytest.mark.parametrize(
    "path",
    ["some/relative/local/path", "s3:/some/s3/path", "C:/cool/windows/path"],
)
def test_dbfs_hdfs_uri_to_fuse_path_raises(path):
    """Inputs without the expected DBFS prefix raise ``MlflowException``."""
    with pytest.raises(MlflowException, match="did not start with expected DBFS URI prefix"):
        dbfs_hdfs_uri_to_fuse_path(path)


def _assert_resolve_uri_if_local(input_uri, expected_uri):
    """Assert ``resolve_uri_if_local(input_uri)`` equals the filled-in template.

    ``expected_uri`` may contain ``{cwd}`` and ``{drive}`` placeholders. On
    Windows the cwd gets a leading slash and the drive a trailing slash before
    substitution.
    """
    working_dir = pathlib.Path.cwd()
    cwd = working_dir.as_posix()
    drive = working_dir.drive
    if is_windows():
        cwd, drive = f"/{cwd}", f"{drive}/"
    expected = expected_uri.format(cwd=cwd, drive=drive)
    assert resolve_uri_if_local(input_uri) == expected


@pytest.mark.skipif(is_windows(), reason="This test fails on Windows")
@pytest.mark.parametrize(
    ("input_uri", "expected_uri"),
    [
        ("my/path", "{cwd}/my/path"),
        ("#my/path?a=b", "{cwd}/#my/path?a=b"),
        ("file://localhost/my/path", "file://localhost/my/path"),
        ("file:///my/path", "file:///{drive}my/path"),
        ("file:my/path", "file://{cwd}/my/path"),
        ("/home/my/path", "/home/my/path"),
        ("dbfs://databricks/a/b", "dbfs://databricks/a/b"),
        ("s3://host/my/path", "s3://host/my/path"),
    ],
)
def test_resolve_uri_if_local(input_uri, expected_uri):
    """Non-Windows expectations for ``resolve_uri_if_local``."""
    _assert_resolve_uri_if_local(input_uri, expected_uri)


@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
@pytest.mark.parametrize(
    ("input_uri", "expected_uri"),
    [
        ("my/path", "file://{cwd}/my/path"),
        ("#my/path?a=b", "file://{cwd}/#my/path?a=b"),
        ("\\myhostname/my/path", "file:///{drive}myhostname/my/path"),
        ("file:///my/path", "file:///{drive}my/path"),
        ("file:my/path", "file://{cwd}/my/path"),
        ("/home/my/path", "file:///{drive}home/my/path"),
        ("dbfs://databricks/a/b", "dbfs://databricks/a/b"),
        ("s3://host/my/path", "s3://host/my/path"),
    ],
)
def test_resolve_uri_if_local_on_windows(input_uri, expected_uri):
    """Windows expectations for ``resolve_uri_if_local``."""
    _assert_resolve_uri_if_local(input_uri, expected_uri)


@pytest.mark.parametrize(
    "uri",
    [
        "/dbfs/my_path",
        "dbfs:/my_path",
        "/Volumes/my_path",
        "/.fuse-mounts/my_path",
        "//dbfs////my_path",
        "///Volumes/",
        "dbfs://my///path",
        "/volumes/path/to/file",
        "/volumes/",
        "DBFS:/my/path",
    ],
)
def test_correctly_detect_fuse_and_uc_uris(uri):
    """Each of these URIs is recognized by ``is_fuse_or_uc_volumes_uri``."""
    assert is_fuse_or_uc_volumes_uri(uri)


@pytest.mark.parametrize(
    "uri",
    [
        "/My_Volumes/my_path",
        "s3a:/my_path",
        "Volumes/my_path",
        "Volume:/my_path",
        "dbfs/my_path",
        "/fuse-mounts/my_path",
    ],
)
def test_negative_detection(uri):
    """None of these URIs is recognized by ``is_fuse_or_uc_volumes_uri``."""
    assert not is_fuse_or_uc_volumes_uri(uri)


@pytest.mark.parametrize(
    "path",
    [
        "path",
        "path/",
        "path/to/file",
        "dog%step%100%timestamp%100",
        "dog+step+100+timestamp+100",
    ],
)
def test_validate_path_is_safe_good(path):
    """``validate_path_is_safe`` does not raise for these paths."""
    validate_path_is_safe(path)


@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
@pytest.mark.parametrize(
    "path",
    [
        # relative path from current directory of C: drive
        ".../...//",
    ],
)
def test_validate_path_is_safe_windows_good(path):
    """``validate_path_is_safe`` does not raise for this path on Windows."""
    validate_path_is_safe(path)


@pytest.mark.skipif(is_windows(), reason="This test does not pass on Windows")
@pytest.mark.parametrize(
    "path",
    [
        "/path",
        "../path",
        "../../path",
        "./../path",
        "path/../to/file",
        "path/../../to/file",
        "file://a#/..//tmp",
        "file://a%23/..//tmp/",
        "/etc/passwd",
        "/etc/passwd%00.jpg",
        "/etc/passwd%00.html",
        "/etc/passwd%00.txt",
        "/etc/passwd%00.php",
        "/etc/passwd%00.asp",
        "/file://etc/passwd",
        # Encoded paths with '..'
        "%2E%2E%2Fpath",
        "%2E%2E%2F%2E%2E%2Fpath",
        # Some URIs are passed to urllib.parse.urlparse after validation,
        # which strips out some whitespace characters. If they are further
        # decoded, this could result in a path that is not safe.
        # In this example, %2%0952e -> %2\t52e -> %252e -> %2e -> .
        "%2%0952e%2%0952e/%2%0A52e%2%0A52e/path",
    ],
)
def test_validate_path_is_safe_bad(path):
    """``validate_path_is_safe`` raises "Invalid path" for these paths off Windows."""
    with pytest.raises(MlflowException, match="Invalid path"):
        validate_path_is_safe(path)


@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
@pytest.mark.parametrize(
    "path",
    [
        r"../path",
        r"../../path",
        r"./../path",
        r"path/../to/file",
        r"path/../../to/file",
        r"..\path",
        r"..\..\path",
        r".\..\path",
        r"path\..\to\file",
        r"path\..\..\to\file",
        # Drive-relative paths
        r"C:path",
        r"C:path/",
        r"C:path/to/file",
        r"C:../path/to/file",
        r"C:\path",
        r"C:/path",
        r"C:\path\to\file",
        r"C:\path/to/file",
        r"C:\path\..\to\file",
        r"C:/path/../to/file",
        # UNC (Universal Naming Convention) paths
        r"\\path\to\file",
        r"\\path/to/file",
        r"\\.\\C:\path\to\file",
        r"\\?\C:\path\to\file",
        r"\\?\UNC/path/to/file",
        # Other potential attackable paths
        r"/etc/password",
        r"/path",
        r"/etc/passwd%00.jpg",
        r"/etc/passwd%00.html",
        r"/etc/passwd%00.txt",
        r"/etc/passwd%00.php",
        r"/etc/passwd%00.asp",
        r"/Windows/no/such/path",
        r"/file://etc/passwd",
        r"/file:c:/passwd",
        r"/file://d:/windows/win.ini",
        r"/file://./windows/win.ini",
        r"file://c:/boot.ini",
        r"file://C:path",
        r"file://C:path/",
        r"file://C:path/to/file",
        r"file:///C:/Windows/System32/",
        r"file:///etc/passwd",
        r"file:///d:/windows/repair/sam",
        r"file:///proc/version",
        r"file:///inetpub/wwwroot/global.asa",
        r"/file://../windows/win.ini",
        r"../etc/passwd",
        r"..\Windows\System32\\",
        r"C:\Windows\System32\\",
        r"/etc/passwd",
        r"::Windows\System32",
        r"..\..\..\..\Windows\System32\\",
        r"../Windows/System32",
        r"....\\",
        r"\\?\C:\Windows\System32\\",
        r"\\.\C:\Windows\System32\\",
        r"\\UNC\Server\Share\\",
        r"\\Server\Share\folder\\",
        r"\\127.0.0.1\c$\Windows\\",
        r"\\localhost\c$\Windows\\",
        r"\\smbserver\share\path\\",
        r"..\\?\C:\Windows\System32\\",
        r"C:/Windows/../Windows/System32/",
        r"C:\Windows\..\Windows\System32\\",
        r"../../../../../../../../../../../../Windows/System32",
        r"../../../../../../../../../../../../etc/passwd",
        r"../../../../../../../../../../../../var/www/html/index.html",
        r"../../../../../../../../../../../../usr/local/etc/openvpn/server.conf",
        r"../../../../../../../../../../../../Program Files (x86)",
        r"/../../../../../../../../../../../../Windows/System32",
        r"/Windows\../etc/passwd",
        r"/Windows\..\Windows\System32\\",
        r"/Windows\..\Windows\System32\cmd.exe",
        r"/Windows\..\Windows\System32\msconfig.exe",
        r"/Windows\..\Windows\System32\regedit.exe",
        r"/Windows\..\Windows\System32\taskmgr.exe",
        r"/Windows\..\Windows\System32\control.exe",
        r"/Windows\..\Windows\System32\services.msc",
        r"/Windows\..\Windows\System32\diskmgmt.msc",
        r"/Windows\..\Windows\System32\eventvwr.msc",
        r"/Windows/System32/drivers/etc/hosts",
    ],
)
def test_validate_path_is_safe_windows_bad(path):
    """``validate_path_is_safe`` raises "Invalid path" for these paths on Windows."""
    with pytest.raises(MlflowException, match="Invalid path"):
        validate_path_is_safe(path)


@pytest.mark.parametrize(
    ("uri", "expected"),
    [
        ("file:///path", "/path"),
        ("file://host/path", "//host/path"),
        ("file://host", "//host"),
    ],
)
def test_strip_scheme(uri: str, expected: str):
    """``strip_scheme`` returns the URI with its scheme removed."""
    stripped = strip_scheme(uri)
    assert stripped == expected


def test_validate_path_within_directory_allows_valid_paths(tmp_path):
    """A path nested under the base directory is returned unchanged."""
    artifacts_root = tmp_path / "artifacts"
    artifacts_root.mkdir()
    candidate = artifacts_root / "subdir" / "file.txt"
    validated = validate_path_within_directory(str(artifacts_root), str(candidate))
    assert validated == str(candidate)


def test_validate_path_within_directory_blocks_symlink_escape(tmp_path):
    """A path going through a symlink that points outside the base directory is rejected."""
    artifacts_root = tmp_path / "artifacts"
    artifacts_root.mkdir()

    # A directory next to the base directory, holding a file the link would expose.
    outside_dir = tmp_path / "external"
    outside_dir.mkdir()
    (outside_dir / "secret.txt").write_text("SECRET")

    escape_link = artifacts_root / "leak"
    os.symlink(str(outside_dir), str(escape_link))
    candidate = escape_link / "secret.txt"

    with pytest.raises(MlflowException, match="resolved path is outside the artifact directory"):
        validate_path_within_directory(str(artifacts_root), str(candidate))


def test_validate_path_within_directory_blocks_parent_symlink(tmp_path):
    """A path that climbs out through a symlink to the parent directory is rejected."""
    artifacts_root = tmp_path / "artifacts"
    artifacts_root.mkdir()
    parent_link = artifacts_root / "parent"
    os.symlink(str(tmp_path), str(parent_link))
    candidate = parent_link / "artifacts" / ".." / "external"

    with pytest.raises(MlflowException, match="resolved path is outside the artifact directory"):
        validate_path_within_directory(str(artifacts_root), str(candidate))


def test_validate_path_within_directory_allows_internal_symlink(tmp_path):
    """A symlink whose target is a file inside the base directory is accepted."""
    artifacts_root = tmp_path / "artifacts"
    artifacts_root.mkdir()
    target_file = artifacts_root / "real_file.txt"
    target_file.write_text("CONTENT")
    inner_link = artifacts_root / "link"
    os.symlink(str(target_file), str(inner_link))

    validated = validate_path_within_directory(str(artifacts_root), str(inner_link))
    assert validated == str(inner_link)


def test_validate_path_within_directory_allows_base_dir_itself(tmp_path):
    """The base directory itself is accepted and returned unchanged."""
    artifacts_root = tmp_path / "artifacts"
    artifacts_root.mkdir()
    root_str = str(artifacts_root)
    validated = validate_path_within_directory(root_str, root_str)
    assert validated == root_str


def test_validate_path_within_directory_allows_subdirectory_symlink(tmp_path):
    """A path through a symlink to a subdirectory of the base directory is accepted."""
    artifacts_root = tmp_path / "artifacts"
    artifacts_root.mkdir()
    nested_dir = artifacts_root / "subdir"
    nested_dir.mkdir()
    (nested_dir / "file.txt").write_text("CONTENT")

    dir_link = artifacts_root / "link_to_subdir"
    os.symlink(str(nested_dir), str(dir_link))
    candidate = dir_link / "file.txt"

    validated = validate_path_within_directory(str(artifacts_root), str(candidate))
    assert validated == str(candidate)