# tests/utils/test_async_artifacts_logging_queue.py
  1  import io
  2  import pickle
  3  import random
  4  import threading
  5  import time
  6  
  7  import pytest
  8  from PIL import Image
  9  
 10  from mlflow import MlflowException
 11  from mlflow.utils.async_logging.async_artifacts_logging_queue import AsyncArtifactsLoggingQueue
 12  
 13  TOTAL_ARTIFACTS = 5
 14  
 15  
 16  class RunArtifacts:
 17      def __init__(self, throw_exception_on_artifact_number=None):
 18          if throw_exception_on_artifact_number is None:
 19              throw_exception_on_artifact_number = []
 20          self.received_run_id = ""
 21          self.received_artifacts = []
 22          self.received_filenames = []
 23          self.received_artifact_paths = []
 24          self.artifact_count = 0
 25          self.throw_exception_on_artifact_number = throw_exception_on_artifact_number or []
 26  
 27      def consume_queue_data(self, filename, artifact_path, artifact):
 28          self.artifact_count += 1
 29          if self.artifact_count in self.throw_exception_on_artifact_number:
 30              raise MlflowException("Failed to log run data")
 31          self.received_artifacts.append(artifact)
 32          self.received_filenames.append(filename)
 33          self.received_artifact_paths.append(artifact_path)
 34  
 35  
 36  def _get_run_artifacts(total_artifacts=TOTAL_ARTIFACTS):
 37      for num in range(0, total_artifacts):
 38          filename = f"image_{num}.png"
 39          artifact_path = f"images/artifact_{num}"
 40          artifact = Image.new("RGB", (100, 100), color="red")
 41          yield filename, artifact_path, artifact
 42  
 43  
 44  def _assert_sent_received_artifacts(
 45      filenames_sent,
 46      artifact_paths_sent,
 47      artifacts_sent,
 48      received_filenames,
 49      received_artifact_paths,
 50      received_artifacts,
 51  ):
 52      for num in range(1, len(filenames_sent)):
 53          assert filenames_sent[num] == received_filenames[num]
 54  
 55      for num in range(1, len(artifact_paths_sent)):
 56          assert artifact_paths_sent[num] == received_artifact_paths[num]
 57  
 58      for num in range(1, len(artifacts_sent)):
 59          assert artifacts_sent[num] == received_artifacts[num]
 60  
 61  
 62  def test_single_thread_publish_consume_queue():
 63      run_artifacts = RunArtifacts()
 64      async_logging_queue = AsyncArtifactsLoggingQueue(run_artifacts.consume_queue_data)
 65      async_logging_queue.activate()
 66      filenames_sent = []
 67      artifact_paths_sent = []
 68      artifacts_sent = []
 69      for filename, artifact_path, artifact in _get_run_artifacts():
 70          async_logging_queue.log_artifacts_async(
 71              filename=filename, artifact_path=artifact_path, artifact=artifact
 72          )
 73          filenames_sent.append(filename)
 74          artifact_paths_sent.append(artifact_path)
 75          artifacts_sent.append(artifact)
 76      async_logging_queue.flush()
 77  
 78      _assert_sent_received_artifacts(
 79          filenames_sent,
 80          artifact_paths_sent,
 81          artifacts_sent,
 82          run_artifacts.received_filenames,
 83          run_artifacts.received_artifact_paths,
 84          run_artifacts.received_artifacts,
 85      )
 86  
 87  
 88  def test_queue_activation():
 89      run_artifacts = RunArtifacts()
 90      async_logging_queue = AsyncArtifactsLoggingQueue(run_artifacts.consume_queue_data)
 91  
 92      assert not async_logging_queue._is_activated
 93  
 94      for filename, artifact_path, artifact in _get_run_artifacts(1):
 95          with pytest.raises(MlflowException, match="AsyncArtifactsLoggingQueue is not activated."):
 96              async_logging_queue.log_artifacts_async(
 97                  filename=filename, artifact_path=artifact_path, artifact=artifact
 98              )
 99  
100      async_logging_queue.activate()
101      assert async_logging_queue._is_activated
102  
103  
def test_partial_logging_failed():
    """Failures on individual artifacts do not affect the rest of the batch.

    The consumer is configured to fail on the 3rd and 4th artifacts; those
    failures must surface via ``wait()`` while all other artifacts are
    logged normally.
    """
    run_data = RunArtifacts(throw_exception_on_artifact_number=[3, 4])

    async_logging_queue = AsyncArtifactsLoggingQueue(run_data.consume_queue_data)
    async_logging_queue.activate()

    filenames_sent = []
    artifact_paths_sent = []
    artifacts_sent = []
    run_operations = []

    # enumerate(start=1) replaces the hand-rolled batch_id counter.
    for batch_id, (filename, artifact_path, artifact) in enumerate(
        _get_run_artifacts(), start=1
    ):
        if batch_id in (3, 4):
            # These artifact numbers are configured to fail; waiting on the
            # operation surfaces the consumer's exception.
            with pytest.raises(MlflowException, match="Failed to log run data"):
                async_logging_queue.log_artifacts_async(
                    filename=filename, artifact_path=artifact_path, artifact=artifact
                ).wait()
        else:
            run_operations.append(
                async_logging_queue.log_artifacts_async(
                    filename=filename, artifact_path=artifact_path, artifact=artifact
                )
            )
            filenames_sent.append(filename)
            artifact_paths_sent.append(artifact_path)
            artifacts_sent.append(artifact)

    for run_operation in run_operations:
        run_operation.wait()

    _assert_sent_received_artifacts(
        filenames_sent,
        artifact_paths_sent,
        artifacts_sent,
        run_data.received_filenames,
        run_data.received_artifact_paths,
        run_data.received_artifacts,
    )
145  
146  
def test_publish_multithread_consume_single_thread():
    """Two producer threads can safely share one logging queue; the consumer
    receives every artifact from both."""
    run_data = RunArtifacts(throw_exception_on_artifact_number=[])

    async_logging_queue = AsyncArtifactsLoggingQueue(run_data.consume_queue_data)
    async_logging_queue.activate()

    def _send_artifacts(logging_queue, run_operations):
        # Enqueue all artifacts, sleeping between submissions so the two
        # producer threads interleave. The previous version also built
        # sent-item lists that were never read; that dead code is removed.
        for filename, artifact_path, artifact in _get_run_artifacts():
            run_operations.append(
                logging_queue.log_artifacts_async(
                    filename=filename, artifact_path=artifact_path, artifact=artifact
                )
            )
            time.sleep(random.randint(1, 3))

    run_operations = []
    t1 = threading.Thread(
        name="test-async-artifacts-1",
        target=_send_artifacts,
        args=(async_logging_queue, run_operations),
    )
    t2 = threading.Thread(
        name="test-async-artifacts-2",
        target=_send_artifacts,
        args=(async_logging_queue, run_operations),
    )

    t1.start()
    t2.start()
    t1.join()
    t2.join()

    for run_operation in run_operations:
        run_operation.wait()

    assert len(run_data.received_filenames) == 2 * TOTAL_ARTIFACTS
    assert len(run_data.received_artifact_paths) == 2 * TOTAL_ARTIFACTS
    assert len(run_data.received_artifacts) == 2 * TOTAL_ARTIFACTS
195  
196  
class Consumer:
    """Picklable consumer whose callback blocks on an Event until released.

    The ``threading.Event`` cannot be pickled, so it is excluded from the
    pickled state and recreated (unset) when the object is unpickled.
    """

    def __init__(self) -> None:
        self.filenames = []
        self.artifact_paths = []
        self.artifacts = []
        self.barrier = threading.Event()

    def consume_queue_data(self, filename, artifact_path, artifact):
        """Wait for the barrier to be released, then record the artifact."""
        self.barrier.wait()
        self.filenames.append(filename)
        self.artifact_paths.append(artifact_path)
        self.artifacts.append(artifact)

    def __getstate__(self):
        # Drop the unpicklable Event from the serialized state.
        return {key: value for key, value in self.__dict__.items() if key != "barrier"}

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Fresh, unset Event: a deserialized consumer starts blocked again.
        self.barrier = threading.Event()
218  
219  
def test_async_logging_queue_pickle():
    """The queue can be pickled both before activation and mid-flight, and a
    deserialized queue is fully functional after re-activation."""
    consumer = Consumer()
    async_logging_queue = AsyncArtifactsLoggingQueue(consumer.consume_queue_data)

    # Pickle the queue without activating it. Only successful round-tripping
    # matters here, so the deserialized object is discarded (the previous
    # version bound it to a variable that was never used).
    buffer = io.BytesIO()
    pickle.dump(async_logging_queue, buffer)
    pickle.loads(buffer.getvalue())

    # Activate the queue and submit 10 items. Workers block on the barrier,
    # so the consumer's state remains empty during pickling.
    async_logging_queue.activate()

    run_operations = [
        async_logging_queue.log_artifacts_async(
            filename=f"image-{val}.png",
            artifact_path="images/image-artifact.png",
            artifact=Image.new("RGB", (100, 100), color="blue"),
        )
        for val in range(10)
    ]

    # Pickle while workers are blocked — consumer state is deterministically empty.
    buffer = io.BytesIO()
    pickle.dump(async_logging_queue, buffer)

    deserialized_queue = pickle.loads(buffer.getvalue())  # Type: AsyncArtifactsLoggingQueue
    assert deserialized_queue._queue.empty()
    assert deserialized_queue._lock is not None
    assert deserialized_queue._is_activated is False

    # Release workers and wait for all operations to complete.
    consumer.barrier.set()

    for run_operation in run_operations:
        run_operation.wait()

    assert len(consumer.filenames) == 10

    # Activate the deserialized queue and submit 10 more items.
    # The deserialized consumer is a separate copy with an empty filenames list.
    deserialized_consumer = deserialized_queue._artifact_logging_func.__self__
    deserialized_consumer.barrier.set()
    deserialized_queue.activate()
    assert deserialized_queue._is_activated

    run_operations = [
        deserialized_queue.log_artifacts_async(
            filename=f"image2-{val}.png",
            artifact_path="images/image-artifact2.png",
            artifact=Image.new("RGB", (100, 100), color="green"),
        )
        for val in range(10)
    ]

    for run_operation in run_operations:
        run_operation.wait()

    assert len(deserialized_consumer.filenames) == 10