# async_artifacts_logging_queue.py
"""
Defines an AsyncArtifactsLoggingQueue that provides async fashion artifact writes using
queue based approach.
"""

import atexit
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from typing import TYPE_CHECKING, Callable

from mlflow.utils.async_logging.run_artifact import RunArtifact
from mlflow.utils.async_logging.run_operations import RunOperations

if TYPE_CHECKING:
    import PIL.Image

_logger = logging.getLogger(__name__)


class AsyncArtifactsLoggingQueue:
    """
    This is a queue based run data processor that queues incoming data and processes it using a
    single worker thread. This class is used to process artifacts saving in async fashion.

    Args:
        artifact_logging_func: A callable function that takes in three arguments:
            - filename: The name of the artifact file.
            - artifact_path: The path to the artifact.
            - artifact: The artifact to be logged.
    """

    def __init__(
        self, artifact_logging_func: Callable[[str, str, "PIL.Image.Image"], None]
    ) -> None:
        # Pending artifacts waiting to be dispatched to the worker thread pool.
        self._queue: Queue[RunArtifact] = Queue()
        self._lock = threading.RLock()
        self._artifact_logging_func = artifact_logging_func

        # Set to tell the draining thread to stop listening for new data.
        self._stop_data_logging_thread_event = threading.Event()
        # Thread/pool attributes are only created in `_set_up_logging_thread`
        # (called from `activate`), so they do not exist before activation.
        self._is_activated = False

    def _at_exit_callback(self) -> None:
        """Callback function to be executed when the program is exiting.

        Stops the data processing thread and waits for the queue to be drained. Finally, shuts
        down the thread pools used for data logging and artifact processing status check.
        """
        try:
            # Stop the data processing thread.
            self._stop_data_logging_thread_event.set()
            # Waits till logging queue is drained.
            self._artifact_logging_thread.join()
            self._artifact_logging_worker_threadpool.shutdown(wait=True)
            self._artifact_status_check_threadpool.shutdown(wait=True)
        except Exception as e:
            # Best-effort cleanup at interpreter exit; never raise from atexit.
            _logger.error(f"Encountered error while trying to finish logging: {e}")

    def flush(self) -> None:
        """Flush the async logging queue.

        Calling this method will flush the queue to ensure all the data are logged.
        """
        # If the queue was never activated there is nothing to flush and the
        # logging thread/pools do not exist yet; bail out instead of raising
        # AttributeError on `self._artifact_logging_thread`.
        if not self._is_activated:
            return

        # Stop the data processing thread.
        self._stop_data_logging_thread_event.set()
        # Waits till logging queue is drained.
        self._artifact_logging_thread.join()
        self._artifact_logging_worker_threadpool.shutdown(wait=True)
        self._artifact_status_check_threadpool.shutdown(wait=True)

        # Restart the thread to listen to incoming data after flushing.
        self._stop_data_logging_thread_event.clear()
        self._set_up_logging_thread()

    def _logging_loop(self) -> None:
        """
        Continuously logs run data until `self._stop_data_logging_thread_event` is set.
        If an exception occurs during logging, a `MlflowException` is raised.
        """
        try:
            while not self._stop_data_logging_thread_event.is_set():
                self._log_artifact()
            # Drain the queue after the stop event is set.
            while not self._queue.empty():
                self._log_artifact()
        except Exception as e:
            from mlflow.exceptions import MlflowException

            # Chain the cause so the original traceback is preserved.
            raise MlflowException(f"Exception inside the run data logging thread: {e}") from e

    def _log_artifact(self) -> None:
        """Dispatch the next queued artifact to the logging worker pool.

        Retrieves the next artifact from the queue (waiting up to one second) and submits it to
        the worker thread pool, which processes it by calling `_artifact_logging_func`. After
        processing, the artifact's completion event is set. If an exception occurs during
        processing, the exception is logged, recorded on the artifact, and the completion event
        is still set so that waiters are released. An empty queue is ignored.
        """
        try:
            run_artifact = self._queue.get(timeout=1)
        except Empty:
            # Ignore empty queue exception; the loop will poll again.
            return

        def logging_func(run_artifact):
            try:
                self._artifact_logging_func(
                    filename=run_artifact.filename,
                    artifact_path=run_artifact.artifact_path,
                    artifact=run_artifact.artifact,
                )

                # Signal the artifact processing is done.
                run_artifact.completion_event.set()

            except Exception as e:
                _logger.error(f"Failed to log artifact {run_artifact.filename}. Exception: {e}")
                # Record the failure, then still set the event so `_wait_for_artifact`
                # unblocks and can re-raise it on the caller's side.
                run_artifact.exception = e
                run_artifact.completion_event.set()

        self._artifact_logging_worker_threadpool.submit(logging_func, run_artifact)

    def _wait_for_artifact(self, artifact: RunArtifact) -> None:
        """Wait for given artifact to be processed by the logging thread.

        Args:
            artifact: The artifact to wait for.

        Raises:
            Exception: If an exception occurred while processing the artifact.
        """
        artifact.completion_event.wait()
        if artifact.exception:
            raise artifact.exception

    def __getstate__(self):
        """Return the state of the object for pickling.

        This method is called by the `pickle` module when the object is being pickled. It
        returns a dictionary containing the object's state, with non-picklable attributes
        removed.

        Returns:
            dict: A dictionary containing the object's state.
        """
        state = self.__dict__.copy()
        # Threads, locks, pools, and events cannot be pickled. `pop` with a default
        # tolerates attributes that were never created (e.g. before activation).
        for attr in (
            "_queue",
            "_lock",
            "_is_activated",
            "_stop_data_logging_thread_event",
            "_artifact_logging_thread",
            "_artifact_logging_worker_threadpool",
            "_artifact_status_check_threadpool",
        ):
            state.pop(attr, None)

        return state

    def __setstate__(self, state):
        """Set the state of the object from a given state dictionary.

        It pops back the removed non-picklable attributes from `self.__getstate__()`.

        Args:
            state (dict): A dictionary containing the state of the object.

        Returns:
            None
        """
        self.__dict__.update(state)
        # Recreate the non-picklable attributes in a deactivated state; the queue
        # must be re-activated (via `activate`) before it can log again.
        self._queue = Queue()
        self._lock = threading.RLock()
        self._is_activated = False
        self._artifact_logging_thread = None
        self._artifact_logging_worker_threadpool = None
        self._artifact_status_check_threadpool = None
        self._stop_data_logging_thread_event = threading.Event()

    def log_artifacts_async(self, filename, artifact_path, artifact) -> RunOperations:
        """Asynchronously logs runs artifacts.

        Args:
            filename: Filename of the artifact to be logged.
            artifact_path: Directory within the run's artifact directory in which to log the
                artifact.
            artifact: The artifact to be logged.

        Returns:
            mlflow.utils.async_utils.RunOperations: An object that encapsulates the
                asynchronous operation of logging the artifact of run data.
                The object contains a list of `concurrent.futures.Future` objects that can be
                used to check the status of the operation and retrieve any exceptions
                that occurred during the operation.

        Raises:
            MlflowException: If the queue has not been activated.
        """
        # Import from the canonical location, consistent with `_logging_loop`.
        from mlflow.exceptions import MlflowException

        if not self._is_activated:
            raise MlflowException("AsyncArtifactsLoggingQueue is not activated.")
        artifact = RunArtifact(
            filename=filename,
            artifact_path=artifact_path,
            artifact=artifact,
            completion_event=threading.Event(),
        )
        self._queue.put(artifact)
        # The returned future completes when the artifact has been processed
        # (or raises the recorded exception).
        operation_future = self._artifact_status_check_threadpool.submit(
            self._wait_for_artifact, artifact
        )
        return RunOperations(operation_futures=[operation_future])

    def is_active(self) -> bool:
        """Return whether the queue has been activated."""
        return self._is_activated

    def _set_up_logging_thread(self) -> None:
        """Sets up the logging thread and worker thread pools.

        Note: this unconditionally creates fresh thread/pool instances; `flush` relies on this
        to restart processing after the previous thread and pools have been shut down.
        """
        with self._lock:
            self._artifact_logging_thread = threading.Thread(
                target=self._logging_loop,
                name="MLflowAsyncArtifactsLoggingLoop",
                daemon=True,
            )
            self._artifact_logging_worker_threadpool = ThreadPoolExecutor(
                max_workers=5,
                thread_name_prefix="MLflowArtifactsLoggingWorkerPool",
            )

            self._artifact_status_check_threadpool = ThreadPoolExecutor(
                max_workers=5,
                thread_name_prefix="MLflowAsyncArtifactsLoggingStatusCheck",
            )
            self._artifact_logging_thread.start()

    def activate(self) -> None:
        """Activates the async logging queue

        1. Initializes queue draining thread.
        2. Initializes threads for checking the status of logged artifact.
        3. Registering an atexit callback to ensure that any remaining log data
           is flushed before the program exits.

        If the queue is already activated, this method does nothing.
        """
        with self._lock:
            if self._is_activated:
                return

            self._set_up_logging_thread()
            atexit.register(self._at_exit_callback)

            self._is_activated = True