/ mlflow / utils / async_logging / async_artifacts_logging_queue.py
async_artifacts_logging_queue.py
"""
Defines an AsyncArtifactsLoggingQueue that provides asynchronous artifact writes
using a queue-based approach.
"""
  5  
  6  import atexit
  7  import logging
  8  import threading
  9  from concurrent.futures import ThreadPoolExecutor
 10  from queue import Empty, Queue
 11  from typing import TYPE_CHECKING, Callable, Union
 12  
 13  from mlflow.utils.async_logging.run_artifact import RunArtifact
 14  from mlflow.utils.async_logging.run_operations import RunOperations
 15  
 16  if TYPE_CHECKING:
 17      import PIL.Image
 18  
 19  _logger = logging.getLogger(__name__)
 20  
 21  
 22  class AsyncArtifactsLoggingQueue:
 23      """
 24      This is a queue based run data processor that queue incoming data and process it using a single
 25      worker thread. This class is used to process artifacts saving in async fashion.
 26  
 27      Args:
 28          logging_func: A callable function that takes in three arguments:
 29              - filename: The name of the artifact file.
 30              - artifact_path: The path to the artifact.
 31              - artifact: The artifact to be logged.
 32      """
 33  
 34      def __init__(
 35          self, artifact_logging_func: Callable[[str, str, Union["PIL.Image.Image"]], None]
 36      ) -> None:
 37          self._queue: Queue[RunArtifact] = Queue()
 38          self._lock = threading.RLock()
 39          self._artifact_logging_func = artifact_logging_func
 40  
 41          self._stop_data_logging_thread_event = threading.Event()
 42          self._is_activated = False
 43  
 44      def _at_exit_callback(self) -> None:
 45          """Callback function to be executed when the program is exiting.
 46  
 47          Stops the data processing thread and waits for the queue to be drained. Finally, shuts down
 48          the thread pools used for data logging and artifact processing status check.
 49          """
 50          try:
 51              # Stop the data processing thread
 52              self._stop_data_logging_thread_event.set()
 53              # Waits till logging queue is drained.
 54              self._artifact_logging_thread.join()
 55              self._artifact_logging_worker_threadpool.shutdown(wait=True)
 56              self._artifact_status_check_threadpool.shutdown(wait=True)
 57          except Exception as e:
 58              _logger.error(f"Encountered error while trying to finish logging: {e}")
 59  
 60      def flush(self) -> None:
 61          """Flush the async logging queue.
 62  
 63          Calling this method will flush the queue to ensure all the data are logged.
 64          """
 65          # Stop the data processing thread.
 66          self._stop_data_logging_thread_event.set()
 67          # Waits till logging queue is drained.
 68          self._artifact_logging_thread.join()
 69          self._artifact_logging_worker_threadpool.shutdown(wait=True)
 70          self._artifact_status_check_threadpool.shutdown(wait=True)
 71  
 72          # Restart the thread to listen to incoming data after flushing.
 73          self._stop_data_logging_thread_event.clear()
 74          self._set_up_logging_thread()
 75  
 76      def _logging_loop(self) -> None:
 77          """
 78          Continuously logs run data until `self._continue_to_process_data` is set to False.
 79          If an exception occurs during logging, a `MlflowException` is raised.
 80          """
 81          try:
 82              while not self._stop_data_logging_thread_event.is_set():
 83                  self._log_artifact()
 84              # Drain the queue after the stop event is set.
 85              while not self._queue.empty():
 86                  self._log_artifact()
 87          except Exception as e:
 88              from mlflow.exceptions import MlflowException
 89  
 90              raise MlflowException(f"Exception inside the run data logging thread: {e}")
 91  
 92      def _log_artifact(self) -> None:
 93          """Process the run's artifacts in the running runs queues.
 94  
 95          For each run in the running runs queues, this method retrieves the next artifact of run
 96          from the queue and processes it by calling the `_artifact_logging_func` method with the run
 97          ID and artifact. If the artifact is empty, it is skipped. After processing the artifact,
 98          the processed watermark is updated and the artifact event is set.
 99          If an exception occurs during processing, the exception is logged and the artifact event
100          is set with the exception. If the queue is empty, it is ignored.
101          """
102          try:
103              run_artifact = self._queue.get(timeout=1)
104          except Empty:
105              # Ignore empty queue exception
106              return
107  
108          def logging_func(run_artifact):
109              try:
110                  self._artifact_logging_func(
111                      filename=run_artifact.filename,
112                      artifact_path=run_artifact.artifact_path,
113                      artifact=run_artifact.artifact,
114                  )
115  
116                  # Signal the artifact processing is done.
117                  run_artifact.completion_event.set()
118  
119              except Exception as e:
120                  _logger.error(f"Failed to log artifact {run_artifact.filename}. Exception: {e}")
121                  run_artifact.exception = e
122                  run_artifact.completion_event.set()
123  
124          self._artifact_logging_worker_threadpool.submit(logging_func, run_artifact)
125  
126      def _wait_for_artifact(self, artifact: RunArtifact) -> None:
127          """Wait for given artifacts to be processed by the logging thread.
128  
129          Args:
130              artifact: The artifact to wait for.
131  
132          Raises:
133              Exception: If an exception occurred while processing the artifact.
134          """
135          artifact.completion_event.wait()
136          if artifact.exception:
137              raise artifact.exception
138  
139      def __getstate__(self):
140          """Return the state of the object for pickling.
141  
142          This method is called by the `pickle` module when the object is being pickled. It returns a
143          dictionary containing the object's state, with non-picklable attributes removed.
144  
145          Returns:
146              dict: A dictionary containing the object's state.
147          """
148          state = self.__dict__.copy()
149          del state["_queue"]
150          del state["_lock"]
151          del state["_is_activated"]
152  
153          if "_stop_data_logging_thread_event" in state:
154              del state["_stop_data_logging_thread_event"]
155          if "_artifact_logging_thread" in state:
156              del state["_artifact_logging_thread"]
157          if "_artifact_logging_worker_threadpool" in state:
158              del state["_artifact_logging_worker_threadpool"]
159          if "_artifact_status_check_threadpool" in state:
160              del state["_artifact_status_check_threadpool"]
161  
162          return state
163  
164      def __setstate__(self, state):
165          """Set the state of the object from a given state dictionary.
166  
167          It pops back the removed non-picklable attributes from `self.__getstate__()`.
168  
169          Args:
170              state (dict): A dictionary containing the state of the object.
171  
172          Returns:
173              None
174          """
175          self.__dict__.update(state)
176          self._queue = Queue()
177          self._lock = threading.RLock()
178          self._is_activated = False
179          self._artifact_logging_thread = None
180          self._artifact_logging_worker_threadpool = None
181          self._artifact_status_check_threadpool = None
182          self._stop_data_logging_thread_event = threading.Event()
183  
184      def log_artifacts_async(self, filename, artifact_path, artifact) -> RunOperations:
185          """Asynchronously logs runs artifacts.
186  
187          Args:
188              filename: Filename of the artifact to be logged.
189              artifact_path: Directory within the run's artifact directory in which to log the
190                  artifact.
191              artifact: The artifact to be logged.
192  
193          Returns:
194              mlflow.utils.async_utils.RunOperations: An object that encapsulates the
195                  asynchronous operation of logging the artifact of run data.
196                  The object contains a list of `concurrent.futures.Future` objects that can be used
197                  to check the status of the operation and retrieve any exceptions
198                  that occurred during the operation.
199          """
200          from mlflow import MlflowException
201  
202          if not self._is_activated:
203              raise MlflowException("AsyncArtifactsLoggingQueue is not activated.")
204          artifact = RunArtifact(
205              filename=filename,
206              artifact_path=artifact_path,
207              artifact=artifact,
208              completion_event=threading.Event(),
209          )
210          self._queue.put(artifact)
211          operation_future = self._artifact_status_check_threadpool.submit(
212              self._wait_for_artifact, artifact
213          )
214          return RunOperations(operation_futures=[operation_future])
215  
216      def is_active(self) -> bool:
217          return self._is_activated
218  
219      def _set_up_logging_thread(self) -> None:
220          """Sets up the logging thread.
221  
222          If the logging thread is already set up, this method does nothing.
223          """
224          with self._lock:
225              self._artifact_logging_thread = threading.Thread(
226                  target=self._logging_loop,
227                  name="MLflowAsyncArtifactsLoggingLoop",
228                  daemon=True,
229              )
230              self._artifact_logging_worker_threadpool = ThreadPoolExecutor(
231                  max_workers=5,
232                  thread_name_prefix="MLflowArtifactsLoggingWorkerPool",
233              )
234  
235              self._artifact_status_check_threadpool = ThreadPoolExecutor(
236                  max_workers=5,
237                  thread_name_prefix="MLflowAsyncArtifactsLoggingStatusCheck",
238              )
239              self._artifact_logging_thread.start()
240  
241      def activate(self) -> None:
242          """Activates the async logging queue
243  
244          1. Initializes queue draining thread.
245          2. Initializes threads for checking the status of logged artifact.
246          3. Registering an atexit callback to ensure that any remaining log data
247              is flushed before the program exits.
248  
249          If the queue is already activated, this method does nothing.
250          """
251          with self._lock:
252              if self._is_activated:
253                  return
254  
255              self._set_up_logging_thread()
256              atexit.register(self._at_exit_callback)
257  
258              self._is_activated = True