/ tests / utils / test_rest_utils.py
test_rest_utils.py
   1  import re
   2  import time
   3  import warnings
   4  from unittest import mock
   5  
   6  import numpy
   7  import pytest
   8  import requests
   9  
  10  from mlflow.deployments.databricks import DatabricksDeploymentClient
  11  from mlflow.environment_variables import (
  12      _MLFLOW_DATABRICKS_TRAFFIC_ID,
  13      MLFLOW_HTTP_REQUEST_TIMEOUT,
  14  )
  15  from mlflow.exceptions import InvalidUrlException, MlflowException, RestException
  16  from mlflow.protos.databricks_pb2 import ENDPOINT_NOT_FOUND, ErrorCode
  17  from mlflow.protos.service_pb2 import GetRun
  18  from mlflow.pyfunc.scoring_server import NumpyEncoder
  19  from mlflow.tracking.request_header.default_request_header_provider import (
  20      _CLIENT_VERSION,
  21      _USER_AGENT,
  22      DefaultRequestHeaderProvider,
  23  )
  24  from mlflow.utils.rest_utils import (
  25      _DATABRICKS_SDK_RETRY_AFTER_SECS_DEPRECATION_WARNING,
  26      MlflowHostCreds,
  27      _can_parse_as_json_object,
  28      augmented_raise_for_status,
  29      call_endpoint,
  30      call_endpoints,
  31      get_workspace_client,
  32      http_request,
  33      http_request_safe,
  34  )
  35  from mlflow.utils.workspace_context import WorkspaceContext
  36  from mlflow.utils.workspace_utils import WORKSPACE_HEADER_NAME
  37  
  38  from tests import helper_functions
  39  
  40  
  41  @pytest.mark.parametrize(
  42      "response_mock",
  43      [
  44          helper_functions.create_mock_response(400, "Error message but not a JSON string"),
  45          helper_functions.create_mock_response(400, ""),
  46          helper_functions.create_mock_response(400, None),
  47      ],
  48  )
  49  def test_malformed_json_error_response(response_mock):
  50      with mock.patch("requests.Session.request", return_value=response_mock):
  51          host_only = MlflowHostCreds("http://my-host")
  52  
  53          response_proto = GetRun.Response()
  54          with pytest.raises(
  55              MlflowException, match="API request to endpoint /my/endpoint failed with error code 400"
  56          ):
  57              call_endpoint(host_only, "/my/endpoint", "GET", None, response_proto)
  58  
  59  
  60  def test_call_endpoints():
  61      with mock.patch("mlflow.utils.rest_utils.call_endpoint") as mock_call_endpoint:
  62          response_proto = GetRun.Response()
  63          mock_call_endpoint.side_effect = [
  64              RestException({"error_code": ErrorCode.Name(ENDPOINT_NOT_FOUND)}),
  65              None,
  66          ]
  67          host_only = MlflowHostCreds("http://my-host")
  68          endpoints = [("/my/endpoint", "POST"), ("/my/endpoint", "GET")]
  69          resp = call_endpoints(host_only, endpoints, "", response_proto)
  70          mock_call_endpoint.assert_has_calls([
  71              mock.call(host_only, endpoint, method, "", response_proto, None)
  72              for endpoint, method in endpoints
  73          ])
  74          assert resp is None
  75  
  76  
  77  def test_call_endpoints_raises_exceptions():
  78      with mock.patch("mlflow.utils.rest_utils.call_endpoint") as mock_call_endpoint:
  79          response_proto = GetRun.Response()
  80          mock_call_endpoint.side_effect = [
  81              RestException({"error_code": ErrorCode.Name(ENDPOINT_NOT_FOUND)}),
  82              RestException({"error_code": ErrorCode.Name(ENDPOINT_NOT_FOUND)}),
  83          ]
  84          host_only = MlflowHostCreds("http://my-host")
  85          endpoints = [("/my/endpoint", "POST"), ("/my/endpoint", "GET")]
  86          with pytest.raises(RestException, match="ENDPOINT_NOT_FOUND"):
  87              call_endpoints(host_only, endpoints, "", response_proto)
  88          mock_call_endpoint.side_effect = [RestException({}), None]
  89          with pytest.raises(RestException, match="INTERNAL_ERROR"):
  90              call_endpoints(host_only, endpoints, "", response_proto)
  91  
  92  
  93  def test_http_request_hostonly():
  94      host_only = MlflowHostCreds("http://my-host")
  95      response = mock.MagicMock()
  96      response.status_code = 200
  97      with mock.patch("requests.Session.request", return_value=response) as mock_request:
  98          http_request(host_only, "/my/endpoint", "GET")
  99          mock_request.assert_called_with(
 100              "GET",
 101              "http://my-host/my/endpoint",
 102              allow_redirects=True,
 103              verify=True,
 104              headers=DefaultRequestHeaderProvider().request_headers(),
 105              timeout=120,
 106          )
 107  
 108  
 109  def test_http_request_includes_workspace_header_for_mlflow_routes():
 110      host_only = MlflowHostCreds("http://my-host")
 111      response = mock.MagicMock()
 112      response.status_code = 200
 113      with WorkspaceContext("team-a"):
 114          with mock.patch("requests.Session.request", return_value=response) as mock_request:
 115              http_request(host_only, "/api/2.0/mlflow/runs/search", "GET")
 116          headers = mock_request.call_args.kwargs["headers"]
 117          assert headers[WORKSPACE_HEADER_NAME] == "team-a"
 118  
 119  
 120  def test_http_request_omits_workspace_header_for_workspace_admin_routes():
 121      host_only = MlflowHostCreds("http://my-host")
 122      response = mock.MagicMock()
 123      response.status_code = 200
 124      with WorkspaceContext("team-a"):
 125          with mock.patch("requests.Session.request", return_value=response) as mock_request:
 126              http_request(host_only, "/api/3.0/mlflow/workspaces/team-a", "GET")
 127          headers = mock_request.call_args.kwargs["headers"]
 128          assert WORKSPACE_HEADER_NAME not in headers
 129  
 130  
 131  def test_http_request_cleans_hostname():
 132      # Add a trailing slash, should be removed.
 133      host_only = MlflowHostCreds("http://my-host/")
 134      response = mock.MagicMock()
 135      response.status_code = 200
 136      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 137          http_request(host_only, "/my/endpoint", "GET")
 138          mock_request.assert_called_with(
 139              "GET",
 140              "http://my-host/my/endpoint",
 141              allow_redirects=True,
 142              verify=True,
 143              headers=DefaultRequestHeaderProvider().request_headers(),
 144              timeout=120,
 145          )
 146  
 147  
 148  def test_http_request_with_basic_auth():
 149      host_only = MlflowHostCreds("http://my-host", username="user", password="pass")
 150      response = mock.MagicMock()
 151      response.status_code = 200
 152      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 153          http_request(host_only, "/my/endpoint", "GET")
 154          headers = DefaultRequestHeaderProvider().request_headers()
 155          headers["Authorization"] = "Basic dXNlcjpwYXNz"
 156          mock_request.assert_called_with(
 157              "GET",
 158              "http://my-host/my/endpoint",
 159              allow_redirects=True,
 160              verify=True,
 161              headers=headers,
 162              timeout=120,
 163          )
 164  
 165  
 166  def test_http_request_with_aws_sigv4(monkeypatch):
 167      from requests_auth_aws_sigv4 import AWSSigV4
 168  
 169      monkeypatch.setenv("AWS_ACCESS_KEY_ID", "access-key")
 170      monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret-key")
 171      monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1")
 172      aws_sigv4 = MlflowHostCreds("http://my-host", aws_sigv4=True)
 173      response = mock.MagicMock()
 174      response.status_code = 200
 175  
 176      class AuthMatcher:
 177          def __eq__(self, other):
 178              return isinstance(other, AWSSigV4)
 179  
 180      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 181          http_request(aws_sigv4, "/my/endpoint", "GET")
 182          mock_request.assert_called_once_with(
 183              "GET",
 184              "http://my-host/my/endpoint",
 185              allow_redirects=True,
 186              verify=mock.ANY,
 187              headers=mock.ANY,
 188              timeout=mock.ANY,
 189              auth=AuthMatcher(),
 190          )
 191  
 192  
 193  def test_http_request_with_auth():
 194      mock_fetch_auth = {"test_name": "test_auth_value"}
 195      auth = "test_auth_name"
 196      host_only = MlflowHostCreds("http://my-host", auth=auth)
 197      response = mock.MagicMock()
 198      response.status_code = 200
 199      with (
 200          mock.patch("requests.Session.request", return_value=response) as mock_request,
 201          mock.patch(
 202              "mlflow.tracking.request_auth.registry.fetch_auth", return_value=mock_fetch_auth
 203          ) as mock_fetch_auth_call,
 204      ):
 205          http_request(host_only, "/my/endpoint", "GET")
 206  
 207          mock_fetch_auth_call.assert_called_with(auth)
 208  
 209          mock_request.assert_called_with(
 210              "GET",
 211              "http://my-host/my/endpoint",
 212              allow_redirects=True,
 213              verify=mock.ANY,
 214              headers=mock.ANY,
 215              timeout=mock.ANY,
 216              auth=mock_fetch_auth,
 217          )
 218  
 219  
 220  def test_http_request_with_token():
 221      host_only = MlflowHostCreds("http://my-host", token="my-token")
 222      response = mock.MagicMock()
 223      response.status_code = 200
 224      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 225          http_request(host_only, "/my/endpoint", "GET")
 226          headers = DefaultRequestHeaderProvider().request_headers()
 227          headers["Authorization"] = "Bearer my-token"
 228          mock_request.assert_called_with(
 229              "GET",
 230              "http://my-host/my/endpoint",
 231              allow_redirects=True,
 232              verify=True,
 233              headers=headers,
 234              timeout=120,
 235          )
 236  
 237  
 238  def test_http_request_with_insecure():
 239      host_only = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
 240      response = mock.MagicMock()
 241      response.status_code = 200
 242      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 243          http_request(host_only, "/my/endpoint", "GET")
 244          mock_request.assert_called_with(
 245              "GET",
 246              "http://my-host/my/endpoint",
 247              allow_redirects=True,
 248              verify=False,
 249              headers=DefaultRequestHeaderProvider().request_headers(),
 250              timeout=120,
 251          )
 252  
 253  
 254  def test_http_request_client_cert_path():
 255      host_only = MlflowHostCreds("http://my-host", client_cert_path="/some/path")
 256      response = mock.MagicMock()
 257      response.status_code = 200
 258      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 259          http_request(host_only, "/my/endpoint", "GET")
 260          mock_request.assert_called_with(
 261              "GET",
 262              "http://my-host/my/endpoint",
 263              allow_redirects=True,
 264              verify=True,
 265              cert="/some/path",
 266              headers=DefaultRequestHeaderProvider().request_headers(),
 267              timeout=120,
 268          )
 269  
 270  
 271  def test_http_request_server_cert_path():
 272      host_only = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
 273      response = mock.MagicMock()
 274      response.status_code = 200
 275      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 276          http_request(host_only, "/my/endpoint", "GET")
 277          mock_request.assert_called_with(
 278              "GET",
 279              "http://my-host/my/endpoint",
 280              allow_redirects=True,
 281              verify="/some/path",
 282              headers=DefaultRequestHeaderProvider().request_headers(),
 283              timeout=120,
 284          )
 285  
 286  
 287  def test_http_request_with_content_type_header():
 288      host_only = MlflowHostCreds("http://my-host", token="my-token")
 289      response = mock.MagicMock()
 290      response.status_code = 200
 291      extra_headers = {"Content-Type": "text/plain"}
 292      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 293          http_request(host_only, "/my/endpoint", "GET", extra_headers=extra_headers)
 294          headers = DefaultRequestHeaderProvider().request_headers()
 295          headers["Authorization"] = "Bearer my-token"
 296          headers["Content-Type"] = "text/plain"
 297          mock_request.assert_called_with(
 298              "GET",
 299              "http://my-host/my/endpoint",
 300              allow_redirects=True,
 301              verify=True,
 302              headers=headers,
 303              timeout=120,
 304          )
 305  
 306  
 307  def test_http_request_request_headers():
 308      from mlflow_test_plugin.request_header_provider import PluginRequestHeaderProvider
 309  
 310      # The test plugin's request header provider always returns False from in_context to avoid
 311      # polluting request headers in developers' environments. The following mock overrides this to
 312      # perform the integration test.
 313      response = mock.MagicMock()
 314      response.status_code = 200
 315      with (
 316          mock.patch("requests.Session.request", return_value=response) as mock_request,
 317          mock.patch.object(PluginRequestHeaderProvider, "in_context", return_value=True),
 318      ):
 319          host_only = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
 320          http_request(host_only, "/my/endpoint", "GET")
 321          mock_request.assert_called_with(
 322              "GET",
 323              "http://my-host/my/endpoint",
 324              allow_redirects=True,
 325              verify="/some/path",
 326              headers={**DefaultRequestHeaderProvider().request_headers(), "test": "header"},
 327              timeout=120,
 328          )
 329  
 330  
 331  def test_http_request_request_headers_default():
 332      from mlflow_test_plugin.request_header_provider import PluginRequestHeaderProvider
 333  
 334      # The test plugin's request header provider always returns False from in_context to avoid
 335      # polluting request headers in developers' environments. The following mock overrides this to
 336      # perform the integration test.
 337      host_only = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
 338      default_headers = DefaultRequestHeaderProvider().request_headers()
 339      expected_headers = {
 340          _USER_AGENT: "{} {}".format(default_headers[_USER_AGENT], "test_user_agent"),
 341          _CLIENT_VERSION: "{} {}".format(default_headers[_CLIENT_VERSION], "test_client_version"),
 342      }
 343  
 344      response = mock.MagicMock()
 345      response.status_code = 200
 346      with (
 347          mock.patch("requests.Session.request", return_value=response) as mock_request,
 348          mock.patch.object(PluginRequestHeaderProvider, "in_context", return_value=True),
 349          mock.patch.object(
 350              PluginRequestHeaderProvider,
 351              "request_headers",
 352              return_value={_USER_AGENT: "test_user_agent", _CLIENT_VERSION: "test_client_version"},
 353          ),
 354      ):
 355          http_request(host_only, "/my/endpoint", "GET")
 356          mock_request.assert_called_with(
 357              "GET",
 358              "http://my-host/my/endpoint",
 359              allow_redirects=True,
 360              verify="/some/path",
 361              headers=expected_headers,
 362              timeout=120,
 363          )
 364  
 365  
 366  def test_http_request_request_headers_default_and_extra_header():
 367      from mlflow_test_plugin.request_header_provider import PluginRequestHeaderProvider
 368  
 369      # The test plugin's request header provider always returns False from in_context to avoid
 370      # polluting request headers in developers' environments. The following mock overrides this to
 371      # perform the integration test.
 372      host_only = MlflowHostCreds("http://my-host", server_cert_path="/some/path")
 373      default_headers = DefaultRequestHeaderProvider().request_headers()
 374      expected_headers = {
 375          _USER_AGENT: "{} {}".format(default_headers[_USER_AGENT], "test_user_agent"),
 376          _CLIENT_VERSION: "{} {}".format(default_headers[_CLIENT_VERSION], "test_client_version"),
 377          "header": "value",
 378      }
 379  
 380      response = mock.MagicMock()
 381      response.status_code = 200
 382      with (
 383          mock.patch("requests.Session.request", return_value=response) as mock_request,
 384          mock.patch.object(PluginRequestHeaderProvider, "in_context", return_value=True),
 385          mock.patch.object(
 386              PluginRequestHeaderProvider,
 387              "request_headers",
 388              return_value={
 389                  _USER_AGENT: "test_user_agent",
 390                  _CLIENT_VERSION: "test_client_version",
 391                  "header": "value",
 392              },
 393          ),
 394      ):
 395          http_request(host_only, "/my/endpoint", "GET")
 396          mock_request.assert_called_with(
 397              "GET",
 398              "http://my-host/my/endpoint",
 399              allow_redirects=True,
 400              verify="/some/path",
 401              headers=expected_headers,
 402              timeout=120,
 403          )
 404  
 405  
 406  def test_http_request_with_invalid_url_raise_invalid_url_exception():
 407      host_only = MlflowHostCreds("http://my-host")
 408  
 409      with pytest.raises(InvalidUrlException, match="Invalid url: http://my-host/invalid_url"):
 410          with mock.patch("requests.Session.request", side_effect=requests.exceptions.InvalidURL):
 411              http_request(host_only, "/invalid_url", "GET")
 412  
 413  
 414  def test_http_request_with_invalid_url_raise_mlflow_exception():
 415      host_only = MlflowHostCreds("http://my-host")
 416  
 417      with pytest.raises(MlflowException, match="Invalid url: http://my-host/invalid_url"):
 418          with mock.patch("requests.Session.request", side_effect=requests.exceptions.InvalidURL):
 419              http_request(host_only, "/invalid_url", "GET")
 420  
 421  
 422  def test_ignore_tls_verification_not_server_cert_path():
 423      with pytest.raises(
 424          MlflowException,
 425          match="When 'ignore_tls_verification' is true then 'server_cert_path' must not be set",
 426      ):
 427          MlflowHostCreds(
 428              "http://my-host",
 429              ignore_tls_verification=True,
 430              server_cert_path="/some/path",
 431          )
 432  
 433  
 434  def test_http_request_wrapper():
 435      host_only = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
 436      response = mock.MagicMock()
 437      response.status_code = 200
 438      response.text = "{}"
 439      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 440          http_request_safe(host_only, "/my/endpoint", "GET")
 441          mock_request.assert_called_with(
 442              "GET",
 443              "http://my-host/my/endpoint",
 444              allow_redirects=True,
 445              verify=False,
 446              headers=DefaultRequestHeaderProvider().request_headers(),
 447              timeout=120,
 448          )
 449          response.text = "non json"
 450          http_request_safe(host_only, "/my/endpoint", "GET")
 451          mock_request.assert_called_with(
 452              "GET",
 453              "http://my-host/my/endpoint",
 454              allow_redirects=True,
 455              verify=False,
 456              headers=DefaultRequestHeaderProvider().request_headers(),
 457              timeout=120,
 458          )
 459          response.status_code = 400
 460          response.text = ""
 461          with pytest.raises(MlflowException, match="Response body"):
 462              http_request_safe(host_only, "/my/endpoint", "GET")
 463          response.text = (
 464              '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
 465          )
 466          with pytest.raises(RestException, match="RESOURCE_DOES_NOT_EXIST: Node type not supported"):
 467              http_request_safe(host_only, "/my/endpoint", "GET")
 468  
 469  
 470  def test_numpy_encoder():
 471      test_number = numpy.int64(42)
 472      ne = NumpyEncoder()
 473      defaulted_val = ne.default(test_number)
 474      assert defaulted_val == 42
 475  
 476  
 477  def test_numpy_encoder_fail():
 478      if not hasattr(numpy, "float128"):
 479          pytest.skip("numpy on exit this platform has no float128")
 480      test_number = numpy.float128
 481      ne = NumpyEncoder()
 482      with pytest.raises(TypeError, match="not JSON serializable"):
 483          ne.default(test_number)
 484  
 485  
 486  def test_can_parse_as_json_object():
 487      assert _can_parse_as_json_object("{}")
 488      assert _can_parse_as_json_object('{"a": "b"}')
 489      assert _can_parse_as_json_object('{"a": {"b": "c"}}')
 490      assert not _can_parse_as_json_object("[0, 1, 2]")
 491      assert not _can_parse_as_json_object('"abc"')
 492      assert not _can_parse_as_json_object("123")
 493  
 494  
 495  def test_http_request_customize_config(monkeypatch):
 496      with mock.patch(
 497          "mlflow.utils.rest_utils._get_http_response_with_retries"
 498      ) as mock_get_http_response_with_retries:
 499          host_only = MlflowHostCreds("http://my-host")
 500          monkeypatch.delenv("MLFLOW_HTTP_REQUEST_MAX_RETRIES", raising=False)
 501          monkeypatch.delenv("MLFLOW_HTTP_REQUEST_BACKOFF_FACTOR", raising=False)
 502          monkeypatch.delenv("MLFLOW_HTTP_REQUEST_TIMEOUT", raising=False)
 503          monkeypatch.delenv("MLFLOW_HTTP_RESPECT_RETRY_AFTER_HEADER", raising=False)
 504          http_request(host_only, "/my/endpoint", "GET")
 505          mock_get_http_response_with_retries.assert_called_with(
 506              mock.ANY,
 507              mock.ANY,
 508              7,
 509              2,
 510              1.0,
 511              mock.ANY,
 512              True,
 513              headers=mock.ANY,
 514              verify=mock.ANY,
 515              timeout=120,
 516              respect_retry_after_header=True,
 517          )
 518          mock_get_http_response_with_retries.reset_mock()
 519          monkeypatch.setenv("MLFLOW_HTTP_REQUEST_MAX_RETRIES", "8")
 520          monkeypatch.setenv("MLFLOW_HTTP_REQUEST_BACKOFF_FACTOR", "3")
 521          monkeypatch.setenv("MLFLOW_HTTP_REQUEST_BACKOFF_JITTER", "1.0")
 522          monkeypatch.setenv("MLFLOW_HTTP_REQUEST_TIMEOUT", "300")
 523          monkeypatch.setenv("MLFLOW_HTTP_RESPECT_RETRY_AFTER_HEADER", "false")
 524          http_request(host_only, "/my/endpoint", "GET")
 525          mock_get_http_response_with_retries.assert_called_with(
 526              mock.ANY,
 527              mock.ANY,
 528              8,
 529              3,
 530              1.0,
 531              mock.ANY,
 532              True,
 533              headers=mock.ANY,
 534              verify=mock.ANY,
 535              timeout=300,
 536              respect_retry_after_header=False,
 537          )
 538  
 539  
 540  def test_http_request_explains_how_to_increase_timeout_in_error_message():
 541      with mock.patch("requests.Session.request", side_effect=requests.exceptions.Timeout):
 542          with pytest.raises(
 543              MlflowException,
 544              match=(
 545                  r"To increase the timeout, set the environment variable "
 546                  + re.escape(str(MLFLOW_HTTP_REQUEST_TIMEOUT))
 547              ),
 548          ):
 549              http_request(MlflowHostCreds("http://my-host"), "/my/endpoint", "GET")
 550  
 551  
 552  def test_augmented_raise_for_status():
 553      response = requests.Response()
 554      response.status_code = 403
 555      response._content = b"Token expired"
 556  
 557      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 558          response = requests.get("https://github.com/mlflow/mlflow.git")
 559          mock_request.assert_called_once()
 560  
 561      with pytest.raises(requests.HTTPError, match="Token expired") as e:
 562          augmented_raise_for_status(response)
 563  
 564      assert e.value.response == response
 565      assert e.value.request == response.request
 566      assert response.text in str(e.value)
 567  
 568  
 569  def test_provide_redirect_kwarg():
 570      with mock.patch("requests.Session.request") as mock_request:
 571          mock_request.return_value.status_code = 302
 572          mock_request.return_value.text = "mock response"
 573  
 574          response = http_request(
 575              MlflowHostCreds("http://my-host"),
 576              "/my/endpoint",
 577              "GET",
 578              allow_redirects=False,
 579          )
 580  
 581          assert response.text == "mock response"
 582          mock_request.assert_called_with(
 583              "GET",
 584              "http://my-host/my/endpoint",
 585              allow_redirects=False,
 586              headers=mock.ANY,
 587              verify=mock.ANY,
 588              timeout=120,
 589          )
 590  
 591  
 592  def test_http_request_max_retries(monkeypatch):
 593      monkeypatch.setenv("_MLFLOW_HTTP_REQUEST_MAX_RETRIES_LIMIT", "15")
 594      host_creds = MlflowHostCreds("http://example.com")
 595  
 596      with mock.patch("requests.Session.request") as mock_request:
 597          # Value exceeding limit should raise
 598          with pytest.raises(MlflowException, match="The configured max_retries"):
 599              http_request(host_creds, "/endpoint", "GET", max_retries=16)
 600          mock_request.assert_not_called()
 601  
 602          # Value equal to limit should succeed (boundary case)
 603          http_request(host_creds, "/endpoint", "GET", max_retries=15)
 604          assert mock_request.call_count == 1
 605  
 606          # Value below limit should succeed
 607          http_request(host_creds, "/endpoint", "GET", max_retries=3)
 608          assert mock_request.call_count == 2
 609  
 610  
 611  def test_http_request_backoff_factor(monkeypatch):
 612      monkeypatch.setenv("_MLFLOW_HTTP_REQUEST_MAX_BACKOFF_FACTOR_LIMIT", "200")
 613      host_creds = MlflowHostCreds("http://example.com")
 614  
 615      with mock.patch("requests.Session.request") as mock_request:
 616          # Value exceeding limit should raise
 617          with pytest.raises(MlflowException, match="The configured backoff_factor"):
 618              http_request(host_creds, "/endpoint", "GET", backoff_factor=250)
 619          mock_request.assert_not_called()
 620  
 621          # Value equal to limit should succeed (boundary case)
 622          http_request(host_creds, "/endpoint", "GET", backoff_factor=200)
 623          assert mock_request.call_count == 1
 624  
 625          # Value below limit should succeed
 626          http_request(host_creds, "/endpoint", "GET", backoff_factor=10)
 627          assert mock_request.call_count == 2
 628  
 629  
 630  def test_http_request_negative_max_retries():
 631      host_creds = MlflowHostCreds("http://example.com")
 632  
 633      with mock.patch("requests.Session.request") as mock_request:
 634          with pytest.raises(MlflowException, match="The max_retries value must be either"):
 635              http_request(host_creds, "/endpoint", "GET", max_retries=-1)
 636          mock_request.assert_not_called()
 637  
 638  
 639  def test_http_request_negative_backoff_factor():
 640      host_creds = MlflowHostCreds("http://example.com")
 641  
 642      with mock.patch("requests.Session.request") as mock_request:
 643          with pytest.raises(MlflowException, match="The backoff_factor value must be"):
 644              http_request(host_creds, "/endpoint", "GET", backoff_factor=-1)
 645          mock_request.assert_not_called()
 646  
 647  
 648  def test_suppress_databricks_retry_after_secs_warnings():
 649      host_creds = MlflowHostCreds("http://example.com", use_databricks_sdk=True)
 650  
 651      def mock_do(*args, **kwargs):
 652          warnings.warn(_DATABRICKS_SDK_RETRY_AFTER_SECS_DEPRECATION_WARNING)
 653          return mock.MagicMock()
 654  
 655      with (
 656          warnings.catch_warnings(record=True) as recorded_warnings,
 657          mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client,
 658      ):
 659          warnings.simplefilter("always")
 660          mock_workspace_client = mock.MagicMock()
 661          mock_workspace_client.api_client.do = mock_do
 662          mock_get_workspace_client.return_value = mock_workspace_client
 663          http_request(host_creds, "/endpoint", "GET")
 664          mock_get_workspace_client.assert_called_once()
 665          assert not any(
 666              _DATABRICKS_SDK_RETRY_AFTER_SECS_DEPRECATION_WARNING in str(w.message)
 667              for w in recorded_warnings
 668          )
 669  
 670  
 671  def test_databricks_sdk_retry_on_transient_errors():
 672      host_creds = MlflowHostCreds("http://example.com", use_databricks_sdk=True)
 673  
 674      call_count = 0
 675  
 676      def mock_do_failing_then_success(*args, **kwargs):
 677          nonlocal call_count
 678          call_count += 1
 679          if call_count <= 2:  # Fail first 2 attempts
 680              from databricks.sdk.errors import DatabricksError
 681  
 682              from mlflow.protos.databricks_pb2 import ErrorCode
 683  
 684              raise DatabricksError(
 685                  error_code=ErrorCode.Name(ErrorCode.INTERNAL_ERROR), message="Transient error"
 686              )
 687          # Success on 3rd attempt
 688          response_mock = mock.MagicMock()
 689          response_mock._response = mock.MagicMock()
 690          return {"contents": response_mock}
 691  
 692      with mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client:
 693          mock_workspace_client = mock.MagicMock()
 694          mock_workspace_client.api_client.do = mock_do_failing_then_success
 695          mock_get_workspace_client.return_value = mock_workspace_client
 696  
 697          # Use smaller retry timeout to make test run faster
 698          response = http_request(
 699              host_creds,
 700              "/endpoint",
 701              "GET",
 702              retry_timeout_seconds=10,
 703              backoff_factor=0.1,  # Very small backoff for faster test
 704          )
 705  
 706          assert call_count == 3  # Should retry 2 times, succeed on 3rd
 707          assert response is not None
 708  
 709  
 710  def test_databricks_sdk_retry_max_retries_exceeded():
 711      host_creds = MlflowHostCreds("http://example.com", use_databricks_sdk=True)
 712  
 713      call_count = 0
 714  
 715      def mock_do_always_fail(*args, **kwargs):
 716          nonlocal call_count
 717          call_count += 1
 718          from databricks.sdk.errors import DatabricksError
 719  
 720          raise DatabricksError(error_code="INTERNAL_ERROR", message="Always fails")
 721  
 722      with (
 723          mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client,
 724          mock.patch("mlflow.utils.rest_utils._logger") as mock_logger,
 725      ):
 726          mock_workspace_client = mock.MagicMock()
 727          mock_workspace_client.api_client.do = mock_do_always_fail
 728          mock_get_workspace_client.return_value = mock_workspace_client
 729  
 730          response = http_request(host_creds, "/endpoint", "GET", max_retries=3)
 731  
 732          assert call_count == 4  # Initial call + 3 retries
 733          assert response.status_code == 500  # Should return error response
 734  
 735          # Check that max retries warning was logged
 736          mock_logger.warning.assert_called()
 737          warning_call = mock_logger.warning.call_args[0][0]
 738          assert "Max retries (3) exceeded" in warning_call
 739  
 740  
 741  def test_databricks_sdk_retry_timeout_exceeded():
 742      host_creds = MlflowHostCreds("http://example.com", use_databricks_sdk=True)
 743  
 744      call_count = 0
 745  
 746      def mock_do_always_fail(*args, **kwargs):
 747          nonlocal call_count
 748          call_count += 1
 749  
 750          time.sleep(0.1)  # Small delay to ensure timeout
 751          from databricks.sdk.errors import DatabricksError
 752  
 753          raise DatabricksError(error_code="INTERNAL_ERROR", message="Always fails")
 754  
 755      with (
 756          mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client,
 757          mock.patch("mlflow.utils.rest_utils._logger") as mock_logger,
 758      ):
 759          mock_workspace_client = mock.MagicMock()
 760          mock_workspace_client.api_client.do = mock_do_always_fail
 761          mock_get_workspace_client.return_value = mock_workspace_client
 762  
 763          response = http_request(
 764              host_creds,
 765              "/endpoint",
 766              "GET",
 767              retry_timeout_seconds=0.2,  # Very short timeout
 768              max_retries=10,  # High retry limit that shouldn't be reached
 769          )
 770  
 771          assert call_count >= 1  # At least initial call
 772          assert response.status_code == 500  # Should return error response
 773  
 774          # Check that timeout warning was logged
 775          mock_logger.warning.assert_called()
 776          warning_call = mock_logger.warning.call_args[0][0]
 777          assert "Retry timeout (0.2s) exceeded" in warning_call
 778  
 779  
 780  def test_databricks_sdk_retry_non_retryable_error():
 781      host_creds = MlflowHostCreds("http://example.com", use_databricks_sdk=True)
 782  
 783      call_count = 0
 784  
 785      def mock_do_non_retryable_error(*args, **kwargs):
 786          nonlocal call_count
 787          call_count += 1
 788          from databricks.sdk.errors import InvalidParameterValue
 789  
 790          # Use an error code that maps to 400 (non-retryable)
 791          raise InvalidParameterValue(error_code="INVALID_PARAMETER_VALUE", message="Bad request")
 792  
 793      with mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client:
 794          mock_workspace_client = mock.MagicMock()
 795          mock_workspace_client.api_client.do = mock_do_non_retryable_error
 796          mock_get_workspace_client.return_value = mock_workspace_client
 797  
 798          response = http_request(host_creds, "/endpoint", "GET", max_retries=5)
 799  
 800          assert call_count == 1  # Should not retry on non-retryable error
 801          assert response.status_code == 400  # Should return 400 for INVALID_PARAMETER_VALUE
 802  
 803  
 804  def test_databricks_sdk_retry_backoff_calculation():
 805      from databricks.sdk.errors import DatabricksError
 806  
 807      from mlflow.utils.request_utils import _TRANSIENT_FAILURE_RESPONSE_CODES
 808      from mlflow.utils.rest_utils import _retry_databricks_sdk_call_with_exponential_backoff
 809  
 810      call_count = 0
 811  
 812      def mock_failing_call():
 813          nonlocal call_count
 814          call_count += 1
 815  
 816          raise DatabricksError(error_code="INTERNAL_ERROR", message="Mock error")
 817  
 818      with mock.patch("mlflow.utils.rest_utils._time_sleep") as mock_sleep:
 819          with pytest.raises(DatabricksError, match="Mock error"):
 820              _retry_databricks_sdk_call_with_exponential_backoff(
 821                  call_func=mock_failing_call,
 822                  retry_codes=_TRANSIENT_FAILURE_RESPONSE_CODES,
 823                  retry_timeout_seconds=10,
 824                  backoff_factor=1,  # Use 1 for predictable calculation
 825                  backoff_jitter=0,  # No jitter for predictable calculation
 826                  max_retries=3,
 827              )
 828  
 829      # Verify sleep was called with correct intervals
 830      # attempt 0 (1st retry): 0 seconds (immediate)
 831      # attempt 1 (2nd retry): 1 * (2^1) = 2 seconds
 832      # attempt 2 (3rd retry): 1 * (2^2) = 4 seconds
 833      expected_sleep_times = [0, 2, 4]
 834      actual_sleep_times = [call.args[0] for call in mock_sleep.call_args_list]
 835      assert actual_sleep_times == expected_sleep_times
 836      assert call_count == 4  # Initial + 3 retries
 837  
 838  
 839  @pytest.mark.skip
 840  def test_timeout_parameter_propagation_with_timeout():
 841      with (
 842          mock.patch("databricks.sdk.WorkspaceClient") as mock_workspace_client,
 843          mock.patch("databricks.sdk.config.Config") as mock_config,
 844      ):
 845          # Test http_request with timeout via get_workspace_client directly
 846          mock_workspace_client_instance = mock.MagicMock()
 847          mock_workspace_client_instance.api_client.do.return_value = {"contents": mock.MagicMock()}
 848          mock_workspace_client.return_value = mock_workspace_client_instance
 849  
 850          get_workspace_client(
 851              use_secret_scope_token=False,
 852              host="http://my-host",
 853              token=None,
 854              databricks_auth_profile="my-profile",
 855              retry_timeout_seconds=None,
 856              timeout=180,
 857          )
 858  
 859          mock_config.assert_called_once_with(
 860              profile="my-profile",
 861              http_timeout_seconds=180,
 862              retry_timeout_seconds=mock.ANY,
 863          )
 864  
 865  
 866  @pytest.mark.skip
 867  def test_timeout_parameter_propagation_without_timeout():
 868      with (
 869          mock.patch("databricks.sdk.WorkspaceClient") as mock_workspace_client,
 870          mock.patch("databricks.sdk.config.Config") as mock_config,
 871      ):
 872          # Test http_request without timeout via get_workspace_client directly
 873          mock_workspace_client_instance = mock.MagicMock()
 874          mock_workspace_client_instance.api_client.do.return_value = {"contents": mock.MagicMock()}
 875          mock_workspace_client.return_value = mock_workspace_client_instance
 876  
 877          get_workspace_client(
 878              use_secret_scope_token=False,
 879              host="http://my-host",
 880              token=None,
 881              databricks_auth_profile="my-profile",
 882              retry_timeout_seconds=None,
 883              timeout=None,
 884          )
 885  
 886          mock_config.assert_called_once_with(
 887              profile="my-profile",
 888              retry_timeout_seconds=mock.ANY,
 889          )
 890  
 891  
 892  def test_deployment_client_timeout_propagation(monkeypatch):
 893      with (
 894          mock.patch("mlflow.utils.rest_utils.get_workspace_client") as mock_get_workspace_client,
 895          mock.patch(
 896              "mlflow.utils.databricks_utils.get_databricks_host_creds"
 897          ) as mock_get_databricks_host_creds,
 898          mock.patch(
 899              "mlflow.deployments.databricks.get_databricks_host_creds"
 900          ) as mock_deployment_host_creds,
 901      ):
 902          # Mock the host creds to use Databricks SDK
 903          mock_host_creds = MlflowHostCreds("http://my-host", use_databricks_sdk=True)
 904          mock_get_databricks_host_creds.return_value = mock_host_creds
 905          mock_deployment_host_creds.return_value = mock_host_creds
 906  
 907          # Mock workspace client and its response
 908          mock_workspace_client_instance = mock.MagicMock()
 909          mock_workspace_client_instance.api_client.do.return_value = {"contents": mock.MagicMock()}
 910          mock_get_workspace_client.return_value = mock_workspace_client_instance
 911  
 912          # Set the environment variable to a custom value using monkeypatch
 913          monkeypatch.setenv("MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT", "300")
 914  
 915          # Create deployment client and call predict
 916          client = DatabricksDeploymentClient("databricks")
 917          client.predict(endpoint="test-endpoint", inputs={"test": "data"})
 918  
 919          # Verify get_workspace_client was called with the deployment predict timeout
 920          mock_get_workspace_client.assert_called_once_with(
 921              False,  # use_secret_scope_token
 922              "http://my-host",  # host
 923              None,  # token
 924              None,  # databricks_auth_profile
 925              retry_timeout_seconds=600,
 926              timeout=300,  # MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT value
 927          )
 928  
 929  
 930  def test_http_request_with_databricks_traffic_id(monkeypatch: pytest.MonkeyPatch):
 931      response = mock.MagicMock()
 932      response.status_code = 200
 933  
 934      # Test with env var set
 935      monkeypatch.setenv(_MLFLOW_DATABRICKS_TRAFFIC_ID.name, "test-traffic-id-12345")
 936      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 937          http_request(MlflowHostCreds("http://my-host"), "/my/endpoint", "GET")
 938          headers = mock_request.call_args.kwargs["headers"]
 939          assert headers["x-databricks-traffic-id"] == "test-traffic-id-12345"
 940  
 941      # Test without env var set
 942      monkeypatch.delenv(_MLFLOW_DATABRICKS_TRAFFIC_ID.name)
 943      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 944          http_request(MlflowHostCreds("http://my-host"), "/my/endpoint", "GET")
 945          headers = mock_request.call_args.kwargs["headers"]
 946          assert "x-databricks-traffic-id" not in headers
 947  
 948  
 949  def test_http_request_with_workspace_id():
 950      response = mock.MagicMock()
 951      response.status_code = 200
 952  
 953      # With workspace_id set, header should be included
 954      creds = MlflowHostCreds("http://my-host", workspace_id="6051921418418893")
 955      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 956          http_request(creds, "/my/endpoint", "GET")
 957          headers = mock_request.call_args.kwargs["headers"]
 958          assert headers["x-databricks-org-id"] == "6051921418418893"
 959  
 960      # Without workspace_id, header should not be present
 961      creds = MlflowHostCreds("http://my-host")
 962      with mock.patch("requests.Session.request", return_value=response) as mock_request:
 963          http_request(creds, "/my/endpoint", "GET")
 964          headers = mock_request.call_args.kwargs["headers"]
 965          assert "x-databricks-org-id" not in headers
 966  
 967  
 968  def test_mlflow_host_creds_workspace_id_equality():
 969      creds1 = MlflowHostCreds("http://my-host", workspace_id="123")
 970      creds2 = MlflowHostCreds("http://my-host", workspace_id="123")
 971      creds3 = MlflowHostCreds("http://my-host", workspace_id="456")
 972      creds4 = MlflowHostCreds("http://my-host")
 973  
 974      assert creds1 == creds2
 975      assert creds1 != creds3
 976      assert creds1 != creds4
 977      assert hash(creds1) == hash(creds2)
 978      assert hash(creds1) != hash(creds3)
 979  
 980  
 981  @pytest.mark.parametrize(
 982      ("timeout", "retry_timeout_seconds", "should_warn"),
 983      [
 984          (300, 120, True),
 985          (120, 600, False),
 986          (300, 300, False),
 987          (None, 120, False),
 988          (300, None, False),
 989          (None, None, False),
 990      ],
 991  )
 992  def test_validate_deployment_timeout_config(timeout, retry_timeout_seconds, should_warn):
 993      from mlflow.utils.rest_utils import validate_deployment_timeout_config
 994  
 995      if should_warn:
 996          with warnings.catch_warnings(record=True) as w:
 997              warnings.simplefilter("always")
 998              validate_deployment_timeout_config(
 999                  timeout=timeout, retry_timeout_seconds=retry_timeout_seconds
1000              )
1001              assert len(w) == 1
1002              warning_msg = str(w[0].message)
1003              assert "MLFLOW_DEPLOYMENT_PREDICT_TOTAL_TIMEOUT" in warning_msg
1004              assert f"({retry_timeout_seconds}s)" in warning_msg
1005              assert f"({timeout}s)" in warning_msg
1006      else:
1007          with warnings.catch_warnings(record=True) as w:
1008              warnings.simplefilter("always")
1009              validate_deployment_timeout_config(
1010                  timeout=timeout, retry_timeout_seconds=retry_timeout_seconds
1011              )
1012              assert len(w) == 0