# test_async_logging_queue.py
import contextlib
import io
import pickle
import random
import threading
import time
import uuid
from unittest.mock import MagicMock, patch

import pytest

import mlflow.utils.async_logging.async_logging_queue
from mlflow import MlflowException
from mlflow.entities.metric import Metric
from mlflow.entities.param import Param
from mlflow.entities.run_tag import RunTag
from mlflow.utils.async_logging.async_logging_queue import AsyncLoggingQueue, QueueStatus

METRIC_PER_BATCH = 250
TAGS_PER_BATCH = 1
PARAMS_PER_BATCH = 1
TOTAL_BATCHES = 5


class RunData:
    """Test double that records every batch the async queue hands to its logging callback.

    Optionally raises ``MlflowException`` on selected (1-based) batch numbers to
    simulate backend logging failures.
    """

    def __init__(self, throw_exception_on_batch_number=None) -> None:
        self.received_run_id = ""
        self.received_metrics = []
        self.received_tags = []
        self.received_params = []
        self.batch_count = 0
        # None default avoids the shared mutable-default-argument pitfall;
        # None means "never fail".
        self.throw_exception_on_batch_number = throw_exception_on_batch_number or []

    def consume_queue_data(self, run_id, metrics, tags, params):
        """Logging callback handed to AsyncLoggingQueue; accumulates received data."""
        self.batch_count += 1
        if self.batch_count in self.throw_exception_on_batch_number:
            raise MlflowException("Failed to log run data")
        self.received_run_id = run_id
        self.received_metrics.extend(metrics or [])
        self.received_params.extend(params or [])
        self.received_tags.extend(tags or [])


@contextlib.contextmanager
def generate_async_logging_queue(clazz):
    """Yield an AsyncLoggingQueue wired to ``clazz.consume_queue_data``.

    Guarantees the queue's worker threads are shut down when the test exits,
    even on assertion failure.
    """
    async_logging_queue = AsyncLoggingQueue(clazz.consume_queue_data)
    try:
        yield async_logging_queue
    finally:
        async_logging_queue.shut_down_async_logging()


def test_single_thread_publish_consume_queue(monkeypatch):
    """Batches logged within the buffering window are grouped before submission."""
    monkeypatch.setenv("MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS", "3")

    with (
        patch.object(
            AsyncLoggingQueue, "_batch_logging_worker_threadpool", create=True
        ) as mock_worker_threadpool,
        patch.object(
            AsyncLoggingQueue, "_batch_status_check_threadpool", create=True
        ) as mock_check_threadpool,
    ):
        mock_worker_threadpool.submit = MagicMock()
        mock_check_threadpool.submit = MagicMock()
        mock_worker_threadpool.shutdown = MagicMock()
        mock_check_threadpool.shutdown = MagicMock()

        run_id = "test_run_id"
        run_data = RunData()
        with generate_async_logging_queue(run_data) as async_logging_queue:
            async_logging_queue.activate()
            # Swap in the mocked pools so submit/shutdown calls can be counted.
            async_logging_queue._batch_logging_worker_threadpool = mock_worker_threadpool
            async_logging_queue._batch_status_check_threadpool = mock_check_threadpool

            for params, tags, metrics in _get_run_data():
                async_logging_queue.log_batch_async(
                    run_id=run_id, metrics=metrics, tags=tags, params=params
                )
            async_logging_queue.flush()
            # 2 batches are sent to the worker thread pool due to grouping, otherwise it would be 5.
            assert mock_worker_threadpool.submit.call_count == 2
            # flush() re-activates the queue after draining it.
            assert async_logging_queue.is_active()
            assert mock_check_threadpool.shutdown.call_count == 1
            assert mock_worker_threadpool.shutdown.call_count == 1


def test_grouping_batch_in_time_window():
    """Everything logged asynchronously is delivered to the consumer after flush()."""
    run_id = "test_run_id"
    run_data = RunData()
    with generate_async_logging_queue(run_data) as async_logging_queue:
        async_logging_queue.activate()
        metrics_sent = []
        tags_sent = []
        params_sent = []

        for params, tags, metrics in _get_run_data():
            async_logging_queue.log_batch_async(
                run_id=run_id, metrics=metrics, tags=tags, params=params
            )
            metrics_sent += metrics
            tags_sent += tags
            params_sent += params

        async_logging_queue.flush()

        _assert_sent_received_data(
            metrics_sent,
            params_sent,
            tags_sent,
            run_data.received_metrics,
            run_data.received_params,
            run_data.received_tags,
        )


def test_queue_activation():
    """Logging before activate() raises; activate() transitions IDLE -> ACTIVE."""
    run_id = "test_run_id"
    run_data = RunData()
    with generate_async_logging_queue(run_data) as async_logging_queue:
        assert async_logging_queue.is_idle()

        metrics = [
            Metric(
                key=f"batch metrics async-{val}",
                value=val,
                timestamp=val,
                step=0,
            )
            for val in range(METRIC_PER_BATCH)
        ]
        with pytest.raises(MlflowException, match="AsyncLoggingQueue is not activated."):
            async_logging_queue.log_batch_async(run_id=run_id, metrics=metrics, tags=[], params=[])

        async_logging_queue.activate()
        assert async_logging_queue.is_active()


def test_end_async_logging():
    """end_async_logging() tears down the queue status but keeps thread pools alive."""
    run_id = "test_run_id"
    run_data = RunData()
    with generate_async_logging_queue(run_data) as async_logging_queue:
        async_logging_queue.activate()

        metrics = [
            Metric(
                key=f"batch metrics async-{val}",
                value=val,
                timestamp=val,
                step=0,
            )
            for val in range(METRIC_PER_BATCH)
        ]
        async_logging_queue.log_batch_async(run_id=run_id, metrics=metrics, tags=[], params=[])
        async_logging_queue.end_async_logging()
        assert async_logging_queue._status == QueueStatus.TEAR_DOWN
        # end_async_logging should not shutdown the threadpool
        assert not async_logging_queue._batch_logging_worker_threadpool._shutdown
        assert not async_logging_queue._batch_status_check_threadpool._shutdown

        async_logging_queue.flush()
        # flush() re-activates the queue once drained.
        assert async_logging_queue.is_active()


def test_partial_logging_failed():
    """Failed batches surface their exception via wait(); successful ones are delivered."""
    run_id = "test_run_id"
    run_data = RunData(throw_exception_on_batch_number=[3, 4])
    with generate_async_logging_queue(run_data) as async_logging_queue:
        async_logging_queue.activate()

        metrics_sent = []
        tags_sent = []
        params_sent = []

        run_operations = []
        batch_id = 1
        for params, tags, metrics in _get_run_data():
            if batch_id in [3, 4]:
                # The consumer raises on these batches; wait() must propagate it.
                with pytest.raises(MlflowException, match="Failed to log run data"):
                    async_logging_queue.log_batch_async(
                        run_id=run_id, metrics=metrics, tags=tags, params=params
                    ).wait()
            else:
                run_operations.append(
                    async_logging_queue.log_batch_async(
                        run_id=run_id, metrics=metrics, tags=tags, params=params
                    )
                )

                # Only successful batches are expected back from the consumer.
                metrics_sent += metrics
                tags_sent += tags
                params_sent += params

            batch_id += 1

        for run_operation in run_operations:
            run_operation.wait()

        _assert_sent_received_data(
            metrics_sent,
            params_sent,
            tags_sent,
            run_data.received_metrics,
            run_data.received_params,
            run_data.received_tags,
        )


def test_publish_multithread_consume_single_thread():
    """Two producer threads logging concurrently lose no data."""
    run_id = "test_run_id"
    run_data = RunData(throw_exception_on_batch_number=[])
    with generate_async_logging_queue(run_data) as async_logging_queue:
        async_logging_queue.activate()

        run_operations = []
        t1 = threading.Thread(
            name="test-async-logging-1",
            target=_send_metrics_tags_params,
            args=(async_logging_queue, run_id, run_operations),
        )
        t2 = threading.Thread(
            name="test-async-logging-2",
            target=_send_metrics_tags_params,
            args=(async_logging_queue, run_id, run_operations),
        )

        t1.start()
        t2.start()
        t1.join()
        t2.join()

        for run_operation in run_operations:
            run_operation.wait()

        assert len(run_data.received_metrics) == 2 * METRIC_PER_BATCH * TOTAL_BATCHES
        assert len(run_data.received_tags) == 2 * TAGS_PER_BATCH * TOTAL_BATCHES
        assert len(run_data.received_params) == 2 * PARAMS_PER_BATCH * TOTAL_BATCHES


class Consumer:
    """Consumer whose callback blocks on an event until released by the test.

    Blocking keeps the queue's state deterministic while it is pickled; the
    (unpicklable) event is dropped from the pickled state and recreated on load.
    """

    def __init__(self) -> None:
        self.metrics = []
        self.tags = []
        self.params = []
        self.barrier = threading.Event()

    def consume_queue_data(self, run_id, metrics, tags, params):
        # Block until the test sets the barrier, so nothing is consumed during pickling.
        self.barrier.wait()
        self.metrics.extend(metrics or [])
        self.params.extend(params or [])
        self.tags.extend(tags or [])

    def __getstate__(self):
        # threading.Event cannot be pickled; exclude it.
        state = self.__dict__.copy()
        del state["barrier"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.barrier = threading.Event()


def test_async_logging_queue_pickle():
    """The queue pickles before and during activity; the copy works independently."""
    run_id = "test_run_id"
    consumer = Consumer()
    with generate_async_logging_queue(consumer) as async_logging_queue:
        # Pickle the queue without activating it.
        buffer = io.BytesIO()
        pickle.dump(async_logging_queue, buffer)
        deserialized_queue = pickle.loads(buffer.getvalue())  # Type: AsyncLoggingQueue

        # Activate the queue and submit 10 items. Workers block on the barrier,
        # so the consumer's state remains empty during pickling.
        async_logging_queue.activate()

        run_operations = [
            async_logging_queue.log_batch_async(
                run_id=run_id,
                metrics=[Metric("metric", val, timestamp=time.time(), step=1)],
                tags=[],
                params=[],
            )
            for val in range(0, 10)
        ]

        # Pickle while workers are blocked — consumer state is deterministically empty.
        buffer = io.BytesIO()
        pickle.dump(async_logging_queue, buffer)

        deserialized_queue = pickle.loads(buffer.getvalue())  # Type: AsyncLoggingQueue
        assert deserialized_queue._queue.empty()
        assert deserialized_queue._lock is not None
        assert deserialized_queue._status is QueueStatus.IDLE

        # Release workers and wait for all operations to complete.
        consumer.barrier.set()

        for run_operation in run_operations:
            run_operation.wait()

        assert len(consumer.metrics) == 10

        # Activate the deserialized queue and submit 10 more items.
        # The deserialized consumer is a separate copy with an empty metrics list.
        deserialized_consumer = deserialized_queue._logging_func.__self__
        deserialized_consumer.barrier.set()
        deserialized_queue.activate()
        assert deserialized_queue.is_active()

        run_operations = []

        for val in range(0, 10):
            run_operations.append(
                deserialized_queue.log_batch_async(
                    run_id=run_id,
                    metrics=[Metric("metric", val, timestamp=time.time(), step=1)],
                    tags=[],
                    params=[],
                )
            )

        for run_operation in run_operations:
            run_operation.wait()

        assert len(deserialized_consumer.metrics) == 10

        deserialized_queue.shut_down_async_logging()


def _send_metrics_tags_params(run_data_queueing_processor, run_id, run_operations=None):
    """Thread target: enqueue every generated batch, collecting the returned operations.

    The random sleep staggers concurrent producer threads so their batches interleave.
    """
    if run_operations is None:
        run_operations = []

    for params, tags, metrics in _get_run_data():
        run_operations.append(
            run_data_queueing_processor.log_batch_async(
                run_id=run_id, metrics=metrics, tags=tags, params=params
            )
        )

        time.sleep(random.randint(1, 3))


def _get_run_data(
    total_batches=TOTAL_BATCHES,
    params_per_batch=PARAMS_PER_BATCH,
    tags_per_batch=TAGS_PER_BATCH,
    metrics_per_batch=METRIC_PER_BATCH,
):
    """Yield ``(params, tags, metrics)`` tuples for ``total_batches`` batches.

    Param/tag keys embed a fresh UUID fragment per batch so they are unique.
    """
    for num in range(0, total_batches):
        guid8 = str(uuid.uuid4())[:8]
        params = [
            Param(f"batch param-{guid8}-{val}", value=str(time.time()))
            for val in range(params_per_batch)
        ]
        tags = [
            RunTag(f"batch tag-{guid8}-{val}", value=str(time.time()))
            for val in range(tags_per_batch)
        ]
        metrics = [
            Metric(
                key=f"batch metrics async-{num}",
                value=val,
                timestamp=int(time.time() * 1000),
                step=0,
            )
            for val in range(metrics_per_batch)
        ]
        yield params, tags, metrics


def _assert_sent_received_data(
    metrics_sent, params_sent, tags_sent, received_metrics, received_params, received_tags
):
    """Assert the consumer received exactly what was sent, element by element."""
    # Compare lengths first so dropped leading/trailing items cannot go unnoticed.
    assert len(metrics_sent) == len(received_metrics)
    assert len(tags_sent) == len(received_tags)
    assert len(params_sent) == len(received_params)

    # Iterate from index 0; the previous range(1, ...) silently skipped the first element.
    for sent, received in zip(metrics_sent, received_metrics):
        assert sent.key == received.key
        assert sent.value == received.value
        assert sent.timestamp == received.timestamp
        assert sent.step == received.step

    for sent, received in zip(tags_sent, received_tags):
        assert sent.key == received.key
        assert sent.value == received.value

    for sent, received in zip(params_sent, received_params):
        assert sent.key == received.key
        assert sent.value == received.value


def test_batch_split(monkeypatch):
    """Oversized groups are split according to the per-batch item/param/tag limits."""
    monkeypatch.setattr(mlflow.utils.async_logging.async_logging_queue, "_MAX_ITEMS_PER_BATCH", 10)
    monkeypatch.setattr(mlflow.utils.async_logging.async_logging_queue, "_MAX_PARAMS_PER_BATCH", 6)
    monkeypatch.setattr(mlflow.utils.async_logging.async_logging_queue, "_MAX_TAGS_PER_BATCH", 8)

    # 2 batches of 3+3+3 items: grouped total exceeds _MAX_ITEMS_PER_BATCH -> split in 2.
    run_data = RunData()
    with generate_async_logging_queue(run_data) as async_logging_queue:
        async_logging_queue.activate()

        run_id = "test_run_id"
        for params, tags, metrics in _get_run_data(2, 3, 3, 3):
            async_logging_queue.log_batch_async(
                run_id=run_id, metrics=metrics, tags=tags, params=params
            )
        async_logging_queue.flush()

        assert run_data.batch_count == 2

    # 2 batches of 4 params: grouped total exceeds _MAX_PARAMS_PER_BATCH -> split in 2.
    run_data = RunData()
    with generate_async_logging_queue(run_data) as async_logging_queue:
        async_logging_queue.activate()

        run_id = "test_run_id"
        for params, tags, metrics in _get_run_data(2, 4, 0, 0):
            async_logging_queue.log_batch_async(
                run_id=run_id, metrics=metrics, tags=tags, params=params
            )
        async_logging_queue.flush()

        assert run_data.batch_count == 2

    # 2 batches of 5 tags: grouped total exceeds _MAX_TAGS_PER_BATCH -> split in 2.
    run_data = RunData()
    with generate_async_logging_queue(run_data) as async_logging_queue:
        async_logging_queue.activate()

        run_id = "test_run_id"
        for params, tags, metrics in _get_run_data(2, 0, 5, 0):
            async_logging_queue.log_batch_async(
                run_id=run_id, metrics=metrics, tags=tags, params=params
            )
        async_logging_queue.flush()

        assert run_data.batch_count == 2