# request_utils.py
# DO NOT IMPORT MLFLOW IN THIS FILE.
# This file is imported by download_cloud_file_chunk.py.
# Importing mlflow is time-consuming and we want to avoid that in artifact download subprocesses.
import os
import random
import socket
from functools import lru_cache

import requests
import urllib3
from packaging.version import Version
from requests.adapters import HTTPAdapter
from requests.exceptions import HTTPError
from urllib3.connection import HTTPConnection
from urllib3.util import Retry

# Response codes that generally indicate transient network failures and merit client retries,
# based on guidance from cloud service providers
# (https://docs.microsoft.com/en-us/azure/architecture/best-practices/retry-service-specific#general-rest-and-retry-guidelines)
_TRANSIENT_FAILURE_RESPONSE_CODES = frozenset([
    408,  # Request Timeout
    429,  # Too Many Requests
    500,  # Internal Server Error
    502,  # Bad Gateway
    503,  # Service Unavailable
    504,  # Gateway Timeout
])


def _build_socket_options() -> list[tuple[int, int, int]]:
    """Returns socket options with TCP keepalive enabled.

    Starts from urllib3's ``HTTPConnection.default_socket_options`` and, when
    the ``MLFLOW_HTTP_TCP_KEEPALIVE`` environment variable is enabled, appends
    platform-appropriate TCP keepalive options (idle time before the first
    probe, probe interval, and probe count) as ``(level, optname, value)``
    tuples suitable for urllib3's ``socket_options`` pool keyword.
    """
    # Imported lazily inside the function so that importing this module does
    # not pull in mlflow (see the module-level comment at the top of the file).
    from mlflow.environment_variables import (
        MLFLOW_HTTP_TCP_KEEPALIVE,
        MLFLOW_HTTP_TCP_KEEPALIVE_COUNT,
        MLFLOW_HTTP_TCP_KEEPALIVE_IDLE,
        MLFLOW_HTTP_TCP_KEEPALIVE_INTERVAL,
    )

    # Copy the defaults so we never mutate urllib3's shared class attribute.
    socket_options = list(HTTPConnection.default_socket_options or [])

    # Keepalive disabled: return the (copied) defaults untouched.
    if not MLFLOW_HTTP_TCP_KEEPALIVE.get():
        return socket_options

    socket_options.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1))
    # TCP_KEEPIDLE (Linux) vs TCP_KEEPALIVE (macOS/BSD) for idle time before first probe
    idle = MLFLOW_HTTP_TCP_KEEPALIVE_IDLE.get()
    if hasattr(socket, "TCP_KEEPIDLE"):
        socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, idle))
    elif hasattr(socket, "TCP_KEEPALIVE"):
        socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, idle))

    # Seconds between keepalive probes; guarded because the constant is not
    # available on every platform.
    interval = MLFLOW_HTTP_TCP_KEEPALIVE_INTERVAL.get()
    if hasattr(socket, "TCP_KEEPINTVL"):
        socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval))

    # Number of failed probes before the connection is considered dead.
    count = MLFLOW_HTTP_TCP_KEEPALIVE_COUNT.get()
    if hasattr(socket, "TCP_KEEPCNT"):
        socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPCNT, count))

    return socket_options


class TCPKeepAliveHTTPAdapter(HTTPAdapter):
    """HTTPAdapter with TCP keepalive enabled to detect stale/dead connections faster."""

    def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
        # setdefault: respect caller-provided socket_options if any were passed.
        pool_kwargs.setdefault("socket_options", _build_socket_options())
        super().init_poolmanager(connections, maxsize, block=block, **pool_kwargs)

    def proxy_manager_for(self, proxy, **proxy_kwargs):
        # Apply the same keepalive options to proxied connections.
        proxy_kwargs.setdefault("socket_options", _build_socket_options())
        return super().proxy_manager_for(proxy, **proxy_kwargs)


class JitteredRetry(Retry):
    """
    urllib3 < 2 doesn't support `backoff_jitter`. This class is a workaround for that.

    It accepts a `backoff_jitter` keyword and adds a uniformly random extra
    delay of up to `backoff_jitter` seconds on top of the exponential backoff,
    clamped to urllib3's maximum backoff constant.
    """

    def __init__(self, *args, backoff_jitter=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        # Maximum number of extra seconds of random jitter per retry.
        self.backoff_jitter = backoff_jitter

    def get_backoff_time(self):
        """
        Source: https://github.com/urllib3/urllib3/commit/214b184923388328919b0a4b0c15bff603aa51be
        """
        backoff_value = super().get_backoff_time()
        if self.backoff_jitter != 0.0:
            backoff_value += random.random() * self.backoff_jitter
        # The attribute `BACKOFF_MAX` was renamed to `DEFAULT_BACKOFF_MAX` in this commit:
        # https://github.com/urllib3/urllib3/commit/f69b1c89f885a74429cabdee2673e030b35979f0
        # which was part of the major release of 2.0 for urllib3 and the support for both
        # constants was added in 1.26.9:
        # https://github.com/urllib3/urllib3/blob/1.26.9/src/urllib3/util/retry.py
        default_backoff = (
            Retry.BACKOFF_MAX
            if Version(urllib3.__version__) < Version("1.26.9")
            else Retry.DEFAULT_BACKOFF_MAX
        )

        # Clamp into [0, default_backoff].
        return float(max(0, min(default_backoff, backoff_value)))
def augmented_raise_for_status(response):
    """Wrap the standard `requests.response.raise_for_status()` method and return reason"""
    try:
        response.raise_for_status()
    except HTTPError as e:
        # Append the response body text to the error message to make server-side
        # failures easier to diagnose; preserve the original request/response.
        if response.text:
            raise HTTPError(
                f"{e}. Response text: {response.text}", request=e.request, response=e.response
            )
        else:
            raise e


def download_chunk(*, range_start, range_end, headers, download_path, http_uri):
    """Download the byte range [range_start, range_end] of `http_uri` and write it
    into `download_path` at offset `range_start`.

    Args:
        range_start: First byte offset of the chunk (inclusive).
        range_end: Last byte offset of the chunk (inclusive).
        headers: Extra HTTP headers to send; a `Range` header is added on top.
        download_path: Path to a pre-existing file the chunk is written into.
        http_uri: URL to download from.

    Raises:
        IOError: If fewer bytes than the advertised `Content-Length` were read.
        HTTPError: If the response status indicates an error.
    """
    combined_headers = {**headers, "Range": f"bytes={range_start}-{range_end}"}

    with cloud_storage_http_request(
        "get",
        http_uri,
        stream=False,
        headers=combined_headers,
        timeout=10,
    ) as response:
        # With stream=False the body is fully read by now; compare the number of
        # bytes consumed from the raw stream against the advertised length to
        # detect a truncated download.
        expected_length = response.headers.get("Content-Length")
        if expected_length is not None:
            actual_length = response.raw.tell()
            expected_length = int(expected_length)
            if actual_length < expected_length:
                raise IOError(
                    "Incomplete read ({} bytes read, {} more expected)".format(
                        actual_length, expected_length - actual_length
                    )
                )
        # File will have been created upstream. Use r+b to ensure chunks
        # don't overwrite the entire file.
        augmented_raise_for_status(response)
        with open(download_path, "r+b") as f:
            f.seek(range_start)
            f.write(response.content)


@lru_cache(maxsize=64)
def _cached_get_request_session(
    max_retries,
    backoff_factor,
    backoff_jitter,
    retry_codes,
    raise_on_status,
    # To create a new Session object for each process, we use the process id as the cache key.
    # This is to avoid sharing the same Session object across processes, which can lead to issues
    # such as https://stackoverflow.com/q/3724900.
    _pid,
    respect_retry_after_header=True,
):
    """
    This function should not be called directly. Instead, use `_get_request_session` below.

    Builds a `requests.Session` mounted with a `TCPKeepAliveHTTPAdapter` whose
    retry policy is derived from the arguments. All arguments must be hashable
    because they form the `lru_cache` key.
    """

    retry_kwargs = {
        "total": max_retries,
        "connect": max_retries,
        "read": max_retries,
        "redirect": max_retries,
        "status": max_retries,
        "status_forcelist": retry_codes,
        "backoff_factor": backoff_factor,
        "backoff_jitter": backoff_jitter,
        "raise_on_status": raise_on_status,
        "respect_retry_after_header": respect_retry_after_header,
    }
    # `method_whitelist` was renamed to `allowed_methods` in urllib3 1.26.0;
    # passing None retries on all methods.
    urllib3_version = Version(urllib3.__version__)
    if urllib3_version >= Version("1.26.0"):
        retry_kwargs["allowed_methods"] = None
    else:
        retry_kwargs["method_whitelist"] = None

    # urllib3 < 2 lacks native `backoff_jitter` support, so use the subclass
    # that emulates it; urllib3 >= 2 handles the kwarg natively.
    if urllib3_version < Version("2.0"):
        retry = JitteredRetry(**retry_kwargs)
    else:
        retry = Retry(**retry_kwargs)
    # Imported lazily so importing this module does not pull in mlflow
    # (this module is used by artifact download subprocesses).
    from mlflow.environment_variables import (
        MLFLOW_HTTP_POOL_CONNECTIONS,
        MLFLOW_HTTP_POOL_MAXSIZE,
    )

    adapter = TCPKeepAliveHTTPAdapter(
        pool_connections=MLFLOW_HTTP_POOL_CONNECTIONS.get(),
        pool_maxsize=MLFLOW_HTTP_POOL_MAXSIZE.get(),
        max_retries=retry,
    )
    session = requests.Session()
    session.mount("https://", adapter)
    session.mount("http://", adapter)
    return session


def _get_request_session(
    max_retries,
    backoff_factor,
    backoff_jitter,
    retry_codes,
    raise_on_status,
    respect_retry_after_header,
):
    """Returns a `Requests.Session` object for making an HTTP request.

    Args:
        max_retries: Maximum total number of retries.
        backoff_factor: A time factor for exponential backoff. e.g. value 5 means the HTTP
            request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
            exponential backoff.
        backoff_jitter: A random jitter to add to the backoff interval.
        retry_codes: A list of HTTP response error codes that qualifies for retry.
        raise_on_status: Whether to raise an exception, or return a response, if status falls
            in retry_codes range and retries have been exhausted.
        respect_retry_after_header: Whether to respect Retry-After header on status codes defined
            as Retry.RETRY_AFTER_STATUS_CODES or not.

    Returns:
        requests.Session object.

    """
    # Delegate to the lru_cache'd builder; passing the current pid as part of
    # the cache key guarantees each process gets its own Session.
    return _cached_get_request_session(
        max_retries,
        backoff_factor,
        backoff_jitter,
        retry_codes,
        raise_on_status,
        _pid=os.getpid(),
        respect_retry_after_header=respect_retry_after_header,
    )


def _get_http_response_with_retries(
    method,
    url,
    max_retries,
    backoff_factor,
    backoff_jitter,
    retry_codes,
    raise_on_status=True,
    allow_redirects=None,
    respect_retry_after_header=True,
    **kwargs,
):
    """Performs an HTTP request using Python's `requests` module with an automatic retry policy.

    Args:
        method: A string indicating the method to use, e.g. "GET", "POST", "PUT".
        url: The target URL address for the HTTP request.
        max_retries: Maximum total number of retries.
        backoff_factor: A time factor for exponential backoff. e.g. value 5 means the HTTP
            request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
            exponential backoff.
        backoff_jitter: A random jitter to add to the backoff interval.
        retry_codes: A list of HTTP response error codes that qualifies for retry.
        raise_on_status: Whether to raise an exception, or return a response, if status falls
            in retry_codes range and retries have been exhausted.
        allow_redirects: Whether to follow redirects; when None, falls back to the
            MLFLOW_ALLOW_HTTP_REDIRECTS environment variable (default: allow).
        respect_retry_after_header: Whether to respect the Retry-After response header.
        kwargs: Additional keyword arguments to pass to `requests.Session.request()`

    Returns:
        requests.Response object.
    """
    session = _get_request_session(
        max_retries,
        backoff_factor,
        backoff_jitter,
        retry_codes,
        raise_on_status,
        respect_retry_after_header,
    )

    # the environment variable is hardcoded here to avoid importing mlflow.
    # however, documentation is available in environment_variables.py
    env_value = os.environ.get("MLFLOW_ALLOW_HTTP_REDIRECTS", "true").lower() in ["true", "1"]
    allow_redirects = env_value if allow_redirects is None else allow_redirects

    return session.request(method, url, allow_redirects=allow_redirects, **kwargs)


def cloud_storage_http_request(
    method,
    url,
    max_retries=5,
    backoff_factor=2,
    backoff_jitter=1.0,
    retry_codes=_TRANSIENT_FAILURE_RESPONSE_CODES,
    timeout=None,
    **kwargs,
):
    """Performs an HTTP PUT/GET/PATCH/DELETE request using Python's `requests` module with
    automatic retry.

    Args:
        method: string of 'PUT', 'GET', 'PATCH' or 'DELETE', specifying the HTTP method to use.
        url: the target URL address for the HTTP request.
        max_retries: maximum number of retries before throwing an exception.
        backoff_factor: a time factor for exponential backoff. e.g. value 5 means the HTTP
            request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
            exponential backoff.
        backoff_jitter: A random jitter to add to the backoff interval.
        retry_codes: a list of HTTP response error codes that qualifies for retry.
        timeout: wait for timeout seconds for response from remote server for connect and
            read request. Default to None owing to long duration operation in read / write.
        kwargs: Additional keyword arguments to pass to `requests.Session.request()`.

    Returns:
        requests.Response object.

    Raises:
        ValueError: If `method` is not one of the supported HTTP methods.
    """
    if method.lower() not in ("put", "get", "patch", "delete"):
        raise ValueError("Illegal http method: " + method)
    return _get_http_response_with_retries(
        method,
        url,
        max_retries,
        backoff_factor,
        backoff_jitter,
        retry_codes,
        timeout=timeout,
        **kwargs,
    )