# tests/utils/test_async_logging_queue.py
  1  import contextlib
  2  import io
  3  import pickle
  4  import random
  5  import threading
  6  import time
  7  import uuid
  8  from unittest.mock import MagicMock, patch
  9  
 10  import pytest
 11  
 12  import mlflow.utils.async_logging.async_logging_queue
 13  from mlflow import MlflowException
 14  from mlflow.entities.metric import Metric
 15  from mlflow.entities.param import Param
 16  from mlflow.entities.run_tag import RunTag
 17  from mlflow.utils.async_logging.async_logging_queue import AsyncLoggingQueue, QueueStatus
 18  
# Shape of the synthetic data produced by _get_run_data: item counts per batch.
METRIC_PER_BATCH = 250
TAGS_PER_BATCH = 1
PARAMS_PER_BATCH = 1
# Number of batches _get_run_data yields per call (default).
TOTAL_BATCHES = 5
 23  
 24  
 25  class RunData:
 26      def __init__(self, throw_exception_on_batch_number=None) -> None:
 27          if throw_exception_on_batch_number is None:
 28              throw_exception_on_batch_number = []
 29          self.received_run_id = ""
 30          self.received_metrics = []
 31          self.received_tags = []
 32          self.received_params = []
 33          self.batch_count = 0
 34          self.throw_exception_on_batch_number = throw_exception_on_batch_number or []
 35  
 36      def consume_queue_data(self, run_id, metrics, tags, params):
 37          self.batch_count += 1
 38          if self.batch_count in self.throw_exception_on_batch_number:
 39              raise MlflowException("Failed to log run data")
 40          self.received_run_id = run_id
 41          self.received_metrics.extend(metrics or [])
 42          self.received_params.extend(params or [])
 43          self.received_tags.extend(tags or [])
 44  
 45  
 46  @contextlib.contextmanager
 47  def generate_async_logging_queue(clazz):
 48      async_logging_queue = AsyncLoggingQueue(clazz.consume_queue_data)
 49      try:
 50          yield async_logging_queue
 51      finally:
 52          async_logging_queue.shut_down_async_logging()
 53  
 54  
 55  def test_single_thread_publish_consume_queue(monkeypatch):
 56      monkeypatch.setenv("MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS", "3")
 57  
 58      with (
 59          patch.object(
 60              AsyncLoggingQueue, "_batch_logging_worker_threadpool", create=True
 61          ) as mock_worker_threadpool,
 62          patch.object(
 63              AsyncLoggingQueue, "_batch_status_check_threadpool", create=True
 64          ) as mock_check_threadpool,
 65      ):
 66          mock_worker_threadpool.submit = MagicMock()
 67          mock_check_threadpool.submit = MagicMock()
 68          mock_worker_threadpool.shutdown = MagicMock()
 69          mock_check_threadpool.shutdown = MagicMock()
 70  
 71          run_id = "test_run_id"
 72          run_data = RunData()
 73          with generate_async_logging_queue(run_data) as async_logging_queue:
 74              async_logging_queue.activate()
 75              async_logging_queue._batch_logging_worker_threadpool = mock_worker_threadpool
 76              async_logging_queue._batch_status_check_threadpool = mock_check_threadpool
 77  
 78              for params, tags, metrics in _get_run_data():
 79                  async_logging_queue.log_batch_async(
 80                      run_id=run_id, metrics=metrics, tags=tags, params=params
 81                  )
 82              async_logging_queue.flush()
 83              # 2 batches are sent to the worker thread pool due to grouping, otherwise it would be 5.
 84              assert mock_worker_threadpool.submit.call_count == 2
 85              assert async_logging_queue.is_active()
 86              assert mock_check_threadpool.shutdown.call_count == 1
 87              assert mock_worker_threadpool.shutdown.call_count == 1
 88  
 89  
 90  def test_grouping_batch_in_time_window():
 91      run_id = "test_run_id"
 92      run_data = RunData()
 93      with generate_async_logging_queue(run_data) as async_logging_queue:
 94          async_logging_queue.activate()
 95          metrics_sent = []
 96          tags_sent = []
 97          params_sent = []
 98  
 99          for params, tags, metrics in _get_run_data():
100              async_logging_queue.log_batch_async(
101                  run_id=run_id, metrics=metrics, tags=tags, params=params
102              )
103              metrics_sent += metrics
104              tags_sent += tags
105              params_sent += params
106  
107          async_logging_queue.flush()
108  
109          _assert_sent_received_data(
110              metrics_sent,
111              params_sent,
112              tags_sent,
113              run_data.received_metrics,
114              run_data.received_params,
115              run_data.received_tags,
116          )
117  
118  
def test_queue_activation():
    """Logging before activate() raises; activate() flips the queue to active."""
    run_data = RunData()
    with generate_async_logging_queue(run_data) as queue:
        assert queue.is_idle()

        metrics = [
            Metric(key=f"batch metrics async-{val}", value=val, timestamp=val, step=0)
            for val in range(METRIC_PER_BATCH)
        ]
        # Submitting to an idle (never-activated) queue must fail loudly.
        with pytest.raises(MlflowException, match="AsyncLoggingQueue is not activated."):
            queue.log_batch_async(run_id="test_run_id", metrics=metrics, tags=[], params=[])

        queue.activate()
        assert queue.is_active()
139  
140  
def test_end_async_logging():
    """end_async_logging moves the queue to TEAR_DOWN without closing the pools."""
    run_data = RunData()
    with generate_async_logging_queue(run_data) as queue:
        queue.activate()

        metrics = [
            Metric(key=f"batch metrics async-{val}", value=val, timestamp=val, step=0)
            for val in range(METRIC_PER_BATCH)
        ]
        queue.log_batch_async(run_id="test_run_id", metrics=metrics, tags=[], params=[])

        queue.end_async_logging()
        assert queue._status == QueueStatus.TEAR_DOWN
        # end_async_logging should not shutdown the threadpool
        assert not queue._batch_logging_worker_threadpool._shutdown
        assert not queue._batch_status_check_threadpool._shutdown

        # flush() brings the queue back to the active state afterwards.
        queue.flush()
        assert queue.is_active()
165  
166  
def test_partial_logging_failed():
    """Failed batches surface their exception without losing the other batches."""
    # Batches 3 and 4 are configured to raise inside the consumer.
    run_data = RunData(throw_exception_on_batch_number=[3, 4])
    with generate_async_logging_queue(run_data) as queue:
        queue.activate()

        sent_metrics, sent_tags, sent_params = [], [], []
        pending_operations = []

        for batch_id, (params, tags, metrics) in enumerate(_get_run_data(), start=1):
            if batch_id in (3, 4):
                # The failure is delivered to the caller via wait().
                with pytest.raises(MlflowException, match="Failed to log run data"):
                    queue.log_batch_async(
                        run_id="test_run_id", metrics=metrics, tags=tags, params=params
                    ).wait()
            else:
                pending_operations.append(
                    queue.log_batch_async(
                        run_id="test_run_id", metrics=metrics, tags=tags, params=params
                    )
                )
                sent_metrics += metrics
                sent_tags += tags
                sent_params += params

        for operation in pending_operations:
            operation.wait()

        _assert_sent_received_data(
            sent_metrics,
            sent_params,
            sent_tags,
            run_data.received_metrics,
            run_data.received_params,
            run_data.received_tags,
        )
208  
209  
def test_publish_multithread_consume_single_thread():
    """Two producer threads can publish concurrently; nothing is dropped."""
    run_data = RunData(throw_exception_on_batch_number=[])
    with generate_async_logging_queue(run_data) as queue:
        queue.activate()

        run_operations = []
        producers = [
            threading.Thread(
                name=f"test-async-logging-{idx}",
                target=_send_metrics_tags_params,
                args=(queue, "test_run_id", run_operations),
            )
            for idx in (1, 2)
        ]

        for producer in producers:
            producer.start()
        for producer in producers:
            producer.join()

        for operation in run_operations:
            operation.wait()

        # Both producers logged TOTAL_BATCHES batches each.
        assert len(run_data.received_metrics) == 2 * METRIC_PER_BATCH * TOTAL_BATCHES
        assert len(run_data.received_tags) == 2 * TAGS_PER_BATCH * TOTAL_BATCHES
        assert len(run_data.received_params) == 2 * PARAMS_PER_BATCH * TOTAL_BATCHES
239  
240  
class Consumer:
    """Picklable test consumer whose workers block until explicitly released.

    ``barrier`` (a :class:`threading.Event`) holds every worker inside
    :meth:`consume_queue_data` until the test calls ``barrier.set()``, making
    the consumer's state deterministic while the queue is being pickled. The
    event itself is not picklable, so it is dropped on serialization and
    recreated (unset) on deserialization.
    """

    def __init__(self) -> None:
        self.metrics = []
        self.tags = []
        self.params = []
        self.barrier = threading.Event()

    def consume_queue_data(self, run_id, metrics, tags, params):
        # Block until the test releases the barrier.
        self.barrier.wait()
        for bucket, items in (
            (self.metrics, metrics),
            (self.params, params),
            (self.tags, tags),
        ):
            bucket.extend(items or [])

    def __getstate__(self):
        # threading.Event cannot be pickled; exclude it from the snapshot.
        return {key: value for key, value in self.__dict__.items() if key != "barrier"}

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Fresh, unset event: deserialized consumers start blocked again.
        self.barrier = threading.Event()
263  
def test_async_logging_queue_pickle():
    """The queue survives pickling both before and after activation.

    Consumer drops its unpicklable threading.Event in __getstate__, so the
    queue serializes cleanly; a deserialized queue must come back IDLE with
    an empty internal queue, and the deserialized consumer is an independent
    copy with its own (empty) state. The statement order below is
    timing-critical: workers are held on the barrier while pickling happens.
    """
    run_id = "test_run_id"
    consumer = Consumer()
    with generate_async_logging_queue(consumer) as async_logging_queue:
        # Pickle the queue without activating it.
        buffer = io.BytesIO()
        pickle.dump(async_logging_queue, buffer)
        deserialized_queue = pickle.loads(buffer.getvalue())  # Type: AsyncLoggingQueue

        # Activate the queue and submit 10 items. Workers block on the barrier,
        # so the consumer's state remains empty during pickling.
        async_logging_queue.activate()

        run_operations = [
            async_logging_queue.log_batch_async(
                run_id=run_id,
                metrics=[Metric("metric", val, timestamp=time.time(), step=1)],
                tags=[],
                params=[],
            )
            for val in range(0, 10)
        ]

        # Pickle while workers are blocked — consumer state is deterministically empty.
        buffer = io.BytesIO()
        pickle.dump(async_logging_queue, buffer)

        deserialized_queue = pickle.loads(buffer.getvalue())  # Type: AsyncLoggingQueue
        # A freshly deserialized queue must start empty and idle with a usable lock.
        assert deserialized_queue._queue.empty()
        assert deserialized_queue._lock is not None
        assert deserialized_queue._status is QueueStatus.IDLE

        # Release workers and wait for all operations to complete.
        consumer.barrier.set()

        for run_operation in run_operations:
            run_operation.wait()

        assert len(consumer.metrics) == 10

        # Activate the deserialized queue and submit 10 more items.
        # The deserialized consumer is a separate copy with an empty metrics list.
        deserialized_consumer = deserialized_queue._logging_func.__self__
        deserialized_consumer.barrier.set()
        deserialized_queue.activate()
        assert deserialized_queue.is_active()

        run_operations = []

        for val in range(0, 10):
            run_operations.append(
                deserialized_queue.log_batch_async(
                    run_id=run_id,
                    metrics=[Metric("metric", val, timestamp=time.time(), step=1)],
                    tags=[],
                    params=[],
                )
            )

        for run_operation in run_operations:
            run_operation.wait()

        assert len(deserialized_consumer.metrics) == 10

        # The context manager only tears down the original queue; shut the
        # deserialized one down explicitly.
        deserialized_queue.shut_down_async_logging()
330  
def _send_metrics_tags_params(run_data_queueing_processor, run_id, run_operations=None):
    """Producer helper: log TOTAL_BATCHES batches of synthetic run data.

    Intended as a threading.Thread target; pending operations are appended to
    the shared ``run_operations`` list so the caller can wait on them.

    Args:
        run_data_queueing_processor: Object exposing ``log_batch_async``
            (an AsyncLoggingQueue in these tests).
        run_id: Run id passed through to each batch.
        run_operations: Optional list collecting the returned RunOperations;
            a fresh list is created when omitted.

    Returns:
        The list of pending run operations (same object as ``run_operations``
        when one was supplied).
    """
    if run_operations is None:
        run_operations = []

    # The previous version also accumulated metrics/tags/params into local
    # lists that were never read — dead code, removed.
    for params, tags, metrics in _get_run_data():
        run_operations.append(
            run_data_queueing_processor.log_batch_async(
                run_id=run_id, metrics=metrics, tags=tags, params=params
            )
        )
        # Sleep between submissions so concurrent producers interleave.
        time.sleep(random.randint(1, 3))

    return run_operations
349  
350  
def _get_run_data(
    total_batches=TOTAL_BATCHES,
    params_per_batch=PARAMS_PER_BATCH,
    tags_per_batch=TAGS_PER_BATCH,
    metrics_per_batch=METRIC_PER_BATCH,
):
    """Yield ``(params, tags, metrics)`` tuples of synthetic run data.

    Param and tag keys embed a fresh 8-char uuid prefix per batch so keys are
    unique across batches; metric keys embed the batch index instead.
    """
    for batch_index in range(total_batches):
        unique_suffix = str(uuid.uuid4())[:8]
        yield (
            [
                Param(f"batch param-{unique_suffix}-{i}", value=str(time.time()))
                for i in range(params_per_batch)
            ],
            [
                RunTag(f"batch tag-{unique_suffix}-{i}", value=str(time.time()))
                for i in range(tags_per_batch)
            ],
            [
                Metric(
                    key=f"batch metrics async-{batch_index}",
                    value=i,
                    timestamp=int(time.time() * 1000),
                    step=0,
                )
                for i in range(metrics_per_batch)
            ],
        )
377  
378  
379  def _assert_sent_received_data(
380      metrics_sent, params_sent, tags_sent, received_metrics, received_params, received_tags
381  ):
382      for num in range(1, len(metrics_sent)):
383          assert metrics_sent[num].key == received_metrics[num].key
384          assert metrics_sent[num].value == received_metrics[num].value
385          assert metrics_sent[num].timestamp == received_metrics[num].timestamp
386          assert metrics_sent[num].step == received_metrics[num].step
387  
388      for num in range(1, len(tags_sent)):
389          assert tags_sent[num].key == received_tags[num].key
390          assert tags_sent[num].value == received_tags[num].value
391  
392      for num in range(1, len(params_sent)):
393          assert params_sent[num].key == received_params[num].key
394          assert params_sent[num].value == received_params[num].value
395  
396  
def test_batch_split(monkeypatch):
    """Batches exceeding any per-batch cap are split before reaching the handler."""
    queue_module = mlflow.utils.async_logging.async_logging_queue
    monkeypatch.setattr(queue_module, "_MAX_ITEMS_PER_BATCH", 10)
    monkeypatch.setattr(queue_module, "_MAX_PARAMS_PER_BATCH", 6)
    monkeypatch.setattr(queue_module, "_MAX_TAGS_PER_BATCH", 8)

    def run_scenario(expected_batches, total, params_n, tags_n, metrics_n):
        # Log `total` batches of the given shape and assert how many handler
        # calls the splitting logic produced.
        run_data = RunData()
        with generate_async_logging_queue(run_data) as queue:
            queue.activate()
            for params, tags, metrics in _get_run_data(total, params_n, tags_n, metrics_n):
                queue.log_batch_async(
                    run_id="test_run_id", metrics=metrics, tags=tags, params=params
                )
            queue.flush()
            assert run_data.batch_count == expected_batches

    # Mixed items: total exceeds _MAX_ITEMS_PER_BATCH → split into 2.
    run_scenario(2, 2, 3, 3, 3)
    # Params only: exceeds _MAX_PARAMS_PER_BATCH → split into 2.
    run_scenario(2, 2, 4, 0, 0)
    # Tags only: exceeds _MAX_TAGS_PER_BATCH → split into 2.
    run_scenario(2, 2, 0, 5, 0)