# test_rest_utils.py
"""Tests for the HTTP/REST helpers in ``mlflow.utils.rest_utils``."""

import re
import time
import warnings
from unittest import mock

import numpy
import pytest
import requests

from mlflow.deployments.databricks import DatabricksDeploymentClient
from mlflow.environment_variables import (
    _MLFLOW_DATABRICKS_TRAFFIC_ID,
    MLFLOW_HTTP_REQUEST_TIMEOUT,
)
from mlflow.exceptions import InvalidUrlException, MlflowException, RestException
from mlflow.protos.databricks_pb2 import ENDPOINT_NOT_FOUND, ErrorCode
from mlflow.protos.service_pb2 import GetRun
from mlflow.pyfunc.scoring_server import NumpyEncoder
from mlflow.tracking.request_header.default_request_header_provider import (
    _CLIENT_VERSION,
    _USER_AGENT,
    DefaultRequestHeaderProvider,
)
from mlflow.utils.rest_utils import (
    _DATABRICKS_SDK_RETRY_AFTER_SECS_DEPRECATION_WARNING,
    MlflowHostCreds,
    _can_parse_as_json_object,
    augmented_raise_for_status,
    call_endpoint,
    call_endpoints,
    get_workspace_client,
    http_request,
    http_request_safe,
)
from mlflow.utils.workspace_context import WorkspaceContext
from mlflow.utils.workspace_utils import WORKSPACE_HEADER_NAME

from tests import helper_functions


@pytest.mark.parametrize(
    "response_mock",
    [
        helper_functions.create_mock_response(400, "Error message but not a JSON string"),
        helper_functions.create_mock_response(400, ""),
        helper_functions.create_mock_response(400, None),
    ],
)
def test_malformed_json_error_response(response_mock):
    with mock.patch("requests.Session.request", return_value=response_mock):
        host_only = MlflowHostCreds("http://my-host")

        response_proto = GetRun.Response()
        with pytest.raises(
            MlflowException, match="API request to endpoint /my/endpoint failed with error code 400"
        ):
            call_endpoint(host_only, "/my/endpoint", "GET", None, response_proto)


def test_call_endpoints():
    with mock.patch("mlflow.utils.rest_utils.call_endpoint") as mock_call_endpoint:
        response_proto = GetRun.Response()
        mock_call_endpoint.side_effect = [
            RestException({"error_code": ErrorCode.Name(ENDPOINT_NOT_FOUND)}),
            None,
        ]
        host_only = MlflowHostCreds("http://my-host")
        endpoints = [("/my/endpoint", "POST"), ("/my/endpoint", "GET")]
        resp = call_endpoints(host_only, endpoints, "", response_proto)
        mock_call_endpoint.assert_has_calls([
            mock.call(host_only, endpoint, method, "", response_proto, None)
            for endpoint, method in endpoints
        ])
        assert resp is None


def test_call_endpoints_raises_exceptions():
    with mock.patch("mlflow.utils.rest_utils.call_endpoint") as mock_call_endpoint:
        response_proto = GetRun.Response()
        mock_call_endpoint.side_effect = [
            RestException({"error_code": ErrorCode.Name(ENDPOINT_NOT_FOUND)}),
            RestException({"error_code": ErrorCode.Name(ENDPOINT_NOT_FOUND)}),
        ]
        host_only = MlflowHostCreds("http://my-host")
        endpoints = [("/my/endpoint", "POST"), ("/my/endpoint", "GET")]
        with pytest.raises(RestException, match="ENDPOINT_NOT_FOUND"):
            call_endpoints(host_only, endpoints, "", response_proto)
        mock_call_endpoint.side_effect = [RestException({}), None]
        with pytest.raises(RestException, match="INTERNAL_ERROR"):
            call_endpoints(host_only, endpoints, "", response_proto)


def test_http_request_hostonly():
    host_only = MlflowHostCreds("http://my-host")
    response = mock.MagicMock()
    response.status_code = 200
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(host_only, "/my/endpoint", "GET")
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify=True,
            headers=DefaultRequestHeaderProvider().request_headers(),
            timeout=120,
        )


def test_http_request_includes_workspace_header_for_mlflow_routes():
    host_only = MlflowHostCreds("http://my-host")
    response = mock.MagicMock()
    response.status_code = 200
    with WorkspaceContext("team-a"):
        with mock.patch("requests.Session.request", return_value=response) as mock_request:
            http_request(host_only, "/api/2.0/mlflow/runs/search", "GET")
            headers = mock_request.call_args.kwargs["headers"]
            assert headers[WORKSPACE_HEADER_NAME] == "team-a"


def test_http_request_omits_workspace_header_for_workspace_admin_routes():
    host_only = MlflowHostCreds("http://my-host")
    response = mock.MagicMock()
    response.status_code = 200
    with WorkspaceContext("team-a"):
        with mock.patch("requests.Session.request", return_value=response) as mock_request:
            http_request(host_only, "/api/3.0/mlflow/workspaces/team-a", "GET")
            headers = mock_request.call_args.kwargs["headers"]
            assert WORKSPACE_HEADER_NAME not in headers


def test_http_request_cleans_hostname():
    # Add a trailing slash, should be removed.
    host_only = MlflowHostCreds("http://my-host/")
    response = mock.MagicMock()
    response.status_code = 200
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(host_only, "/my/endpoint", "GET")
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify=True,
            headers=DefaultRequestHeaderProvider().request_headers(),
            timeout=120,
        )


def test_http_request_with_basic_auth():
    host_only = MlflowHostCreds("http://my-host", username="user", password="pass")
    response = mock.MagicMock()
    response.status_code = 200
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(host_only, "/my/endpoint", "GET")
        headers = DefaultRequestHeaderProvider().request_headers()
        headers["Authorization"] = "Basic dXNlcjpwYXNz"
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify=True,
            headers=headers,
            timeout=120,
        )


def test_http_request_with_aws_sigv4(monkeypatch):
    from requests_auth_aws_sigv4 import AWSSigV4

    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "access-key")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret-key")
    monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1")
    aws_sigv4 = MlflowHostCreds("http://my-host", aws_sigv4=True)
    response = mock.MagicMock()
    response.status_code = 200

    class AuthMatcher:
        def __eq__(self, other):
            return isinstance(other, AWSSigV4)

    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(aws_sigv4, "/my/endpoint", "GET")
        mock_request.assert_called_once_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify=mock.ANY,
            headers=mock.ANY,
            timeout=mock.ANY,
            auth=AuthMatcher(),
        )


def test_http_request_with_auth():
    mock_fetch_auth = {"test_name": "test_auth_value"}
    auth = "test_auth_name"
    host_only = MlflowHostCreds("http://my-host", auth=auth)
    response = mock.MagicMock()
    response.status_code = 200
    with (
        mock.patch("requests.Session.request", return_value=response) as mock_request,
        mock.patch(
            "mlflow.tracking.request_auth.registry.fetch_auth", return_value=mock_fetch_auth
        ) as mock_fetch_auth_call,
    ):
        http_request(host_only, "/my/endpoint", "GET")

        mock_fetch_auth_call.assert_called_with(auth)

        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify=mock.ANY,
            headers=mock.ANY,
            timeout=mock.ANY,
            auth=mock_fetch_auth,
        )


def test_http_request_with_token():
    host_only = MlflowHostCreds("http://my-host", token="my-token")
    response = mock.MagicMock()
    response.status_code = 200
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(host_only, "/my/endpoint", "GET")
        headers = DefaultRequestHeaderProvider().request_headers()
        headers["Authorization"] = "Bearer my-token"
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify=True,
            headers=headers,
            timeout=120,
        )


def test_http_request_with_insecure():
    host_only = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    response = mock.MagicMock()
    response.status_code = 200
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(host_only, "/my/endpoint", "GET")
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify=False,
            headers=DefaultRequestHeaderProvider().request_headers(),
            timeout=120,
        )


def test_http_request_client_cert_path():
    host_only = MlflowHostCreds("http://my-host", client_cert_path="/some/path")
    response = mock.MagicMock()
    response.status_code = 200
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(host_only, "/my/endpoint", "GET")
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify=True,
            cert="/some/path",
            headers=DefaultRequestHeaderProvider().request_headers(),
            timeout=120,
        )


def test_http_request_server_cert_path():
    host_only = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
    response = mock.MagicMock()
    response.status_code = 200
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(host_only, "/my/endpoint", "GET")
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify="/some/path",
            headers=DefaultRequestHeaderProvider().request_headers(),
            timeout=120,
        )


def test_http_request_with_content_type_header():
    host_only = MlflowHostCreds("http://my-host", token="my-token")
    response = mock.MagicMock()
    response.status_code = 200
    extra_headers = {"Content-Type": "text/plain"}
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(host_only, "/my/endpoint", "GET", extra_headers=extra_headers)
        headers = DefaultRequestHeaderProvider().request_headers()
        headers["Authorization"] = "Bearer my-token"
        headers["Content-Type"] = "text/plain"
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify=True,
            headers=headers,
            timeout=120,
        )


def test_http_request_request_headers():
    from mlflow_test_plugin.request_header_provider import PluginRequestHeaderProvider

    # The test plugin's request header provider always returns False from in_context to avoid
    # polluting request headers in developers' environments. The following mock overrides this to
    # perform the integration test.
    response = mock.MagicMock()
    response.status_code = 200
    with (
        mock.patch("requests.Session.request", return_value=response) as mock_request,
        mock.patch.object(PluginRequestHeaderProvider, "in_context", return_value=True),
    ):
        host_only = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
        http_request(host_only, "/my/endpoint", "GET")
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify="/some/path",
            headers={**DefaultRequestHeaderProvider().request_headers(), "test": "header"},
            timeout=120,
        )


def test_http_request_request_headers_default():
    from mlflow_test_plugin.request_header_provider import PluginRequestHeaderProvider

    # The test plugin's request header provider always returns False from in_context to avoid
    # polluting request headers in developers' environments. The following mock overrides this to
    # perform the integration test.
    host_only = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
    default_headers = DefaultRequestHeaderProvider().request_headers()
    expected_headers = {
        _USER_AGENT: "{} {}".format(default_headers[_USER_AGENT], "test_user_agent"),
        _CLIENT_VERSION: "{} {}".format(default_headers[_CLIENT_VERSION], "test_client_version"),
    }

    response = mock.MagicMock()
    response.status_code = 200
    with (
        mock.patch("requests.Session.request", return_value=response) as mock_request,
        mock.patch.object(PluginRequestHeaderProvider, "in_context", return_value=True),
        mock.patch.object(
            PluginRequestHeaderProvider,
            "request_headers",
            return_value={_USER_AGENT: "test_user_agent", _CLIENT_VERSION: "test_client_version"},
        ),
    ):
        http_request(host_only, "/my/endpoint", "GET")
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify="/some/path",
            headers=expected_headers,
            timeout=120,
        )


def test_http_request_request_headers_default_and_extra_header():
    from mlflow_test_plugin.request_header_provider import PluginRequestHeaderProvider

    # The test plugin's request header provider always returns False from in_context to avoid
    # polluting request headers in developers' environments. The following mock overrides this to
    # perform the integration test.
    host_only = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
    default_headers = DefaultRequestHeaderProvider().request_headers()
    expected_headers = {
        _USER_AGENT: "{} {}".format(default_headers[_USER_AGENT], "test_user_agent"),
        _CLIENT_VERSION: "{} {}".format(default_headers[_CLIENT_VERSION], "test_client_version"),
        "header": "value",
    }

    response = mock.MagicMock()
    response.status_code = 200
    with (
        mock.patch("requests.Session.request", return_value=response) as mock_request,
        mock.patch.object(PluginRequestHeaderProvider, "in_context", return_value=True),
        mock.patch.object(
            PluginRequestHeaderProvider,
            "request_headers",
            return_value={
                _USER_AGENT: "test_user_agent",
                _CLIENT_VERSION: "test_client_version",
                "header": "value",
            },
        ),
    ):
        http_request(host_only, "/my/endpoint", "GET")
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify="/some/path",
            headers=expected_headers,
            timeout=120,
        )


def test_http_request_with_invalid_url_raise_invalid_url_exception():
    host_only = MlflowHostCreds("http://my-host")

    with pytest.raises(InvalidUrlException, match="Invalid url: http://my-host/invalid_url"):
        with mock.patch("requests.Session.request", side_effect=requests.exceptions.InvalidURL):
            http_request(host_only, "/invalid_url", "GET")


def test_http_request_with_invalid_url_raise_mlflow_exception():
    host_only = MlflowHostCreds("http://my-host")

    with pytest.raises(MlflowException, match="Invalid url: http://my-host/invalid_url"):
        with mock.patch("requests.Session.request", side_effect=requests.exceptions.InvalidURL):
            http_request(host_only, "/invalid_url", "GET")


def test_ignore_tls_verification_not_server_cert_path():
    with pytest.raises(
        MlflowException,
        match="When 'ignore_tls_verification' is true then 'server_cert_path' must not be set",
    ):
        MlflowHostCreds(
            "http://my-host",
            ignore_tls_verification=True,
            server_cert_path="/some/path",
        )


def test_http_request_wrapper():
    host_only = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    response = mock.MagicMock()
    response.status_code = 200
    response.text = "{}"
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request_safe(host_only, "/my/endpoint", "GET")
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify=False,
            headers=DefaultRequestHeaderProvider().request_headers(),
            timeout=120,
        )
        response.text = "non json"
        http_request_safe(host_only, "/my/endpoint", "GET")
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=True,
            verify=False,
            headers=DefaultRequestHeaderProvider().request_headers(),
            timeout=120,
        )
        response.status_code = 400
        response.text = ""
        with pytest.raises(MlflowException, match="Response body"):
            http_request_safe(host_only, "/my/endpoint", "GET")
        response.text = (
            '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
        )
        with pytest.raises(RestException, match="RESOURCE_DOES_NOT_EXIST: Node type not supported"):
            http_request_safe(host_only, "/my/endpoint", "GET")


def test_numpy_encoder():
    test_number = numpy.int64(42)
    ne = NumpyEncoder()
    defaulted_val = ne.default(test_number)
    assert defaulted_val == 42


def test_numpy_encoder_fail():
    if not hasattr(numpy, "float128"):
        pytest.skip("numpy on this platform has no float128")
    test_number = numpy.float128
    ne = NumpyEncoder()
    with pytest.raises(TypeError, match="not JSON serializable"):
        ne.default(test_number)


def test_can_parse_as_json_object():
    assert _can_parse_as_json_object("{}")
    assert _can_parse_as_json_object('{"a": "b"}')
    assert _can_parse_as_json_object('{"a": {"b": "c"}}')
    assert not _can_parse_as_json_object("[0, 1, 2]")
    assert not _can_parse_as_json_object('"abc"')
    assert not _can_parse_as_json_object("123")


def test_http_request_customize_config(monkeypatch):
    with mock.patch(
        "mlflow.utils.rest_utils._get_http_response_with_retries"
    ) as mock_get_http_response_with_retries:
        host_only = MlflowHostCreds("http://my-host")
        # Clear every knob asserted below so the defaults are checked regardless of the
        # developer's / CI environment (the jitter variable was previously not cleared).
        monkeypatch.delenv("MLFLOW_HTTP_REQUEST_MAX_RETRIES", raising=False)
        monkeypatch.delenv("MLFLOW_HTTP_REQUEST_BACKOFF_FACTOR", raising=False)
        monkeypatch.delenv("MLFLOW_HTTP_REQUEST_BACKOFF_JITTER", raising=False)
        monkeypatch.delenv("MLFLOW_HTTP_REQUEST_TIMEOUT", raising=False)
        monkeypatch.delenv("MLFLOW_HTTP_RESPECT_RETRY_AFTER_HEADER", raising=False)
        http_request(host_only, "/my/endpoint", "GET")
        mock_get_http_response_with_retries.assert_called_with(
            mock.ANY,
            mock.ANY,
            7,
            2,
            1.0,
            mock.ANY,
            True,
            headers=mock.ANY,
            verify=mock.ANY,
            timeout=120,
            respect_retry_after_header=True,
        )
        mock_get_http_response_with_retries.reset_mock()
        monkeypatch.setenv("MLFLOW_HTTP_REQUEST_MAX_RETRIES", "8")
        monkeypatch.setenv("MLFLOW_HTTP_REQUEST_BACKOFF_FACTOR", "3")
        monkeypatch.setenv("MLFLOW_HTTP_REQUEST_BACKOFF_JITTER", "1.0")
        monkeypatch.setenv("MLFLOW_HTTP_REQUEST_TIMEOUT", "300")
        monkeypatch.setenv("MLFLOW_HTTP_RESPECT_RETRY_AFTER_HEADER", "false")
        http_request(host_only, "/my/endpoint", "GET")
        mock_get_http_response_with_retries.assert_called_with(
            mock.ANY,
            mock.ANY,
            8,
            3,
            1.0,
            mock.ANY,
            True,
            headers=mock.ANY,
            verify=mock.ANY,
            timeout=300,
            respect_retry_after_header=False,
        )


def test_http_request_explains_how_to_increase_timeout_in_error_message():
    with mock.patch("requests.Session.request", side_effect=requests.exceptions.Timeout):
        with pytest.raises(
            MlflowException,
            match=(
                r"To increase the timeout, set the environment variable "
                + re.escape(str(MLFLOW_HTTP_REQUEST_TIMEOUT))
            ),
        ):
            http_request(MlflowHostCreds("http://my-host"), "/my/endpoint", "GET")


def test_augmented_raise_for_status():
    response = requests.Response()
    response.status_code = 403
    response._content = b"Token expired"

    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        response = requests.get("https://github.com/mlflow/mlflow.git")
        mock_request.assert_called_once()

    with pytest.raises(requests.HTTPError, match="Token expired") as e:
        augmented_raise_for_status(response)

    assert e.value.response == response
    assert e.value.request == response.request
    assert response.text in str(e.value)


def test_provide_redirect_kwarg():
    with mock.patch("requests.Session.request") as mock_request:
        mock_request.return_value.status_code = 302
        mock_request.return_value.text = "mock response"

        response = http_request(
            MlflowHostCreds("http://my-host"),
            "/my/endpoint",
            "GET",
            allow_redirects=False,
        )

        assert response.text == "mock response"
        mock_request.assert_called_with(
            "GET",
            "http://my-host/my/endpoint",
            allow_redirects=False,
            headers=mock.ANY,
            verify=mock.ANY,
            timeout=120,
        )


def test_http_request_max_retries(monkeypatch):
    monkeypatch.setenv("_MLFLOW_HTTP_REQUEST_MAX_RETRIES_LIMIT", "15")
    host_creds = MlflowHostCreds("http://example.com")

    with mock.patch("requests.Session.request") as mock_request:
        # Value exceeding limit should raise
        with pytest.raises(MlflowException, match="The configured max_retries"):
            http_request(host_creds, "/endpoint", "GET", max_retries=16)
        mock_request.assert_not_called()

        # Value equal to limit should succeed (boundary case)
        http_request(host_creds, "/endpoint", "GET", max_retries=15)
        assert mock_request.call_count == 1

        # Value below limit should succeed
        http_request(host_creds, "/endpoint", "GET", max_retries=3)
        assert mock_request.call_count == 2


def test_http_request_backoff_factor(monkeypatch):
    monkeypatch.setenv("_MLFLOW_HTTP_REQUEST_MAX_BACKOFF_FACTOR_LIMIT", "200")
    host_creds = MlflowHostCreds("http://example.com")

    with mock.patch("requests.Session.request") as mock_request:
        # Value exceeding limit should raise
        with pytest.raises(MlflowException, match="The configured backoff_factor"):
            http_request(host_creds, "/endpoint", "GET", backoff_factor=250)
        mock_request.assert_not_called()

        # Value equal to limit should succeed (boundary case)
        http_request(host_creds, "/endpoint", "GET", backoff_factor=200)
        assert mock_request.call_count == 1

        # Value below limit should succeed
        http_request(host_creds, "/endpoint", "GET", backoff_factor=10)
        assert mock_request.call_count == 2


def test_http_request_negative_max_retries():
    host_creds = MlflowHostCreds("http://example.com")

    with mock.patch("requests.Session.request") as mock_request:
        with pytest.raises(MlflowException, match="The max_retries value must be either"):
            http_request(host_creds, "/endpoint", "GET", max_retries=-1)
        mock_request.assert_not_called()


def test_http_request_negative_backoff_factor():
    host_creds = MlflowHostCreds("http://example.com")

    with mock.patch("requests.Session.request") as mock_request:
        with pytest.raises(MlflowException, match="The backoff_factor value must be"):
            http_request(host_creds, "/endpoint", "GET", backoff_factor=-1)
        mock_request.assert_not_called()


def test_suppress_databricks_retry_after_secs_warnings():
    host_creds = MlflowHostCreds("http://example.com", use_databricks_sdk=True)

    def mock_do(*args, **kwargs):
        warnings.warn(_DATABRICKS_SDK_RETRY_AFTER_SECS_DEPRECATION_WARNING)
        return mock.MagicMock()

    with (
        warnings.catch_warnings(record=True) as recorded_warnings,
        mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client,
    ):
        warnings.simplefilter("always")
        mock_workspace_client = mock.MagicMock()
        mock_workspace_client.api_client.do = mock_do
        mock_get_workspace_client.return_value = mock_workspace_client
        http_request(host_creds, "/endpoint", "GET")
        mock_get_workspace_client.assert_called_once()
        assert not any(
            _DATABRICKS_SDK_RETRY_AFTER_SECS_DEPRECATION_WARNING in str(w.message)
            for w in recorded_warnings
        )


def test_databricks_sdk_retry_on_transient_errors():
    host_creds = MlflowHostCreds("http://example.com", use_databricks_sdk=True)

    call_count = 0

    def mock_do_failing_then_success(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        if call_count <= 2:  # Fail first 2 attempts
            from databricks.sdk.errors import DatabricksError

            # ErrorCode is already imported at module level.
            raise DatabricksError(
                error_code=ErrorCode.Name(ErrorCode.INTERNAL_ERROR), message="Transient error"
            )
        # Success on 3rd attempt
        response_mock = mock.MagicMock()
        response_mock._response = mock.MagicMock()
        return {"contents": response_mock}

    with mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client:
        mock_workspace_client = mock.MagicMock()
        mock_workspace_client.api_client.do = mock_do_failing_then_success
        mock_get_workspace_client.return_value = mock_workspace_client

        # Use smaller retry timeout to make test run faster
        response = http_request(
            host_creds,
            "/endpoint",
            "GET",
            retry_timeout_seconds=10,
            backoff_factor=0.1,  # Very small backoff for faster test
        )

        assert call_count == 3  # Should retry 2 times, succeed on 3rd
        assert response is not None


def test_databricks_sdk_retry_max_retries_exceeded():
    host_creds = MlflowHostCreds("http://example.com", use_databricks_sdk=True)

    call_count = 0

    def mock_do_always_fail(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        from databricks.sdk.errors import DatabricksError

        raise DatabricksError(error_code="INTERNAL_ERROR", message="Always fails")

    with (
        mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client,
        mock.patch("mlflow.utils.rest_utils._logger") as mock_logger,
    ):
        mock_workspace_client = mock.MagicMock()
        mock_workspace_client.api_client.do = mock_do_always_fail
        mock_get_workspace_client.return_value = mock_workspace_client

        response = http_request(host_creds, "/endpoint", "GET", max_retries=3)

        assert call_count == 4  # Initial call + 3 retries
        assert response.status_code == 500  # Should return error response

        # Check that max retries warning was logged
        mock_logger.warning.assert_called()
        warning_call = mock_logger.warning.call_args[0][0]
        assert "Max retries (3) exceeded" in warning_call


def test_databricks_sdk_retry_timeout_exceeded():
    host_creds = MlflowHostCreds("http://example.com", use_databricks_sdk=True)

    call_count = 0

    def mock_do_always_fail(*args, **kwargs):
        nonlocal call_count
        call_count += 1

        time.sleep(0.1)  # Small delay to ensure timeout
        from databricks.sdk.errors import DatabricksError

        raise DatabricksError(error_code="INTERNAL_ERROR", message="Always fails")

    with (
        mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client,
        mock.patch("mlflow.utils.rest_utils._logger") as mock_logger,
    ):
        mock_workspace_client = mock.MagicMock()
        mock_workspace_client.api_client.do = mock_do_always_fail
        mock_get_workspace_client.return_value = mock_workspace_client

        response = http_request(
            host_creds,
            "/endpoint",
            "GET",
            retry_timeout_seconds=0.2,  # Very short timeout
            max_retries=10,  # High retry limit that shouldn't be reached
        )

        assert call_count >= 1  # At least initial call
        assert response.status_code == 500  # Should return error response

        # Check that timeout warning was logged
        mock_logger.warning.assert_called()
        warning_call = mock_logger.warning.call_args[0][0]
        assert "Retry timeout (0.2s) exceeded" in warning_call


def test_databricks_sdk_retry_non_retryable_error():
    host_creds = MlflowHostCreds("http://example.com", use_databricks_sdk=True)

    call_count = 0

    def mock_do_non_retryable_error(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        from databricks.sdk.errors import InvalidParameterValue

        # Use an error code that maps to 400 (non-retryable)
        raise InvalidParameterValue(error_code="INVALID_PARAMETER_VALUE", message="Bad request")

    with mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client:
        mock_workspace_client = mock.MagicMock()
        mock_workspace_client.api_client.do = mock_do_non_retryable_error
        mock_get_workspace_client.return_value = mock_workspace_client

        response = http_request(host_creds, "/endpoint", "GET", max_retries=5)

        assert call_count == 1  # Should not retry on non-retryable error
        assert response.status_code == 400  # Should return 400 for INVALID_PARAMETER_VALUE


def test_databricks_sdk_retry_backoff_calculation():
    from databricks.sdk.errors import DatabricksError

    from mlflow.utils.request_utils import _TRANSIENT_FAILURE_RESPONSE_CODES
    from mlflow.utils.rest_utils import _retry_databricks_sdk_call_with_exponential_backoff

    call_count = 0

    def mock_failing_call():
        nonlocal call_count
        call_count += 1

        raise DatabricksError(error_code="INTERNAL_ERROR", message="Mock error")

    with mock.patch("mlflow.utils.rest_utils._time_sleep") as mock_sleep:
        with pytest.raises(DatabricksError, match="Mock error"):
            _retry_databricks_sdk_call_with_exponential_backoff(
                call_func=mock_failing_call,
                retry_codes=_TRANSIENT_FAILURE_RESPONSE_CODES,
                retry_timeout_seconds=10,
                backoff_factor=1,  # Use 1 for predictable calculation
                backoff_jitter=0,  # No jitter for predictable calculation
                max_retries=3,
            )

        # Verify sleep was called with correct intervals
        # attempt 0 (1st retry): 0 seconds (immediate)
        # attempt 1 (2nd retry): 1 * (2^1) = 2 seconds
        # attempt 2 (3rd retry): 1 * (2^2) = 4 seconds
        expected_sleep_times = [0, 2, 4]
        actual_sleep_times = [call.args[0] for call in mock_sleep.call_args_list]
        assert actual_sleep_times == expected_sleep_times
        assert call_count == 4  # Initial + 3 retries


@pytest.mark.skip
def test_timeout_parameter_propagation_with_timeout():
    with (
        mock.patch("databricks.sdk.WorkspaceClient") as mock_workspace_client,
        mock.patch("databricks.sdk.config.Config") as mock_config,
    ):
        # Test http_request with timeout via get_workspace_client directly
        mock_workspace_client_instance = mock.MagicMock()
        mock_workspace_client_instance.api_client.do.return_value = {"contents": mock.MagicMock()}
        mock_workspace_client.return_value = mock_workspace_client_instance

        get_workspace_client(
            use_secret_scope_token=False,
            host="http://my-host",
            token=None,
            databricks_auth_profile="my-profile",
            retry_timeout_seconds=None,
            timeout=180,
        )

        mock_config.assert_called_once_with(
            profile="my-profile",
            http_timeout_seconds=180,
            retry_timeout_seconds=mock.ANY,
        )


@pytest.mark.skip
def test_timeout_parameter_propagation_without_timeout():
    with (
        mock.patch("databricks.sdk.WorkspaceClient") as mock_workspace_client,
        mock.patch("databricks.sdk.config.Config") as mock_config,
    ):
        # Test http_request without timeout via get_workspace_client directly
        mock_workspace_client_instance = mock.MagicMock()
        mock_workspace_client_instance.api_client.do.return_value = {"contents": mock.MagicMock()}
        mock_workspace_client.return_value = mock_workspace_client_instance

        get_workspace_client(
            use_secret_scope_token=False,
            host="http://my-host",
            token=None,
            databricks_auth_profile="my-profile",
            retry_timeout_seconds=None,
            timeout=None,
        )

        mock_config.assert_called_once_with(
            profile="my-profile",
            retry_timeout_seconds=mock.ANY,
        )


def test_deployment_client_timeout_propagation(monkeypatch):
    with (
        mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client,
        mock.patch(
            "mlflow.utils.databricks_utils.get_databricks_host_creds"
        ) as mock_get_databricks_host_creds,
        mock.patch(
            "mlflow.deployments.databricks.get_databricks_host_creds"
        ) as mock_deployment_host_creds,
    ):
        # Mock the host creds to use Databricks SDK
        mock_host_creds = MlflowHostCreds("http://my-host", use_databricks_sdk=True)
        mock_get_databricks_host_creds.return_value = mock_host_creds
        mock_deployment_host_creds.return_value = mock_host_creds

        # Mock workspace client and its response
        mock_workspace_client_instance = mock.MagicMock()
        mock_workspace_client_instance.api_client.do.return_value = {"contents": mock.MagicMock()}
        mock_get_workspace_client.return_value = mock_workspace_client_instance

        # Set the environment variable to a custom value using monkeypatch
        monkeypatch.setenv("MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT", "300")

        # Create deployment client and call predict
        client = DatabricksDeploymentClient("databricks")
        client.predict(endpoint="test-endpoint", inputs={"test": "data"})

        # Verify get_workspace_client was called with the deployment predict timeout
        mock_get_workspace_client.assert_called_once_with(
            False,  # use_secret_scope_token
            "http://my-host",  # host
            None,  # token
            None,  # databricks_auth_profile
            retry_timeout_seconds=600,
            timeout=300,  # MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT value
        )


def test_http_request_with_databricks_traffic_id(monkeypatch: pytest.MonkeyPatch):
    response = mock.MagicMock()
    response.status_code = 200

    # Test with env var set
    monkeypatch.setenv(_MLFLOW_DATABRICKS_TRAFFIC_ID.name, "test-traffic-id-12345")
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(MlflowHostCreds("http://my-host"), "/my/endpoint", "GET")
        headers = mock_request.call_args.kwargs["headers"]
        assert headers["x-databricks-traffic-id"] == "test-traffic-id-12345"

    # Test without env var set
    monkeypatch.delenv(_MLFLOW_DATABRICKS_TRAFFIC_ID.name)
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(MlflowHostCreds("http://my-host"), "/my/endpoint", "GET")
        headers = mock_request.call_args.kwargs["headers"]
        assert "x-databricks-traffic-id" not in headers


def test_http_request_with_workspace_id():
    response = mock.MagicMock()
    response.status_code = 200

    # With workspace_id set, header should be included
    creds = MlflowHostCreds("http://my-host", workspace_id="6051921418418893")
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(creds, "/my/endpoint", "GET")
        headers = mock_request.call_args.kwargs["headers"]
        assert headers["x-databricks-org-id"] == "6051921418418893"

    # Without workspace_id, header should not be present
    creds = MlflowHostCreds("http://my-host")
    with mock.patch("requests.Session.request", return_value=response) as mock_request:
        http_request(creds, "/my/endpoint", "GET")
        headers = mock_request.call_args.kwargs["headers"]
        assert "x-databricks-org-id" not in headers


def test_mlflow_host_creds_workspace_id_equality():
    creds1 = MlflowHostCreds("http://my-host", workspace_id="123")
    creds2 = MlflowHostCreds("http://my-host", workspace_id="123")
    creds3 = MlflowHostCreds("http://my-host", workspace_id="456")
    creds4 = MlflowHostCreds("http://my-host")

    assert creds1 == creds2
    assert creds1 != creds3
    assert creds1 != creds4
    assert hash(creds1) == hash(creds2)
    assert hash(creds1) != hash(creds3)


@pytest.mark.parametrize(
    ("timeout", "retry_timeout_seconds", "should_warn"),
    [
        (300, 120, True),
        (120, 600, False),
        (300, 300, False),
        (None, 120, False),
        (300, None, False),
        (None, None, False),
    ],
)
def test_validate_deployment_timeout_config(timeout, retry_timeout_seconds, should_warn):
    from mlflow.utils.rest_utils import validate_deployment_timeout_config

    # Record warnings once for both cases; the recorded list stays readable after the
    # context manager exits, so only the assertions need to branch on should_warn.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        validate_deployment_timeout_config(
            timeout=timeout, retry_timeout_seconds=retry_timeout_seconds
        )

    if should_warn:
        assert len(w) == 1
        warning_msg = str(w[0].message)
        assert "MLFLOW_DEPLOYMENT_PREDICT_TOTAL_TIMEOUT" in warning_msg
        assert f"({retry_timeout_seconds}s)" in warning_msg
        assert f"({timeout}s)" in warning_msg
    else:
        assert len(w) == 0