/ mlflow / utils / thread_utils.py
thread_utils.py
 1  import os
 2  import threading
 3  from typing import Any
 4  
 5  
 6  class ThreadLocalVariable:
 7      """
 8      Class for creating a thread local variable.
 9  
10      Args:
11          default_factory: A function used to create the default value
12          reset_in_subprocess: Indicating whether the variable is reset in subprocess.
13      """
14  
15      def __init__(self, default_factory, reset_in_subprocess=True):
16          self.reset_in_subprocess = reset_in_subprocess
17          self.default_factory = default_factory
18          self.thread_local = threading.local()
19          # The `__global_thread_values` attribute saves all thread-local values,
20          # the key is thread ID.
21          self.__global_thread_values: dict[int, Any] = {}
22  
23      def get(self):
24          """
25          Get the thread-local variable value.
26          If the thread-local variable is not set, return the provided `init_value` value.
27          If `get` is called in a forked subprocess and `reset_in_subprocess` is True,
28          return the provided `init_value` value
29          """
30          if hasattr(self.thread_local, "value"):
31              value, pid = self.thread_local.value
32              if self.reset_in_subprocess and pid != os.getpid():
33                  # `get` is called in a forked subprocess, reset it.
34                  init_value = self.default_factory()
35                  self.set(init_value)
36                  return init_value
37              else:
38                  return value
39          else:
40              init_value = self.default_factory()
41              self.set(init_value)
42              return init_value
43  
44      def set(self, value):
45          """
46          Set a value for the thread-local variable.
47          """
48          self.thread_local.value = (value, os.getpid())
49          self.__global_thread_values[threading.get_ident()] = value
50  
51      def get_all_thread_values(self) -> dict[int, Any]:
52          """
53          Return all thread values as a dict, dict key is the thread ID.
54          """
55          return self.__global_thread_values.copy()
56  
57      def reset(self):
58          """
59          Reset the thread-local variable.
60          Clear the global thread values and create a new thread local variable.
61          """
62          self.__global_thread_values.clear()
63          self.thread_local = threading.local()