# test_validation.py
"""Tests for the input-validation helpers in ``mlflow.utils.validation``."""

import copy
import re
import socket
from unittest.mock import patch

import pytest

from mlflow.entities import Metric, Param, RunTag
from mlflow.environment_variables import MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, ErrorCode
from mlflow.utils.os import is_windows
from mlflow.utils.validation import (
    MAX_TAG_VAL_LENGTH,
    _is_numeric,
    _validate_batch_log_data,
    _validate_batch_log_limits,
    _validate_db_type_string,
    _validate_experiment_artifact_location,
    _validate_experiment_artifact_location_length,
    _validate_experiment_name,
    _validate_list_param,
    _validate_metric_name,
    _validate_model_alias_name,
    _validate_model_alias_name_reserved,
    _validate_model_name,
    _validate_model_renaming,
    _validate_param_name,
    _validate_run_id,
    _validate_tag_name,
    _validate_webhook_url,
    path_not_unique,
)

# Names the metric, param and tag name validators must accept.
GOOD_METRIC_OR_PARAM_NAMES = [
    "a",
    "Ab-5_",
    "a/b/c",
    "a.b.c",
    ".a",
    "b.",
    "a..a/._./o_O/.e.",
    "a b/c d",
]
# Names the metric, param and tag name validators must reject: the empty name,
# "."/".." (alone or as path segments), doubled/leading/trailing slashes and a
# lone backslash.
BAD_METRIC_OR_PARAM_NAMES = [
    "",
    ".",
    "/",
    "..",
    "//",
    "a//b",
    "a/./b",
    "/a",
    "a/",
    "\\",
    "./",
    "/./",
]

# Model alias names that must be accepted (255 characters is still valid).
GOOD_ALIAS_NAMES = [
    "a",
    "Ab-5_",
    "test-alias",
    "1a2b5cDeFgH",
    "a" * 255,
    "lates",  # spellchecker: disable-line
    "v123_temp",
    "123",
    "123v",
    "temp_V123",
]

# Model alias names that must be rejected, including a 256-character name,
# None, whitespace, path separators and "$".
BAD_ALIAS_NAMES = [
    "",
    ".",
    "/",
    "..",
    "//",
    "a b",
    "a/./b",
    "/a",
    "a/",
    ":",
    "\\",
    "./",
    "/./",
    "a" * 256,
    None,
    "$dgs",
]


@pytest.mark.parametrize(
    ("path", "expected"),
    [
        ("a", False),
        ("a/b/c", False),
        ("a.b/c", False),
        (".a", False),
        # Not unique paths
        ("./a", True),
        ("a/b/../c", True),
        (".", True),
        ("../a/b", True),
        ("/a/b/c", True),
    ],
)
def test_path_not_unique(path, expected):
    assert path_not_unique(path) is expected


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        (0, True),
        (0.0, True),
        # Non-numeric cases (booleans are deliberately not treated as numbers)
        (True, False),
        (False, False),
        ("0", False),
        (None, False),
    ],
)
def test_is_numeric(value, expected):
    assert _is_numeric(value) is expected


@pytest.mark.parametrize("metric_name", GOOD_METRIC_OR_PARAM_NAMES)
def test_validate_metric_name_good(metric_name):
    _validate_metric_name(metric_name)


def _bad_parameter_pattern(name):
    """Return a regex matching the "Invalid value ..." error raised for ``name``.

    The name is passed through ``re.escape`` so regex metacharacters in the
    tested value (``.``, ``*``, ...) are matched literally instead of acting as
    wildcards.

    Args:
        name: The rejected value that the validator is expected to echo back.

    Returns:
        A regex string suitable for ``pytest.raises(match=...)``.
    """
    if name == "\\":
        # The expected message shows a lone backslash escaped, i.e. as two
        # backslash characters, so match two literal backslashes here.
        return r'Invalid value "\\\\" for parameter'
    return f'Invalid value "{re.escape(name)}" for parameter'


@pytest.mark.parametrize("metric_name", BAD_METRIC_OR_PARAM_NAMES)
def test_validate_metric_name_bad(metric_name):
    with pytest.raises(MlflowException, match=_bad_parameter_pattern(metric_name)) as e:
        _validate_metric_name(metric_name)
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


@pytest.mark.parametrize("param_name", GOOD_METRIC_OR_PARAM_NAMES)
def test_validate_param_name_good(param_name):
    _validate_param_name(param_name)


@pytest.mark.parametrize("param_name", BAD_METRIC_OR_PARAM_NAMES)
def test_validate_param_name_bad(param_name):
    with pytest.raises(MlflowException, match=_bad_parameter_pattern(param_name)) as e:
        _validate_param_name(param_name)
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


@pytest.mark.skipif(
    not is_windows(),
    reason="Colons in param and metric names are only rejected on Windows",
)
@pytest.mark.parametrize(
    "param_name",
    [
        ":",
        "aa:bb:cc",
    ],
)
def test_validate_colon_name_bad_windows(param_name):
    with pytest.raises(MlflowException, match=_bad_parameter_pattern(param_name)) as e:
        _validate_param_name(param_name)
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


@pytest.mark.parametrize("tag_name", GOOD_METRIC_OR_PARAM_NAMES)
def test_validate_tag_name_good(tag_name):
    _validate_tag_name(tag_name)


@pytest.mark.parametrize("tag_name", BAD_METRIC_OR_PARAM_NAMES)
def test_validate_tag_name_bad(tag_name):
    with pytest.raises(MlflowException, match=_bad_parameter_pattern(tag_name)) as e:
        _validate_tag_name(tag_name)
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


@pytest.mark.parametrize("alias_name", GOOD_ALIAS_NAMES)
def test_validate_model_alias_name_good(alias_name):
    _validate_model_alias_name(alias_name)


@pytest.mark.parametrize("alias_name", BAD_ALIAS_NAMES)
def test_validate_model_alias_name_bad(alias_name):
    with pytest.raises(MlflowException, match="alias name") as e:
        _validate_model_alias_name(alias_name)
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


@pytest.mark.parametrize("alias_name", ["latest", "LATEST", "Latest", "v123", "V1"])
def test_validate_model_alias_name_reserved(alias_name):
    # "latest" is reserved case-insensitively, as are "v<digits>" version-style names.
    with pytest.raises(MlflowException, match="reserved") as e:
        _validate_model_alias_name_reserved(alias_name)
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


@pytest.mark.parametrize(
    "run_id",
    [
        "a" * 32,
        "f0" * 16,
        "abcdef0123456789" * 2,
        "a" * 33,
        "a" * 31,
        "a" * 256,
        "A" * 32,
        "g" * 32,
        "a_" * 32,
        "abcdefghijklmnopqrstuvqxyz",
    ],
)
def test_validate_run_id_good(run_id):
    _validate_run_id(run_id)


@pytest.mark.parametrize("run_id", ["a/bc" * 8, "", "a" * 400, "*" * 5])
def test_validate_run_id_bad(run_id):
    with pytest.raises(MlflowException, match=_bad_parameter_pattern(run_id)) as e:
        _validate_run_id(run_id)
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_validate_batch_log_limits():
    """Batch requests over the per-type or aggregate entity limits are rejected."""
    too_many_metrics = [Metric(f"metric-key-{i}", 1, 0, i * 2) for i in range(1001)]
    too_many_params = [Param(f"param-key-{i}", "b") for i in range(101)]
    too_many_tags = [RunTag(f"tag-key-{i}", "b") for i in range(101)]

    good_kwargs = {"metrics": [], "params": [], "tags": []}
    bad_kwargs = {
        "metrics": [too_many_metrics],
        "params": [too_many_params],
        "tags": [too_many_tags],
    }
    match = r"A batch logging request can contain at most \d+"
    # Exceed one entity type at a time while the others stay empty.
    for arg_name, arg_values in bad_kwargs.items():
        for arg_value in arg_values:
            final_kwargs = copy.deepcopy(good_kwargs)
            final_kwargs[arg_name] = arg_value
            with pytest.raises(MlflowException, match=match):
                _validate_batch_log_limits(**final_kwargs)
    # Test the case where there are too many entities in aggregate (900 + 51 + 50 = 1001)
    with pytest.raises(MlflowException, match=match):
        _validate_batch_log_limits(too_many_metrics[:900], too_many_params[:51], too_many_tags[:50])
    # Test that we don't reject entities within the limit
    _validate_batch_log_limits(too_many_metrics[:1000], [], [])
    _validate_batch_log_limits([], too_many_params[:100], [])
    _validate_batch_log_limits([], [], too_many_tags[:100])


def test_validate_batch_log_data(monkeypatch):
    """Malformed metrics, params or tags in a batch are rejected; valid ones pass."""
    metrics_with_bad_key = [
        Metric("good-metric-key", 1.0, 0, 0),
        Metric("super-long-bad-key" * 1000, 4.0, 0, 0),
    ]
    metrics_with_bad_val = [Metric("good-metric-key", "not-a-double-val", 0, 0)]
    metrics_with_bool_val = [Metric("good-metric-key", True, 0, 0)]
    metrics_with_bad_ts = [Metric("good-metric-key", 1.0, "not-a-timestamp", 0)]
    metrics_with_neg_ts = [Metric("good-metric-key", 1.0, -123, 0)]
    metrics_with_bad_step = [Metric("good-metric-key", 1.0, 0, "not-a-step")]
    params_with_bad_key = [
        Param("good-param-key", "hi"),
        Param("super-long-bad-key" * 1000, "but-good-val"),
    ]
    params_with_bad_val = [
        Param("good-param-key", "hi"),
        Param("another-good-key", "but-bad-val" * 1000),
    ]
    tags_with_bad_key = [
        RunTag("good-tag-key", "hi"),
        RunTag("super-long-bad-key" * 1000, "but-good-val"),
    ]
    tags_with_bad_val = [
        RunTag("good-tag-key", "hi"),
        RunTag("another-good-key", "a" * (MAX_TAG_VAL_LENGTH + 1)),
    ]
    bad_kwargs = {
        "metrics": [
            metrics_with_bad_key,
            metrics_with_bad_val,
            metrics_with_bool_val,
            metrics_with_bad_ts,
            metrics_with_neg_ts,
            metrics_with_bad_step,
        ],
        "params": [params_with_bad_key, params_with_bad_val],
        "tags": [tags_with_bad_key, tags_with_bad_val],
    }
    good_kwargs = {"metrics": [], "params": [], "tags": []}
    # Disable truncation so over-long values raise instead of being shortened.
    monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "false")
    for arg_name, arg_values in bad_kwargs.items():
        for arg_value in arg_values:
            final_kwargs = copy.deepcopy(good_kwargs)
            final_kwargs[arg_name] = arg_value
            with pytest.raises(MlflowException, match=r".+"):
                _validate_batch_log_data(**final_kwargs)
    # Test that we don't reject entities within the limit
    _validate_batch_log_data(
        metrics=[Metric("metric-key", 1.0, 0, 0)],
        params=[Param("param-key", "param-val")],
        tags=[RunTag("tag-key", "tag-val")],
    )


@pytest.mark.parametrize("location", ["abcde", None])
def test_validate_experiment_artifact_location_good(location):
    _validate_experiment_artifact_location(location)


@pytest.mark.parametrize("location", ["runs:/blah/bleh/blergh"])
def test_validate_experiment_artifact_location_bad(location):
    with pytest.raises(MlflowException, match="Artifact location cannot be a runs:/ URI"):
        _validate_experiment_artifact_location(location)


@pytest.mark.parametrize("experiment_name", ["validstring", b"test byte string".decode("utf-8")])
def test_validate_experiment_name_good(experiment_name):
    _validate_experiment_name(experiment_name)


@pytest.mark.parametrize("experiment_name", ["", 12, 12.7, None, {}, []])
def test_validate_experiment_name_bad(experiment_name):
    with pytest.raises(MlflowException, match="Invalid experiment name"):
        _validate_experiment_name(experiment_name)


@pytest.mark.parametrize("db_type", ["mysql", "mssql", "postgresql", "sqlite"])
def test_validate_db_type_string_good(db_type):
    _validate_db_type_string(db_type)


@pytest.mark.parametrize("db_type", ["MySQL", "mongo", "cassandra", "sql", ""])
def test_validate_db_type_string_bad(db_type):
    # Matching is case-sensitive: "MySQL" is rejected although "mysql" is accepted.
    # `match=` already verifies the message, so no separate message assertion is needed.
    with pytest.raises(MlflowException, match="Invalid database engine"):
        _validate_db_type_string(db_type)


@pytest.mark.parametrize(
    "artifact_location",
    [
        "s3://test-bucket/",
        "file:///path/to/artifacts",
        "mlflow-artifacts:/path/to/artifacts",
        "dbfs:/databricks/mlflow-tracking/some-id",
    ],
)
def test_validate_experiment_artifact_location_length_good(artifact_location):
    _validate_experiment_artifact_location_length(artifact_location)


@pytest.mark.parametrize(
    "artifact_location",
    ["s3://test-bucket/" + "a" * 10000, "file:///path/to/" + "directory" * 1111],
    ids=["s3_long_path", "file_long_path"],
)
def test_validate_experiment_artifact_location_length_bad(artifact_location):
    with pytest.raises(MlflowException, match="Invalid artifact path length"):
        _validate_experiment_artifact_location_length(artifact_location)


def test_setting_experiment_artifact_location_env_var_works(monkeypatch):
    """MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH overrides the maximum accepted length."""
    artifact_location = "file://aaaa"  # length 11

    # Clear any value inherited from the outer environment so the first check
    # runs against the built-in default limit.
    monkeypatch.delenv(MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH.name, raising=False)

    # should not throw
    _validate_experiment_artifact_location_length(artifact_location)

    # reduce limit to 10
    monkeypatch.setenv(MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH.name, "10")
    with pytest.raises(MlflowException, match="Invalid artifact path length"):
        _validate_experiment_artifact_location_length(artifact_location)

    # increase limit to 11 (the limit is inclusive)
    monkeypatch.setenv(MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH.name, "11")
    _validate_experiment_artifact_location_length(artifact_location)


@pytest.mark.parametrize(
    "param_value",
    [
        ["1", "2", "3"],
        [],
        [1, 2, 3],
    ],
)
def test_validate_list_param_with_valid_list(param_value):
    _validate_list_param("experiment_ids", param_value)


def test_validate_list_param_with_none_not_allowed():
    with pytest.raises(MlflowException, match="experiment_ids must be a list"):
        _validate_list_param("experiment_ids", None, allow_none=False)


def test_validate_list_param_with_none_allowed():
    _validate_list_param("experiment_ids", None, allow_none=True)


@pytest.mark.parametrize(
    ("param_name", "param_value", "expected_type"),
    [
        ("experiment_ids", 4, "int"),
        ("param_name", "value", "str"),
        ("my_param", {"key": "value"}, "dict"),
    ],
)
def test_validate_list_param_with_invalid_type(param_name, param_value, expected_type):
    """A non-list value is rejected with its type name and a wrap-in-a-list hint."""
    with pytest.raises(
        MlflowException, match=rf"{param_name} must be a list, got {expected_type}"
    ) as exc_info:
        _validate_list_param(param_name, param_value)
    assert f"Did you mean to use {param_name}=[{param_value!r}]?" in str(exc_info.value)
    assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


# -- _validate_webhook_url tests --


def _mock_getaddrinfo(ip_str):
    """Return a fake ``socket.getaddrinfo`` that resolves every host to ``ip_str``.

    Only the sockaddr element (index 4) of the single result tuple is populated;
    the family/type/proto/canonname fields are ``None``.
    """
    return lambda host, port, *a, **kw: [(None, None, None, None, (ip_str, 0))]


@pytest.mark.parametrize(
    ("url", "expected_match"),
    [
        (123, "Webhook URL must be a string"),
        ("", "Webhook URL cannot be empty"),
        ("   ", "Webhook URL cannot be empty"),
        ("ftp://example.com", "Invalid webhook URL scheme"),
        ("http://example.com", "Invalid webhook URL scheme"),
        ("https://", "must include a hostname"),
    ],
)
def test_validate_webhook_url_rejects_invalid_input(url, expected_match):
    with pytest.raises(MlflowException, match=expected_match):
        _validate_webhook_url(url)


@pytest.mark.parametrize(
    ("url", "resolved_ip"),
    [
        ("https://127.0.0.1/callback", "127.0.0.1"),
        ("https://localhost/callback", "127.0.0.1"),
        ("https://internal.corp/hook", "10.0.0.1"),
        ("https://internal.corp/hook", "172.16.0.1"),
        ("https://internal.corp/hook", "192.168.1.1"),
        ("https://metadata.internal/hook", "169.254.169.254"),
        ("https://cgnat.internal/hook", "100.64.0.1"),
        ("https://ipv6-loopback.internal/hook", "::1"),
        ("https://ipv6-private.internal/hook", "fc00::1"),
    ],
)
def test_validate_webhook_url_rejects_private_ips(url, resolved_ip, monkeypatch):
    # The opt-out env var must not leak in from the outer environment.
    monkeypatch.delenv("MLFLOW_WEBHOOK_ALLOW_PRIVATE_IPS", raising=False)
    with patch(
        "mlflow.utils.validation.socket.getaddrinfo",
        side_effect=_mock_getaddrinfo(resolved_ip),
    ):
        with pytest.raises(MlflowException, match="must not resolve to a non-public"):
            _validate_webhook_url(url)


def test_validate_webhook_url_rejects_unresolvable_hostname():
    with patch(
        "mlflow.utils.validation.socket.getaddrinfo",
        side_effect=socket.gaierror("Name or service not known"),
    ):
        with pytest.raises(MlflowException, match="Cannot resolve webhook URL hostname"):
            _validate_webhook_url("https://does-not-exist.invalid/hook")


def test_validate_webhook_url_rejects_if_any_resolved_address_is_private(monkeypatch):
    """One private address among several public ones is enough to reject the URL."""
    # The opt-out env var must not leak in from the outer environment.
    monkeypatch.delenv("MLFLOW_WEBHOOK_ALLOW_PRIVATE_IPS", raising=False)

    def multi_resolve(host, port, *a, **kw):
        return [
            (None, None, None, None, ("8.8.8.8", 0)),
            (None, None, None, None, ("10.0.0.1", 0)),
        ]

    with patch("mlflow.utils.validation.socket.getaddrinfo", side_effect=multi_resolve):
        with pytest.raises(MlflowException, match="must not resolve to a non-public"):
            _validate_webhook_url("https://dual-homed.example.com/hook")


def test_validate_webhook_url_accepts_public_ip():
    with patch(
        "mlflow.utils.validation.socket.getaddrinfo",
        side_effect=_mock_getaddrinfo("8.8.8.8"),
    ):
        _validate_webhook_url("https://example.com/webhook")


def test_validate_webhook_url_allow_private_ips_env_var(monkeypatch):
    """MLFLOW_WEBHOOK_ALLOW_PRIVATE_IPS=true lets a loopback-resolving URL through."""
    monkeypatch.setenv("MLFLOW_WEBHOOK_ALLOW_PRIVATE_IPS", "true")
    with patch(
        "mlflow.utils.validation.socket.getaddrinfo",
        side_effect=_mock_getaddrinfo("127.0.0.1"),
    ):
        _validate_webhook_url("https://localhost/callback")


@pytest.mark.parametrize("invalid_name", ["my/model", "model:v1", "name/with:both"])
def test_validate_model_name_invalid_chars(invalid_name):
    with pytest.raises(
        MlflowException,
        match="Names cannot contain '/' or ':'",
        check=lambda e: e.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE),
    ):
        _validate_model_name(invalid_name)


@pytest.mark.parametrize("invalid_name", ["my/model", "model:v1", "name/with:both"])
def test_validate_model_renaming_invalid_chars(invalid_name):
    with pytest.raises(
        MlflowException,
        match="Names cannot contain '/' or ':'",
        check=lambda e: e.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE),
    ):
        _validate_model_renaming(invalid_name)