client.py
  1  """
  2  Defines an MlflowAutologgingQueueingClient developer API that provides batching, queueing, and
  3  asynchronous execution capabilities for a subset of MLflow Tracking logging operations used most
  4  frequently by autologging operations.
  5  
  6  TODO(dbczumar): Migrate request batching, queueing, and async execution support from
  7  MlflowAutologgingQueueingClient to MlflowClient in order to provide broader benefits to end users.
  8  Remove this developer API.
  9  """
 10  
 11  import logging
 12  import os
 13  from concurrent.futures import ThreadPoolExecutor
 14  from itertools import zip_longest
 15  from typing import TYPE_CHECKING, Any, NamedTuple, Optional
 16  
 17  from mlflow.entities import Metric, Param, RunTag
 18  from mlflow.entities.dataset_input import DatasetInput
 19  from mlflow.exceptions import MlflowException
 20  from mlflow.utils import _truncate_dict, chunk_list
 21  from mlflow.utils.time import get_current_time_millis
 22  from mlflow.utils.validation import (
 23      MAX_DATASETS_PER_BATCH,
 24      MAX_ENTITIES_PER_BATCH,
 25      MAX_ENTITY_KEY_LENGTH,
 26      MAX_METRICS_PER_BATCH,
 27      MAX_PARAM_VAL_LENGTH,
 28      MAX_PARAMS_TAGS_PER_BATCH,
 29      MAX_TAG_VAL_LENGTH,
 30  )
 31  
# `Dataset` is imported only while static type checking; the guard prevents the import
# from executing at runtime (it is referenced below solely in a string-free annotation
# via `Optional["Dataset"]`).
if TYPE_CHECKING:
    from mlflow.data.dataset import Dataset

# Module-level logger, named after this module per standard `logging` convention.
_logger = logging.getLogger(__name__)
 36  
 37  
 38  class _PendingCreateRun(NamedTuple):
 39      experiment_id: str
 40      start_time: int | None
 41      tags: list[RunTag]
 42      run_name: str | None
 43  
 44  
 45  class _PendingSetTerminated(NamedTuple):
 46      status: str | None
 47      end_time: int | None
 48  
 49  
 50  class PendingRunId:
 51      """
 52      Serves as a placeholder for the ID of a run that does not yet exist, enabling additional
 53      metadata (e.g. metrics, params, ...) to be enqueued for the run prior to its creation.
 54      """
 55  
 56  
 57  class RunOperations:
 58      """
 59      Represents a collection of operations on one or more MLflow Runs, such as run creation
 60      or metric logging.
 61      """
 62  
 63      def __init__(self, operation_futures):
 64          self._operation_futures = operation_futures
 65  
 66      def await_completion(self):
 67          """
 68          Blocks on completion of the MLflow Run operations.
 69          """
 70          failed_operations = []
 71          for future in self._operation_futures:
 72              try:
 73                  future.result()
 74              except Exception as e:
 75                  failed_operations.append(e)
 76  
 77          if len(failed_operations) > 0:
 78              raise MlflowException(
 79                  message=(
 80                      "The following failures occurred while performing one or more logging"
 81                      f" operations: {failed_operations}"
 82                  )
 83              )
 84  
 85  
# Define a threadpool for use across `MlflowAutologgingQueueingClient` instances to ensure that
# `MlflowAutologgingQueueingClient` instances can be pickled (ThreadPoolExecutor objects are not
# pickleable and therefore cannot be assigned as instance attributes).
#
# We limit the number of threads used for run operations, using at most 8 threads or 2 * the number
# of CPU cores available on the system (whichever is smaller)
#
# NB: `os.cpu_count()` may return None when the CPU count is undetermined; fall back to
# assuming 4 cores in that case.
num_cpus = os.cpu_count() or 4
num_logging_workers = min(num_cpus * 2, 8)
_AUTOLOGGING_QUEUEING_CLIENT_THREAD_POOL = ThreadPoolExecutor(
    max_workers=num_logging_workers,
    thread_name_prefix="MlflowAutologgingQueueingClient",
)
 98  
 99  
class MlflowAutologgingQueueingClient:
    """
    Efficiently implements a subset of MLflow Tracking's `MlflowClient` and fluent APIs to provide
    automatic batching and async execution of run operations by way of queueing, as well as
    parameter / tag truncation for autologging use cases. Run operations defined by this client,
    such as `create_run` and `log_metrics`, enqueue data for future persistence to MLflow
    Tracking. Data is not persisted until the queue is flushed via the `flush()` method, which
    supports synchronous and asynchronous execution.

    MlflowAutologgingQueueingClient is not threadsafe; none of its APIs should be called
    concurrently.
    """

    def __init__(self, tracking_uri: str | None = None):
        # NB(review): imported at function scope rather than module scope, presumably to
        # avoid a circular import between this module and `mlflow.tracking.client` —
        # confirm before moving it to the top of the file.
        from mlflow.tracking.client import MlflowClient

        self._client = MlflowClient(tracking_uri)
        # Maps run ID (a `str` or a `PendingRunId` placeholder) to the
        # `_PendingRunOperations` instance that accumulates queued operations for that run.
        self._pending_ops_by_run_id = {}

    def __enter__(self):
        """
        Enables `MlflowAutologgingQueueingClient` to be used as a context manager with
        synchronous flushing upon exit, removing the need to call `flush()` for use cases
        where logging completion can be waited upon synchronously.

        Run content is only flushed if the context exited without an exception.
        """
        return self

    def __exit__(self, exc_type, exc, traceback):
        """
        Enables `MlflowAutologgingQueueingClient` to be used as a context manager with
        synchronous flushing upon exit, removing the need to call `flush()` for use cases
        where logging completion can be waited upon synchronously.

        Run content is only flushed if the context exited without an exception.
        """
        # NB: Run content is only flushed upon context exit to ensure that we don't elide the
        # original exception thrown by the context (because `flush()` itself may throw). This
        # is consistent with the behavior of a routine that calls `flush()` explicitly: content
        # is not logged if an exception preempts the call to `flush()`
        if exc is None and exc_type is None and traceback is None:
            self.flush(synchronous=True)
        else:
            _logger.debug(
                "Skipping run content logging upon MlflowAutologgingQueueingClient context because"
                " an exception was raised within the context: %s",
                exc,
            )

    def create_run(
        self,
        experiment_id: str,
        start_time: int | None = None,
        tags: dict[str, Any] | None = None,
        run_name: str | None = None,
    ) -> PendingRunId:
        """
        Enqueues a CreateRun operation with the specified attributes, returning a `PendingRunId`
        instance that can be used as input to other client logging APIs (e.g. `log_metrics`,
        `log_params`, ...).

        Args:
            experiment_id: ID of the experiment under which to create the run.
            start_time: Start time of the run, if specified.
            tags: Tags to set on the run. Keys and values are truncated to
                `MAX_ENTITY_KEY_LENGTH` / `MAX_TAG_VAL_LENGTH` before enqueueing, and
                values are stringified.
            run_name: Name of the run, if specified.

        Returns:
            A `PendingRunId` that can be passed as the `run_id` parameter to other client
            logging APIs, such as `log_params` and `log_metrics`.
        """
        tags = tags or {}
        tags = _truncate_dict(
            tags, max_key_length=MAX_ENTITY_KEY_LENGTH, max_value_length=MAX_TAG_VAL_LENGTH
        )
        run_id = PendingRunId()
        self._get_pending_operations(run_id).enqueue(
            create_run=_PendingCreateRun(
                experiment_id=experiment_id,
                start_time=start_time,
                tags=[RunTag(key, str(value)) for key, value in tags.items()],
                run_name=run_name,
            )
        )
        return run_id

    def set_terminated(
        self,
        run_id: str | PendingRunId,
        status: str | None = None,
        end_time: int | None = None,
    ) -> None:
        """
        Enqueues an UpdateRun operation with the specified `status` and `end_time` attributes
        for the specified `run_id`.
        """
        self._get_pending_operations(run_id).enqueue(
            set_terminated=_PendingSetTerminated(status=status, end_time=end_time)
        )

    def log_params(self, run_id: str | PendingRunId, params: dict[str, Any]) -> None:
        """
        Enqueues a collection of Parameters to be logged to the run specified by `run_id`.

        Keys and values are truncated to `MAX_ENTITY_KEY_LENGTH` / `MAX_PARAM_VAL_LENGTH`,
        and values are stringified before enqueueing.
        """
        params = _truncate_dict(
            params, max_key_length=MAX_ENTITY_KEY_LENGTH, max_value_length=MAX_PARAM_VAL_LENGTH
        )
        params_arr = [Param(key, str(value)) for key, value in params.items()]
        self._get_pending_operations(run_id).enqueue(params=params_arr)

    def log_inputs(self, run_id: str | PendingRunId, datasets: list[DatasetInput] | None) -> None:
        """
        Enqueues a collection of Dataset to be logged to the run specified by `run_id`.

        A `None` or empty `datasets` argument is a no-op.
        """
        if datasets is None or len(datasets) == 0:
            return
        self._get_pending_operations(run_id).enqueue(datasets=datasets)

    def log_metrics(
        self,
        run_id: str | PendingRunId,
        metrics: dict[str, float],
        step: int | None = None,
        dataset: Optional["Dataset"] = None,
        model_id: str | None = None,
    ) -> None:
        """
        Enqueues a collection of Metrics to be logged to the run specified by `run_id` at the
        step specified by `step`.

        All metrics in the collection share a single timestamp captured at enqueue time.
        Metric keys are truncated to `MAX_ENTITY_KEY_LENGTH`.
        """
        metrics = _truncate_dict(metrics, max_key_length=MAX_ENTITY_KEY_LENGTH)
        timestamp_ms = get_current_time_millis()
        metrics_arr = [
            Metric(
                key,
                value,
                timestamp_ms,
                step or 0,  # NB: a `step` of None is recorded as step 0
                model_id=model_id,
                # `dataset` is optional; `and` short-circuits to None when it is absent
                dataset_name=dataset and dataset.name,
                dataset_digest=dataset and dataset.digest,
            )
            for key, value in metrics.items()
        ]
        self._get_pending_operations(run_id).enqueue(metrics=metrics_arr)

    def set_tags(self, run_id: str | PendingRunId, tags: dict[str, Any]) -> None:
        """
        Enqueues a collection of Tags to be logged to the run specified by `run_id`.

        Keys and values are truncated to `MAX_ENTITY_KEY_LENGTH` / `MAX_TAG_VAL_LENGTH`,
        and values are stringified before enqueueing.
        """
        tags = _truncate_dict(
            tags, max_key_length=MAX_ENTITY_KEY_LENGTH, max_value_length=MAX_TAG_VAL_LENGTH
        )
        tags_arr = [RunTag(key, str(value)) for key, value in tags.items()]
        self._get_pending_operations(run_id).enqueue(tags=tags_arr)

    def flush(self, synchronous: bool = True) -> RunOperations:
        """
        Flushes all queued run operations, resulting in the creation or mutation of runs
        and run data.

        Args:
            synchronous: If `True`, run operations are performed synchronously, and a
                `RunOperations` result object is only returned once all operations
                are complete. If `False`, run operations are performed asynchronously,
                and an `RunOperations` object is returned that represents the ongoing
                run operations.

        Returns:
            A `RunOperations` instance representing the flushed operations. These operations
            are already complete if `synchronous` is `True`. If `synchronous` is `False`, these
            operations may still be inflight. Operation completion can be synchronously waited
            on via `RunOperations.await_completion()`.
        """
        # Submit one flush task per run to the shared module-level thread pool; each task
        # persists that run's queued operations sequentially (see
        # `_flush_pending_operations` for why per-run work is not parallelized further).
        logging_futures = [
            _AUTOLOGGING_QUEUEING_CLIENT_THREAD_POOL.submit(
                self._flush_pending_operations,
                pending_operations=pending_operations,
            )
            for pending_operations in self._pending_ops_by_run_id.values()
        ]
        # Reset the queue immediately so that subsequent flushes only see operations
        # enqueued after this call.
        self._pending_ops_by_run_id = {}

        logging_operations = RunOperations(logging_futures)
        if synchronous:
            logging_operations.await_completion()
        return logging_operations

    def _get_pending_operations(self, run_id: str | PendingRunId) -> "_PendingRunOperations":
        """
        Returns:
            A `_PendingRunOperations` containing all pending operations for the
            specified `run_id`. A new, empty `_PendingRunOperations` is created and
            cached on first access for a given `run_id`.
        """
        if run_id not in self._pending_ops_by_run_id:
            self._pending_ops_by_run_id[run_id] = _PendingRunOperations(run_id=run_id)
        return self._pending_ops_by_run_id[run_id]

    def _try_operation(self, fn, *args, **kwargs):
        """
        Attempt to evaluate the specified function, `fn`, on the specified `*args` and `**kwargs`,
        returning either the result of the function evaluation (if evaluation was successful) or
        the exception raised by the function evaluation (if evaluation was unsuccessful).
        """
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            return e

    def _flush_pending_operations(self, pending_operations: "_PendingRunOperations"):
        """
        Synchronously and sequentially flushes the specified list of pending run operations.

        Raises an `MlflowException` aggregating all failed operations if any operation
        failed; successful operations are still applied.

        NB: Operations are not parallelized on a per-run basis because MLflow's File Store, which
        is frequently used for local ML development, does not support threadsafe metadata logging
        within a given run.
        """
        if pending_operations.create_run:
            # Fold as many queued tags as possible into the CreateRun call itself (up to
            # `MAX_ENTITIES_PER_BATCH` total tags), removing them from the tags queue so
            # they are not logged a second time below.
            create_run_tags = pending_operations.create_run.tags
            num_additional_tags_to_include_during_creation = MAX_ENTITIES_PER_BATCH - len(
                create_run_tags
            )
            if num_additional_tags_to_include_during_creation > 0:
                create_run_tags.extend(
                    pending_operations.tags_queue[:num_additional_tags_to_include_during_creation]
                )
                pending_operations.tags_queue = pending_operations.tags_queue[
                    num_additional_tags_to_include_during_creation:
                ]

            new_run = self._client.create_run(
                experiment_id=pending_operations.create_run.experiment_id,
                start_time=pending_operations.create_run.start_time,
                tags={tag.key: tag.value for tag in create_run_tags},
            )
            # Replace the `PendingRunId` placeholder with the real run ID for the
            # remaining logging operations.
            pending_operations.run_id = new_run.info.run_id

        run_id = pending_operations.run_id
        assert not isinstance(run_id, PendingRunId), "Run ID cannot be pending for logging"

        # Each operation result is either the operation's return value or the exception
        # it raised (see `_try_operation`); failures are aggregated and raised at the end.
        operation_results = []

        param_batches_to_log = chunk_list(
            pending_operations.params_queue,
            chunk_size=MAX_PARAMS_TAGS_PER_BATCH,
        )
        tag_batches_to_log = chunk_list(
            pending_operations.tags_queue,
            chunk_size=MAX_PARAMS_TAGS_PER_BATCH,
        )
        # Pair up param and tag batches (padding the shorter sequence with empty lists)
        # and top up each log_batch call with metrics, respecting both the overall
        # per-batch entity limit and the per-batch metric limit.
        for params_batch, tags_batch in zip_longest(
            param_batches_to_log, tag_batches_to_log, fillvalue=[]
        ):
            metrics_batch_size = min(
                MAX_ENTITIES_PER_BATCH - len(params_batch) - len(tags_batch),
                MAX_METRICS_PER_BATCH,
            )
            # Guard against a negative size, which would slice from the wrong end.
            metrics_batch_size = max(metrics_batch_size, 0)
            metrics_batch = pending_operations.metrics_queue[:metrics_batch_size]
            pending_operations.metrics_queue = pending_operations.metrics_queue[metrics_batch_size:]

            operation_results.append(
                self._try_operation(
                    self._client.log_batch,
                    run_id=run_id,
                    metrics=metrics_batch,
                    params=params_batch,
                    tags=tags_batch,
                )
            )

        # Log any metrics remaining after the combined param/tag/metric batches above.
        operation_results.extend(
            self._try_operation(self._client.log_batch, run_id=run_id, metrics=metrics_batch)
            for metrics_batch in chunk_list(
                pending_operations.metrics_queue, chunk_size=MAX_METRICS_PER_BATCH
            )
        )

        operation_results.extend(
            self._try_operation(self._client.log_inputs, run_id=run_id, datasets=datasets_batch)
            for datasets_batch in chunk_list(
                pending_operations.datasets_queue, chunk_size=MAX_DATASETS_PER_BATCH
            )
        )

        # Terminate the run last so that it is not marked finished before its data lands.
        if pending_operations.set_terminated:
            operation_results.append(
                self._try_operation(
                    self._client.set_terminated,
                    run_id=run_id,
                    status=pending_operations.set_terminated.status,
                    end_time=pending_operations.set_terminated.end_time,
                )
            )

        failures = [result for result in operation_results if isinstance(result, Exception)]
        if len(failures) > 0:
            raise MlflowException(
                message=(
                    f"Failed to perform one or more operations on the run with ID {run_id}."
                    f" Failed operations: {failures}"
                )
            )
398  
399  
400  class _PendingRunOperations:
401      """
402      Represents a collection of queued / pending MLflow Run operations.
403      """
404  
405      def __init__(self, run_id):
406          self.run_id = run_id
407          self.create_run = None
408          self.set_terminated = None
409          self.params_queue = []
410          self.tags_queue = []
411          self.metrics_queue = []
412          self.datasets_queue = []
413  
414      def enqueue(
415          self,
416          params=None,
417          tags=None,
418          metrics=None,
419          datasets=None,
420          create_run=None,
421          set_terminated=None,
422      ):
423          """
424          Enqueues a new pending logging operation for the associated MLflow Run.
425          """
426          if create_run:
427              assert not self.create_run, "Attempted to create the same run multiple times"
428              self.create_run = create_run
429          if set_terminated:
430              assert not self.set_terminated, "Attempted to terminate the same run multiple times"
431              self.set_terminated = set_terminated
432  
433          self.params_queue += params or []
434          self.tags_queue += tags or []
435          self.metrics_queue += metrics or []
436          self.datasets_queue += datasets or []