# mlflow/openai/api_request_parallel_processor.py
# Based on: https://github.com/openai/openai-cookbook/blob/6df6ceff470eeba26a56de131254e775292eac22/examples/api_request_parallel_processor.py
  2  # Several changes were made to make it work with MLflow.
  3  
  4  """
  5  API REQUEST PARALLEL PROCESSOR
  6  
  7  Using the OpenAI API to process lots of text quickly takes some care.
  8  If you trickle in a million API requests one by one, they'll take days to complete.
  9  If you flood a million API requests in parallel, they'll exceed the rate limits and fail with
 10  errors. To maximize throughput, parallel requests need to be throttled to stay under rate limits.
 11  
 12  This script parallelizes requests to the OpenAI API
 13  
 14  Features:
 15  - Makes requests concurrently, to maximize throughput
 16  - Retries failed requests up to {max_attempts} times, to avoid missing data
 17  - Logs errors, to diagnose problems with requests
 18  """
 19  
from __future__ import annotations

import logging
import threading
from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait
from dataclasses import dataclass, field
from typing import Any, Callable

import mlflow
 29  
 30  _logger = logging.getLogger(__name__)
 31  
 32  
 33  @dataclass
 34  class StatusTracker:
 35      """Stores metadata about the script's progress. Only one instance is created."""
 36  
 37      num_tasks_started: int = 0
 38      num_tasks_in_progress: int = 0  # script ends when this reaches 0
 39      num_tasks_succeeded: int = 0
 40      num_tasks_failed: int = 0
 41      num_rate_limit_errors: int = 0
 42      lock: threading.Lock = threading.Lock()
 43      error = None
 44  
 45      def start_task(self):
 46          with self.lock:
 47              self.num_tasks_started += 1
 48              self.num_tasks_in_progress += 1
 49  
 50      def complete_task(self, *, success: bool):
 51          with self.lock:
 52              self.num_tasks_in_progress -= 1
 53              if success:
 54                  self.num_tasks_succeeded += 1
 55              else:
 56                  self.num_tasks_failed += 1
 57  
 58      def increment_num_rate_limit_errors(self):
 59          with self.lock:
 60              self.num_rate_limit_errors += 1
 61  
 62  
 63  def call_api(
 64      index: int,
 65      results: list[tuple[int, Any]],
 66      task: Callable[[], Any],
 67      status_tracker: StatusTracker,
 68  ):
 69      import openai
 70  
 71      status_tracker.start_task()
 72      try:
 73          result = task()
 74          _logger.debug(f"Request #{index} succeeded")
 75          status_tracker.complete_task(success=True)
 76          results.append((index, result))
 77      except openai.RateLimitError as e:
 78          status_tracker.complete_task(success=False)
 79          _logger.debug(f"Request #{index} failed with: {e}")
 80          status_tracker.increment_num_rate_limit_errors()
 81          status_tracker.error = mlflow.MlflowException(
 82              f"Request #{index} failed with rate limit: {e}."
 83          )
 84      except Exception as e:
 85          status_tracker.complete_task(success=False)
 86          _logger.debug(f"Request #{index} failed with: {e}")
 87          status_tracker.error = mlflow.MlflowException(
 88              f"Request #{index} failed with: {e.__cause__}"
 89          )
 90  
 91  
 92  def process_api_requests(
 93      request_tasks: list[Callable[[], Any]],
 94      max_workers: int = 10,
 95  ):
 96      """Processes API requests in parallel"""
 97      # initialize trackers
 98      status_tracker = StatusTracker()  # single instance to track a collection of variables
 99  
100      results: list[tuple[int, Any]] = []
101      request_tasks_iter = enumerate(request_tasks)
102      _logger.debug(f"Request pool executor will run {len(request_tasks)} requests")
103      with ThreadPoolExecutor(
104          max_workers=max_workers, thread_name_prefix="MlflowOpenAiApi"
105      ) as executor:
106          futures = [
107              executor.submit(
108                  call_api,
109                  index=index,
110                  task=task,
111                  results=results,
112                  status_tracker=status_tracker,
113              )
114              for index, task in request_tasks_iter
115          ]
116          wait(futures, return_when=FIRST_EXCEPTION)
117  
118      # after finishing, log final status
119      if status_tracker.num_tasks_failed > 0:
120          if status_tracker.num_tasks_failed == 1:
121              raise status_tracker.error
122          raise mlflow.MlflowException(
123              f"{status_tracker.num_tasks_failed} tasks failed. See logs for details."
124          )
125      if status_tracker.num_rate_limit_errors > 0:
126          _logger.debug(
127              f"{status_tracker.num_rate_limit_errors} rate limit errors received. "
128              "Consider running at a lower rate."
129          )
130  
131      return [res for _, res in sorted(results)]