# examples/workflows.py
  1  """
  2  Build txtai workflows.
  3  
  4  Requires streamlit to be installed.
  5    pip install streamlit
  6  """
  7  
  8  import contextlib
  9  import copy
 10  import os
 11  import re
 12  import tempfile
 13  import threading
 14  import time
 15  
 16  import uvicorn
 17  import yaml
 18  
 19  import pandas as pd
 20  import streamlit as st
 21  
 22  import txtai.api.application
 23  import txtai.app
 24  
 25  
 26  class Server(uvicorn.Server):
 27      """
 28      Threaded uvicorn server used to bring up an API service.
 29      """
 30  
 31      def __init__(self, application=None, host="127.0.0.1", port=8000, log_level="info"):
 32          """
 33          Initialize server configuration.
 34          """
 35  
 36          config = uvicorn.Config(application, host=host, port=port, log_level=log_level)
 37          super().__init__(config)
 38  
 39      def install_signal_handlers(self):
 40          """
 41          Signal handlers no-op.
 42          """
 43  
 44      @contextlib.contextmanager
 45      def service(self):
 46          """
 47          Runs threaded server service.
 48          """
 49  
 50          # pylint: disable=W0201
 51          thread = threading.Thread(target=self.run)
 52          thread.start()
 53          try:
 54              while not self.started:
 55                  time.sleep(1e-3)
 56              yield
 57  
 58          finally:
 59              self.should_exit = True
 60              thread.join()
 61  
 62  
 63  class Process:
 64      """
 65      Container for an active Workflow process instance.
 66      """
 67  
 68      @staticmethod
 69      @st.cache_resource(show_spinner=False)
 70      def get(name, config):
 71          """
 72          Lookup or creates a new workflow process instance.
 73  
 74          Args:
 75              name: workflow name
 76              config: application configuration
 77  
 78          Returns:
 79              Process
 80          """
 81  
 82          process = Process()
 83  
 84          # Build workflow
 85          with st.spinner("Building workflow...."):
 86              process.build(name, config)
 87  
 88          return process
 89  
 90      def __init__(self):
 91          """
 92          Creates a new Process.
 93          """
 94  
 95          # Application handle
 96          self.application = None
 97  
 98          # Workflow name
 99          self.name = None
100  
101          # Workflow data
102          self.data = None
103  
104      def build(self, name, config):
105          """
106          Builds an application.
107  
108          Args:
109              name: workflow name
110              config: application configuration
111          """
112  
113          # Create application
114          self.application = txtai.app.Application(config)
115  
116          # Workflow name
117          self.name = name
118  
119      def run(self, data):
120          """
121          Runs a workflow using data as input.
122  
123          Args:
124              data: input data
125          """
126  
127          if data and self.application:
128              # Build tuples for embedding index
129              if self.application.embeddings:
130                  data = [(x, element, None) for x, element in enumerate(data)]
131  
132              # Process workflow
133              with st.spinner("Running workflow...."):
134                  results = []
135                  for result in self.application.workflow(self.name, data):
136                      # Store result
137                      results.append(result)
138  
139                      # Write result if this isn't an indexing workflow
140                      if not self.application.embeddings:
141                          st.write(result)
142  
143                  # Store workflow results
144                  self.data = results
145  
146      def search(self, query):
147          """
148          Runs a search.
149  
150          Args:
151              query: input query
152          """
153  
154          if self.application and query:
155              st.markdown(
156                  """
157              <style>
158              table td:nth-child(1) {
159                  display: none
160              }
161              table th:nth-child(1) {
162                  display: none
163              }
164              table {text-align: left !important}
165              </style>
166              """,
167                  unsafe_allow_html=True,
168              )
169  
170              results = []
171              for result in self.application.search(query, 5):
172                  # Text is only present when content is stored
173                  if "text" not in result:
174                      uid, score = result["id"], result["score"]
175                      results.append({"text": self.find(uid), "score": f"{score:.2}"})
176                  else:
177                      if "id" in result and "text" in result:
178                          result["text"] = self.content(result.pop("id"), result["text"])
179                      if "score" in result and result["score"]:
180                          result["score"] = f'{result["score"]:.2}'
181  
182                      results.append(result)
183  
184              df = pd.DataFrame(results)
185              st.write(df.to_html(escape=False), unsafe_allow_html=True)
186  
187      def find(self, key):
188          """
189          Lookup record from cached data by uid key.
190  
191          Args:
192              key: id to search for
193  
194          Returns:
195              text for matching id
196          """
197  
198          # Lookup text by id
199          text = [text for uid, text, _ in self.data if uid == key][0]
200          return self.content(key, text)
201  
202      def content(self, uid, text):
203          """
204          Builds a content reference for uid and text.
205  
206          Args:
207              uid: record id
208              text: record text
209  
210          Returns:
211              content
212          """
213  
214          if uid and isinstance(uid, str) and uid.lower().startswith("http"):
215              return f"<a href='{uid}' rel='noopener noreferrer' target='blank'>{text}</a>"
216  
217          return text
218  
219  
class Application:
    """
    Main application.

    Renders the Streamlit workflow builder, converts the selected components into a txtai
    application configuration and runs the resulting workflow.
    """

    def load(self, components):
        """
        Load an existing workflow file.

        Args:
            components: list of components to load

        Returns:
            (names of components loaded, workflow config, file changed)
        """

        workflow = st.file_uploader("Load workflow", type=["yml"])
        if workflow:
            # Detect file upload change
            upload = workflow.name != self.state("path")
            st.session_state["path"] = workflow.name

            workflow = yaml.safe_load(workflow)

            st.markdown("---")

            # Get tasks for first workflow
            tasks = list(workflow["workflow"].values())[0]["tasks"]
            selected = []

            for task in tasks:
                name = task.get("action", task.get("task"))
                if name in components:
                    selected.append(name)
                elif name in ["index", "upsert"]:
                    # Indexing actions map to the embeddings component
                    selected.append("embeddings")

            return (selected, workflow, upload)

        return (None, None, None)

    def state(self, key):
        """
        Lookup a session state variable.

        Args:
            key: variable key

        Returns:
            variable value, None if the key is not set
        """

        if key in st.session_state:
            return st.session_state[key]

        return None

    def appsetting(self, workflow, name):
        """
        Looks up an application configuration setting.

        Args:
            workflow: workflow configuration
            name: setting name

        Returns:
            app setting value, None if the workflow has no "app" section or setting
        """

        if workflow:
            config = workflow.get("app")
            if config:
                return config.get(name)

        return None

    def setting(self, config, name, default=None):
        """
        Looks up a component configuration setting.

        Args:
            config: component configuration
            name: setting name
            default: default setting value

        Returns:
            setting value
        """

        return config.get(name, default) if config else default

    def text(self, label, component, config, name, default=None):
        """
        Create a new text input field.

        Args:
            label: field label
            component: component name
            config: component configuration
            name: setting name
            default: default setting value

        Returns:
            text input field value
        """

        default = self.setting(config, name, default)
        if not default:
            default = ""
        elif isinstance(default, list):
            # Lists are shown as comma separated text
            default = ",".join(default)
        elif isinstance(default, dict):
            # Only dict keys are shown
            default = ",".join(default.keys())

        return st.text_input(label, value=default, key=component + name)

    def number(self, label, component, config, name, default=None):
        """
        Creates a new numeric input field.

        Args:
            label: field label
            component: component name
            config: component configuration
            name: setting name
            default: default setting value

        Returns:
            numeric value, None if the field is empty
        """

        value = self.text(label, component, config, name, default)
        return int(value) if value else None

    def boolean(self, label, component, config, name, default=False):
        """
        Creates a new checkbox field.

        Args:
            label: field label
            component: component name
            config: component configuration
            name: setting name
            default: default setting value

        Returns:
            boolean value
        """

        default = self.setting(config, name, default)
        return st.checkbox(label, value=default, key=component + name)

    @staticmethod
    def _position(options, value, default=0):
        """
        Finds the position of value within options.

        Args:
            options: list of options
            value: option value to search for
            default: value to return when no option matches

        Returns:
            index of the first option equal to value, default if there is no match
        """

        matches = [x for x, option in enumerate(options) if option == value]
        return matches[0] if matches else default

    def select(self, label, component, config, name, options, default=0):
        """
        Creates a new select box field.

        Args:
            label: field label
            component: component name
            config: component configuration
            name: setting name
            options: list of dropdown options
            default: default setting, either an option value or an option index

        Returns:
            selected option
        """

        # Default may be given as an option value, otherwise it is used as an index
        fallback = self._position(options, default, default)

        # Derive default index, a configured setting that matches an option takes precedence
        index = self._position(options, self.setting(config, name), fallback)

        return st.selectbox(label, options, index=index, key=component + name)

    def split(self, text):
        """
        Splits text on commas and returns a list. Surrounding whitespace is removed and
        empty entries (for example from a trailing comma) are dropped.

        Args:
            text: input text

        Returns:
            list
        """

        return [x.strip() for x in text.split(",") if x.strip()]

    def options(self, component, workflow, index):
        """
        Extracts component settings into a component configuration dict.

        Args:
            component: component type
            workflow: existing workflow, can be None
            index: task index

        Returns:
            dict with component settings
        """

        # pylint: disable=R0912, R0915
        options = {"type": component}

        st.markdown("---")

        # Lookup component configuration
        #   - Runtime components have config defined within tasks
        #   - Pipeline components have config defined at workflow root
        config = None
        if workflow:
            if component in ["service", "translation"]:
                # Service config is found in tasks section
                tasks = list(workflow["workflow"].values())[0]["tasks"]
                tasks = [task for task in tasks if task.get("task") == component or task.get("action") == component]
                if tasks:
                    config = tasks[0]
            else:
                config = workflow.get(component)

        if component == "embeddings":
            st.markdown(f"**{index + 1}.) Embeddings Index**  \n*Index workflow output*")
            options["index"] = self.text("Embeddings storage path", component, config, "index")
            options["path"] = self.text("Embeddings model path", component, config, "path", "sentence-transformers/nli-mpnet-base-v2")
            options["upsert"] = self.boolean("Upsert", component, config, "upsert")
            options["content"] = self.boolean("Content", component, config, "content")

        elif component in ("segmentation", "textractor"):
            if component == "segmentation":
                st.markdown(f"**{index + 1}.) Segment**  \n*Split text into semantic units*")
            else:
                st.markdown(f"**{index + 1}.) Textract**  \n*Extract text from documents*")

            options["sentences"] = self.boolean("Split sentences", component, config, "sentences")
            options["lines"] = self.boolean("Split lines", component, config, "lines")
            options["paragraphs"] = self.boolean("Split paragraphs", component, config, "paragraphs")
            options["join"] = self.boolean("Join tokenized", component, config, "join")
            options["minlength"] = self.number("Min section length", component, config, "minlength")

        elif component == "service":
            st.markdown(f"**{index + 1}.) Service**  \n*Extract data from an API*")
            options["url"] = self.text("URL", component, config, "url")
            options["method"] = self.select("Method", component, config, "method", ["get", "post"], 0)
            options["params"] = self.text("URL parameters", component, config, "params")
            options["batch"] = self.boolean("Run as batch", component, config, "batch", True)
            options["extract"] = self.text("Subsection(s) to extract", component, config, "extract")

            # Convert comma separated text into the structures stored in the task configuration
            if options["params"]:
                options["params"] = {key: None for key in self.split(options["params"])}
            if options["extract"]:
                options["extract"] = self.split(options["extract"])

        elif component == "summary":
            st.markdown(f"**{index + 1}.) Summary**  \n*Abstractive text summarization*")
            options["path"] = self.text("Model", component, config, "path", "sshleifer/distilbart-cnn-12-6")
            options["minlength"] = self.number("Min length", component, config, "minlength")
            options["maxlength"] = self.number("Max length", component, config, "maxlength")

        elif component == "tabular":
            st.markdown(f"**{index + 1}.) Tabular**  \n*Split tabular data into rows and columns*")
            options["idcolumn"] = self.text("Id columns", component, config, "idcolumn")
            options["textcolumns"] = self.text("Text columns", component, config, "textcolumns")
            options["content"] = self.text("Content", component, config, "content")

            if options["textcolumns"]:
                options["textcolumns"] = self.split(options["textcolumns"])

            if options["content"]:
                options["content"] = self.split(options["content"])

                # A single "1" is kept as a scalar instead of a list of columns
                if len(options["content"]) == 1 and options["content"][0] == "1":
                    options["content"] = options["content"][0]

        elif component == "transcription":
            st.markdown(f"**{index + 1}.) Transcribe**  \n*Transcribe audio to text*")
            options["path"] = self.text("Model", component, config, "path", "facebook/wav2vec2-base-960h")

        elif component == "translation":
            st.markdown(f"**{index + 1}.) Translate**  \n*Machine translation*")
            options["target"] = self.text("Target language code", component, config, "args", "en")

        return options

    def config(self, components):
        """
        Builds configuration for components.

        The workflow is named after the last component, or "index" when the last component is embeddings.

        Args:
            components: list of components to add to configuration

        Returns:
            (workflow name, configuration)
        """

        data = {}
        tasks = []
        name = None

        for component in components:
            # Copy to avoid modifying the input component settings
            component = dict(component)
            name = wtype = component.pop("type")

            if wtype == "embeddings":
                index = component.pop("index")
                upsert = component.pop("upsert")

                data[wtype] = component
                data["writable"] = True

                if index:
                    data["path"] = index

                name = "index"
                tasks.append({"action": "upsert" if upsert else "index"})

            elif wtype == "segmentation":
                data[wtype] = component
                tasks.append({"action": wtype})

            elif wtype == "service":
                # Service settings are stored on the task itself
                config = {**component}
                config["task"] = wtype
                tasks.append(config)

            elif wtype == "summary":
                data[wtype] = {"path": component.pop("path")}
                tasks.append({"action": wtype})

            elif wtype == "tabular":
                data[wtype] = component
                tasks.append({"action": wtype})

            elif wtype == "textractor":
                data[wtype] = component
                tasks.append({"action": wtype, "task": "url"})

            elif wtype == "transcription":
                data[wtype] = {"path": component.pop("path")}
                tasks.append({"action": wtype, "task": "url"})

            elif wtype == "translation":
                data[wtype] = {}
                tasks.append({"action": wtype, "args": list(component.values())})

        # Add in workflow
        data["workflow"] = {name: {"tasks": tasks}}

        # Return workflow name and application configuration
        return (name, data)

    def api(self, config):
        """
        Starts an internal uvicorn server to host an API service for the current workflow.

        The configuration is written to workflow.yml in the system temp directory and the path is
        exported through the CONFIG environment variable before the API application is started.

        Args:
            config: workflow configuration as YAML string
        """

        # Generate workflow file
        workflow = os.path.join(tempfile.gettempdir(), "workflow.yml")
        with open(workflow, "w", encoding="utf-8") as f:
            f.write(config)

        os.environ["CONFIG"] = workflow
        txtai.api.application.start()
        server = Server(txtai.api.application.app)
        with server.service():
            # Re-render a stop button every 5 seconds, each with a new widget key.
            # NOTE(review): this loop has no break - presumably a click ends it through a
            # Streamlit script rerun, confirm against Streamlit's execution model.
            uid = 0
            while True:
                stop = st.empty()
                click = stop.button("stop", key=uid)
                if not click:
                    time.sleep(5)
                    uid += 1
                stop.empty()

    def inputs(self, selected, workflow):
        """
        Generate process input fields.

        Args:
            selected: list of selected components
            workflow: workflow configuration

        Returns:
            truthy value if inputs changed or the "api" or "download" state is set, falsy otherwise
        """

        change, query = False, None
        with st.expander("Data", expanded="embeddings" not in selected):
            default = self.appsetting(workflow, "data")
            default = default if default else ""

            data = st.text_area("Input", height=10, value=default)

            if selected and data and data != self.state("data"):
                change = True

            # Save data and workflow state
            st.session_state["data"] = data

        if "embeddings" in selected:
            default = self.appsetting(workflow, "query")
            default = default if default else ""

            # Set query and limit
            query = st.text_input("Query", value=default)

            if selected and query and query != self.state("query"):
                change = True

        # Save query state
        st.session_state["query"] = query

        return change or self.state("api") or self.state("download")

    def data(self):
        """
        Gets input data.

        Returns:
            list of inputs, empty when no input data has been entered
        """

        data = self.state("data")

        # No input, return an empty list so that callers can skip processing
        if not data or not data.strip():
            return []

        # Split on newlines if urls detected, allows a list of urls to be processed
        if re.match(r"^(http|https|file):\/\/", data):
            return [x.strip() for x in data.splitlines() if x.strip()]

        return [data]

    def process(self, components, index):
        """
        Processes the current application action.

        Args:
            components: workflow components
            index: True if this is an indexing workflow

        Returns:
            (workflow name, configuration)
        """

        # Generate application configuration
        name, config = self.config(components)

        # Get workflow process, pass a copy so the returned configuration is not shared with it
        process = Process.get(name, copy.deepcopy(config))

        # Run workflow process
        process.run(self.data())

        # Run search
        if index:
            process.search(self.state("query"))

        return name, config

    def run(self):
        """
        Runs Streamlit application.
        """

        build = False
        with st.sidebar:
            st.image("https://github.com/neuml/txtai/raw/master/logo.png", width=256)
            st.markdown("# Workflow builder  \n*Build and apply workflows to data*  ")
            st.markdown("---")

            # Component configuration
            labels = {"segmentation": "segment", "textractor": "textract", "transcription": "transcribe", "translation": "translate"}
            components = ["embeddings", "segmentation", "service", "summary", "tabular", "textractor", "transcription", "translation"]

            selected, workflow, upload = self.load(components)
            selected = st.multiselect("Select components", components, default=selected, format_func=lambda text: labels.get(text, text))

            if selected:
                st.markdown(
                    """
                <style>
                [data-testid="stForm"] {
                    border: 0;
                    padding: 0;
                }
                </style>
                """,
                    unsafe_allow_html=True,
                )

                with st.form("workflow"):
                    # Get selected options
                    components = [self.options(component, workflow, x) for x, component in enumerate(selected)]
                    st.markdown("---")

                    # Build or re-build workflow when build button clicked or new workflow loaded
                    build = st.form_submit_button("Build", help="Build the workflow and run within this application")

        # Generate input fields
        inputs = self.inputs(selected, workflow)

        # Only execute if build button clicked, new workflow uploaded or inputs changed.
        # Component options only exist when components are selected, skip processing otherwise.
        if selected and (build or upload or inputs):
            # Process current action
            name, config = self.process(components, "embeddings" in selected)

            with st.sidebar:
                with st.expander("Other Actions", expanded=True):
                    col1, col2 = st.columns(2)

                    # Add state information to configuration and export to YAML string
                    config = config.copy()
                    config.update({"app": {"data": self.state("data"), "query": self.state("query")}})
                    config = yaml.dump(config)

                    api = col1.button("API", key="api", help="Start an API instance within this application")
                    if api:
                        with st.spinner(f"Running workflow '{name}' via API service, click stop to terminate"):
                            self.api(config)

                    col2.download_button("Export", config, file_name="workflow.yml", key="download", help="Export the API workflow as YAML")
738  
739  
if __name__ == "__main__":
    # Set TOKENIZERS_PARALLELISM to "false" before the application is created
    os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

    # Instantiate the workflow builder, then start it
    app = Application()
    app.run()