import io
import pickle
import random
import threading
import time

import pytest
from PIL import Image

from mlflow import MlflowException
from mlflow.utils.async_logging.async_artifacts_logging_queue import AsyncArtifactsLoggingQueue

TOTAL_ARTIFACTS = 5


class RunArtifacts:
    """Test double that records every artifact the queue hands to its consumer.

    Optionally raises ``MlflowException`` when consuming the N-th artifact, to
    exercise the queue's error-propagation path.
    """

    def __init__(self, throw_exception_on_artifact_number=None):
        self.received_run_id = ""
        self.received_artifacts = []
        self.received_filenames = []
        self.received_artifact_paths = []
        self.artifact_count = 0
        # Avoid a mutable default argument; None means "never fail".
        self.throw_exception_on_artifact_number = throw_exception_on_artifact_number or []

    def consume_queue_data(self, filename, artifact_path, artifact):
        """Consumer callback passed to AsyncArtifactsLoggingQueue.

        Raises on configured artifact numbers; otherwise records the artifact.
        """
        self.artifact_count += 1
        if self.artifact_count in self.throw_exception_on_artifact_number:
            raise MlflowException("Failed to log run data")
        self.received_artifacts.append(artifact)
        self.received_filenames.append(filename)
        self.received_artifact_paths.append(artifact_path)


def _get_run_artifacts(total_artifacts=TOTAL_ARTIFACTS):
    """Yield (filename, artifact_path, artifact) triples for test input."""
    for num in range(total_artifacts):
        filename = f"image_{num}.png"
        artifact_path = f"images/artifact_{num}"
        artifact = Image.new("RGB", (100, 100), color="red")
        yield filename, artifact_path, artifact


def _assert_sent_received_artifacts(
    filenames_sent,
    artifact_paths_sent,
    artifacts_sent,
    received_filenames,
    received_artifact_paths,
    received_artifacts,
):
    """Assert that everything sent was received, in order.

    Note: the previous implementation iterated ``range(1, len(...))``, silently
    skipping index 0 — a corrupted or dropped first artifact went undetected.
    We now also verify the lengths match, catching dropped trailing items.
    """
    assert len(filenames_sent) == len(received_filenames)
    assert len(artifact_paths_sent) == len(received_artifact_paths)
    assert len(artifacts_sent) == len(received_artifacts)

    for sent, received in zip(filenames_sent, received_filenames):
        assert sent == received

    for sent, received in zip(artifact_paths_sent, received_artifact_paths):
        assert sent == received

    for sent, received in zip(artifacts_sent, received_artifacts):
        assert sent == received


def test_single_thread_publish_consume_queue():
    run_artifacts = RunArtifacts()
    async_logging_queue = AsyncArtifactsLoggingQueue(run_artifacts.consume_queue_data)
    async_logging_queue.activate()
    filenames_sent = []
    artifact_paths_sent = []
    artifacts_sent = []
    for filename, artifact_path, artifact in _get_run_artifacts():
        async_logging_queue.log_artifacts_async(
            filename=filename, artifact_path=artifact_path, artifact=artifact
        )
        filenames_sent.append(filename)
        artifact_paths_sent.append(artifact_path)
        artifacts_sent.append(artifact)
    # Drain the queue so every artifact has been consumed before asserting.
    async_logging_queue.flush()

    _assert_sent_received_artifacts(
        filenames_sent,
        artifact_paths_sent,
        artifacts_sent,
        run_artifacts.received_filenames,
        run_artifacts.received_artifact_paths,
        run_artifacts.received_artifacts,
    )


def test_queue_activation():
    run_artifacts = RunArtifacts()
    async_logging_queue = AsyncArtifactsLoggingQueue(run_artifacts.consume_queue_data)

    assert not async_logging_queue._is_activated

    # Logging before activation must raise.
    for filename, artifact_path, artifact in _get_run_artifacts(1):
        with pytest.raises(MlflowException, match="AsyncArtifactsLoggingQueue is not activated."):
            async_logging_queue.log_artifacts_async(
                filename=filename, artifact_path=artifact_path, artifact=artifact
            )

    async_logging_queue.activate()
    assert async_logging_queue._is_activated


def test_partial_logging_failed():
    # Artifacts 3 and 4 fail in the consumer; the rest must still arrive.
    run_data = RunArtifacts(throw_exception_on_artifact_number=[3, 4])

    async_logging_queue = AsyncArtifactsLoggingQueue(run_data.consume_queue_data)
    async_logging_queue.activate()

    filenames_sent = []
    artifact_paths_sent = []
    artifacts_sent = []

    run_operations = []
    batch_id = 1
    for filename, artifact_path, artifact in _get_run_artifacts():
        if batch_id in [3, 4]:
            # The consumer's exception surfaces from wait() on the operation.
            with pytest.raises(MlflowException, match="Failed to log run data"):
                async_logging_queue.log_artifacts_async(
                    filename=filename, artifact_path=artifact_path, artifact=artifact
                ).wait()
        else:
            run_operations.append(
                async_logging_queue.log_artifacts_async(
                    filename=filename, artifact_path=artifact_path, artifact=artifact
                )
            )
            filenames_sent.append(filename)
            artifact_paths_sent.append(artifact_path)
            artifacts_sent.append(artifact)

        batch_id += 1

    for run_operation in run_operations:
        run_operation.wait()

    _assert_sent_received_artifacts(
        filenames_sent,
        artifact_paths_sent,
        artifacts_sent,
        run_data.received_filenames,
        run_data.received_artifact_paths,
        run_data.received_artifacts,
    )


def test_publish_multithread_consume_single_thread():
    run_data = RunArtifacts(throw_exception_on_artifact_number=[])

    async_logging_queue = AsyncArtifactsLoggingQueue(run_data.consume_queue_data)
    async_logging_queue.activate()

    def _send_artifact(run_data_queueing_processor, run_operations=None):
        # Guard against the mutable-default-argument pitfall.
        if run_operations is None:
            run_operations = []
        filenames_sent = []
        artifact_paths_sent = []
        artifacts_sent = []

        for filename, artifact_path, artifact in _get_run_artifacts():
            run_operations.append(
                run_data_queueing_processor.log_artifacts_async(
                    filename=filename, artifact_path=artifact_path, artifact=artifact
                )
            )

            # Jitter so the two producer threads interleave their submissions.
            time.sleep(random.randint(1, 3))
            filenames_sent.append(filename)
            artifact_paths_sent.append(artifact_path)
            artifacts_sent.append(artifact)

    run_operations = []
    t1 = threading.Thread(
        name="test-async-artifacts-1",
        target=_send_artifact,
        args=(async_logging_queue, run_operations),
    )
    t2 = threading.Thread(
        name="test-async-artifacts-2",
        target=_send_artifact,
        args=(async_logging_queue, run_operations),
    )

    t1.start()
    t2.start()
    t1.join()
    t2.join()

    for run_operation in run_operations:
        run_operation.wait()

    # Both producers submitted TOTAL_ARTIFACTS each; order across threads is
    # nondeterministic, so only counts are asserted.
    assert len(run_data.received_filenames) == 2 * TOTAL_ARTIFACTS
    assert len(run_data.received_artifact_paths) == 2 * TOTAL_ARTIFACTS
    assert len(run_data.received_artifacts) == 2 * TOTAL_ARTIFACTS


class Consumer:
    """Consumer whose callback blocks on an Event, so in-flight items stay
    queued while the test pickles the logging queue.

    The Event is dropped from the pickled state (threading primitives are not
    picklable) and recreated, unset, on deserialization.
    """

    def __init__(self) -> None:
        self.filenames = []
        self.artifact_paths = []
        self.artifacts = []
        self.barrier = threading.Event()

    def consume_queue_data(self, filename, artifact_path, artifact):
        # Block until the test releases the barrier.
        self.barrier.wait()
        self.filenames.append(filename)
        self.artifact_paths.append(artifact_path)
        self.artifacts.append(artifact)

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["barrier"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.barrier = threading.Event()


def test_async_logging_queue_pickle():
    consumer = Consumer()
    async_logging_queue = AsyncArtifactsLoggingQueue(consumer.consume_queue_data)

    # Pickle the queue without activating it; the round-trip must produce an
    # inactive queue (the original discarded this result without asserting).
    buffer = io.BytesIO()
    pickle.dump(async_logging_queue, buffer)
    deserialized_queue = pickle.loads(buffer.getvalue())  # Type: AsyncArtifactsLoggingQueue
    assert deserialized_queue._is_activated is False

    # Activate the queue and submit 10 items. Workers block on the barrier,
    # so the consumer's state remains empty during pickling.
    async_logging_queue.activate()

    run_operations = [
        async_logging_queue.log_artifacts_async(
            filename=f"image-{val}.png",
            artifact_path="images/image-artifact.png",
            artifact=Image.new("RGB", (100, 100), color="blue"),
        )
        for val in range(10)
    ]

    # Pickle while workers are blocked — consumer state is deterministically empty.
    buffer = io.BytesIO()
    pickle.dump(async_logging_queue, buffer)

    deserialized_queue = pickle.loads(buffer.getvalue())  # Type: AsyncArtifactsLoggingQueue
    assert deserialized_queue._queue.empty()
    assert deserialized_queue._lock is not None
    assert deserialized_queue._is_activated is False

    # Release workers and wait for all operations to complete.
    consumer.barrier.set()

    for run_operation in run_operations:
        run_operation.wait()

    assert len(consumer.filenames) == 10

    # Activate the deserialized queue and submit 10 more items.
    # The deserialized consumer is a separate copy with an empty filenames list.
    deserialized_consumer = deserialized_queue._artifact_logging_func.__self__
    deserialized_consumer.barrier.set()
    deserialized_queue.activate()
    assert deserialized_queue._is_activated

    run_operations = []

    for val in range(10):
        run_operations.append(
            deserialized_queue.log_artifacts_async(
                filename=f"image2-{val}.png",
                artifact_path="images/image-artifact2.png",
                artifact=Image.new("RGB", (100, 100), color="green"),
            )
        )

    for run_operation in run_operations:
        run_operation.wait()

    assert len(deserialized_consumer.filenames) == 10