client.py
1 """ 2 Defines an MlflowAutologgingQueueingClient developer API that provides batching, queueing, and 3 asynchronous execution capabilities for a subset of MLflow Tracking logging operations used most 4 frequently by autologging operations. 5 6 TODO(dbczumar): Migrate request batching, queueing, and async execution support from 7 MlflowAutologgingQueueingClient to MlflowClient in order to provide broader benefits to end users. 8 Remove this developer API. 9 """ 10 11 import logging 12 import os 13 from concurrent.futures import ThreadPoolExecutor 14 from itertools import zip_longest 15 from typing import TYPE_CHECKING, Any, NamedTuple, Optional 16 17 from mlflow.entities import Metric, Param, RunTag 18 from mlflow.entities.dataset_input import DatasetInput 19 from mlflow.exceptions import MlflowException 20 from mlflow.utils import _truncate_dict, chunk_list 21 from mlflow.utils.time import get_current_time_millis 22 from mlflow.utils.validation import ( 23 MAX_DATASETS_PER_BATCH, 24 MAX_ENTITIES_PER_BATCH, 25 MAX_ENTITY_KEY_LENGTH, 26 MAX_METRICS_PER_BATCH, 27 MAX_PARAM_VAL_LENGTH, 28 MAX_PARAMS_TAGS_PER_BATCH, 29 MAX_TAG_VAL_LENGTH, 30 ) 31 32 if TYPE_CHECKING: 33 from mlflow.data.dataset import Dataset 34 35 _logger = logging.getLogger(__name__) 36 37 38 class _PendingCreateRun(NamedTuple): 39 experiment_id: str 40 start_time: int | None 41 tags: list[RunTag] 42 run_name: str | None 43 44 45 class _PendingSetTerminated(NamedTuple): 46 status: str | None 47 end_time: int | None 48 49 50 class PendingRunId: 51 """ 52 Serves as a placeholder for the ID of a run that does not yet exist, enabling additional 53 metadata (e.g. metrics, params, ...) to be enqueued for the run prior to its creation. 54 """ 55 56 57 class RunOperations: 58 """ 59 Represents a collection of operations on one or more MLflow Runs, such as run creation 60 or metric logging. 61 """ 62 63 def __init__(self, operation_futures): 64 self._operation_futures = operation_futures 65 66 def await_completion(self): 67 """ 68 Blocks on completion of the MLflow Run operations. 69 """ 70 failed_operations = [] 71 for future in self._operation_futures: 72 try: 73 future.result() 74 except Exception as e: 75 failed_operations.append(e) 76 77 if len(failed_operations) > 0: 78 raise MlflowException( 79 message=( 80 "The following failures occurred while performing one or more logging" 81 f" operations: {failed_operations}" 82 ) 83 ) 84 85 86 # Define a threadpool for use across `MlflowAutologgingQueueingClient` instances to ensure that 87 # `MlflowAutologgingQueueingClient` instances can be pickled (ThreadPoolExecutor objects are not 88 # pickleable and therefore cannot be assigned as instance attributes). 89 # 90 # We limit the number of threads used for run operations, using at most 8 threads or 2 * the number 91 # of CPU cores available on the system (whichever is smaller) 92 num_cpus = os.cpu_count() or 4 93 num_logging_workers = min(num_cpus * 2, 8) 94 _AUTOLOGGING_QUEUEING_CLIENT_THREAD_POOL = ThreadPoolExecutor( 95 max_workers=num_logging_workers, 96 thread_name_prefix="MlflowAutologgingQueueingClient", 97 ) 98 99 100 class MlflowAutologgingQueueingClient: 101 """ 102 Efficiently implements a subset of MLflow Tracking's `MlflowClient` and fluent APIs to provide 103 automatic batching and async execution of run operations by way of queueing, as well as 104 parameter / tag truncation for autologging use cases. Run operations defined by this client, 105 such as `create_run` and `log_metrics`, enqueue data for future persistence to MLflow 106 Tracking. Data is not persisted until the queue is flushed via the `flush()` method, which 107 supports synchronous and asynchronous execution. 108 109 MlflowAutologgingQueueingClient is not threadsafe; none of its APIs should be called 110 concurrently. 111 """ 112 113 def __init__(self, tracking_uri=None): 114 from mlflow.tracking.client import MlflowClient 115 116 self._client = MlflowClient(tracking_uri) 117 self._pending_ops_by_run_id = {} 118 119 def __enter__(self): 120 """ 121 Enables `MlflowAutologgingQueueingClient` to be used as a context manager with 122 synchronous flushing upon exit, removing the need to call `flush()` for use cases 123 where logging completion can be waited upon synchronously. 124 125 Run content is only flushed if the context exited without an exception. 126 """ 127 return self 128 129 def __exit__(self, exc_type, exc, traceback): 130 """ 131 Enables `MlflowAutologgingQueueingClient` to be used as a context manager with 132 synchronous flushing upon exit, removing the need to call `flush()` for use cases 133 where logging completion can be waited upon synchronously. 134 135 Run content is only flushed if the context exited without an exception. 136 """ 137 # NB: Run content is only flushed upon context exit to ensure that we don't elide the 138 # original exception thrown by the context (because `flush()` itself may throw). This 139 # is consistent with the behavior of a routine that calls `flush()` explicitly: content 140 # is not logged if an exception preempts the call to `flush()` 141 if exc is None and exc_type is None and traceback is None: 142 self.flush(synchronous=True) 143 else: 144 _logger.debug( 145 "Skipping run content logging upon MlflowAutologgingQueueingClient context because" 146 " an exception was raised within the context: %s", 147 exc, 148 ) 149 150 def create_run( 151 self, 152 experiment_id: str, 153 start_time: int | None = None, 154 tags: dict[str, Any] | None = None, 155 run_name: str | None = None, 156 ) -> PendingRunId: 157 """ 158 Enqueues a CreateRun operation with the specified attributes, returning a `PendingRunId` 159 instance that can be used as input to other client logging APIs (e.g. `log_metrics`, 160 `log_params`, ...). 161 162 Returns: 163 A `PendingRunId` that can be passed as the `run_id` parameter to other client 164 logging APIs, such as `log_params` and `log_metrics`. 165 """ 166 tags = tags or {} 167 tags = _truncate_dict( 168 tags, max_key_length=MAX_ENTITY_KEY_LENGTH, max_value_length=MAX_TAG_VAL_LENGTH 169 ) 170 run_id = PendingRunId() 171 self._get_pending_operations(run_id).enqueue( 172 create_run=_PendingCreateRun( 173 experiment_id=experiment_id, 174 start_time=start_time, 175 tags=[RunTag(key, str(value)) for key, value in tags.items()], 176 run_name=run_name, 177 ) 178 ) 179 return run_id 180 181 def set_terminated( 182 self, 183 run_id: str | PendingRunId, 184 status: str | None = None, 185 end_time: int | None = None, 186 ) -> None: 187 """ 188 Enqueues an UpdateRun operation with the specified `status` and `end_time` attributes 189 for the specified `run_id`. 190 """ 191 self._get_pending_operations(run_id).enqueue( 192 set_terminated=_PendingSetTerminated(status=status, end_time=end_time) 193 ) 194 195 def log_params(self, run_id: str | PendingRunId, params: dict[str, Any]) -> None: 196 """ 197 Enqueues a collection of Parameters to be logged to the run specified by `run_id`. 198 """ 199 params = _truncate_dict( 200 params, max_key_length=MAX_ENTITY_KEY_LENGTH, max_value_length=MAX_PARAM_VAL_LENGTH 201 ) 202 params_arr = [Param(key, str(value)) for key, value in params.items()] 203 self._get_pending_operations(run_id).enqueue(params=params_arr) 204 205 def log_inputs(self, run_id: str | PendingRunId, datasets: list[DatasetInput] | None) -> None: 206 """ 207 Enqueues a collection of Dataset to be logged to the run specified by `run_id`. 208 """ 209 if datasets is None or len(datasets) == 0: 210 return 211 self._get_pending_operations(run_id).enqueue(datasets=datasets) 212 213 def log_metrics( 214 self, 215 run_id: str | PendingRunId, 216 metrics: dict[str, float], 217 step: int | None = None, 218 dataset: Optional["Dataset"] = None, 219 model_id: str | None = None, 220 ) -> None: 221 """ 222 Enqueues a collection of Metrics to be logged to the run specified by `run_id` at the 223 step specified by `step`. 224 """ 225 metrics = _truncate_dict(metrics, max_key_length=MAX_ENTITY_KEY_LENGTH) 226 timestamp_ms = get_current_time_millis() 227 metrics_arr = [ 228 Metric( 229 key, 230 value, 231 timestamp_ms, 232 step or 0, 233 model_id=model_id, 234 dataset_name=dataset and dataset.name, 235 dataset_digest=dataset and dataset.digest, 236 ) 237 for key, value in metrics.items() 238 ] 239 self._get_pending_operations(run_id).enqueue(metrics=metrics_arr) 240 241 def set_tags(self, run_id: str | PendingRunId, tags: dict[str, Any]) -> None: 242 """ 243 Enqueues a collection of Tags to be logged to the run specified by `run_id`. 244 """ 245 tags = _truncate_dict( 246 tags, max_key_length=MAX_ENTITY_KEY_LENGTH, max_value_length=MAX_TAG_VAL_LENGTH 247 ) 248 tags_arr = [RunTag(key, str(value)) for key, value in tags.items()] 249 self._get_pending_operations(run_id).enqueue(tags=tags_arr) 250 251 def flush(self, synchronous=True): 252 """ 253 Flushes all queued run operations, resulting in the creation or mutation of runs 254 and run data. 255 256 Args: 257 synchronous: If `True`, run operations are performed synchronously, and a 258 `RunOperations` result object is only returned once all operations 259 are complete. If `False`, run operations are performed asynchronously, 260 and an `RunOperations` object is returned that represents the ongoing 261 run operations. 262 263 Returns: 264 A `RunOperations` instance representing the flushed operations. These operations 265 are already complete if `synchronous` is `True`. If `synchronous` is `False`, these 266 operations may still be inflight. Operation completion can be synchronously waited 267 on via `RunOperations.await_completion()`. 268 """ 269 logging_futures = [ 270 _AUTOLOGGING_QUEUEING_CLIENT_THREAD_POOL.submit( 271 self._flush_pending_operations, 272 pending_operations=pending_operations, 273 ) 274 for pending_operations in self._pending_ops_by_run_id.values() 275 ] 276 self._pending_ops_by_run_id = {} 277 278 logging_operations = RunOperations(logging_futures) 279 if synchronous: 280 logging_operations.await_completion() 281 return logging_operations 282 283 def _get_pending_operations(self, run_id): 284 """ 285 Returns: 286 A `_PendingRunOperations` containing all pending operations for the 287 specified `run_id`. 288 """ 289 if run_id not in self._pending_ops_by_run_id: 290 self._pending_ops_by_run_id[run_id] = _PendingRunOperations(run_id=run_id) 291 return self._pending_ops_by_run_id[run_id] 292 293 def _try_operation(self, fn, *args, **kwargs): 294 """ 295 Attempt to evaluate the specified function, `fn`, on the specified `*args` and `**kwargs`, 296 returning either the result of the function evaluation (if evaluation was successful) or 297 the exception raised by the function evaluation (if evaluation was unsuccessful). 298 """ 299 try: 300 return fn(*args, **kwargs) 301 except Exception as e: 302 return e 303 304 def _flush_pending_operations(self, pending_operations): 305 """ 306 Synchronously and sequentially flushes the specified list of pending run operations. 307 308 NB: Operations are not parallelized on a per-run basis because MLflow's File Store, which 309 is frequently used for local ML development, does not support threadsafe metadata logging 310 within a given run. 311 """ 312 if pending_operations.create_run: 313 create_run_tags = pending_operations.create_run.tags 314 num_additional_tags_to_include_during_creation = MAX_ENTITIES_PER_BATCH - len( 315 create_run_tags 316 ) 317 if num_additional_tags_to_include_during_creation > 0: 318 create_run_tags.extend( 319 pending_operations.tags_queue[:num_additional_tags_to_include_during_creation] 320 ) 321 pending_operations.tags_queue = pending_operations.tags_queue[ 322 num_additional_tags_to_include_during_creation: 323 ] 324 325 new_run = self._client.create_run( 326 experiment_id=pending_operations.create_run.experiment_id, 327 start_time=pending_operations.create_run.start_time, 328 tags={tag.key: tag.value for tag in create_run_tags}, 329 ) 330 pending_operations.run_id = new_run.info.run_id 331 332 run_id = pending_operations.run_id 333 assert not isinstance(run_id, PendingRunId), "Run ID cannot be pending for logging" 334 335 operation_results = [] 336 337 param_batches_to_log = chunk_list( 338 pending_operations.params_queue, 339 chunk_size=MAX_PARAMS_TAGS_PER_BATCH, 340 ) 341 tag_batches_to_log = chunk_list( 342 pending_operations.tags_queue, 343 chunk_size=MAX_PARAMS_TAGS_PER_BATCH, 344 ) 345 for params_batch, tags_batch in zip_longest( 346 param_batches_to_log, tag_batches_to_log, fillvalue=[] 347 ): 348 metrics_batch_size = min( 349 MAX_ENTITIES_PER_BATCH - len(params_batch) - len(tags_batch), 350 MAX_METRICS_PER_BATCH, 351 ) 352 metrics_batch_size = max(metrics_batch_size, 0) 353 metrics_batch = pending_operations.metrics_queue[:metrics_batch_size] 354 pending_operations.metrics_queue = pending_operations.metrics_queue[metrics_batch_size:] 355 356 operation_results.append( 357 self._try_operation( 358 self._client.log_batch, 359 run_id=run_id, 360 metrics=metrics_batch, 361 params=params_batch, 362 tags=tags_batch, 363 ) 364 ) 365 366 operation_results.extend( 367 self._try_operation(self._client.log_batch, run_id=run_id, metrics=metrics_batch) 368 for metrics_batch in chunk_list( 369 pending_operations.metrics_queue, chunk_size=MAX_METRICS_PER_BATCH 370 ) 371 ) 372 373 operation_results.extend( 374 self._try_operation(self._client.log_inputs, run_id=run_id, datasets=datasets_batch) 375 for datasets_batch in chunk_list( 376 pending_operations.datasets_queue, chunk_size=MAX_DATASETS_PER_BATCH 377 ) 378 ) 379 380 if pending_operations.set_terminated: 381 operation_results.append( 382 self._try_operation( 383 self._client.set_terminated, 384 run_id=run_id, 385 status=pending_operations.set_terminated.status, 386 end_time=pending_operations.set_terminated.end_time, 387 ) 388 ) 389 390 failures = [result for result in operation_results if isinstance(result, Exception)] 391 if len(failures) > 0: 392 raise MlflowException( 393 message=( 394 f"Failed to perform one or more operations on the run with ID {run_id}." 395 f" Failed operations: {failures}" 396 ) 397 ) 398 399 400 class _PendingRunOperations: 401 """ 402 Represents a collection of queued / pending MLflow Run operations. 403 """ 404 405 def __init__(self, run_id): 406 self.run_id = run_id 407 self.create_run = None 408 self.set_terminated = None 409 self.params_queue = [] 410 self.tags_queue = [] 411 self.metrics_queue = [] 412 self.datasets_queue = [] 413 414 def enqueue( 415 self, 416 params=None, 417 tags=None, 418 metrics=None, 419 datasets=None, 420 create_run=None, 421 set_terminated=None, 422 ): 423 """ 424 Enqueues a new pending logging operation for the associated MLflow Run. 425 """ 426 if create_run: 427 assert not self.create_run, "Attempted to create the same run multiple times" 428 self.create_run = create_run 429 if set_terminated: 430 assert not self.set_terminated, "Attempted to terminate the same run multiple times" 431 self.set_terminated = set_terminated 432 433 self.params_queue += params or [] 434 self.tags_queue += tags or [] 435 self.metrics_queue += metrics or [] 436 self.datasets_queue += datasets or []