/ mlflow / utils / request_utils.py
request_utils.py
# DO NOT IMPORT MLFLOW IN THIS FILE.
  2  # This file is imported by download_cloud_file_chunk.py.
  3  # Importing mlflow is time-consuming and we want to avoid that in artifact download subprocesses.
  4  import os
  5  import random
  6  import socket
  7  from functools import lru_cache
  8  
  9  import requests
 10  import urllib3
 11  from packaging.version import Version
 12  from requests.adapters import HTTPAdapter
 13  from requests.exceptions import HTTPError
 14  from urllib3.connection import HTTPConnection
 15  from urllib3.util import Retry
 16  
 17  # Response codes that generally indicate transient network failures and merit client retries,
 18  # based on guidance from cloud service providers
 19  # (https://docs.microsoft.com/en-us/azure/architecture/best-practices/retry-service-specific#general-rest-and-retry-guidelines)
 20  _TRANSIENT_FAILURE_RESPONSE_CODES = frozenset([
 21      408,  # Request Timeout
 22      429,  # Too Many Requests
 23      500,  # Internal Server Error
 24      502,  # Bad Gateway
 25      503,  # Service Unavailable
 26      504,  # Gateway Timeout
 27  ])
 28  
 29  
 30  def _build_socket_options() -> list[tuple[int, int, int]]:
 31      """Returns socket options with TCP keepalive enabled."""
 32      from mlflow.environment_variables import (
 33          MLFLOW_HTTP_TCP_KEEPALIVE,
 34          MLFLOW_HTTP_TCP_KEEPALIVE_COUNT,
 35          MLFLOW_HTTP_TCP_KEEPALIVE_IDLE,
 36          MLFLOW_HTTP_TCP_KEEPALIVE_INTERVAL,
 37      )
 38  
 39      socket_options = list(HTTPConnection.default_socket_options or [])
 40  
 41      if not MLFLOW_HTTP_TCP_KEEPALIVE.get():
 42          return socket_options
 43  
 44      socket_options.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1))
 45      # TCP_KEEPIDLE (Linux) vs TCP_KEEPALIVE (macOS/BSD) for idle time before first probe
 46      idle = MLFLOW_HTTP_TCP_KEEPALIVE_IDLE.get()
 47      if hasattr(socket, "TCP_KEEPIDLE"):
 48          socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, idle))
 49      elif hasattr(socket, "TCP_KEEPALIVE"):
 50          socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, idle))
 51  
 52      interval = MLFLOW_HTTP_TCP_KEEPALIVE_INTERVAL.get()
 53      if hasattr(socket, "TCP_KEEPINTVL"):
 54          socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval))
 55  
 56      count = MLFLOW_HTTP_TCP_KEEPALIVE_COUNT.get()
 57      if hasattr(socket, "TCP_KEEPCNT"):
 58          socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPCNT, count))
 59  
 60      return socket_options
 61  
 62  
 63  class TCPKeepAliveHTTPAdapter(HTTPAdapter):
 64      """HTTPAdapter with TCP keepalive enabled to detect stale/dead connections faster."""
 65  
 66      def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
 67          pool_kwargs.setdefault("socket_options", _build_socket_options())
 68          super().init_poolmanager(connections, maxsize, block=block, **pool_kwargs)
 69  
 70      def proxy_manager_for(self, proxy, **proxy_kwargs):
 71          proxy_kwargs.setdefault("socket_options", _build_socket_options())
 72          return super().proxy_manager_for(proxy, **proxy_kwargs)
 73  
 74  
 75  class JitteredRetry(Retry):
 76      """
 77      urllib3 < 2 doesn't support `backoff_jitter`. This class is a workaround for that.
 78      """
 79  
 80      def __init__(self, *args, backoff_jitter=0.0, **kwargs):
 81          super().__init__(*args, **kwargs)
 82          self.backoff_jitter = backoff_jitter
 83  
 84      def get_backoff_time(self):
 85          """
 86          Source: https://github.com/urllib3/urllib3/commit/214b184923388328919b0a4b0c15bff603aa51be
 87          """
 88          backoff_value = super().get_backoff_time()
 89          if self.backoff_jitter != 0.0:
 90              backoff_value += random.random() * self.backoff_jitter
 91          # The attribute `BACKOFF_MAX` was renamed to `DEFAULT_BACKOFF_MAX` in this commit:
 92          # https://github.com/urllib3/urllib3/commit/f69b1c89f885a74429cabdee2673e030b35979f0
 93          # which was part of the major release of 2.0 for urllib3 and the support for both
 94          # constants was added in 1.26.9:
 95          # https://github.com/urllib3/urllib3/blob/1.26.9/src/urllib3/util/retry.py
 96          default_backoff = (
 97              Retry.BACKOFF_MAX
 98              if Version(urllib3.__version__) < Version("1.26.9")
 99              else Retry.DEFAULT_BACKOFF_MAX
100          )
101  
102          return float(max(0, min(default_backoff, backoff_value)))
103  
104  
def augmented_raise_for_status(response):
    """Call `response.raise_for_status()`, re-raising the HTTPError with the
    response body text appended to the message when one is available."""
    try:
        response.raise_for_status()
    except HTTPError as e:
        if not response.text:
            raise
        raise HTTPError(
            f"{e}. Response text: {response.text}", request=e.request, response=e.response
        )
116  
117  
def download_chunk(*, range_start, range_end, headers, download_path, http_uri):
    """Download the byte range [range_start, range_end] of `http_uri` into
    `download_path` at offset `range_start`.

    The target file must already exist (it is created upstream); the chunk is
    written with `r+b` so other chunks of the same file are not clobbered.

    Raises:
        IOError: If the server sent fewer bytes than its Content-Length promised.
        requests.HTTPError: If the response status indicates failure.
    """
    chunk_headers = {**headers, "Range": f"bytes={range_start}-{range_end}"}

    with cloud_storage_http_request(
        "get",
        http_uri,
        stream=False,
        headers=chunk_headers,
        timeout=10,
    ) as response:
        content_length = response.headers.get("Content-Length")
        if content_length is not None:
            bytes_read = response.raw.tell()
            bytes_expected = int(content_length)
            if bytes_read < bytes_expected:
                raise IOError(
                    "Incomplete read ({} bytes read, {} more expected)".format(
                        bytes_read, bytes_expected - bytes_read
                    )
                )
        augmented_raise_for_status(response)
        # File will have been created upstream. Use r+b to ensure chunks
        # don't overwrite the entire file.
        with open(download_path, "r+b") as f:
            f.seek(range_start)
            f.write(response.content)
144  
145  
@lru_cache(maxsize=64)
def _cached_get_request_session(
    max_retries,
    backoff_factor,
    backoff_jitter,
    retry_codes,
    raise_on_status,
    # The process id participates in the cache key so every process gets its own
    # Session object: sharing a Session across processes leads to issues such as
    # https://stackoverflow.com/q/3724900.
    _pid,
    respect_retry_after_header=True,
):
    """Build (and memoize) a `requests.Session` with retry + TCP-keepalive adapters.

    This function should not be called directly. Instead, use `_get_request_session` below.
    """
    retry_kwargs = dict(
        total=max_retries,
        connect=max_retries,
        read=max_retries,
        redirect=max_retries,
        status=max_retries,
        status_forcelist=retry_codes,
        backoff_factor=backoff_factor,
        backoff_jitter=backoff_jitter,
        raise_on_status=raise_on_status,
        respect_retry_after_header=respect_retry_after_header,
    )

    installed_urllib3 = Version(urllib3.__version__)
    # `method_whitelist` was renamed to `allowed_methods` in urllib3 1.26.0.
    if installed_urllib3 >= Version("1.26.0"):
        retry_kwargs["allowed_methods"] = None
    else:
        retry_kwargs["method_whitelist"] = None

    # urllib3 < 2 lacks native `backoff_jitter` support; JitteredRetry backports it.
    retry_cls = JitteredRetry if installed_urllib3 < Version("2.0") else Retry
    retry = retry_cls(**retry_kwargs)

    # Deferred import to keep `import mlflow` out of this module's import path.
    from mlflow.environment_variables import (
        MLFLOW_HTTP_POOL_CONNECTIONS,
        MLFLOW_HTTP_POOL_MAXSIZE,
    )

    adapter = TCPKeepAliveHTTPAdapter(
        pool_connections=MLFLOW_HTTP_POOL_CONNECTIONS.get(),
        pool_maxsize=MLFLOW_HTTP_POOL_MAXSIZE.get(),
        max_retries=retry,
    )
    session = requests.Session()
    for scheme in ("https://", "http://"):
        session.mount(scheme, adapter)
    return session
199  
200  
def _get_request_session(
    max_retries,
    backoff_factor,
    backoff_jitter,
    retry_codes,
    raise_on_status,
    respect_retry_after_header,
):
    """Return a cached, per-process `requests.Session` for making HTTP requests.

    Args:
        max_retries: Maximum total number of retries.
        backoff_factor: A time factor for exponential backoff. e.g. value 5 means the HTTP
            request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off
            the exponential backoff.
        backoff_jitter: A random jitter to add to the backoff interval.
        retry_codes: A list of HTTP response error codes that qualifies for retry.
        raise_on_status: Whether to raise an exception, or return a response, if status falls
            in retry_codes range and retries have been exhausted.
        respect_retry_after_header: Whether to respect Retry-After header on status codes defined
            as Retry.RETRY_AFTER_STATUS_CODES or not.

    Returns:
        requests.Session object.
    """
    # os.getpid() keys the cache per process so Sessions are never shared
    # across forked workers.
    return _cached_get_request_session(
        max_retries,
        backoff_factor,
        backoff_jitter,
        retry_codes,
        raise_on_status,
        _pid=os.getpid(),
        respect_retry_after_header=respect_retry_after_header,
    )
236  
237  
def _get_http_response_with_retries(
    method,
    url,
    max_retries,
    backoff_factor,
    backoff_jitter,
    retry_codes,
    raise_on_status=True,
    allow_redirects=None,
    respect_retry_after_header=True,
    **kwargs,
):
    """Performs an HTTP request using Python's `requests` module with an automatic retry policy.

    Args:
        method: A string indicating the method to use, e.g. "GET", "POST", "PUT".
        url: The target URL address for the HTTP request.
        max_retries: Maximum total number of retries.
        backoff_factor: A time factor for exponential backoff. e.g. value 5 means the HTTP
            request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off
            the exponential backoff.
        backoff_jitter: A random jitter to add to the backoff interval.
        retry_codes: A list of HTTP response error codes that qualifies for retry.
        raise_on_status: Whether to raise an exception, or return a response, if status falls
            in retry_codes range and retries have been exhausted.
        allow_redirects: Whether to follow redirects; when None, the
            MLFLOW_ALLOW_HTTP_REDIRECTS environment variable decides (default: allow).
        respect_retry_after_header: Whether to respect the Retry-After header on status
            codes defined as Retry.RETRY_AFTER_STATUS_CODES or not.
        kwargs: Additional keyword arguments to pass to `requests.Session.request()`

    Returns:
        requests.Response object.
    """
    session = _get_request_session(
        max_retries,
        backoff_factor,
        backoff_jitter,
        retry_codes,
        raise_on_status,
        respect_retry_after_header,
    )

    if allow_redirects is None:
        # The environment variable name is hardcoded here to avoid importing mlflow;
        # its documentation lives in environment_variables.py.
        allow_redirects = os.environ.get("MLFLOW_ALLOW_HTTP_REDIRECTS", "true").lower() in (
            "true",
            "1",
        )

    return session.request(method, url, allow_redirects=allow_redirects, **kwargs)
283  
284  
def cloud_storage_http_request(
    method,
    url,
    max_retries=5,
    backoff_factor=2,
    backoff_jitter=1.0,
    retry_codes=_TRANSIENT_FAILURE_RESPONSE_CODES,
    timeout=None,
    **kwargs,
):
    """Performs an HTTP PUT/GET/PATCH/DELETE request using Python's `requests` module with
    automatic retry.

    Args:
        method: string of 'PUT', 'GET', 'PATCH' or 'DELETE' (case-insensitive), specifying
            which HTTP method to use.
        url: the target URL address for the HTTP request.
        max_retries: maximum number of retries before throwing an exception.
        backoff_factor: a time factor for exponential backoff. e.g. value 5 means the HTTP
            request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
            exponential backoff.
        backoff_jitter: A random jitter to add to the backoff interval.
        retry_codes: a list of HTTP response error codes that qualifies for retry.
        timeout: wait for timeout seconds for response from remote server for connect and
            read request. Default to None owing to long duration operation in read / write.
        kwargs: Additional keyword arguments to pass to `requests.Session.request()`.

    Returns:
        requests.Response object.

    Raises:
        ValueError: if `method` is not one of the supported HTTP methods.
    """
    if method.lower() not in ("put", "get", "patch", "delete"):
        raise ValueError("Illegal http method: " + method)
    return _get_http_response_with_retries(
        method,
        url,
        max_retries,
        backoff_factor,
        backoff_jitter,
        retry_codes,
        timeout=timeout,
        **kwargs,
    )