"""Tests for ``mlflow.utils.request_utils`` (file: ``test_request_utils.py``).

Covers standalone import isolation, chunked-download error handling, HTTP
redirect configuration, and TCP keepalive socket options / HTTP adapter wiring.
"""

import socket
import subprocess
import sys
from unittest import mock

import pytest
from requests.adapters import HTTPAdapter

from mlflow.utils import request_utils
from mlflow.utils.request_utils import (
    TCPKeepAliveHTTPAdapter,
    _build_socket_options,
)

# Environment variables whose effect these tests exercise. Every test starts
# with all of them unset so that assertions about *default* behavior do not
# depend on the developer's or CI runner's shell environment; tests that need a
# specific value set it explicitly with ``monkeypatch.setenv``.
_ENV_VARS_UNDER_TEST = (
    "MLFLOW_ALLOW_HTTP_REDIRECTS",
    "MLFLOW_HTTP_TCP_KEEPALIVE",
    "MLFLOW_HTTP_TCP_KEEPALIVE_IDLE",
    "MLFLOW_HTTP_TCP_KEEPALIVE_INTERVAL",
    "MLFLOW_HTTP_TCP_KEEPALIVE_COUNT",
)


@pytest.fixture(autouse=True)
def _isolate_env(monkeypatch):
    """Unset every variable in ``_ENV_VARS_UNDER_TEST`` for the duration of a test."""
    for name in _ENV_VARS_UNDER_TEST:
        # raising=False: the variable is usually not set at all.
        monkeypatch.delenv(name, raising=False)


def test_request_utils_does_not_import_mlflow(tmp_path):
    """Loading request_utils by file path in a fresh interpreter must not import ``mlflow``."""
    file_content = f"""
import importlib.util
import os
import sys

file_path = r"{request_utils.__file__}"
module_name = "mlflow.utils.request_utils"

spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)

assert "mlflow" not in sys.modules
assert "mlflow.utils.request_utils" in sys.modules
"""
    test_file = tmp_path.joinpath("test_request_utils_does_not_import_mlflow.py")
    test_file.write_text(file_content)

    # A subprocess guarantees a clean sys.modules, unpolluted by this test session.
    subprocess.check_call([sys.executable, str(test_file)])


class IncompleteResponse:
    """Fake response whose body stops short of its advertised length.

    ``Content-Length`` claims 100 bytes while ``raw.tell()`` reports that only
    50 were consumed. Entering/exiting it as a context manager is a no-op.
    """

    def __init__(self):
        self.headers = {"Content-Length": "100"}
        raw = mock.MagicMock()
        raw.tell.return_value = 50
        self.raw = raw

    def __enter__(self):
        return self

    def __exit__(self, *args):
        pass


def test_download_chunk_incomplete_read(tmp_path):
    """download_chunk must raise IOError("Incomplete read...") on a short body."""
    with mock.patch.object(
        request_utils, "cloud_storage_http_request", return_value=IncompleteResponse()
    ):
        download_path = tmp_path / "chunk"
        download_path.touch()
        with pytest.raises(IOError, match="Incomplete read"):
            request_utils.download_chunk(
                range_start=0,
                range_end=999,
                headers={},
                download_path=download_path,
                http_uri="https://example.com",
            )


@pytest.mark.parametrize("env_value", ["0", "false", "False", "FALSE"])
def test_redirects_disabled_if_env_var_set(monkeypatch, env_value):
    """Any falsy spelling of MLFLOW_ALLOW_HTTP_REDIRECTS disables redirects."""
    monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)

    with mock.patch("requests.Session.request") as mock_request:
        mock_request.return_value.status_code = 302
        mock_request.return_value.text = "mock response"

        response = request_utils.cloud_storage_http_request("GET", "http://localhost:5000")

        assert response.text == "mock response"
        mock_request.assert_called_once_with(
            "GET",
            "http://localhost:5000",
            allow_redirects=False,
            timeout=None,
        )


@pytest.mark.parametrize("env_value", ["1", "true", "True", "TRUE"])
def test_redirects_enabled_if_env_var_set(monkeypatch, env_value):
    """Any truthy spelling of MLFLOW_ALLOW_HTTP_REDIRECTS enables redirects."""
    monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)

    with mock.patch("requests.Session.request") as mock_request:
        mock_request.return_value.status_code = 302
        mock_request.return_value.text = "mock response"

        response = request_utils.cloud_storage_http_request(
            "GET",
            "http://localhost:5000",
        )

        assert response.text == "mock response"
        mock_request.assert_called_once_with(
            "GET",
            "http://localhost:5000",
            allow_redirects=True,
            timeout=None,
        )


@pytest.mark.parametrize("env_value", ["0", "false", "False", "FALSE"])
def test_redirect_kwarg_overrides_env_value_false(monkeypatch, env_value):
    """An explicit allow_redirects=True wins over a falsy env setting."""
    monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)

    with mock.patch("requests.Session.request") as mock_request:
        mock_request.return_value.status_code = 302
        mock_request.return_value.text = "mock response"

        response = request_utils.cloud_storage_http_request(
            "GET", "http://localhost:5000", allow_redirects=True
        )

        assert response.text == "mock response"
        mock_request.assert_called_once_with(
            "GET",
            "http://localhost:5000",
            allow_redirects=True,
            timeout=None,
        )


@pytest.mark.parametrize("env_value", ["1", "true", "True", "TRUE"])
def test_redirect_kwarg_overrides_env_value_true(monkeypatch, env_value):
    """An explicit allow_redirects=False wins over a truthy env setting."""
    monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)

    with mock.patch("requests.Session.request") as mock_request:
        mock_request.return_value.status_code = 302
        mock_request.return_value.text = "mock response"

        response = request_utils.cloud_storage_http_request(
            "GET", "http://localhost:5000", allow_redirects=False
        )

        assert response.text == "mock response"
        mock_request.assert_called_once_with(
            "GET",
            "http://localhost:5000",
            allow_redirects=False,
            timeout=None,
        )


def test_redirects_enabled_by_default():
    """With MLFLOW_ALLOW_HTTP_REDIRECTS unset (see _isolate_env), redirects are allowed."""
    with mock.patch("requests.Session.request") as mock_request:
        mock_request.return_value.status_code = 302
        mock_request.return_value.text = "mock response"

        response = request_utils.cloud_storage_http_request(
            "GET",
            "http://localhost:5000",
        )

        assert response.text == "mock response"
        mock_request.assert_called_once_with(
            "GET",
            "http://localhost:5000",
            allow_redirects=True,
            timeout=None,
        )


# --- TCP Keepalive tests ---


def test_build_socket_options_includes_keepalive():
    """SO_KEEPALIVE is enabled by default."""
    options = _build_socket_options()
    assert (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) in options


def test_build_socket_options_platform_specific():
    """Default idle/interval/count are 30/10/3 for whichever options the OS exposes."""
    options = _build_socket_options()
    # The idle-time option name differs by platform: TCP_KEEPIDLE where it
    # exists, otherwise TCP_KEEPALIVE.
    if hasattr(socket, "TCP_KEEPIDLE"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) in options
    elif hasattr(socket, "TCP_KEEPALIVE"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, 30) in options
    if hasattr(socket, "TCP_KEEPINTVL"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10) in options
    if hasattr(socket, "TCP_KEEPCNT"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3) in options


def test_build_socket_options_disabled_via_env(monkeypatch):
    """MLFLOW_HTTP_TCP_KEEPALIVE=false drops the SO_KEEPALIVE option."""
    monkeypatch.setenv("MLFLOW_HTTP_TCP_KEEPALIVE", "false")
    options = _build_socket_options()
    assert (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) not in options


def test_build_socket_options_custom_values_via_env(monkeypatch):
    """Idle/interval/count are overridable through MLFLOW_HTTP_TCP_KEEPALIVE_* vars."""
    monkeypatch.setenv("MLFLOW_HTTP_TCP_KEEPALIVE_IDLE", "60")
    monkeypatch.setenv("MLFLOW_HTTP_TCP_KEEPALIVE_INTERVAL", "20")
    monkeypatch.setenv("MLFLOW_HTTP_TCP_KEEPALIVE_COUNT", "5")
    options = _build_socket_options()
    if hasattr(socket, "TCP_KEEPIDLE"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 60) in options
    elif hasattr(socket, "TCP_KEEPALIVE"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, 60) in options
    if hasattr(socket, "TCP_KEEPINTVL"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 20) in options
    if hasattr(socket, "TCP_KEEPCNT"):
        assert (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5) in options


def test_tcp_keepalive_adapter_init_poolmanager():
    """init_poolmanager forwards keepalive socket_options to HTTPAdapter."""
    adapter = TCPKeepAliveHTTPAdapter()
    with mock.patch.object(HTTPAdapter, "init_poolmanager") as mock_init:
        adapter.init_poolmanager(1, 1)
        mock_init.assert_called_once()
        _, kwargs = mock_init.call_args
        assert "socket_options" in kwargs
        assert (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) in kwargs["socket_options"]


def test_tcp_keepalive_adapter_proxy_manager_for():
    """proxy_manager_for injects keepalive socket_options when none are given."""
    adapter = TCPKeepAliveHTTPAdapter()
    with mock.patch.object(HTTPAdapter, "proxy_manager_for") as mock_proxy:
        adapter.proxy_manager_for("http://proxy:8080")
        mock_proxy.assert_called_once()
        _, kwargs = mock_proxy.call_args
        assert "socket_options" in kwargs
        assert (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) in kwargs["socket_options"]


def test_tcp_keepalive_adapter_proxy_respects_explicit_options():
    """Caller-supplied socket_options are passed through unchanged."""
    adapter = TCPKeepAliveHTTPAdapter()
    custom_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 0)]
    with mock.patch.object(HTTPAdapter, "proxy_manager_for") as mock_proxy:
        adapter.proxy_manager_for("http://proxy:8080", socket_options=custom_options)
        mock_proxy.assert_called_once()
        _, kwargs = mock_proxy.call_args
        assert kwargs["socket_options"] == custom_options


def test_session_uses_tcp_keepalive_adapter():
    """Sessions built by _get_request_session mount TCPKeepAliveHTTPAdapter for http(s)."""
    # Start from an empty cache so a session built by an earlier test is not reused.
    request_utils._cached_get_request_session.cache_clear()
    try:
        session = request_utils._get_request_session(
            max_retries=3,
            backoff_factor=1,
            backoff_jitter=0.5,
            retry_codes=(500,),
            raise_on_status=True,
            respect_retry_after_header=True,
        )
        assert isinstance(session.get_adapter("https://example.com"), TCPKeepAliveHTTPAdapter)
        assert isinstance(session.get_adapter("http://example.com"), TCPKeepAliveHTTPAdapter)
    finally:
        # Always clear, even if an assertion failed, so the session created here
        # does not leak into later tests through the cache.
        request_utils._cached_get_request_session.cache_clear()