/ mlflow / utils / async_logging / async_logging_queue.py
async_logging_queue.py
  1  """
Defines an AsyncLoggingQueue that provides asynchronous logging of metrics/tags/params
using a queue-based approach.
  4  """
  5  
  6  import atexit
  7  import enum
  8  import logging
  9  import threading
 10  from concurrent.futures import ThreadPoolExecutor
 11  from queue import Empty, Queue
 12  from typing import Callable
 13  
 14  from mlflow.entities.metric import Metric
 15  from mlflow.entities.param import Param
 16  from mlflow.entities.run_tag import RunTag
 17  from mlflow.environment_variables import (
 18      MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS,
 19      MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE,
 20  )
 21  from mlflow.utils.async_logging.run_batch import RunBatch
 22  from mlflow.utils.async_logging.run_operations import RunOperations
 23  
 24  _logger = logging.getLogger(__name__)
 25  
 26  
 27  ASYNC_LOGGING_WORKER_THREAD_PREFIX = "MLflowBatchLoggingWorkerPool"
 28  ASYNC_LOGGING_STATUS_CHECK_THREAD_PREFIX = "MLflowAsyncLoggingStatusCheck"
 29  
 30  
 31  class QueueStatus(enum.Enum):
 32      """Status of the async queue"""
 33  
 34      # The queue is listening to new data and logging enqueued data to MLflow.
 35      ACTIVE = 1
 36      # The queue is not listening to new data, but still logging enqueued data to MLflow.
 37      TEAR_DOWN = 2
 38      # The queue is neither listening to new data or logging enqueued data to MLflow.
 39      IDLE = 3
 40  
 41  
# Upper bounds applied when merging queued batches in `_fetch_batch_from_queue`
# (presumably chosen to stay within MLflow's per-request log-batch limits --
# TODO confirm against the backend contract).
_MAX_ITEMS_PER_BATCH = 1000
_MAX_PARAMS_PER_BATCH = 100
_MAX_TAGS_PER_BATCH = 100
 45  
 46  
 47  class AsyncLoggingQueue:
 48      """
    This is a queue-based run data processor that queues incoming batches and
    processes them using a single worker thread.
 51      """
 52  
 53      def __init__(
 54          self, logging_func: Callable[[str, list[Metric], list[Param], list[RunTag]], None]
 55      ) -> None:
 56          """Initializes an AsyncLoggingQueue object.
 57  
 58          Args:
 59              logging_func: A callable function that takes in four arguments: a string
 60                  representing the run_id, a list of Metric objects,
 61                  a list of Param objects, and a list of RunTag objects.
 62          """
 63          self._queue = Queue()
 64          self._lock = threading.RLock()
 65          self._logging_func = logging_func
 66  
 67          self._stop_data_logging_thread_event = threading.Event()
 68          self._status = QueueStatus.IDLE
 69  
 70      def _at_exit_callback(self) -> None:
 71          """Callback function to be executed when the program is exiting.
 72  
 73          Stops the data processing thread and waits for the queue to be drained. Finally, shuts down
 74          the thread pools used for data logging and batch processing status check.
 75          """
 76          try:
 77              # Stop the data processing thread
 78              self._stop_data_logging_thread_event.set()
 79              # Waits till logging queue is drained.
 80              self._batch_logging_thread.join()
 81              self._batch_logging_worker_threadpool.shutdown(wait=True)
 82              self._batch_status_check_threadpool.shutdown(wait=True)
 83          except Exception as e:
 84              _logger.error(f"Encountered error while trying to finish logging: {e}")
 85  
    def end_async_logging(self) -> None:
        """Stop accepting new data and drain the queue, keeping worker pools alive.

        After this returns the queue is in TEAR_DOWN: the logging thread has
        exited, but batches already handed to the worker pool may still be in
        flight. Unlike `shut_down_async_logging`, the thread pools are left
        running so logging can be re-activated afterwards.
        """
        with self._lock:
            # Stop the data processing thread.
            self._stop_data_logging_thread_event.set()
            # Waits till logging queue is drained.
            self._batch_logging_thread.join()
            # Set the status to tear down. The worker threads will still process
            # the remaining data.
            self._status = QueueStatus.TEAR_DOWN
            # Clear the status to avoid blocking next logging.
            self._stop_data_logging_thread_event.clear()
 97  
 98      def shut_down_async_logging(self) -> None:
 99          """
100          Shut down the async logging queue and wait for the queue to be drained.
101          Use this method if the async logging should be terminated.
102          """
103          self.end_async_logging()
104          self._batch_logging_worker_threadpool.shutdown(wait=True)
105          self._batch_status_check_threadpool.shutdown(wait=True)
106          self._status = QueueStatus.IDLE
107  
108      def flush(self) -> None:
109          """
110          Flush the async logging queue and restart thread to listen
111          to incoming data after flushing.
112  
113          Calling this method will flush the queue to ensure all the data are logged.
114          """
115          self.shut_down_async_logging()
116          # Reinitialize the logging thread and set the status to active.
117          self.activate()
118  
119      def _logging_loop(self) -> None:
120          """
121          Continuously logs run data until `self._continue_to_process_data` is set to False.
122          If an exception occurs during logging, a `MlflowException` is raised.
123          """
124          try:
125              while not self._stop_data_logging_thread_event.is_set():
126                  self._log_run_data()
127              # Drain the queue after the stop event is set.
128              while not self._queue.empty():
129                  self._log_run_data()
130          except Exception as e:
131              from mlflow.exceptions import MlflowException
132  
133              raise MlflowException(f"Exception inside the run data logging thread: {e}")
134  
135      def _fetch_batch_from_queue(self) -> list[RunBatch]:
136          """Fetches a batch of run data from the queue.
137  
138          Returns:
139              RunBatch: A batch of run data.
140          """
141          batches = []
142          if self._queue.empty():
143              return batches
144          queue_size = self._queue.qsize()  # Estimate the queue's size.
145          merged_batch = self._queue.get()
146          for i in range(queue_size - 1):
147              if self._queue.empty():
148                  # `queue_size` is an estimate, so we need to check if the queue is empty.
149                  break
150              batch = self._queue.get()
151  
152              if (
153                  merged_batch.run_id != batch.run_id
154                  or (
155                      len(merged_batch.metrics + merged_batch.params + merged_batch.tags)
156                      + len(batch.metrics + batch.params + batch.tags)
157                  )
158                  >= _MAX_ITEMS_PER_BATCH
159                  or len(merged_batch.params) + len(batch.params) >= _MAX_PARAMS_PER_BATCH
160                  or len(merged_batch.tags) + len(batch.tags) >= _MAX_TAGS_PER_BATCH
161              ):
162                  # Make a new batch if the run_id is different or the batch is full.
163                  batches.append(merged_batch)
164                  merged_batch = batch
165              else:
166                  merged_batch.add_child_batch(batch)
167                  merged_batch.params.extend(batch.params)
168                  merged_batch.tags.extend(batch.tags)
169                  merged_batch.metrics.extend(batch.metrics)
170  
171          batches.append(merged_batch)
172          return batches
173  
174      def _log_run_data(self) -> None:
175          """Process the run data in the running runs queues.
176  
177          For each run in the running runs queues, this method retrieves the next batch of run data
178          from the queue and processes it by calling the `_processing_func` method with the run ID,
179          metrics, parameters, and tags in the batch. If the batch is empty, it is skipped. After
180          processing the batch, the processed watermark is updated and the batch event is set.
181          If an exception occurs during processing, the exception is logged and the batch event is set
182          with the exception. If the queue is empty, it is ignored.
183  
184          Returns: None
185          """
186          async_logging_buffer_seconds = MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS.get()
187          try:
188              if async_logging_buffer_seconds:
189                  self._stop_data_logging_thread_event.wait(async_logging_buffer_seconds)
190                  run_batches = self._fetch_batch_from_queue()
191              else:
192                  run_batches = [self._queue.get(timeout=1)]
193          except Empty:
194              # Ignore empty queue exception
195              return
196  
197          def logging_func(run_batch):
198              try:
199                  self._logging_func(
200                      run_id=run_batch.run_id,
201                      metrics=run_batch.metrics,
202                      params=run_batch.params,
203                      tags=run_batch.tags,
204                  )
205              except Exception as e:
206                  _logger.error(f"Run Id {run_batch.run_id}: Failed to log run data: Exception: {e}")
207                  run_batch.exception = e
208              finally:
209                  run_batch.complete()
210  
211          for run_batch in run_batches:
212              try:
213                  self._batch_logging_worker_threadpool.submit(logging_func, run_batch)
214              except Exception as e:
215                  _logger.error(
216                      f"Failed to submit batch for logging: {e}. Usually this means you are not "
217                      "shutting down MLflow properly before exiting. Please make sure you are using "
218                      "context manager, e.g., `with mlflow.start_run():` or call `mlflow.end_run()`"
219                      "explicitly to terminate MLflow logging before exiting."
220                  )
221                  run_batch.exception = e
222                  run_batch.complete()
223  
224      def _wait_for_batch(self, batch: RunBatch) -> None:
225          """Wait for the given batch to be processed by the logging thread.
226  
227          Args:
228              batch: The batch to wait for.
229  
230          Raises:
231              Exception: If an exception occurred while processing the batch.
232          """
233          batch.completion_event.wait()
234          if batch.exception:
235              raise batch.exception
236  
237      def __getstate__(self):
238          """Return the state of the object for pickling.
239  
240          This method is called by the `pickle` module when the object is being pickled. It returns a
241          dictionary containing the object's state, with non-picklable attributes removed.
242  
243          Returns:
244              dict: A dictionary containing the object's state.
245          """
246          state = self.__dict__.copy()
247          del state["_queue"]
248          del state["_lock"]
249          del state["_status"]
250  
251          if "_run_data_logging_thread" in state:
252              del state["_run_data_logging_thread"]
253          if "_stop_data_logging_thread_event" in state:
254              del state["_stop_data_logging_thread_event"]
255          if "_batch_logging_thread" in state:
256              del state["_batch_logging_thread"]
257          if "_batch_logging_worker_threadpool" in state:
258              del state["_batch_logging_worker_threadpool"]
259          if "_batch_status_check_threadpool" in state:
260              del state["_batch_status_check_threadpool"]
261  
262          return state
263  
264      def __setstate__(self, state):
265          """Set the state of the object from a given state dictionary.
266  
267          It pops back the removed non-picklable attributes from `self.__getstate__()`.
268  
269          Args:
270              state (dict): A dictionary containing the state of the object.
271  
272          Returns:
273              None
274          """
275          self.__dict__.update(state)
276          self._queue = Queue()
277          self._lock = threading.RLock()
278          self._status = QueueStatus.IDLE
279          self._batch_logging_thread = None
280          self._batch_logging_worker_threadpool = None
281          self._batch_status_check_threadpool = None
282          self._stop_data_logging_thread_event = threading.Event()
283  
284      def log_batch_async(
285          self, run_id: str, params: list[Param], tags: list[RunTag], metrics: list[Metric]
286      ) -> RunOperations:
287          """Asynchronously logs a batch of run data (parameters, tags, and metrics).
288  
289          Args:
290              run_id (str): The ID of the run to log data for.
291              params (list[mlflow.entities.Param]): A list of parameters to log for the run.
292              tags (list[mlflow.entities.RunTag]): A list of tags to log for the run.
293              metrics (list[mlflow.entities.Metric]): A list of metrics to log for the run.
294  
295          Returns:
296              mlflow.utils.async_utils.RunOperations: An object that encapsulates the
297                  asynchronous operation of logging the batch of run data.
298                  The object contains a list of `concurrent.futures.Future` objects that can be used
299                  to check the status of the operation and retrieve any exceptions
300                  that occurred during the operation.
301          """
302          from mlflow import MlflowException
303  
304          if not self.is_active():
305              raise MlflowException("AsyncLoggingQueue is not activated.")
306          batch = RunBatch(
307              run_id=run_id,
308              params=params,
309              tags=tags,
310              metrics=metrics,
311              completion_event=threading.Event(),
312          )
313          self._queue.put(batch)
314          operation_future = self._batch_status_check_threadpool.submit(self._wait_for_batch, batch)
315          return RunOperations(operation_futures=[operation_future])
316  
317      def is_active(self) -> bool:
318          return self._status == QueueStatus.ACTIVE
319  
320      def is_idle(self) -> bool:
321          return self._status == QueueStatus.IDLE
322  
323      def _set_up_logging_thread(self) -> None:
324          """
325          Sets up the logging thread.
326  
327          This method shouldn't be called directly without shutting down the async
328          logging first if an existing async logging exists, otherwise it might
329          hang the program.
330          """
331          with self._lock:
332              self._batch_logging_thread = threading.Thread(
333                  target=self._logging_loop,
334                  name="MLflowAsyncLoggingLoop",
335                  daemon=True,
336              )
337              self._batch_logging_worker_threadpool = ThreadPoolExecutor(
338                  max_workers=MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.get() or 10,
339                  thread_name_prefix=ASYNC_LOGGING_WORKER_THREAD_PREFIX,
340              )
341  
342              self._batch_status_check_threadpool = ThreadPoolExecutor(
343                  max_workers=MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.get() or 10,
344                  thread_name_prefix=ASYNC_LOGGING_STATUS_CHECK_THREAD_PREFIX,
345              )
346  
347              self._batch_logging_thread.start()
348  
349      def activate(self) -> None:
350          """Activates the async logging queue
351  
352          1. Initializes queue draining thread.
353          2. Initializes threads for checking the status of logged batch.
354          3. Registering an atexit callback to ensure that any remaining log data
355              is flushed before the program exits.
356  
357          If the queue is already activated, this method does nothing.
358          """
359          with self._lock:
360              if self.is_active():
361                  return
362  
363              self._set_up_logging_thread()
364              atexit.register(self._at_exit_callback)
365  
366              self._status = QueueStatus.ACTIVE