/ mlflow / langchain / api_request_parallel_processor.py
api_request_parallel_processor.py
# Based on: https://github.com/openai/openai-cookbook/blob/6df6ceff470eeba26a56de131254e775292eac22/examples/api_request_parallel_processor.py
  2  # Several changes were made to make it work with MLflow.
  3  # Currently, only chat completion is supported.
  4  
  5  """
  6  API REQUEST PARALLEL PROCESSOR
  7  
  8  Using the LangChain API to process lots of text quickly takes some care.
  9  If you trickle in a million API requests one by one, they'll take days to complete.
 10  This script parallelizes requests using LangChain API.
 11  
 12  Features:
 13  - Streams requests from file, to avoid running out of memory for giant jobs
 14  - Makes requests concurrently, to maximize throughput
 15  - Logs errors, to diagnose problems with requests
 16  """
 17  
 18  from __future__ import annotations
 19  
import logging
import queue
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any
 28  
 29  from mlflow.langchain._compat import import_base_callback_handler, try_import_chain
 30  
 31  BaseCallbackHandler = import_base_callback_handler()
 32  Chain = try_import_chain()
 33  
 34  import mlflow
 35  from mlflow.exceptions import MlflowException
 36  from mlflow.langchain.utils.chat import (
 37      transform_request_json_for_chat_if_necessary,
 38      try_transform_response_iter_to_chat_format,
 39      try_transform_response_to_chat_format,
 40  )
 41  from mlflow.langchain.utils.serialization import convert_to_serializable
 42  from mlflow.pyfunc.context import Context, get_prediction_context
 43  from mlflow.tracing.utils import maybe_set_prediction_context
 44  
 45  _logger = logging.getLogger(__name__)
 46  
 47  
 48  @dataclass
 49  class StatusTracker:
 50      """
 51      Stores metadata about the script's progress. Only one instance is created.
 52      """
 53  
 54      num_tasks_started: int = 0
 55      num_tasks_in_progress: int = 0  # script ends when this reaches 0
 56      num_tasks_succeeded: int = 0
 57      num_tasks_failed: int = 0
 58      num_api_errors: int = 0  # excluding rate limit errors, counted above
 59      lock: threading.Lock = threading.Lock()
 60  
 61      def start_task(self):
 62          with self.lock:
 63              self.num_tasks_started += 1
 64              self.num_tasks_in_progress += 1
 65  
 66      def complete_task(self, *, success: bool):
 67          with self.lock:
 68              self.num_tasks_in_progress -= 1
 69              if success:
 70                  self.num_tasks_succeeded += 1
 71              else:
 72                  self.num_tasks_failed += 1
 73  
 74      def increment_num_api_errors(self):
 75          with self.lock:
 76              self.num_api_errors += 1
 77  
 78  
@dataclass
class APIRequest:
    """
    Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API
    call.

    Args:
        index: The request's index in the tasks list
        lc_model: The LangChain model to call
        request_json: The request's input data
        results: The list to append the request's output data to, it's a list of tuples
            (index, response)
        errors: A dictionary to store any errors that occur
        convert_chat_responses: Whether to convert the model's responses to chat format
        did_perform_chat_conversion: Whether the input data was converted to chat format
            based on the model's type and input data.
        stream: Whether the request is a stream request
        params: Extra keyword arguments expanded (``**``) into each model invocation
        prediction_context: The prediction context to use for the request
    """

    index: int
    lc_model: Any
    request_json: dict[str, Any]
    results: list[tuple[int, str]]
    errors: dict[int, str]
    convert_chat_responses: bool
    did_perform_chat_conversion: bool
    stream: bool
    params: dict[str, Any]
    prediction_context: Context | None = None

    def _predict_single_input(self, single_input, callback_handlers, **kwargs):
        """
        Invoke the model on a single input, merging ``callback_handlers`` into the
        call's config so they receive LangChain callback events.
        """
        config = kwargs.pop("config", {})
        config["callbacks"] = config.get("callbacks", []) + (callback_handlers or [])
        if self.stream:
            return self.lc_model.stream(single_input, config=config, **kwargs)
        if hasattr(self.lc_model, "invoke"):
            return self.lc_model.invoke(single_input, config=config, **kwargs)
        else:
            # for backwards compatibility, __call__ is deprecated and will be removed in 0.3.0
            # kwargs shouldn't have config field if invoking with __call__
            return self.lc_model(single_input, callbacks=callback_handlers, **kwargs)

    def _try_convert_response(self, response):
        """
        Best-effort conversion of the response to chat format; stream responses are
        iterators and need the iterator-aware transform.
        """
        if self.stream:
            return try_transform_response_iter_to_chat_format(response)
        else:
            return try_transform_response_to_chat_format(response)

    def single_call_api(self, callback_handlers: list[BaseCallbackHandler] | None):
        """
        Dispatch the request using the invocation style matching the model's type
        (retriever, runnable/graph, or legacy chain) and return a serializable response.
        """
        from mlflow.langchain._compat import import_base_retriever
        from mlflow.langchain.utils.logging import langgraph_types, lc_runnables_types

        BaseRetriever = import_base_retriever()

        if isinstance(self.lc_model, BaseRetriever):
            # Retrievers are invoked differently than Chains
            response = self.lc_model.get_relevant_documents(
                **self.request_json, callbacks=callback_handlers, **self.params
            )
        elif isinstance(self.lc_model, lc_runnables_types() + langgraph_types()):
            if isinstance(self.request_json, dict):
                # This is a temporary fix for the case when spark_udf converts
                # input into pandas dataframe with column name, while the model
                # does not accept dictionaries as input, it leads to errors like
                # Expected Scalar value for String field 'query_text'
                try:
                    response = self._predict_single_input(
                        self.request_json, callback_handlers, **self.params
                    )
                except TypeError as e:
                    _logger.debug(
                        f"Failed to invoke {self.lc_model.__class__.__name__} "
                        f"with {self.request_json}. Error: {e!r}. Trying to "
                        "invoke with the first value of the dictionary."
                    )
                    # Retry with the dict's first value and redo chat-format
                    # detection, since the unwrapped payload may convert differently.
                    self.request_json = next(iter(self.request_json.values()))
                    (
                        prepared_request_json,
                        did_perform_chat_conversion,
                    ) = transform_request_json_for_chat_if_necessary(
                        self.request_json, self.lc_model
                    )
                    self.did_perform_chat_conversion = did_perform_chat_conversion

                    response = self._predict_single_input(
                        prepared_request_json, callback_handlers, **self.params
                    )
            else:
                response = self._predict_single_input(
                    self.request_json, callback_handlers, **self.params
                )

            if self.did_perform_chat_conversion or self.convert_chat_responses:
                response = self._try_convert_response(response)
        else:
            # return_only_outputs is invalid for stream call
            if Chain and isinstance(self.lc_model, Chain) and not self.stream:
                kwargs = {"return_only_outputs": True}
            else:
                kwargs = {}
            kwargs.update(**self.params)
            response = self._predict_single_input(self.request_json, callback_handlers, **kwargs)

            if self.did_perform_chat_conversion or self.convert_chat_responses:
                response = self._try_convert_response(response)
            elif isinstance(response, dict) and len(response) == 1:
                # to maintain existing code, single output chains will still return
                # only the result
                response = response.popitem()[1]

        return convert_to_serializable(response)

    def call_api(
        self, status_tracker: StatusTracker, callback_handlers: list[BaseCallbackHandler] | None
    ):
        """
        Calls the LangChain API and stores results.

        On success the response is appended to ``self.results`` as ``(index, response)``;
        on failure the formatted error (with traceback) is stored in ``self.errors``
        under this request's index. ``status_tracker`` is updated in both cases.
        """
        _logger.debug(f"Request #{self.index} started with payload: {self.request_json}")

        try:
            with maybe_set_prediction_context(self.prediction_context):
                response = self.single_call_api(callback_handlers)
            _logger.debug(f"Request #{self.index} succeeded with response: {response}")
            self.results.append((self.index, response))
            status_tracker.complete_task(success=True)
        except Exception as e:
            self.errors[self.index] = (
                f"error: {e!r} {traceback.format_exc()}\n request payload: {self.request_json}"
            )
            status_tracker.increment_num_api_errors()
            status_tracker.complete_task(success=False)
212  
213  
def process_api_requests(
    lc_model,
    requests: list[Any | dict[str, Any]] | None = None,
    max_workers: int = 10,
    callback_handlers: list[BaseCallbackHandler] | None = None,
    convert_chat_responses: bool = False,
    params: dict[str, Any] | None = None,
    context: Context | None = None,
):
    """
    Processes API requests in parallel.

    Args:
        lc_model: The LangChain model (retriever, runnable/graph, or chain) to call.
        requests: The batch of request payloads to process.
        max_workers: Maximum number of worker threads for concurrent requests.
        callback_handlers: Optional LangChain callback handlers attached to every call.
        convert_chat_responses: Whether to force responses into chat format.
        params: Extra keyword arguments forwarded to each model invocation.
        context: Prediction context propagated to worker threads; falls back to the
            currently active context when not provided.

    Returns:
        The list of responses, ordered by the index of their originating request.

    Raises:
        MlflowException: If any request failed; the message aggregates all errors.
    """
    # initialize trackers
    retry_queue = queue.Queue()
    status_tracker = StatusTracker()  # single instance to track a collection of variables
    next_request = None  # variable to hold the next request to call
    context = context or get_prediction_context()

    results = []
    errors = {}

    # Note: we should call `transform_request_json_for_chat_if_necessary`
    # for the whole batch data, because the conversion should obey the rule
    # that if any record in the batch can't be converted, then all the record
    # in this batch can't be converted.
    (
        converted_chat_requests,
        did_perform_chat_conversion,
    ) = transform_request_json_for_chat_if_necessary(requests, lc_model)

    requests_iter = enumerate(converted_chat_requests)
    with ThreadPoolExecutor(
        max_workers=max_workers, thread_name_prefix="MlflowLangChainApi"
    ) as executor:
        while True:
            # get next request (if one is not already waiting for capacity)
            if not retry_queue.empty():
                next_request = retry_queue.get_nowait()
                _logger.warning(f"Retrying request {next_request.index}: {next_request}")
            elif req := next(requests_iter, None):
                # get new request (enumerate yields non-empty tuples, so `req`
                # is always truthy here; None means the iterator is exhausted)
                index, converted_chat_request_json = req
                next_request = APIRequest(
                    index=index,
                    lc_model=lc_model,
                    request_json=converted_chat_request_json,
                    results=results,
                    errors=errors,
                    convert_chat_responses=convert_chat_responses,
                    did_perform_chat_conversion=did_perform_chat_conversion,
                    stream=False,
                    prediction_context=context,
                    # normalize None to {} so `**self.params` in APIRequest cannot fail
                    params=params or {},
                )
                status_tracker.start_task()
            else:
                next_request = None

            # if enough capacity available, call API
            if next_request:
                # call API
                executor.submit(
                    next_request.call_api,
                    status_tracker=status_tracker,
                    callback_handlers=callback_handlers,
                )

            # if all tasks are finished, break
            # check next_request to avoid terminating the process
            # before extra requests need to be processed
            if status_tracker.num_tasks_in_progress == 0 and next_request is None:
                break

            time.sleep(0.001)  # avoid busy waiting

    # after finishing, log final status (use the directly imported MlflowException
    # for consistency with the rest of this module)
    if status_tracker.num_tasks_failed > 0:
        raise MlflowException(
            f"{status_tracker.num_tasks_failed} tasks failed. Errors: {errors}"
        )

    # sort by request index so the output order matches the input order
    return [res for _, res in sorted(results)]
297  
298  
def process_stream_request(
    lc_model,
    request_json: Any | dict[str, Any],
    callback_handlers: list[BaseCallbackHandler] | None = None,
    convert_chat_responses: bool = False,
    params: dict[str, Any] | None = None,
):
    """
    Process single stream request.

    Args:
        lc_model: The LangChain model to call; must expose a ``stream`` method.
        request_json: The request's input data.
        callback_handlers: Optional LangChain callback handlers attached to the call.
        convert_chat_responses: Whether to force the response chunks into chat format.
        params: Extra keyword arguments forwarded to the model invocation.

    Returns:
        An iterator over the (possibly chat-converted) streamed response chunks.

    Raises:
        MlflowException: If the model does not support streaming prediction output.
    """
    if not hasattr(lc_model, "stream"):
        raise MlflowException(
            f"Model {lc_model.__class__.__name__} does not support streaming prediction output. "
            "No `stream` method found."
        )

    (
        converted_chat_requests,
        did_perform_chat_conversion,
    ) = transform_request_json_for_chat_if_necessary(request_json, lc_model)

    api_request = APIRequest(
        index=0,
        lc_model=lc_model,
        request_json=converted_chat_requests,
        results=None,
        errors=None,
        convert_chat_responses=convert_chat_responses,
        did_perform_chat_conversion=did_perform_chat_conversion,
        stream=True,
        prediction_context=get_prediction_context(),
        # normalize None to {} so `**self.params` in APIRequest cannot fail
        params=params or {},
    )
    with maybe_set_prediction_context(api_request.prediction_context):
        return api_request.single_call_api(callback_handlers)