/ src / python / txtai / database / sql / base.py
base.py
  1  """
  2  SQL module
  3  """
  4  
  5  from io import StringIO
  6  from shlex import shlex
  7  
  8  from .expression import Expression
  9  
 10  
 11  class SQL:
 12      """
 13      Translates txtai SQL statements into database native queries.
 14      """
 15  
 16      # List of clauses to parse
 17      CLAUSES = ["select", "from", "where", "group", "having", "order", "limit", "offset"]
 18  
 19      def __init__(self, database=None, tolist=False):
 20          """
 21          Creates a new SQL query parser.
 22  
 23          Args:
 24              database: database instance that provides resolver callback, if any
 25              tolist: outputs expression lists if True, expression text otherwise, defaults to False
 26          """
 27  
 28          # Expression parser
 29          self.expression = Expression(database.resolve if database else self.defaultresolve, tolist)
 30  
 31      def __call__(self, query):
 32          """
 33          Parses an input SQL query and normalizes column names in the query clauses. This method will also embed
 34          similarity search placeholders into the query.
 35  
 36          Args:
 37              query: input query
 38  
 39          Returns:
 40              {clause name: clause text}
 41          """
 42  
 43          clauses = None
 44          if self.issql(query):
 45              # Ignore multiple statements
 46              query = query.split(";")[0]
 47  
 48              # Tokenize query
 49              tokens, positions = self.tokenize(query)
 50  
 51              # Alias clauses and similar queries
 52              aliases, similar = {}, []
 53  
 54              # Parse SQL clauses
 55              clauses = {
 56                  "select": self.parse(tokens, positions, "select", alias=True, aliases=aliases),
 57                  "where": self.parse(tokens, positions, "where", aliases=aliases, similar=similar),
 58                  "groupby": self.parse(tokens, positions, "group", offset=2, aliases=aliases),
 59                  "having": self.parse(tokens, positions, "having", aliases=aliases),
 60                  "orderby": self.parse(tokens, positions, "order", offset=2, aliases=aliases),
 61                  "limit": self.parse(tokens, positions, "limit", aliases=aliases),
 62                  "offset": self.parse(tokens, positions, "offset", aliases=aliases),
 63              }
 64  
 65              # Add parsed similar queries, if any
 66              if similar:
 67                  clauses["similar"] = similar
 68  
 69          # Return clauses, default to full query if this is not a SQL query
 70          return clauses if clauses else {"similar": [[query]]}
 71  
 72      # pylint: disable=W0613
 73      def defaultresolve(self, name, alias=None):
 74          """
 75          Default resolve function. Performs no processing, only returns name.
 76  
 77          Args:
 78              name: query column name
 79              alias: alias name, defaults to None
 80  
 81          Returns:
 82              name
 83          """
 84  
 85          return name
 86  
 87      def issql(self, query):
 88          """
 89          Detects if this is a SQL query.
 90  
 91          Args:
 92              query: input query
 93  
 94          Returns:
 95              True if this is a valid SQL query, False otherwise
 96          """
 97  
 98          if isinstance(query, str):
 99              # Reduce query to a lower-cased single line stripped of leading/trailing whitespace
100              query = query.lower().strip(";").replace("\n", " ").replace("\t", " ").strip()
101  
102              # Detect if this is a valid txtai SQL statement
103              return query.startswith("select ") and (" from txtai " in query or query.endswith(" from txtai"))
104  
105          return False
106  
107      def snippet(self, text):
108          """
109          Parses a partial SQL snippet.
110  
111          Args:
112              text: SQL snippet
113  
114          Returns:
115              parsed snippet
116          """
117  
118          tokens, _ = self.tokenize(text)
119          return self.expression(tokens)
120  
121      def tokenize(self, query):
122          """
123          Tokenizes SQL query into tokens.
124  
125          Args:
126              query: input query
127  
128          Returns:
129              (tokenized query, token positions)
130          """
131  
132          # Build a simple SQL lexer
133          #   - Punctuation chars are parsed as standalone tokens which helps identify operators
134          #   - Add additional wordchars to prevent splitting on those values
135          #   - Disable comments
136          tokens = shlex(StringIO(query), punctuation_chars="=!<>+-*/%|")
137          tokens.wordchars += ":@#"
138          tokens.commenters = ""
139          tokens = list(tokens)
140  
141          # Identify sql clause token positions
142          positions = {}
143  
144          # Get position of clause keywords. For multi-term clauses, validate next token matches as well
145          for x, token in enumerate(tokens):
146              t = token.lower()
147              if t not in positions and t in SQL.CLAUSES and (t not in ["group", "order"] or (x + 1 < len(tokens) and tokens[x + 1].lower() == "by")):
148                  positions[t] = x
149  
150          return (tokens, positions)
151  
152      def parse(self, tokens, positions, name, offset=1, alias=False, aliases=None, similar=None):
153          """
154          Runs query column name to database column name mappings for clauses. This method will also
155          parse SIMILAR() function calls, extract parameters for those calls and leave a placeholder
156          to be filled in with similarity results.
157  
158          Args:
159              tokens: query tokens
160              positions: token positions - used to locate the start of sql clauses
161              name: current query clause name
162              offset: how many tokens are in the clause name
163              alias: True if terms in the clause should be aliased (i.e. column as alias)
164              aliases: dict of generated aliases, if present these tokens should NOT be resolved
165              similar: list where parsed similar clauses should be stored
166  
167          Returns:
168              formatted clause
169          """
170  
171          clause = None
172          if name in positions:
173              # Find the next clause token
174              end = [positions.get(x, len(tokens)) for x in SQL.CLAUSES[SQL.CLAUSES.index(name) + 1 :]]
175              end = min(end) if end else len(tokens)
176  
177              # Start after current clause token and end before next clause or end of string
178              clause = tokens[positions[name] + offset : end]
179  
180              # Parse and resolve parameters
181              clause = self.expression(clause, alias, aliases, similar)
182  
183          return clause
184  
185  
class SQLError(Exception):
    """
    Exception type signalling a failure caused by a user-supplied SQL query.
    """