/ test / python / testworkflow.py
testworkflow.py
  1  """
  2  Workflow module tests
  3  """
  4  
  5  import contextlib
  6  import glob
  7  import io
  8  import os
  9  import tempfile
 10  import sys
 11  import unittest
 12  
 13  import numpy as np
 14  import torch
 15  
 16  from txtai.api import API
 17  from txtai.embeddings import Documents, Embeddings
 18  from txtai.pipeline import Nop, Segmentation, Summary, Translation, Textractor
 19  from txtai.workflow import (
 20      Workflow,
 21      Task,
 22      ConsoleTask,
 23      ExportTask,
 24      FileTask,
 25      ImageTask,
 26      RagTask,
 27      RetrieveTask,
 28      StorageTask,
 29      TemplateTask,
 30      WorkflowTask,
 31  )
 32  
 33  # pylint: disable=C0411
 34  from utils import Utils
 35  
 36  
 37  # pylint: disable=R0904
 38  class TestWorkflow(unittest.TestCase):
 39      """
 40      Workflow tests.
 41      """
 42  
 43      @classmethod
 44      def setUpClass(cls):
 45          """
 46          Initialize test data.
 47          """
 48  
 49          # Default YAML workflow configuration
 50          cls.config = """
 51          # Embeddings index
 52          writable: true
 53          embeddings:
 54              scoring: bm25
 55              path: google/bert_uncased_L-2_H-128_A-2
 56              content: true
 57  
 58          # Text segmentation
 59          segmentation:
 60              sentences: true
 61  
 62          # Workflow definitions
 63          workflow:
 64              index:
 65                  tasks:
 66                      - action: segmentation
 67                      - action: index
 68              search:
 69                  tasks:
 70                      - search
 71              transform:
 72                  tasks:
 73                      - transform
 74          """
 75  
 76      def testBaseWorkflow(self):
 77          """
 78          Test a basic workflow
 79          """
 80  
 81          translate = Translation()
 82  
 83          # Workflow that translate text to Spanish
 84          workflow = Workflow([Task(lambda x: translate(x, "es"))])
 85  
 86          results = list(workflow(["The sky is blue", "Forest through the trees"]))
 87  
 88          self.assertEqual(len(results), 2)
 89  
 90      def testChainWorkflow(self):
 91          """
 92          Test a chain of workflows
 93          """
 94  
 95          workflow1 = Workflow([Task(lambda x: [y * 2 for y in x])])
 96          workflow2 = Workflow([Task(lambda x: [y - 1 for y in x])], batch=4)
 97  
 98          results = list(workflow2(workflow1([1, 2, 4, 8, 16, 32])))
 99          self.assertEqual(results, [1, 3, 7, 15, 31, 63])
100  
101      def testComplexWorkflow(self):
102          """
103          Test a complex workflow
104          """
105  
106          textractor = Textractor(paragraphs=True, minlength=150, join=True)
107          summary = Summary("t5-small")
108  
109          embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2"})
110          documents = Documents()
111  
112          def index(x):
113              documents.add(x)
114              return x
115  
116          # Extract text and summarize articles
117          articles = Workflow([FileTask(textractor), Task(lambda x: summary(x, maxlength=15))])
118  
119          # Complex workflow that extracts text, runs summarization then loads into an embeddings index
120          tasks = [WorkflowTask(articles, r".\.pdf$"), Task(index, unpack=False)]
121  
122          data = ["file://" + Utils.PATH + "/article.pdf", "Workflows can process audio files, documents and snippets"]
123  
124          # Convert file paths to data tuples
125          data = [(x, element, None) for x, element in enumerate(data)]
126  
127          # Execute workflow, discard results as they are streamed
128          workflow = Workflow(tasks)
129          data = list(workflow(data))
130  
131          # Build the embeddings index
132          embeddings.index(documents)
133  
134          # Cleanup temporary storage
135          documents.close()
136  
137          # Run search and validate result
138          index, _ = embeddings.search("search text", 1)[0]
139          self.assertEqual(index, 0)
140          self.assertEqual(data[0][1], "txtai builds an AI-powered index over sections")
141  
142      def testConcurrentWorkflow(self):
143          """
144          Test running concurrent task actions
145          """
146  
147          nop = Nop()
148  
149          workflow = Workflow([Task([nop, nop], concurrency="thread")])
150          results = list(workflow([2, 4]))
151          self.assertEqual(results, [(2, 2), (4, 4)])
152  
153          workflow = Workflow([Task([nop, nop], concurrency="process")])
154          results = list(workflow([2, 4]))
155          self.assertEqual(results, [(2, 2), (4, 4)])
156  
157          workflow = Workflow([Task([nop, nop], concurrency="unknown")])
158          results = list(workflow([2, 4]))
159          self.assertEqual(results, [(2, 2), (4, 4)])
160  
161      def testConsoleWorkflow(self):
162          """
163          Test a console task
164          """
165  
166          # Excel export
167          workflow = Workflow([ConsoleTask()])
168  
169          output = io.StringIO()
170          with contextlib.redirect_stdout(output):
171              list(workflow([{"id": 1, "text": "Sentence 1"}, {"id": 2, "text": "Sentence 2"}]))
172  
173          self.assertIn("Sentence 2", output.getvalue())
174  
175      def testExportWorkflow(self):
176          """
177          Test an export task
178          """
179  
180          # Excel export
181          path = os.path.join(tempfile.gettempdir(), "export.xlsx")
182          workflow = Workflow([ExportTask(output=path)])
183          list(workflow([{"id": 1, "text": "Sentence 1"}, {"id": 2, "text": "Sentence 2"}]))
184          self.assertGreater(os.path.getsize(path), 0)
185  
186          # Export CSV
187          path = os.path.join(tempfile.gettempdir(), "export.csv")
188          workflow = Workflow([ExportTask(output=path)])
189          list(workflow([{"id": 1, "text": "Sentence 1"}, {"id": 2, "text": "Sentence 2"}]))
190          self.assertGreater(os.path.getsize(path), 0)
191  
192          # Export CSV with timestamp
193          path = os.path.join(tempfile.gettempdir(), "export-timestamp.csv")
194          workflow = Workflow([ExportTask(output=path, timestamp=True)])
195          list(workflow([{"id": 1, "text": "Sentence 1"}, {"id": 2, "text": "Sentence 2"}]))
196  
197          # Find timestamped file and ensure it has data
198          path = glob.glob(os.path.join(tempfile.gettempdir(), "export-timestamp*.csv"))[0]
199          self.assertGreater(os.path.getsize(path), 0)
200  
201      def testExtractWorkflow(self):
202          """
203          Test column extraction tasks
204          """
205  
206          workflow = Workflow([Task(lambda x: x, unpack=False, column=0)], batch=1)
207  
208          results = list(workflow([(0, 1)]))
209          self.assertEqual(results[0], 0)
210  
211          results = list(workflow([(0, (1, 2), None)]))
212          self.assertEqual(results[0], (0, 1, None))
213  
214          results = list(workflow([1]))
215          self.assertEqual(results[0], 1)
216  
217      def testImageWorkflow(self):
218          """
219          Test an image task
220          """
221  
222          workflow = Workflow([ImageTask()])
223  
224          results = list(workflow([Utils.PATH + "/books.jpg"]))
225  
226          self.assertEqual(results[0].size, (1024, 682))
227  
228      def testInvalidWorkflow(self):
229          """
230          Test task with invalid parameters
231          """
232  
233          with self.assertRaises(TypeError):
234              Task(invalid=True)
235  
236      def testMergeWorkflow(self):
237          """
238          Test merge tasks
239          """
240  
241          task = Task([lambda x: [pow(y, 2) for y in x], lambda x: [pow(y, 3) for y in x]], merge="hstack")
242  
243          # Test hstack (column-wise) merge
244          workflow = Workflow([task])
245          results = list(workflow([2, 4]))
246          self.assertEqual(results, [(4, 8), (16, 64)])
247  
248          # Test vstack (row-wise) merge
249          task.merge = "vstack"
250          results = list(workflow([2, 4]))
251          self.assertEqual(results, [4, 8, 16, 64])
252  
253          # Test concat (values joined into single string) merge
254          task.merge = "concat"
255          results = list(workflow([2, 4]))
256          self.assertEqual(results, ["4. 8", "16. 64"])
257  
258          # Test no merge
259          task.merge = None
260          results = list(workflow([2, 4, 6]))
261          self.assertEqual(results, [[4, 16, 36], [8, 64, 216]])
262  
263          # Test generated (id, data, tag) tuples are properly returned
264          workflow = Workflow([Task(lambda x: [(0, y, None) for y in x])])
265          results = list(workflow([(1, "text", "tags")]))
266          self.assertEqual(results[0], (0, "text", None))
267  
268      def testMergeUnbalancedWorkflow(self):
269          """
270          Test merge tasks with unbalanced outputs (i.e. one action produce more output than another for same input).
271          """
272  
273          nop = Nop()
274          segment1 = Segmentation(sentences=True)
275  
276          task = Task([nop, segment1])
277  
278          # Test hstack
279          workflow = Workflow([task])
280          results = list(workflow(["This is a test sentence. And another sentence to split."]))
281          self.assertEqual(
282              results, [("This is a test sentence. And another sentence to split.", ["This is a test sentence.", "And another sentence to split."])]
283          )
284  
285          # Test vstack
286          task.merge = "vstack"
287          workflow = Workflow([task])
288          results = list(workflow(["This is a test sentence. And another sentence to split."]))
289          self.assertEqual(
290              results, ["This is a test sentence. And another sentence to split.", "This is a test sentence.", "And another sentence to split."]
291          )
292  
293      def testNumpyWorkflow(self):
294          """
295          Test a numpy workflow
296          """
297  
298          task = Task([lambda x: np.power(x, 2), lambda x: np.power(x, 3)], merge="hstack")
299  
300          # Test hstack (column-wise) merge
301          workflow = Workflow([task])
302          results = list(workflow(np.array([2, 4])))
303          self.assertTrue(np.array_equal(np.array(results), np.array([[4, 8], [16, 64]])))
304  
305          # Test vstack (row-wise) merge
306          task.merge = "vstack"
307          results = list(workflow(np.array([2, 4])))
308          self.assertEqual(results, [4, 8, 16, 64])
309  
310          # Test no merge
311          task.merge = None
312          results = list(workflow(np.array([2, 4, 6])))
313          self.assertTrue(np.array_equal(np.array(results), np.array([[4, 16, 36], [8, 64, 216]])))
314  
315      def testRetrieveWorkflow(self):
316          """
317          Test a retrieve task
318          """
319  
320          # Test retrieve with generated temporary directory
321          workflow = Workflow([RetrieveTask()])
322          results = list(workflow(["file://" + Utils.PATH + "/books.jpg"]))
323          self.assertTrue(results[0].endswith("books.jpg"))
324  
325          # Test retrieve with specified temporary directory
326          workflow = Workflow([RetrieveTask(directory=os.path.join(tempfile.gettempdir(), "retrieve"))])
327          results = list(workflow(["file://" + Utils.PATH + "/books.jpg"]))
328          self.assertTrue(results[0].endswith("books.jpg"))
329  
330          # Test with directory structures
331          workflow = Workflow([RetrieveTask(flatten=False)])
332          results = list(workflow(["file://" + Utils.PATH + "/books.jpg"]))
333          self.assertTrue(results[0].endswith("books.jpg") and "txtai" in results[0])
334  
335      def testScheduleWorkflow(self):
336          """
337          Test workflow schedules
338          """
339  
340          # Test workflow schedule with Python
341          workflow = Workflow([Task()])
342          workflow.schedule("* * * * * *", ["test"], 1)
343          self.assertEqual(len(workflow.tasks), 1)
344  
345          # Test workflow schedule with YAML
346          workflow = """
347          segmentation:
348              sentences: true
349          workflow:
350              segment:
351                  schedule:
352                      cron: '* * * * * *'
353                      elements:
354                          - a sentence to segment
355                      iterations: 1
356                  tasks:
357                      - action: segmentation
358                        task: console
359          """
360  
361          output = io.StringIO()
362          with contextlib.redirect_stdout(output):
363              app = API(workflow)
364              app.wait()
365  
366          self.assertIn("a sentence to segment", output.getvalue())
367  
368      def testScheduleErrorWorkflow(self):
369          """
370          Test workflow schedules with errors
371          """
372  
373          def action(elements):
374              raise FileNotFoundError
375  
376          # Test workflow proceeds after exception raised
377          with self.assertLogs() as logs:
378              workflow = Workflow([Task(action=action)])
379              workflow.schedule("* * * * * *", ["test"], 1)
380  
381          self.assertIn("FileNotFoundError", " ".join(logs.output))
382  
383      def testStorageWorkflow(self):
384          """
385          Test a storage task
386          """
387  
388          workflow = Workflow([StorageTask()])
389  
390          results = list(workflow(["local://" + Utils.PATH, "test string"]))
391  
392          self.assertEqual(len(results), 22)
393  
394      def testTemplateInput(self):
395          """
396          Test template task input
397          """
398  
399          workflow = Workflow([TemplateTask(template="This is a {text}")])
400  
401          # Test with string inputs
402          results = list(workflow(["prompt"]))
403          self.assertEqual(results[0], "This is a prompt")
404  
405          # Test with dict inputs
406          results = list(workflow([{"text": "prompt"}]))
407          self.assertEqual(results[0], "This is a prompt")
408  
409          # Test with tuple inputs
410          workflow = Workflow([TemplateTask(template="This is a {arg0}", unpack=False)])
411          results = list(workflow([("prompt",)]))
412          self.assertEqual(results[0], "This is a prompt")
413  
414          # Test invalid inputs
415          with self.assertRaises(KeyError):
416              workflow = Workflow([TemplateTask(template="No variables")])
417              results = list(workflow([{"unused": "prompt"}]))
418  
419          # Test no template
420          workflow = Workflow([TemplateTask()])
421          results = list(workflow(["prompt"]))
422          self.assertEqual(results[0], "prompt")
423  
424      def testTemplateRules(self):
425          """
426          Test template task rules
427          """
428  
429          # Test rule applied
430          workflow = Workflow([TemplateTask(template="This is a {text}", rules={"text": "Test skip"})])
431          results = list(workflow([{"text": "Test skip"}]))
432          self.assertEqual(results[0], "Test skip")
433  
434          # Test rule not applied
435          results = list(workflow([{"text": "prompt"}]))
436          self.assertEqual(results[0], "This is a prompt")
437  
438      def testTemplateRag(self):
439          """
440          Test rag template task
441          """
442  
443          # Test outputs
444          workflow = Workflow([RagTask(template="This is a {text}")])
445          results = list(workflow(["prompt"]))
446          self.assertEqual(results[0], {"query": "prompt", "question": "This is a prompt"})
447  
448          # Test partial outputs
449          workflow = Workflow([RagTask(template="This is a {text}")])
450          results = list(workflow([{"query": "query", "question": "prompt"}]))
451          self.assertEqual(results[0], {"query": "query", "question": "This is a prompt"})
452  
453          # Test additional template parameters
454          workflow = Workflow([RagTask(template="This is a {text} with another {param}")])
455          results = list(workflow([{"query": "query", "question": "prompt", "param": "value"}]))
456          self.assertEqual(results[0], {"query": "query", "question": "This is a prompt with another value", "param": "value"})
457  
458      def testTensorTransformWorkflow(self):
459          """
460          Test a tensor workflow with list transformations
461          """
462  
463          # Test one-one list transformation
464          task = Task(lambda x: x.tolist())
465          workflow = Workflow([task])
466          results = list(workflow(np.array([2])))
467          self.assertEqual(results, [2])
468  
469          # Test one-many list transformation
470          task = Task(lambda x: [x.tolist() * 2])
471          workflow = Workflow([task])
472          results = list(workflow(np.array([2])))
473          self.assertEqual(results, [2, 2])
474  
475      def testTorchWorkflow(self):
476          """
477          Test a torch workflow
478          """
479  
480          # pylint: disable=E1101,E1102
481          task = Task([lambda x: torch.pow(x, 2), lambda x: torch.pow(x, 3)], merge="hstack")
482  
483          # Test hstack (column-wise) merge
484          workflow = Workflow([task])
485          results = np.array([x.numpy() for x in workflow(torch.tensor([2, 4]))])
486          self.assertTrue(np.array_equal(results, np.array([[4, 8], [16, 64]])))
487  
488          # Test vstack (row-wise) merge
489          task.merge = "vstack"
490          results = list(workflow(torch.tensor([2, 4])))
491          self.assertEqual(results, [4, 8, 16, 64])
492  
493          # Test no merge
494          task.merge = None
495          results = np.array([x.numpy() for x in workflow(torch.tensor([2, 4, 6]))])
496          self.assertTrue(np.array_equal(np.array(results), np.array([[4, 16, 36], [8, 64, 216]])))
497  
498      def testYamlFunctionWorkflow(self):
499          """
500          Test YAML workflow with a function action
501          """
502  
503          # Create function and add to module
504          def action(elements):
505              return [x * 2 for x in elements]
506  
507          sys.modules[__name__].action = action
508  
509          workflow = """
510          workflow:
511              run:
512                  tasks:
513                      - testworkflow.action
514          """
515  
516          app = API(workflow)
517          self.assertEqual(list(app.workflow("run", [1, 2])), [2, 4])
518  
519      def testYamlIndexWorkflow(self):
520          """
521          Test reading a YAML index workflow in Python.
522          """
523  
524          app = API(self.config)
525          self.assertEqual(
526              list(app.workflow("index", ["This is a test sentence. And another sentence to split."])),
527              ["This is a test sentence.", "And another sentence to split."],
528          )
529  
530          # Read from file
531          path = os.path.join(tempfile.gettempdir(), "workflow.yml")
532          with open(path, "w", encoding="utf-8") as f:
533              f.write(self.config)
534  
535          app = API(path)
536          self.assertEqual(
537              list(app.workflow("index", ["This is a test sentence. And another sentence to split."])),
538              ["This is a test sentence.", "And another sentence to split."],
539          )
540  
541          # Read from YAML object
542          app = API(API.read(self.config))
543          self.assertEqual(
544              list(app.workflow("index", ["This is a test sentence. And another sentence to split."])),
545              ["This is a test sentence.", "And another sentence to split."],
546          )
547  
548      def testYamlSearchWorkflow(self):
549          """
550          Test reading a YAML search workflow in Python.
551          """
552  
553          # Test search
554          app = API(self.config)
555          list(app.workflow("index", ["This is a test sentence. And another sentence to split."]))
556          self.assertEqual(
557              list(app.workflow("search", ["another"]))[0]["text"],
558              "And another sentence to split.",
559          )
560  
561      def testYamlWorkflowTask(self):
562          """
563          Test YAML workflow with a workflow task
564          """
565  
566          # Create function and add to module
567          def action(elements):
568              return [x * 2 for x in elements]
569  
570          sys.modules[__name__].action = action
571  
572          workflow = """
573          workflow:
574              run:
575                  tasks:
576                      - testworkflow.action
577              flow:
578                  tasks:
579                      - run
580          """
581  
582          app = API(workflow)
583          self.assertEqual(list(app.workflow("flow", [1, 2])), [2, 4])
584  
585      def testYamlTransformWorkflow(self):
586          """
587          Test reading a YAML transform workflow in Python.
588          """
589  
590          # Test search
591          app = API(self.config)
592          self.assertEqual(len(list(app.workflow("transform", ["text"]))[0]), 128)
593  
594      def testYamlError(self):
595          """
596          Test reading a YAML workflow with errors.
597          """
598  
599          # Read from string
600          config = """
601          # Workflow definitions
602          workflow:
603              error:
604                  tasks:
605                      - action: error
606          """
607  
608          with self.assertRaises(KeyError):
609              API(config)