"""
Defines an AsyncLoggingQueue that provides async fashion logging of metrics/tags/params using
queue based approach.
"""

import atexit
import enum
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from typing import Callable

from mlflow.entities.metric import Metric
from mlflow.entities.param import Param
from mlflow.entities.run_tag import RunTag
from mlflow.environment_variables import (
    MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS,
    MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE,
)
from mlflow.utils.async_logging.run_batch import RunBatch
from mlflow.utils.async_logging.run_operations import RunOperations

_logger = logging.getLogger(__name__)


ASYNC_LOGGING_WORKER_THREAD_PREFIX = "MLflowBatchLoggingWorkerPool"
ASYNC_LOGGING_STATUS_CHECK_THREAD_PREFIX = "MLflowAsyncLoggingStatusCheck"


class QueueStatus(enum.Enum):
    """Status of the async queue"""

    # The queue is listening to new data and logging enqueued data to MLflow.
    ACTIVE = 1
    # The queue is not listening to new data, but still logging enqueued data to MLflow.
    TEAR_DOWN = 2
    # The queue is neither listening to new data or logging enqueued data to MLflow.
    IDLE = 3


# Caps applied when merging adjacent queued batches into one backend call
# (see `_fetch_batch_from_queue`).
_MAX_ITEMS_PER_BATCH = 1000
_MAX_PARAMS_PER_BATCH = 100
_MAX_TAGS_PER_BATCH = 100


class AsyncLoggingQueue:
    """
    This is a queue based run data processor that queues incoming batches and processes them using
    single worker thread.
    """

    def __init__(
        self, logging_func: Callable[[str, list[Metric], list[Param], list[RunTag]], None]
    ) -> None:
        """Initializes an AsyncLoggingQueue object.

        Args:
            logging_func: A callable function that takes in four arguments: a string
                representing the run_id, a list of Metric objects,
                a list of Param objects, and a list of RunTag objects.
        """
        self._queue = Queue()
        self._lock = threading.RLock()
        self._logging_func = logging_func

        # Signals the draining thread (`_logging_loop`) to stop listening for new data.
        self._stop_data_logging_thread_event = threading.Event()
        self._status = QueueStatus.IDLE

    def _at_exit_callback(self) -> None:
        """Callback function to be executed when the program is exiting.

        Stops the data processing thread and waits for the queue to be drained. Finally, shuts down
        the thread pools used for data logging and batch processing status check.
        """
        try:
            # Stop the data processing thread
            self._stop_data_logging_thread_event.set()
            # Waits till logging queue is drained.
            self._batch_logging_thread.join()
            self._batch_logging_worker_threadpool.shutdown(wait=True)
            self._batch_status_check_threadpool.shutdown(wait=True)
        except Exception as e:
            # Best-effort cleanup at interpreter exit; never propagate.
            _logger.error(f"Encountered error while trying to finish logging: {e}")

    def end_async_logging(self) -> None:
        """Stop accepting new data and wait until the queue itself is drained.

        Worker threads may still be processing already-dequeued batches afterwards
        (status becomes TEAR_DOWN, not IDLE).
        """
        with self._lock:
            # Stop the data processing thread.
            self._stop_data_logging_thread_event.set()
            # Waits till logging queue is drained.
            self._batch_logging_thread.join()
            # Set the status to tear down. The worker threads will still process
            # the remaining data.
            self._status = QueueStatus.TEAR_DOWN
            # Clear the status to avoid blocking next logging.
            self._stop_data_logging_thread_event.clear()

    def shut_down_async_logging(self) -> None:
        """
        Shut down the async logging queue and wait for the queue to be drained.
        Use this method if the async logging should be terminated.
        """
        self.end_async_logging()
        self._batch_logging_worker_threadpool.shutdown(wait=True)
        self._batch_status_check_threadpool.shutdown(wait=True)
        self._status = QueueStatus.IDLE

    def flush(self) -> None:
        """
        Flush the async logging queue and restart thread to listen
        to incoming data after flushing.

        Calling this method will flush the queue to ensure all the data are logged.
        """
        self.shut_down_async_logging()
        # Reinitialize the logging thread and set the status to active.
        self.activate()

    def _logging_loop(self) -> None:
        """
        Continuously logs run data until `self._stop_data_logging_thread_event` is set.
        If an exception occurs during logging, a `MlflowException` is raised.
        """
        try:
            while not self._stop_data_logging_thread_event.is_set():
                self._log_run_data()
            # Drain the queue after the stop event is set.
            while not self._queue.empty():
                self._log_run_data()
        except Exception as e:
            from mlflow.exceptions import MlflowException

            raise MlflowException(f"Exception inside the run data logging thread: {e}")

    def _fetch_batch_from_queue(self) -> list[RunBatch]:
        """Fetches a batch of run data from the queue.

        Consecutive queued batches for the same run are merged together until one of
        the per-batch limits (`_MAX_ITEMS_PER_BATCH`, `_MAX_PARAMS_PER_BATCH`,
        `_MAX_TAGS_PER_BATCH`) would be exceeded or the run_id changes.

        Returns:
            A list of (possibly merged) RunBatch objects; empty if the queue is empty.
        """
        batches = []
        if self._queue.empty():
            return batches
        queue_size = self._queue.qsize()  # Estimate the queue's size.
        merged_batch = self._queue.get()
        for _ in range(queue_size - 1):
            if self._queue.empty():
                # `queue_size` is an estimate, so we need to check if the queue is empty.
                break
            batch = self._queue.get()

            if (
                merged_batch.run_id != batch.run_id
                or (
                    len(merged_batch.metrics + merged_batch.params + merged_batch.tags)
                    + len(batch.metrics + batch.params + batch.tags)
                )
                >= _MAX_ITEMS_PER_BATCH
                or len(merged_batch.params) + len(batch.params) >= _MAX_PARAMS_PER_BATCH
                or len(merged_batch.tags) + len(batch.tags) >= _MAX_TAGS_PER_BATCH
            ):
                # Make a new batch if the run_id is different or the batch is full.
                batches.append(merged_batch)
                merged_batch = batch
            else:
                # Track the child so its completion event fires when the merged
                # batch completes.
                merged_batch.add_child_batch(batch)
                merged_batch.params.extend(batch.params)
                merged_batch.tags.extend(batch.tags)
                merged_batch.metrics.extend(batch.metrics)

        batches.append(merged_batch)
        return batches

    def _log_run_data(self) -> None:
        """Process the run data in the running runs queues.

        For each run in the running runs queues, this method retrieves the next batch of run data
        from the queue and processes it by calling the `_logging_func` with the run ID,
        metrics, parameters, and tags in the batch. If the batch is empty, it is skipped. After
        processing the batch, the processed watermark is updated and the batch event is set.
        If an exception occurs during processing, the exception is logged and the batch event is set
        with the exception. If the queue is empty, it is ignored.

        Returns: None
        """
        async_logging_buffer_seconds = MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS.get()
        try:
            if async_logging_buffer_seconds:
                # Buffering mode: wait up to the configured window (or until stop is
                # signaled), then merge everything accumulated in the queue.
                self._stop_data_logging_thread_event.wait(async_logging_buffer_seconds)
                run_batches = self._fetch_batch_from_queue()
            else:
                run_batches = [self._queue.get(timeout=1)]
        except Empty:
            # Ignore empty queue exception
            return

        def logging_func(run_batch):
            try:
                self._logging_func(
                    run_id=run_batch.run_id,
                    metrics=run_batch.metrics,
                    params=run_batch.params,
                    tags=run_batch.tags,
                )
            except Exception as e:
                _logger.error(f"Run Id {run_batch.run_id}: Failed to log run data: Exception: {e}")
                run_batch.exception = e
            finally:
                # Always signal completion so waiters in `_wait_for_batch` unblock.
                run_batch.complete()

        for run_batch in run_batches:
            try:
                self._batch_logging_worker_threadpool.submit(logging_func, run_batch)
            except Exception as e:
                # Submission fails if the pool is already shut down (e.g. logging after
                # exit without ending the run); fail the batch so waiters see the error.
                _logger.error(
                    f"Failed to submit batch for logging: {e}. Usually this means you are not "
                    "shutting down MLflow properly before exiting. Please make sure you are using "
                    "context manager, e.g., `with mlflow.start_run():` or call `mlflow.end_run()` "
                    "explicitly to terminate MLflow logging before exiting."
                )
                run_batch.exception = e
                run_batch.complete()

    def _wait_for_batch(self, batch: RunBatch) -> None:
        """Wait for the given batch to be processed by the logging thread.

        Args:
            batch: The batch to wait for.

        Raises:
            Exception: If an exception occurred while processing the batch.
        """
        batch.completion_event.wait()
        if batch.exception:
            raise batch.exception

    def __getstate__(self):
        """Return the state of the object for pickling.

        This method is called by the `pickle` module when the object is being pickled. It returns a
        dictionary containing the object's state, with non-picklable attributes removed.

        Returns:
            dict: A dictionary containing the object's state.
        """
        state = self.__dict__.copy()
        del state["_queue"]
        del state["_lock"]
        del state["_status"]

        if "_run_data_logging_thread" in state:
            del state["_run_data_logging_thread"]
        if "_stop_data_logging_thread_event" in state:
            del state["_stop_data_logging_thread_event"]
        if "_batch_logging_thread" in state:
            del state["_batch_logging_thread"]
        if "_batch_logging_worker_threadpool" in state:
            del state["_batch_logging_worker_threadpool"]
        if "_batch_status_check_threadpool" in state:
            del state["_batch_status_check_threadpool"]

        return state

    def __setstate__(self, state):
        """Set the state of the object from a given state dictionary.

        It pops back the removed non-picklable attributes from `self.__getstate__()`.
        The unpickled queue starts IDLE; callers must invoke `activate()` again.

        Args:
            state (dict): A dictionary containing the state of the object.

        Returns:
            None
        """
        self.__dict__.update(state)
        self._queue = Queue()
        self._lock = threading.RLock()
        self._status = QueueStatus.IDLE
        self._batch_logging_thread = None
        self._batch_logging_worker_threadpool = None
        self._batch_status_check_threadpool = None
        self._stop_data_logging_thread_event = threading.Event()

    def log_batch_async(
        self, run_id: str, params: list[Param], tags: list[RunTag], metrics: list[Metric]
    ) -> RunOperations:
        """Asynchronously logs a batch of run data (parameters, tags, and metrics).

        Args:
            run_id (str): The ID of the run to log data for.
            params (list[mlflow.entities.Param]): A list of parameters to log for the run.
            tags (list[mlflow.entities.RunTag]): A list of tags to log for the run.
            metrics (list[mlflow.entities.Metric]): A list of metrics to log for the run.

        Returns:
            mlflow.utils.async_utils.RunOperations: An object that encapsulates the
                asynchronous operation of logging the batch of run data.
                The object contains a list of `concurrent.futures.Future` objects that can be used
                to check the status of the operation and retrieve any exceptions
                that occurred during the operation.

        Raises:
            MlflowException: If the queue has not been activated.
        """
        from mlflow import MlflowException

        if not self.is_active():
            raise MlflowException("AsyncLoggingQueue is not activated.")
        batch = RunBatch(
            run_id=run_id,
            params=params,
            tags=tags,
            metrics=metrics,
            completion_event=threading.Event(),
        )
        self._queue.put(batch)
        operation_future = self._batch_status_check_threadpool.submit(self._wait_for_batch, batch)
        return RunOperations(operation_futures=[operation_future])

    def is_active(self) -> bool:
        return self._status == QueueStatus.ACTIVE

    def is_idle(self) -> bool:
        return self._status == QueueStatus.IDLE

    def _set_up_logging_thread(self) -> None:
        """
        Sets up the logging thread.

        This method shouldn't be called directly without shutting down the async
        logging first if an existing async logging exists, otherwise it might
        hang the program.
        """
        with self._lock:
            self._batch_logging_thread = threading.Thread(
                target=self._logging_loop,
                name="MLflowAsyncLoggingLoop",
                daemon=True,
            )
            self._batch_logging_worker_threadpool = ThreadPoolExecutor(
                max_workers=MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.get() or 10,
                thread_name_prefix=ASYNC_LOGGING_WORKER_THREAD_PREFIX,
            )

            self._batch_status_check_threadpool = ThreadPoolExecutor(
                max_workers=MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.get() or 10,
                thread_name_prefix=ASYNC_LOGGING_STATUS_CHECK_THREAD_PREFIX,
            )

            self._batch_logging_thread.start()

    def activate(self) -> None:
        """Activates the async logging queue

        1. Initializes queue draining thread.
        2. Initializes threads for checking the status of logged batch.
        3. Registering an atexit callback to ensure that any remaining log data
           is flushed before the program exits.

        If the queue is already activated, this method does nothing.
        """
        with self._lock:
            if self.is_active():
                return

            self._set_up_logging_thread()
            # NOTE(review): registered on every (re)activation, so repeated
            # flush()/activate() cycles stack duplicate callbacks; the callback is
            # idempotent, so this is wasteful but harmless.
            atexit.register(self._at_exit_callback)

            self._status = QueueStatus.ACTIVE