execute.py
"""
Execute module
"""

from multiprocessing.pool import Pool, ThreadPool

import torch.multiprocessing


class Execute:
    """
    Supports sequential, multithreading and multiprocessing based execution of tasks.
    """

    def __init__(self, workers=None):
        """
        Creates a new execute instance. Functions can be executed sequentially, in a thread pool
        or in a process pool. Once created, the thread and/or process pool will stay open until the
        close method is called.

        Args:
            workers: number of workers for thread/process pools, passed straight to the pool constructors
        """

        # Number of workers to use in thread/process pools
        self.workers = workers

        # Pools are created lazily on first use (see the pool method) and released by close
        self.thread = None
        self.process = None

    def __del__(self):
        # Best-effort cleanup of any open pools when the instance is garbage collected
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, etype, value, traceback):
        self.close()

    def run(self, method, function, args):
        """
        Runs multiple calls of function for each tuple in args. The method parameter controls if the calls are
        sequential (method = None), multithreaded (method = "thread") or with multiprocessing (method="process").

        Concurrent execution is only used when there is more than one call to make and method names a known
        pool type; otherwise the calls run sequentially in the current thread.

        Args:
            method: run method - "thread" for multithreading, "process" for multiprocessing, otherwise runs sequentially
            function: function to run
            args: iterable of tuples with arguments to each call (list, tuple, generator, ...)

        Returns:
            list with the result of each call, in the same order as args
        """

        if method:
            # The size check below needs len(); materialize unsized iterables such as generators so concurrent
            # runs accept the same inputs as sequential runs
            if not hasattr(args, "__len__"):
                args = list(args)

            # Concurrent processing
            if len(args) > 1:
                pool = self.pool(method)
                if pool:
                    # chunksize of 1 submits each call as its own task
                    return pool.starmap(function, args, 1)

        # Sequential processing
        return [function(*arg) for arg in args]

    def pool(self, method):
        """
        Gets a handle to a concurrent processing pool. This method will create the pool if it doesn't already exist.

        Args:
            method: pool type - "thread" or "process"

        Returns:
            concurrent processing pool or None if no pool of that type available
        """

        if method == "thread":
            if not self.thread:
                self.thread = ThreadPool(self.workers)

            return self.thread

        if method == "process":
            if not self.process:
                # Importing torch.multiprocessing will register torch shared memory serialization for cuda.
                # The "spawn" start method is used, presumably because CUDA does not support forked subprocesses.
                self.process = Pool(self.workers, context=torch.multiprocessing.get_context("spawn"))

            return self.process

        return None

    def close(self):
        """
        Closes concurrent processing pools. Each open pool stops accepting work, is waited on until its
        outstanding tasks finish and is then reset to None. Safe to call multiple times.
        """

        # hasattr guards cover __del__ running on an instance whose __init__ did not complete
        if hasattr(self, "thread") and self.thread:
            self.thread.close()
            self.thread.join()
            self.thread = None

        if hasattr(self, "process") and self.process:
            self.process.close()
            self.process.join()
            self.process = None