base.py
"""
Console module
"""

import os
import shlex

from cmd import Cmd

# Conditional import
try:
    from rich import box
    from rich.console import Console as RichConsole
    from rich.table import Table

    RICH = True
except ImportError:
    RICH = False

from txtai.app import Application
from txtai.embeddings import Embeddings


class Console(Cmd):
    """
    txtai console.

    Interactive command loop built on cmd.Cmd. Input routed to default() is matched against
    the dot-commands below, anything else is run as a search query.

      .config     prints the configuration of the loaded app
      .highlight  sets the highlight color and switches search to explain mode
      .limit      sets the maximum number of search results
      .load       loads an application (YAML) or an embeddings index
      .workflow   runs a named workflow on the loaded application
    """

    def __init__(self, path=None):
        """
        Creates a new command line console.

        Args:
            path: path to initial configuration, if any

        Raises:
            ImportError: if the optional rich dependency is not installed
        """

        super().__init__()

        if not RICH:
            raise ImportError('Console is not available - install "console" extra to enable')

        self.prompt = ">>> "

        # Rich console
        self.console = RichConsole()

        # App parameters
        self.app = None
        self.path = path

        # Parameters
        self.vhighlight = None
        self.vlimit = None

    def preloop(self):
        """
        Prints the console banner and loads the initial configuration, if a path was provided.
        """

        self.console.print("txtai console", style="#03a9f4")

        # Load default path
        if self.path:
            self.load(self.path)

    def default(self, line):
        """
        Default event loop. Dispatches dot-commands and treats all other input as a search query.

        Any exception raised while processing the line is printed to the console instead of
        being propagated, so a failed command does not end the session.

        Args:
            line: command line
        """

        # pylint: disable=W0703
        try:
            # Commands are matched case-insensitively. Handlers that take case-sensitive
            # arguments (.load, .workflow) are passed the original line.
            command = line.lower()
            if command.startswith(".config"):
                self.config()
            elif command.startswith(".highlight"):
                self.highlight(command)
            elif command.startswith(".limit"):
                self.limit(command)
            elif command.startswith(".load"):
                command = self.split(line)
                self.path = command[1]
                self.load(self.path)
            elif command.startswith(".workflow"):
                self.workflow(line)
            else:
                # Search is default action
                self.search(line)
        except Exception:
            self.console.print_exception()

    def config(self):
        """
        Processes .config command. Prints the configuration of the loaded app.
        """

        self.console.print(self.app.config)

    def highlight(self, command):
        """
        Processes .highlight command. Stores the highlight color, defaults to "#ffff00" when
        no color is given. A non-empty highlight makes search() call explain instead of search.

        Args:
            command: command line
        """

        _, action = self.split(command, "#ffff00")
        self.vhighlight = action
        self.console.print(f"Set highlight to {self.vhighlight}")

    def limit(self, command):
        """
        Processes .limit command. Stores the result limit, defaults to 10 when no value is given.

        Args:
            command: command line

        Raises:
            ValueError: if the limit argument is not an integer
        """

        _, action = self.split(command, 10)
        self.vlimit = int(action)
        self.console.print(f"Set limit to {self.vlimit}")

    def load(self, path):
        """
        Processes .load command. Loads an Application when isyaml(path) is truthy, otherwise
        loads path as an Embeddings index.

        Args:
            path: path to configuration
        """

        if self.isyaml(path):
            self.console.print(f"Loading application {path}")
            self.app = Application(path)
        else:
            self.console.print(f"Loading index {path}")

            # Load embeddings index
            self.app = Embeddings()
            self.app.load(path)

    def search(self, query):
        """
        Runs a search query and prints the results as a table. Calls explain on the loaded app
        when a highlight color is set, search otherwise.

        Args:
            query: query to run
        """

        if self.vhighlight:
            results = self.app.explain(query, limit=self.vlimit)
        else:
            results = self.app.search(query, limit=self.vlimit)

        # Column layout is derived from the first result, nothing to tabulate without one
        if not results:
            self.console.print("No results found")
            return

        table = Table(box=box.SQUARE, style="#03a9f4")

        # Build column list, tuple results are (id, score), otherwise use the result keys.
        # The tokens field is only used for highlighting and is not shown as a column.
        first = results[0]
        if isinstance(first, tuple):
            columns = ["id", "score"]
        else:
            columns = [x for x in dict(first) if x != "tokens"]

        # Add columns to table
        for column in columns:
            table.add_column(column)

        # Add rows to table
        for result in results:
            if isinstance(result, tuple):
                table.add_row(*(self.render(result, None, x) for x in result))
            else:
                table.add_row(*(self.render(result, column, result.get(column)) for column in columns))

        # Print table to console
        self.console.print(table)

    def workflow(self, command):
        """
        Processes .workflow command. The command line is parsed with shlex, the first argument
        is the workflow name and the remaining arguments are the workflow inputs.

        The command is ignored when the loaded app is not an Application.

        Args:
            command: command line
        """

        command = shlex.split(command)
        if isinstance(self.app, Application):
            self.console.print(list(self.app.workflow(command[1], command[2:])))

    def isyaml(self, path):
        """
        Checks if file at path is a valid YAML file.

        Args:
            path: file to check

        Returns:
            result of Application.read(path) if path is an existing file and the read succeeds,
            False otherwise. Callers only test the result for truthiness.
        """

        if os.path.exists(path) and os.path.isfile(path):
            # Best effort - a file that fails to read is treated as not YAML.
            # Only Exception is caught so KeyboardInterrupt/SystemExit still propagate.
            # pylint: disable=W0703
            try:
                return Application.read(path)
            except Exception:
                pass

        return False

    def split(self, command, default=None):
        """
        Splits command into a (name, action) pair on the first run of whitespace.

        Args:
            command: command line
            default: default command action

        Returns:
            [name, action] when an action is present, otherwise (command, default)
        """

        # Split on any whitespace so tabs or repeated spaces between the command name
        # and its argument don't end up in (or hide) the argument
        values = command.split(maxsplit=1)
        return values if len(values) > 1 else (command, default)

    def render(self, result, column, value):
        """
        Renders a search result column value.

        Floats are formatted with 4 decimal places. For the text column of a result with a
        tokens field, tokens are joined with spaces and tokens with a score >= 0.02 are wrapped
        in rich markup using the highlight color. All other values are converted with str.

        Args:
            result: result row
            column: column name
            value: column value

        Returns:
            rendered value as a string
        """

        if isinstance(value, float):
            return f"{value:.4f}"

        # Only the text column of an explain result is highlighted
        if column != "text" or "tokens" not in result:
            return str(value)

        # Explain highlighting
        style = f"b {self.vhighlight}"
        spans = [(token, score, style if score >= 0.02 else None) for token, score in result["tokens"]]

        # No tokens to highlight, fall back to the plain value
        if not spans:
            return str(value)

        # When the result scored well but no single token crossed the threshold,
        # highlight the top scoring token(s) so the match is still visible
        if result["score"] >= 0.05 and not any(color for _, _, color in spans):
            mscore = max(score for _, score, _ in spans)
            spans = [(token, score, style if score == mscore else color) for token, score, color in spans]

        # Each token is followed by a single space, including the last one
        return "".join(f"[{color}]{token}[/{color}] " if color else f"{token} " for token, _, color in spans)