/ src / python / txtai / console / base.py
base.py
  1  """
  2  Console module
  3  """
  4  
import os
import shlex

from cmd import Cmd

# Conditional import
try:
    from rich import box
    from rich.console import Console as RichConsole
    from rich.markup import escape
    from rich.table import Table

    RICH = True
except ImportError:
    RICH = False

from txtai.app import Application
from txtai.embeddings import Embeddings
 22  
 23  
 24  class Console(Cmd):
 25      """
 26      txtai console.
 27      """
 28  
 29      def __init__(self, path=None):
 30          """
 31          Creates a new command line console.
 32  
 33          Args:
 34              path: path to initial configuration, if any
 35          """
 36  
 37          super().__init__()
 38  
 39          if not RICH:
 40              raise ImportError('Console is not available - install "console" extra to enable')
 41  
 42          self.prompt = ">>> "
 43  
 44          # Rich console
 45          self.console = RichConsole()
 46  
 47          # App parameters
 48          self.app = None
 49          self.path = path
 50  
 51          # Parameters
 52          self.vhighlight = None
 53          self.vlimit = None
 54  
 55      def preloop(self):
 56          """
 57          Loads initial configuration.
 58          """
 59  
 60          self.console.print("txtai console", style="#03a9f4")
 61  
 62          # Load default path
 63          if self.path:
 64              self.load(self.path)
 65  
 66      def default(self, line):
 67          """
 68          Default event loop.
 69  
 70          Args:
 71              line: command line
 72          """
 73  
 74          # pylint: disable=W0703
 75          try:
 76              command = line.lower()
 77              if command.startswith(".config"):
 78                  self.config()
 79              elif command.startswith(".highlight"):
 80                  self.highlight(command)
 81              elif command.startswith(".limit"):
 82                  self.limit(command)
 83              elif command.startswith(".load"):
 84                  command = self.split(line)
 85                  self.path = command[1]
 86                  self.load(self.path)
 87              elif command.startswith(".workflow"):
 88                  self.workflow(line)
 89              else:
 90                  # Search is default action
 91                  self.search(line)
 92          except Exception:
 93              self.console.print_exception()
 94  
 95      def config(self):
 96          """
 97          Processes .config command.
 98          """
 99  
100          self.console.print(self.app.config)
101  
102      def highlight(self, command):
103          """
104          Processes .highlight command.
105  
106          Args:
107              command: command line
108          """
109  
110          _, action = self.split(command, "#ffff00")
111          self.vhighlight = action
112          self.console.print(f"Set highlight to {self.vhighlight}")
113  
114      def limit(self, command):
115          """
116          Processes .limit command.
117  
118          Args:
119              command: command line
120          """
121  
122          _, action = self.split(command, 10)
123          self.vlimit = int(action)
124          self.console.print(f"Set limit to {self.vlimit}")
125  
126      def load(self, path):
127          """
128          Processes .load command.
129  
130          Args:
131              path: path to configuration
132          """
133  
134          if self.isyaml(path):
135              self.console.print(f"Loading application {path}")
136              self.app = Application(path)
137          else:
138              self.console.print(f"Loading index {path}")
139  
140              # Load embeddings index
141              self.app = Embeddings()
142              self.app.load(path)
143  
144      def search(self, query):
145          """
146          Runs a search query.
147  
148          Args:
149              query: query to run
150          """
151  
152          if self.vhighlight:
153              results = self.app.explain(query, limit=self.vlimit)
154          else:
155              results = self.app.search(query, limit=self.vlimit)
156  
157          columns, table = {}, Table(box=box.SQUARE, style="#03a9f4")
158  
159          # Build column list
160          result = results[0]
161          if isinstance(result, tuple):
162              columns = dict.fromkeys(["id", "score"])
163          else:
164              columns = dict(result)
165  
166          # Add columns to table
167          columns = list(x for x in columns if x != "tokens")
168          for column in columns:
169              table.add_column(column)
170  
171          # Add rows to table
172          for result in results:
173              if isinstance(result, tuple):
174                  table.add_row(*(self.render(result, None, x) for x in result))
175              else:
176                  table.add_row(*(self.render(result, column, result.get(column)) for column in columns))
177  
178          # Print table to console
179          self.console.print(table)
180  
181      def workflow(self, command):
182          """
183          Processes .workflow command.
184  
185          Args:
186              command: command line
187          """
188  
189          command = shlex.split(command)
190          if isinstance(self.app, Application):
191              self.console.print(list(self.app.workflow(command[1], command[2:])))
192  
193      def isyaml(self, path):
194          """
195          Checks if file at path is a valid YAML file.
196  
197          Args:
198              path: file to check
199  
200          Returns:
201              True if file is valid YAML, False otherwise
202          """
203  
204          if os.path.exists(path) and os.path.isfile(path):
205              try:
206                  return Application.read(path)
207              # pylint: disable=W0702
208              except:
209                  pass
210  
211          return False
212  
213      def split(self, command, default=None):
214          """
215          Splits command by whitespace.
216  
217          Args:
218              command: command line
219              default: default command action
220  
221          Returns:
222              command action
223          """
224  
225          values = command.split(" ", 1)
226          return values if len(values) > 1 else (command, default)
227  
228      def render(self, result, column, value):
229          """
230          Renders a search result column value.
231  
232          Args:
233              result: result row
234              column: column name
235              value: column value
236          """
237  
238          if isinstance(value, float):
239              return f"{value:.4f}"
240  
241          # Explain highlighting
242          if column == "text" and "tokens" in result:
243              spans = []
244              for token, score in result["tokens"]:
245                  color = None
246                  if score >= 0.02:
247                      color = f"b {self.vhighlight}"
248  
249                  spans.append((token, score, color))
250  
251              if result["score"] >= 0.05 and not [color for _, _, color in spans if color]:
252                  mscore = max(score for _, score, _ in spans)
253                  spans = [(token, score, f"b {self.vhighlight}" if score == mscore else color) for token, score, color in spans]
254  
255              output = ""
256              for token, _, color in spans:
257                  if color:
258                      output += f"[{color}]{token}[/{color}] "
259                  else:
260                      output += f"{token} "
261  
262              return output
263  
264          return str(value)