/ src / python / txtai / workflow / task / base.py
base.py
  1  """
  2  Task module
  3  """
  4  
  5  import logging
  6  import re
  7  import types
  8  
  9  import numpy as np
 10  import torch
 11  
 12  # Logging configuration
 13  logger = logging.getLogger(__name__)
 14  
 15  
 16  class Task:
 17      """
 18      Base class for all workflow tasks.
 19      """
 20  
 21      def __init__(
 22          self,
 23          action=None,
 24          select=None,
 25          unpack=True,
 26          column=None,
 27          merge="hstack",
 28          initialize=None,
 29          finalize=None,
 30          concurrency=None,
 31          onetomany=True,
 32          **kwargs,
 33      ):
 34          """
 35          Creates a new task. A task defines two methods, type of data it accepts and the action to execute
 36          for each data element. Action is a callable function or list of callable functions.
 37  
 38          Args:
 39              action: action(s) to execute on each data element
 40              select: filter(s) used to select data to process
 41              unpack: if data elements should be unpacked or unwrapped from (id, data, tag) tuples
 42              column: column index to select if element is a tuple, defaults to all
 43              merge: merge mode for joining multi-action outputs, defaults to hstack
 44              initialize: action to execute before processing
 45              finalize: action to execute after processing
 46              concurrency: sets concurrency method when execute instance available
 47                           valid values: "thread" for thread-based concurrency, "process" for process-based concurrency
 48              onetomany: if one-to-many data transformations should be enabled, defaults to True
 49              kwargs: additional keyword arguments
 50          """
 51  
 52          # Standardize into list of actions
 53          if not action:
 54              action = []
 55          elif not isinstance(action, list):
 56              action = [action]
 57  
 58          self.action = action
 59          self.select = select
 60          self.unpack = unpack
 61          self.column = column
 62          self.merge = merge
 63          self.initialize = initialize
 64          self.finalize = finalize
 65          self.concurrency = concurrency
 66          self.onetomany = onetomany
 67  
 68          # Check for custom registration. Adds additional instance members and validates required dependencies available.
 69          if hasattr(self, "register"):
 70              self.register(**kwargs)
 71          elif kwargs:
 72              # Raise error if additional keyword arguments passed in without register method
 73              kwargs = ", ".join(f"'{kw}'" for kw in kwargs)
 74              raise TypeError(f"__init__() got unexpected keyword arguments: {kwargs}")
 75  
 76      def __call__(self, elements, executor=None):
 77          """
 78          Executes action for a list of data elements.
 79  
 80          Args:
 81              elements: iterable data elements
 82              executor: execute instance, enables concurrent task actions
 83  
 84          Returns:
 85              transformed data elements
 86          """
 87  
 88          if isinstance(elements, list):
 89              return self.filteredrun(elements, executor)
 90  
 91          return self.run(elements, executor)
 92  
 93      def filteredrun(self, elements, executor):
 94          """
 95          Executes a filtered run, which will tag all inputs with a process id, filter elements down to elements the
 96          task can handle and execute on that subset. Items not selected for processing will be returned unmodified.
 97  
 98          Args:
 99              elements: iterable data elements
100              executor: execute instance, enables concurrent task actions
101  
102          Returns:
103              transformed data elements
104          """
105  
106          # Build list of elements with unique process ids
107          indexed = list(enumerate(elements))
108  
109          # Filter data down to data this task handles
110          data = [(x, self.upack(element)) for x, element in indexed if self.accept(self.upack(element, True))]
111  
112          # Get list of filtered process ids
113          ids = [x for x, _ in data]
114  
115          # Prepare elements and execute task action(s)
116          results = self.execute([self.prepare(element) for _, element in data], executor)
117  
118          # Pack results back into elements
119          if self.merge:
120              elements = self.filteredpack(results, indexed, ids)
121          else:
122              elements = [self.filteredpack(r, indexed, ids) for r in results]
123  
124          return elements
125  
126      def filteredpack(self, results, indexed, ids):
127          """
128          Processes and packs results back into original input elements.
129  
130          Args:
131              results: task results
132              indexed: original elements indexed by process id
133              ids: process ids accepted by this task
134  
135          Returns:
136              packed elements
137          """
138  
139          # Update with transformed elements. Handle one to many transformations.
140          elements = []
141          for x, element in indexed:
142              if x in ids:
143                  # Get result for process id
144                  result = results[ids.index(x)]
145  
146                  if isinstance(result, OneToMany):
147                      # One to many transformations
148                      elements.extend([self.pack(element, r) for r in result])
149                  else:
150                      # One to one transformations
151                      elements.append(self.pack(element, result))
152              else:
153                  # Pass unprocessed elements through
154                  elements.append(element)
155  
156          return elements
157  
158      def run(self, elements, executor):
159          """
160          Executes a task run for elements. A standard run processes all elements.
161  
162          Args:
163              elements: iterable data elements
164              executor: execute instance, enables concurrent task actions
165  
166          Returns:
167              transformed data elements
168          """
169  
170          # Execute task actions
171          results = self.execute(elements, executor)
172  
173          # Handle one to many transformations
174          if isinstance(results, list):
175              elements = []
176              for result in results:
177                  if isinstance(result, OneToMany):
178                      # One to many transformations
179                      elements.extend(result)
180                  else:
181                      # One to one transformations
182                      elements.append(result)
183  
184              return elements
185  
186          return results
187  
188      def accept(self, element):
189          """
190          Determines if this task can handle the input data format.
191  
192          Args:
193              element: input data element
194  
195          Returns:
196              True if this task can process this data element, False otherwise
197          """
198  
199          return (isinstance(element, str) and re.search(self.select, element.lower())) if element is not None and self.select else True
200  
201      def upack(self, element, force=False):
202          """
203          Unpacks data for processing.
204  
205          Args:
206              element: input data element
207              force: if True, data is unpacked even if task has unpack set to False
208  
209          Returns:
210              data
211          """
212  
213          # Extract data from (id, data, tag) formatted elements
214          if (self.unpack or force) and isinstance(element, tuple) and len(element) > 1:
215              return element[1]
216  
217          return element
218  
219      def pack(self, element, data):
220          """
221          Packs data after processing.
222  
223          Args:
224              element: transformed data element
225              data: item to pack element into
226  
227          Returns:
228              packed data
229          """
230  
231          # Pack data into (id, data, tag) formatted elements
232          if self.unpack and isinstance(element, tuple) and len(element) > 1:
233              # If new data is a (id, data, tag) tuple use that except for multi-action "hstack" merges which produce tuples
234              if isinstance(data, tuple) and (len(self.action) <= 1 or self.merge != "hstack"):
235                  return data
236  
237              # Create a copy of tuple, update data element and return
238              element = list(element)
239              element[1] = data
240              return tuple(element)
241  
242          return data
243  
244      def prepare(self, element):
245          """
246          Method that allows downstream tasks to prepare data element for processing.
247  
248          Args:
249              element: input data element
250  
251          Returns:
252              data element ready for processing
253          """
254  
255          return element
256  
257      def execute(self, elements, executor):
258          """
259          Executes action(s) on elements.
260  
261          Args:
262              elements: list of data elements
263              executor: execute instance, enables concurrent task actions
264  
265          Returns:
266              transformed data elements
267          """
268  
269          if self.action:
270              # Run actions
271              outputs = []
272              for x, action in enumerate(self.action):
273                  # Filter elements by column index if necessary - supports a single int or an action index to column index mapping
274                  index = self.column[x] if isinstance(self.column, dict) else self.column
275                  inputs = [self.extract(e, index) for e in elements] if index is not None else elements
276  
277                  # Queue arguments for executor, process immediately if no executor available
278                  outputs.append((action, inputs) if executor else self.process(action, inputs))
279  
280              # Run with executor if available
281              if executor:
282                  outputs = executor.run(self.concurrency, self.process, outputs)
283  
284              # Run post process operations
285              return self.postprocess(outputs)
286  
287          return elements
288  
289      def extract(self, element, index):
290          """
291          Extracts a column from element by index if the element is a tuple.
292  
293          Args:
294              element: input element
295              index: column index
296  
297          Returns:
298              extracted column
299          """
300  
301          if isinstance(element, tuple):
302              if not self.unpack and len(element) == 3 and isinstance(element[1], tuple):
303                  return (element[0], element[1][index], element[2])
304  
305              return element[index]
306  
307          return element
308  
309      def process(self, action, inputs):
310          """
311          Executes action using inputs as arguments.
312  
313          Args:
314              action: callable object
315              inputs: action inputs
316  
317          Returns:
318              action outputs
319          """
320  
321          # Log inputs
322          logger.debug("Inputs: %s", inputs)
323  
324          # Execute action and get outputs
325          outputs = action(inputs)
326  
327          # Consume generator output, if necessary
328          if isinstance(outputs, types.GeneratorType):
329              outputs = list(outputs)
330  
331          # Log outputs
332          logger.debug("Outputs: %s", outputs)
333  
334          return outputs
335  
336      def postprocess(self, outputs):
337          """
338          Runs post process routines after a task action.
339  
340          Args:
341              outputs: task outputs
342  
343          Returns:
344              postprocessed outputs
345          """
346  
347          # Unpack single action tasks
348          if len(self.action) == 1:
349              return self.single(outputs[0])
350  
351          # Return unmodified outputs when merge set to None
352          if not self.merge:
353              return outputs
354  
355          if self.merge == "vstack":
356              return self.vstack(outputs)
357          if self.merge == "concat":
358              return self.concat(outputs)
359  
360          # Default mode is hstack
361          return self.hstack(outputs)
362  
363      def single(self, outputs):
364          """
365          Post processes and returns single action outputs.
366  
367          Args:
368              outputs: outputs from a single task
369  
370          Returns:
371              post processed outputs
372          """
373  
374          if self.onetomany and isinstance(outputs, list):
375              # Wrap one to many transformations
376              outputs = [OneToMany(output) if isinstance(output, list) else output for output in outputs]
377  
378          return outputs
379  
380      def vstack(self, outputs):
381          """
382          Merges outputs row-wise. Returns a list of lists which will be interpreted as a one to many transformation.
383  
384          Row-wise merge example (2 actions)
385  
386            Inputs: [a, b, c]
387  
388            Outputs => [[a1, b1, c1], [a2, b2, c2]]
389  
390            Row Merge => [[a1, a2], [b1, b2], [c1, c2]] = [a1, a2, b1, b2, c1, c2]
391  
392          Args:
393              outputs: task outputs
394  
395          Returns:
396              list of aggregated/zipped outputs as one to many transforms (row-wise)
397          """
398  
399          # If all outputs are numpy arrays, use native method
400          if all(isinstance(output, np.ndarray) for output in outputs):
401              return np.concatenate(np.stack(outputs, axis=1))
402  
403          # If all outputs are torch tensors, use native method
404          # pylint: disable=E1101
405          if all(torch.is_tensor(output) for output in outputs):
406              return torch.cat(tuple(torch.stack(outputs, axis=1)))
407  
408          # Flatten into lists of outputs per input row. Wrap as one to many transformation.
409          merge = []
410          for x in zip(*outputs):
411              combine = []
412              for y in x:
413                  if isinstance(y, list):
414                      combine.extend(y)
415                  else:
416                      combine.append(y)
417  
418              merge.append(OneToMany(combine))
419  
420          return merge
421  
422      def hstack(self, outputs):
423          """
424          Merges outputs column-wise. Returns a list of tuples which will be interpreted as a one to one transformation.
425  
426          Column-wise merge example (2 actions)
427  
428            Inputs: [a, b, c]
429  
430            Outputs => [[a1, b1, c1], [a2, b2, c2]]
431  
432            Column Merge => [(a1, a2), (b1, b2), (c1, c2)]
433  
434          Args:
435              outputs: task outputs
436  
437          Returns:
438              list of aggregated/zipped outputs as tuples (column-wise)
439          """
440  
441          # If all outputs are numpy arrays, use native method
442          if all(isinstance(output, np.ndarray) for output in outputs):
443              return np.stack(outputs, axis=1)
444  
445          # If all outputs are torch tensors, use native method
446          # pylint: disable=E1101
447          if all(torch.is_tensor(output) for output in outputs):
448              return torch.stack(outputs, axis=1)
449  
450          return list(zip(*outputs))
451  
452      def concat(self, outputs):
453          """
454          Merges outputs column-wise and concats values together into a string. Returns a list of strings.
455  
456          Concat merge example (2 actions)
457  
458            Inputs: [a, b, c]
459  
460            Outputs => [[a1, b1, c1], [a2, b2, c2]]
461  
462            Concat Merge => [(a1, a2), (b1, b2), (c1, c2)] => ["a1. a2", "b1. b2", "c1. c2"]
463  
464          Args:
465              outputs: task outputs
466  
467          Returns:
468              list of concat outputs
469          """
470  
471          return [". ".join([str(y) for y in x if y]) for x in self.hstack(outputs)]
472  
473  
class OneToMany:
    """
    Encapsulates list output for a one to many transformation.

    Iterating an instance yields the wrapped values. Task packing code checks for this type to expand one
    input element into multiple output elements.
    """

    def __init__(self, values):
        """
        Creates a new OneToMany transformation.

        Args:
            values: list of outputs
        """

        self.values = values

    def __iter__(self):
        return iter(self.values)

    def __repr__(self):
        # Readable representation for debug logging and interactive inspection
        return f"{type(self).__name__}({self.values!r})"