api_request_parallel_processor.py
1 # Based ons: https://github.com/openai/openai-cookbook/blob/6df6ceff470eeba26a56de131254e775292eac22/examples/api_request_parallel_processor.py 2 # Several changes were made to make it work with MLflow. 3 # Currently, only chat completion is supported. 4 5 """ 6 API REQUEST PARALLEL PROCESSOR 7 8 Using the LangChain API to process lots of text quickly takes some care. 9 If you trickle in a million API requests one by one, they'll take days to complete. 10 This script parallelizes requests using LangChain API. 11 12 Features: 13 - Streams requests from file, to avoid running out of memory for giant jobs 14 - Makes requests concurrently, to maximize throughput 15 - Logs errors, to diagnose problems with requests 16 """ 17 18 from __future__ import annotations 19 20 import logging 21 import queue 22 import threading 23 import time 24 import traceback 25 from concurrent.futures import ThreadPoolExecutor 26 from dataclasses import dataclass 27 from typing import Any 28 29 from mlflow.langchain._compat import import_base_callback_handler, try_import_chain 30 31 BaseCallbackHandler = import_base_callback_handler() 32 Chain = try_import_chain() 33 34 import mlflow 35 from mlflow.exceptions import MlflowException 36 from mlflow.langchain.utils.chat import ( 37 transform_request_json_for_chat_if_necessary, 38 try_transform_response_iter_to_chat_format, 39 try_transform_response_to_chat_format, 40 ) 41 from mlflow.langchain.utils.serialization import convert_to_serializable 42 from mlflow.pyfunc.context import Context, get_prediction_context 43 from mlflow.tracing.utils import maybe_set_prediction_context 44 45 _logger = logging.getLogger(__name__) 46 47 48 @dataclass 49 class StatusTracker: 50 """ 51 Stores metadata about the script's progress. Only one instance is created. 52 """ 53 54 num_tasks_started: int = 0 55 num_tasks_in_progress: int = 0 # script ends when this reaches 0 56 num_tasks_succeeded: int = 0 57 num_tasks_failed: int = 0 58 num_api_errors: int = 0 # excluding rate limit errors, counted above 59 lock: threading.Lock = threading.Lock() 60 61 def start_task(self): 62 with self.lock: 63 self.num_tasks_started += 1 64 self.num_tasks_in_progress += 1 65 66 def complete_task(self, *, success: bool): 67 with self.lock: 68 self.num_tasks_in_progress -= 1 69 if success: 70 self.num_tasks_succeeded += 1 71 else: 72 self.num_tasks_failed += 1 73 74 def increment_num_api_errors(self): 75 with self.lock: 76 self.num_api_errors += 1 77 78 79 @dataclass 80 class APIRequest: 81 """ 82 Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API 83 call. 84 85 Args: 86 index: The request's index in the tasks list 87 lc_model: The LangChain model to call 88 request_json: The request's input data 89 results: The list to append the request's output data to, it's a list of tuples 90 (index, response) 91 errors: A dictionary to store any errors that occur 92 convert_chat_responses: Whether to convert the model's responses to chat format 93 did_perform_chat_conversion: Whether the input data was converted to chat format 94 based on the model's type and input data. 95 stream: Whether the request is a stream request 96 prediction_context: The prediction context to use for the request 97 """ 98 99 index: int 100 lc_model: Any 101 request_json: dict[str, Any] 102 results: list[tuple[int, str]] 103 errors: dict[int, str] 104 convert_chat_responses: bool 105 did_perform_chat_conversion: bool 106 stream: bool 107 params: dict[str, Any] 108 prediction_context: Context | None = None 109 110 def _predict_single_input(self, single_input, callback_handlers, **kwargs): 111 config = kwargs.pop("config", {}) 112 config["callbacks"] = config.get("callbacks", []) + (callback_handlers or []) 113 if self.stream: 114 return self.lc_model.stream(single_input, config=config, **kwargs) 115 if hasattr(self.lc_model, "invoke"): 116 return self.lc_model.invoke(single_input, config=config, **kwargs) 117 else: 118 # for backwards compatibility, __call__ is deprecated and will be removed in 0.3.0 119 # kwargs shouldn't have config field if invoking with __call__ 120 return self.lc_model(single_input, callbacks=callback_handlers, **kwargs) 121 122 def _try_convert_response(self, response): 123 if self.stream: 124 return try_transform_response_iter_to_chat_format(response) 125 else: 126 return try_transform_response_to_chat_format(response) 127 128 def single_call_api(self, callback_handlers: list[BaseCallbackHandler] | None): 129 from mlflow.langchain._compat import import_base_retriever 130 from mlflow.langchain.utils.logging import langgraph_types, lc_runnables_types 131 132 BaseRetriever = import_base_retriever() 133 134 if isinstance(self.lc_model, BaseRetriever): 135 # Retrievers are invoked differently than Chains 136 response = self.lc_model.get_relevant_documents( 137 **self.request_json, callbacks=callback_handlers, **self.params 138 ) 139 elif isinstance(self.lc_model, lc_runnables_types() + langgraph_types()): 140 if isinstance(self.request_json, dict): 141 # This is a temporary fix for the case when spark_udf converts 142 # input into pandas dataframe with column name, while the model 143 # does not accept dictionaries as input, it leads to errors like 144 # Expected Scalar value for String field 'query_text' 145 try: 146 response = self._predict_single_input( 147 self.request_json, callback_handlers, **self.params 148 ) 149 except TypeError as e: 150 _logger.debug( 151 f"Failed to invoke {self.lc_model.__class__.__name__} " 152 f"with {self.request_json}. Error: {e!r}. Trying to " 153 "invoke with the first value of the dictionary." 154 ) 155 self.request_json = next(iter(self.request_json.values())) 156 ( 157 prepared_request_json, 158 did_perform_chat_conversion, 159 ) = transform_request_json_for_chat_if_necessary( 160 self.request_json, self.lc_model 161 ) 162 self.did_perform_chat_conversion = did_perform_chat_conversion 163 164 response = self._predict_single_input( 165 prepared_request_json, callback_handlers, **self.params 166 ) 167 else: 168 response = self._predict_single_input( 169 self.request_json, callback_handlers, **self.params 170 ) 171 172 if self.did_perform_chat_conversion or self.convert_chat_responses: 173 response = self._try_convert_response(response) 174 else: 175 # return_only_outputs is invalid for stream call 176 if Chain and isinstance(self.lc_model, Chain) and not self.stream: 177 kwargs = {"return_only_outputs": True} 178 else: 179 kwargs = {} 180 kwargs.update(**self.params) 181 response = self._predict_single_input(self.request_json, callback_handlers, **kwargs) 182 183 if self.did_perform_chat_conversion or self.convert_chat_responses: 184 response = self._try_convert_response(response) 185 elif isinstance(response, dict) and len(response) == 1: 186 # to maintain existing code, single output chains will still return 187 # only the result 188 response = response.popitem()[1] 189 190 return convert_to_serializable(response) 191 192 def call_api( 193 self, status_tracker: StatusTracker, callback_handlers: list[BaseCallbackHandler] | None 194 ): 195 """ 196 Calls the LangChain API and stores results. 197 """ 198 _logger.debug(f"Request #{self.index} started with payload: {self.request_json}") 199 200 try: 201 with maybe_set_prediction_context(self.prediction_context): 202 response = self.single_call_api(callback_handlers) 203 _logger.debug(f"Request #{self.index} succeeded with response: {response}") 204 self.results.append((self.index, response)) 205 status_tracker.complete_task(success=True) 206 except Exception as e: 207 self.errors[self.index] = ( 208 f"error: {e!r} {traceback.format_exc()}\n request payload: {self.request_json}" 209 ) 210 status_tracker.increment_num_api_errors() 211 status_tracker.complete_task(success=False) 212 213 214 def process_api_requests( 215 lc_model, 216 requests: list[Any | dict[str, Any]] | None = None, 217 max_workers: int = 10, 218 callback_handlers: list[BaseCallbackHandler] | None = None, 219 convert_chat_responses: bool = False, 220 params: dict[str, Any] | None = None, 221 context: Context | None = None, 222 ): 223 """ 224 Processes API requests in parallel. 225 """ 226 227 # initialize trackers 228 retry_queue = queue.Queue() 229 status_tracker = StatusTracker() # single instance to track a collection of variables 230 next_request = None # variable to hold the next request to call 231 context = context or get_prediction_context() 232 233 results = [] 234 errors = {} 235 236 # Note: we should call `transform_request_json_for_chat_if_necessary` 237 # for the whole batch data, because the conversion should obey the rule 238 # that if any record in the batch can't be converted, then all the record 239 # in this batch can't be converted. 240 ( 241 converted_chat_requests, 242 did_perform_chat_conversion, 243 ) = transform_request_json_for_chat_if_necessary(requests, lc_model) 244 245 requests_iter = enumerate(converted_chat_requests) 246 with ThreadPoolExecutor( 247 max_workers=max_workers, thread_name_prefix="MlflowLangChainApi" 248 ) as executor: 249 while True: 250 # get next request (if one is not already waiting for capacity) 251 if not retry_queue.empty(): 252 next_request = retry_queue.get_nowait() 253 _logger.warning(f"Retrying request {next_request.index}: {next_request}") 254 elif req := next(requests_iter, None): 255 # get new request 256 index, converted_chat_request_json = req 257 next_request = APIRequest( 258 index=index, 259 lc_model=lc_model, 260 request_json=converted_chat_request_json, 261 results=results, 262 errors=errors, 263 convert_chat_responses=convert_chat_responses, 264 did_perform_chat_conversion=did_perform_chat_conversion, 265 stream=False, 266 prediction_context=context, 267 params=params, 268 ) 269 status_tracker.start_task() 270 else: 271 next_request = None 272 273 # if enough capacity available, call API 274 if next_request: 275 # call API 276 executor.submit( 277 next_request.call_api, 278 status_tracker=status_tracker, 279 callback_handlers=callback_handlers, 280 ) 281 282 # if all tasks are finished, break 283 # check next_request to avoid terminating the process 284 # before extra requests need to be processed 285 if status_tracker.num_tasks_in_progress == 0 and next_request is None: 286 break 287 288 time.sleep(0.001) # avoid busy waiting 289 290 # after finishing, log final status 291 if status_tracker.num_tasks_failed > 0: 292 raise mlflow.MlflowException( 293 f"{status_tracker.num_tasks_failed} tasks failed. Errors: {errors}" 294 ) 295 296 return [res for _, res in sorted(results)] 297 298 299 def process_stream_request( 300 lc_model, 301 request_json: Any | dict[str, Any], 302 callback_handlers: list[BaseCallbackHandler] | None = None, 303 convert_chat_responses: bool = False, 304 params: dict[str, Any] | None = None, 305 ): 306 """ 307 Process single stream request. 308 """ 309 if not hasattr(lc_model, "stream"): 310 raise MlflowException( 311 f"Model {lc_model.__class__.__name__} does not support streaming prediction output. " 312 "No `stream` method found." 313 ) 314 315 ( 316 converted_chat_requests, 317 did_perform_chat_conversion, 318 ) = transform_request_json_for_chat_if_necessary(request_json, lc_model) 319 320 api_request = APIRequest( 321 index=0, 322 lc_model=lc_model, 323 request_json=converted_chat_requests, 324 results=None, 325 errors=None, 326 convert_chat_responses=convert_chat_responses, 327 did_perform_chat_conversion=did_perform_chat_conversion, 328 stream=True, 329 prediction_context=get_prediction_context(), 330 params=params, 331 ) 332 with maybe_set_prediction_context(api_request.prediction_context): 333 return api_request.single_call_api(callback_handlers)