# Based on: https://github.com/openai/openai-cookbook/blob/6df6ceff470eeba26a56de131254e775292eac22/examples/api_request_parallel_processor.py
# Several changes were made to make it work with MLflow.

"""
API REQUEST PARALLEL PROCESSOR

Using the OpenAI API to process lots of text quickly takes some care.
If you trickle in a million API requests one by one, they'll take days to complete.
If you flood a million API requests in parallel, they'll exceed the rate limits and fail with
errors. To maximize throughput, parallel requests need to be throttled to stay under rate limits.

This script parallelizes requests to the OpenAI API

Features:
- Makes requests concurrently, to maximize throughput
- Logs errors, to diagnose problems with requests

NOTE(review): unlike the original cookbook script, this adaptation contains no
retry logic — a failed request is recorded on the status tracker and reported
after all tasks finish, not re-attempted.
"""

from __future__ import annotations

import logging
import threading
from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait
from dataclasses import dataclass, field
from typing import Any, Callable

import mlflow

_logger = logging.getLogger(__name__)


@dataclass
class StatusTracker:
    """Stores metadata about the script's progress. Only one instance is created."""

    num_tasks_started: int = 0
    num_tasks_in_progress: int = 0  # script ends when this reaches 0
    num_tasks_succeeded: int = 0
    num_tasks_failed: int = 0
    num_rate_limit_errors: int = 0
    # default_factory gives each tracker its own lock. A plain
    # `threading.Lock()` default would be a single lock object shared by
    # every StatusTracker instance (classic mutable-default pitfall).
    lock: threading.Lock = field(default_factory=threading.Lock)
    # Last exception observed by a worker thread; re-raised by the caller
    # when exactly one task failed.
    error: Exception | None = None

    def start_task(self):
        """Record that one task has started (thread-safe)."""
        with self.lock:
            self.num_tasks_started += 1
            self.num_tasks_in_progress += 1

    def complete_task(self, *, success: bool):
        """Record that one task finished, updating success/failure counters."""
        with self.lock:
            self.num_tasks_in_progress -= 1
            if success:
                self.num_tasks_succeeded += 1
            else:
                self.num_tasks_failed += 1

    def increment_num_rate_limit_errors(self):
        """Record that a task failed specifically due to rate limiting."""
        with self.lock:
            self.num_rate_limit_errors += 1


def call_api(
    index: int,
    results: list[tuple[int, Any]],
    task: Callable[[], Any],
    status_tracker: StatusTracker,
):
    """Run a single API task and append ``(index, result)`` to ``results``.

    Exceptions are never re-raised from this worker; they are recorded on
    ``status_tracker`` so the coordinating thread can report them once all
    tasks have finished.

    Args:
        index: Position of this task in the original request list.
        results: Shared output list; appends are thread-safe under the GIL.
        task: Zero-argument callable performing the actual API request.
        status_tracker: Shared counters for progress/error reporting.
    """
    # Imported lazily so this module can be imported without openai installed.
    import openai

    status_tracker.start_task()
    try:
        result = task()
        _logger.debug(f"Request #{index} succeeded")
        status_tracker.complete_task(success=True)
        results.append((index, result))
    except openai.RateLimitError as e:
        status_tracker.complete_task(success=False)
        _logger.debug(f"Request #{index} failed with: {e}")
        status_tracker.increment_num_rate_limit_errors()
        status_tracker.error = mlflow.MlflowException(
            f"Request #{index} failed with rate limit: {e}."
        )
    except Exception as e:
        status_tracker.complete_task(success=False)
        _logger.debug(f"Request #{index} failed with: {e}")
        # Report the exception itself. The previous `e.__cause__` was None
        # (and thus useless) for any exception not raised via
        # `raise ... from ...`.
        status_tracker.error = mlflow.MlflowException(f"Request #{index} failed with: {e}")


def process_api_requests(
    request_tasks: list[Callable[[], Any]],
    max_workers: int = 10,
):
    """Processes API requests in parallel.

    Args:
        request_tasks: Zero-argument callables, each performing one API request.
        max_workers: Maximum number of concurrent worker threads.

    Returns:
        The task results, ordered to match ``request_tasks``.

    Raises:
        mlflow.MlflowException: If one or more tasks failed.
    """
    # initialize trackers
    status_tracker = StatusTracker()  # single instance to track a collection of variables

    results: list[tuple[int, Any]] = []
    _logger.debug(f"Request pool executor will run {len(request_tasks)} requests")
    with ThreadPoolExecutor(
        max_workers=max_workers, thread_name_prefix="MlflowOpenAiApi"
    ) as executor:
        futures = [
            executor.submit(
                call_api,
                index=index,
                task=task,
                results=results,
                status_tracker=status_tracker,
            )
            for index, task in enumerate(request_tasks)
        ]
        # call_api swallows all exceptions, so FIRST_EXCEPTION rarely fires;
        # exiting the `with` block also joins all workers.
        wait(futures, return_when=FIRST_EXCEPTION)

    # after finishing, log final status
    if status_tracker.num_tasks_failed > 0:
        if status_tracker.num_tasks_failed == 1:
            raise status_tracker.error
        raise mlflow.MlflowException(
            f"{status_tracker.num_tasks_failed} tasks failed. See logs for details."
        )
    if status_tracker.num_rate_limit_errors > 0:
        _logger.debug(
            f"{status_tracker.num_rate_limit_errors} rate limit errors received. "
            "Consider running at a lower rate."
        )

    # Indices are unique, so sorting the (index, result) tuples never
    # compares the (possibly non-comparable) results themselves.
    return [res for _, res in sorted(results)]