/ tests / utils / test_request_utils.py
test_request_utils.py
  1  import socket
  2  import subprocess
  3  import sys
  4  from unittest import mock
  5  
  6  import pytest
  7  from requests.adapters import HTTPAdapter
  8  
  9  from mlflow.utils import request_utils
 10  from mlflow.utils.request_utils import (
 11      TCPKeepAliveHTTPAdapter,
 12      _build_socket_options,
 13  )
 14  
 15  
 16  def test_request_utils_does_not_import_mlflow(tmp_path):
 17      file_content = f"""
 18  import importlib.util
 19  import os
 20  import sys
 21  
 22  file_path = r"{request_utils.__file__}"
 23  module_name = "mlflow.utils.request_utils"
 24  
 25  spec = importlib.util.spec_from_file_location(module_name, file_path)
 26  module = importlib.util.module_from_spec(spec)
 27  sys.modules[module_name] = module
 28  spec.loader.exec_module(module)
 29  
 30  assert "mlflow" not in sys.modules
 31  assert "mlflow.utils.request_utils" in sys.modules
 32  """
 33      test_file = tmp_path.joinpath("test_request_utils_does_not_import_mlflow.py")
 34      test_file.write_text(file_content)
 35  
 36      subprocess.check_call([sys.executable, str(test_file)])
 37  
 38  
 39  class IncompleteResponse:
 40      def __init__(self):
 41          self.headers = {"Content-Length": "100"}
 42          raw = mock.MagicMock()
 43          raw.tell.return_value = 50
 44          self.raw = raw
 45  
 46      def __enter__(self):
 47          return self
 48  
 49      def __exit__(self, *args):
 50          pass
 51  
 52  
 53  def test_download_chunk_incomplete_read(tmp_path):
 54      with mock.patch.object(
 55          request_utils, "cloud_storage_http_request", return_value=IncompleteResponse()
 56      ):
 57          download_path = tmp_path / "chunk"
 58          download_path.touch()
 59          with pytest.raises(IOError, match="Incomplete read"):
 60              request_utils.download_chunk(
 61                  range_start=0,
 62                  range_end=999,
 63                  headers={},
 64                  download_path=download_path,
 65                  http_uri="https://example.com",
 66              )
 67  
 68  
 69  @pytest.mark.parametrize("env_value", ["0", "false", "False", "FALSE"])
 70  def test_redirects_disabled_if_env_var_set(monkeypatch, env_value):
 71      monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)
 72  
 73      with mock.patch("requests.Session.request") as mock_request:
 74          mock_request.return_value.status_code = 302
 75          mock_request.return_value.text = "mock response"
 76  
 77          response = request_utils.cloud_storage_http_request("GET", "http://localhost:5000")
 78  
 79          assert response.text == "mock response"
 80          mock_request.assert_called_once_with(
 81              "GET",
 82              "http://localhost:5000",
 83              allow_redirects=False,
 84              timeout=None,
 85          )
 86  
 87  
 88  @pytest.mark.parametrize("env_value", ["1", "true", "True", "TRUE"])
 89  def test_redirects_enabled_if_env_var_set(monkeypatch, env_value):
 90      monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)
 91  
 92      with mock.patch("requests.Session.request") as mock_request:
 93          mock_request.return_value.status_code = 302
 94          mock_request.return_value.text = "mock response"
 95  
 96          response = request_utils.cloud_storage_http_request(
 97              "GET",
 98              "http://localhost:5000",
 99          )
100  
101          assert response.text == "mock response"
102          mock_request.assert_called_once_with(
103              "GET",
104              "http://localhost:5000",
105              allow_redirects=True,
106              timeout=None,
107          )
108  
109  
@pytest.mark.parametrize("env_value", ["0", "false", "False", "FALSE"])
def test_redirect_kwarg_overrides_env_value_false(monkeypatch, env_value):
    """An explicit ``allow_redirects=True`` wins over a falsy env setting."""
    monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)
    endpoint = "http://localhost:5000"

    with mock.patch("requests.Session.request") as patched_request:
        patched_request.return_value.configure_mock(
            status_code=302, text="mock response"
        )
        outcome = request_utils.cloud_storage_http_request(
            "GET", endpoint, allow_redirects=True
        )

    assert outcome.text == "mock response"
    patched_request.assert_called_once_with(
        "GET", endpoint, allow_redirects=True, timeout=None
    )
129  
130  
@pytest.mark.parametrize("env_value", ["1", "true", "True", "TRUE"])
def test_redirect_kwarg_overrides_env_value_true(monkeypatch, env_value):
    """An explicit ``allow_redirects=False`` wins over a truthy env setting."""
    monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)
    endpoint = "http://localhost:5000"

    with mock.patch("requests.Session.request") as patched_request:
        stub = patched_request.return_value
        stub.status_code = 302
        stub.text = "mock response"
        outcome = request_utils.cloud_storage_http_request(
            "GET", endpoint, allow_redirects=False
        )

    assert outcome.text == "mock response"
    patched_request.assert_called_once_with(
        "GET", endpoint, allow_redirects=False, timeout=None
    )
150  
151  
def test_redirects_enabled_by_default():
    """Without ``MLFLOW_ALLOW_HTTP_REDIRECTS`` set, redirects are followed."""
    with pytest.MonkeyPatch.context() as mp, mock.patch(
        "requests.Session.request"
    ) as mock_request:
        # Really test the *default*. A developer shell or CI job that exports
        # MLFLOW_ALLOW_HTTP_REDIRECTS would otherwise decide the outcome.
        mp.delenv("MLFLOW_ALLOW_HTTP_REDIRECTS", raising=False)

        mock_request.return_value.status_code = 302
        mock_request.return_value.text = "mock response"

        response = request_utils.cloud_storage_http_request(
            "GET",
            "http://localhost:5000",
        )

        assert response.text == "mock response"
        mock_request.assert_called_once_with(
            "GET",
            "http://localhost:5000",
            allow_redirects=True,
            timeout=None,
        )
169  
170  
171  # --- TCP Keepalive tests ---
172  
173  
def test_build_socket_options_includes_keepalive():
    """With no opt-out configured, ``SO_KEEPALIVE`` is switched on."""
    with pytest.MonkeyPatch.context() as mp:
        # Clear any ambient MLFLOW_HTTP_TCP_KEEPALIVE (e.g. "false" exported in
        # CI) so this checks the built-in default rather than the environment.
        mp.delenv("MLFLOW_HTTP_TCP_KEEPALIVE", raising=False)
        options = _build_socket_options()
    assert (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) in options
177  
178  
def test_build_socket_options_platform_specific():
    """Expected default idle/interval/count values appear for each TCP constant the OS exposes."""
    with pytest.MonkeyPatch.context() as mp:
        # Remove ambient overrides so the asserted values are the built-in
        # defaults rather than whatever the test environment exports.
        for var in (
            "MLFLOW_HTTP_TCP_KEEPALIVE",
            "MLFLOW_HTTP_TCP_KEEPALIVE_IDLE",
            "MLFLOW_HTTP_TCP_KEEPALIVE_INTERVAL",
            "MLFLOW_HTTP_TCP_KEEPALIVE_COUNT",
        ):
            mp.delenv(var, raising=False)
        options = _build_socket_options()

    # The idle time uses TCP_KEEPIDLE where available, otherwise TCP_KEEPALIVE
    # (the name some platforms, e.g. macOS, use for the same setting).
    if hasattr(socket, "TCP_KEEPIDLE"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) in options
    elif hasattr(socket, "TCP_KEEPALIVE"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, 30) in options
    if hasattr(socket, "TCP_KEEPINTVL"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10) in options
    if hasattr(socket, "TCP_KEEPCNT"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3) in options
189  
190  
def test_build_socket_options_disabled_via_env(monkeypatch):
    """``MLFLOW_HTTP_TCP_KEEPALIVE=false`` drops ``SO_KEEPALIVE`` from the options."""
    monkeypatch.setenv("MLFLOW_HTTP_TCP_KEEPALIVE", "false")
    keepalive_on = (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    assert keepalive_on not in _build_socket_options()
195  
196  
def test_build_socket_options_custom_values_via_env(monkeypatch):
    """The idle/interval/count env vars override the keepalive defaults."""
    # Keepalive must be on for these options to be emitted at all. Drop an
    # ambient MLFLOW_HTTP_TCP_KEEPALIVE=false so it cannot mask the check.
    monkeypatch.delenv("MLFLOW_HTTP_TCP_KEEPALIVE", raising=False)
    monkeypatch.setenv("MLFLOW_HTTP_TCP_KEEPALIVE_IDLE", "60")
    monkeypatch.setenv("MLFLOW_HTTP_TCP_KEEPALIVE_INTERVAL", "20")
    monkeypatch.setenv("MLFLOW_HTTP_TCP_KEEPALIVE_COUNT", "5")
    options = _build_socket_options()

    # Idle time is set through TCP_KEEPIDLE, or TCP_KEEPALIVE where that is the
    # platform's name for it.
    if hasattr(socket, "TCP_KEEPIDLE"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 60) in options
    elif hasattr(socket, "TCP_KEEPALIVE"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, 60) in options
    if hasattr(socket, "TCP_KEEPINTVL"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 20) in options
    if hasattr(socket, "TCP_KEEPCNT"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5) in options
210  
211  
def test_tcp_keepalive_adapter_init_poolmanager():
    """``init_poolmanager`` passes keepalive ``socket_options`` up to ``HTTPAdapter``."""
    with pytest.MonkeyPatch.context() as mp:
        # The assertion relies on keepalive being on by default. Shield the
        # test from an ambient MLFLOW_HTTP_TCP_KEEPALIVE opt-out.
        mp.delenv("MLFLOW_HTTP_TCP_KEEPALIVE", raising=False)
        adapter = TCPKeepAliveHTTPAdapter()
        with mock.patch.object(HTTPAdapter, "init_poolmanager") as mock_init:
            adapter.init_poolmanager(1, 1)

    mock_init.assert_called_once()
    _, kwargs = mock_init.call_args
    assert "socket_options" in kwargs
    assert (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) in kwargs["socket_options"]
220  
221  
def test_tcp_keepalive_adapter_proxy_manager_for():
    """``proxy_manager_for`` injects keepalive ``socket_options`` when none are given."""
    with pytest.MonkeyPatch.context() as mp:
        # The assertion relies on keepalive being on by default. Shield the
        # test from an ambient MLFLOW_HTTP_TCP_KEEPALIVE opt-out.
        mp.delenv("MLFLOW_HTTP_TCP_KEEPALIVE", raising=False)
        adapter = TCPKeepAliveHTTPAdapter()
        with mock.patch.object(HTTPAdapter, "proxy_manager_for") as mock_proxy:
            adapter.proxy_manager_for("http://proxy:8080")

    mock_proxy.assert_called_once()
    _, kwargs = mock_proxy.call_args
    assert "socket_options" in kwargs
    assert (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) in kwargs["socket_options"]
230  
231  
def test_tcp_keepalive_adapter_proxy_respects_explicit_options():
    """Caller-supplied ``socket_options`` reach ``HTTPAdapter`` unchanged."""
    adapter = TCPKeepAliveHTTPAdapter()
    explicit = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 0)]

    with mock.patch.object(HTTPAdapter, "proxy_manager_for") as patched:
        adapter.proxy_manager_for("http://proxy:8080", socket_options=explicit)
        _, forwarded_kwargs = patched.call_args

    assert forwarded_kwargs["socket_options"] == explicit
239  
240  
def test_session_uses_tcp_keepalive_adapter():
    """A freshly built request session mounts the keepalive adapter for http and https."""
    # Start from an empty cache so a new session is built with the arguments
    # below. Presumably _get_request_session is served through this cache.
    request_utils._cached_get_request_session.cache_clear()
    try:
        session = request_utils._get_request_session(
            max_retries=3,
            backoff_factor=1,
            backoff_jitter=0.5,
            retry_codes=(500,),
            raise_on_status=True,
            respect_retry_after_header=True,
        )
        assert isinstance(session.get_adapter("https://example.com"), TCPKeepAliveHTTPAdapter)
        assert isinstance(session.get_adapter("http://example.com"), TCPKeepAliveHTTPAdapter)
    finally:
        # Clear in `finally` so the test-specific session is never left in the
        # cache for later tests, even when an assertion above fails.
        request_utils._cached_get_request_session.cache_clear()