# src/python/txtai/workflow/execute.py
 1  """
 2  Execute module
 3  """
 4  
 5  from multiprocessing.pool import Pool, ThreadPool
 6  
 7  import torch.multiprocessing
 8  
 9  
class Execute:
    """
    Supports sequential, multithreading and multiprocessing based execution of tasks.
    """

    def __init__(self, workers=None):
        """
        Creates a new execute instance. Functions can be executed sequentially, in a thread pool
        or in a process pool. Pools are created lazily on first use and remain open until the
        close method is called.

        Args:
            workers: number of workers for thread/process pools
        """

        # Worker count shared by both pool types
        self.workers = workers

        # Lazily-created pools, populated on first request
        self.thread, self.process = None, None

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, etype, value, traceback):
        self.close()

    def run(self, method, function, args):
        """
        Runs function once per tuple in args. Execution is sequential when method is None,
        multithreaded when method is "thread" and uses multiprocessing when method is "process".

        Args:
            method: run method - "thread" for multithreading, "process" for multiprocessing, otherwise runs sequentially
            function: function to run
            args: list of tuples with arguments to each call
        """

        # Only reach for a pool when one is requested and there is more than a single call
        if method and len(args) > 1:
            executor = self.pool(method)
            if executor:
                return executor.starmap(function, args, 1)

        # Fallback: run each call in order on the current thread
        return [function(*arguments) for arguments in args]

    def pool(self, method):
        """
        Gets a handle to a concurrent processing pool, creating it on the first request.

        Args:
            method: pool type - "thread" or "process"

        Returns:
            concurrent processing pool or None if no pool of that type available
        """

        if method == "thread":
            # Create thread pool on first use, reuse afterwards
            self.thread = self.thread if self.thread else ThreadPool(self.workers)
            return self.thread

        if method == "process":
            if self.process is None:
                # Importing torch.multiprocessing will register torch shared memory serialization for cuda
                self.process = Pool(self.workers, context=torch.multiprocessing.get_context("spawn"))
            return self.process

        # Unknown pool type
        return None

    def close(self):
        """
        Closes concurrent processing pools.
        """

        # getattr guards against partially-constructed instances (__del__ after a failed __init__)
        for name in ("thread", "process"):
            executor = getattr(self, name, None)
            if executor:
                executor.close()
                executor.join()
                setattr(self, name, None)