thread_utils.py
import os
import threading
from collections.abc import Callable
from typing import Any


class ThreadLocalVariable:
    """
    Class for creating a thread local variable.

    Every thread sees its own value, which is created lazily by calling
    ``default_factory`` the first time the thread reads the variable. Each
    value is also mirrored into a registry keyed by thread ID so that the
    values of all threads can be inspected via `get_all_thread_values`.

    Args:
        default_factory: A zero-argument callable used to create the default value.
        reset_in_subprocess: Indicating whether the variable is reset in subprocess.
    """

    def __init__(
        self,
        default_factory: Callable[[], Any],
        reset_in_subprocess: bool = True,
    ) -> None:
        self.reset_in_subprocess = reset_in_subprocess
        self.default_factory = default_factory
        self.thread_local = threading.local()
        # The `__global_thread_values` attribute saves all thread-local values,
        # the key is thread ID.
        self.__global_thread_values: dict[int, Any] = {}

    def get(self) -> Any:
        """
        Get the thread-local variable value.

        If the calling thread has not set a value yet, a default is created by
        calling `default_factory`, stored for the thread, and returned.
        If `get` is called in a forked subprocess and `reset_in_subprocess` is
        True, the inherited value is discarded and replaced by a fresh default
        created the same way.

        Returns:
            The value belonging to the calling thread.
        """
        # `set` always stores a `(value, pid)` tuple, so `None` can only mean
        # "this thread has never stored anything".
        stored = getattr(self.thread_local, "value", None)
        if stored is not None:
            value, pid = stored
            # A differing PID means the value was stored before a fork and
            # therefore belongs to the parent process.
            if not self.reset_in_subprocess or pid == os.getpid():
                return value

        init_value = self.default_factory()
        self.set(init_value)
        return init_value

    def set(self, value: Any) -> None:
        """
        Set a value for the thread-local variable.

        The value is stored for the calling thread together with the current
        process ID (used by `get` to detect forked subprocesses) and is also
        recorded in the cross-thread registry under the calling thread's ID.
        """
        self.thread_local.value = (value, os.getpid())
        self.__global_thread_values[threading.get_ident()] = value

    def get_all_thread_values(self) -> dict[int, Any]:
        """
        Return all thread values as a dict, dict key is the thread ID.

        The returned dict is a shallow copy, so mutating it does not affect
        the stored values. Entries are only removed by `reset`, so the dict
        may still contain values recorded by threads that have since exited
        or, in a forked subprocess, by threads of the parent process.
        """
        return self.__global_thread_values.copy()

    def reset(self) -> None:
        """
        Reset the thread-local variable.

        Clear the global thread values and create a new thread local variable,
        so that every thread falls back to `default_factory` on its next `get`.
        """
        self.__global_thread_values.clear()
        self.thread_local = threading.local()