/ src / python / txtai / database / sql / aggregate.py
aggregate.py
  1  """
  2  Aggregate module
  3  """
  4  
  5  import itertools
  6  import operator
  7  
  8  from .base import SQL
  9  
 10  
 11  class Aggregate(SQL):
 12      """
 13      Aggregates partial results from queries. Partial results come from queries when working with sharded indexes.
 14      """
 15  
 16      def __init__(self, database=None):
 17          # Always return token lists as this method requires them
 18          super().__init__(database, True)
 19  
 20      def __call__(self, query, results):
 21          """
 22          Analyzes query results, combines aggregate function results and applies ordering.
 23  
 24          Args:
 25              query: input query
 26              results: query results
 27  
 28          Returns:
 29              aggregated query results
 30          """
 31  
 32          # Parse query
 33          query = super().__call__(query)
 34  
 35          # Check if this is a SQL query
 36          if "select" in query:
 37              # Get list of unique and aggregate columns. If no aggregate columns or order by found, skip
 38              columns = list(results[0].keys())
 39              aggcolumns = self.aggcolumns(columns)
 40              if aggcolumns or query["orderby"]:
 41                  # Merge aggregate columns
 42                  if aggcolumns:
 43                      results = self.aggregate(query, results, columns, aggcolumns)
 44  
 45                  # Sort results and return
 46                  return self.orderby(query, results) if query["orderby"] else self.defaultsort(results)
 47  
 48          # Otherwise, run default sort
 49          return self.defaultsort(results)
 50  
 51      def aggcolumns(self, columns):
 52          """
 53          Filters columns for columns that have an aggregate function call.
 54  
 55          Args:
 56              columns: list of columns
 57  
 58          Returns:
 59              list of aggregate columns
 60          """
 61  
 62          aggregates = {}
 63          for column in columns:
 64              column = column.lower()
 65              if column.startswith(("count(", "sum(", "total(")):
 66                  aggregates[column] = sum
 67              elif column.startswith("max("):
 68                  aggregates[column] = max
 69              elif column.startswith("min("):
 70                  aggregates[column] = min
 71              elif column.startswith("avg("):
 72                  aggregates[column] = lambda x: sum(x) / len(x)
 73  
 74          return aggregates
 75  
 76      def aggregate(self, query, results, columns, aggcolumns):
 77          """
 78          Merges aggregate columns in results.
 79  
 80          Args:
 81              query: input query
 82              results: query results
 83              columns: list of select columns
 84              aggcolumns: list of aggregate columns
 85  
 86          Returns:
 87              results with aggregates merged
 88          """
 89  
 90          # Group data, if necessary
 91          if query["groupby"]:
 92              results = self.groupby(query, results, columns)
 93          else:
 94              results = [results]
 95  
 96          # Compute column values
 97          rows = []
 98          for result in results:
 99              # Calculate/copy column values
100              row = {}
101              for column in columns:
102                  if column in aggcolumns:
103                      # Calculate aggregate value
104                      function = aggcolumns[column]
105                      row[column] = function([r[column] for r in result])
106                  else:
107                      # Non aggregate column value repeat, use first value
108                      row[column] = result[0][column]
109  
110              # Add row using original query columns
111              rows.append(row)
112  
113          return rows
114  
115      def groupby(self, query, results, columns):
116          """
117          Groups results using query group by clause.
118  
119          Args:
120              query: input query
121              results: query results
122              columns: list of select columns
123  
124          Returns:
125              results grouped using group by clause
126          """
127  
128          groupby = [column for column in columns if column.lower() in query["groupby"]]
129          if groupby:
130              results = sorted(results, key=operator.itemgetter(*groupby))
131              return [list(value) for _, value in itertools.groupby(results, operator.itemgetter(*groupby))]
132  
133          return [results]
134  
135      def orderby(self, query, results):
136          """
137          Applies an order by clause to results.
138  
139          Args:
140              query: input query
141              results: query results
142  
143          Returns:
144              results ordered using order by clause
145          """
146  
147          # Sort in reverse order
148          for clause in query["orderby"][::-1]:
149              # Order by columns must be selected
150              reverse = False
151              if clause.lower().endswith(" asc"):
152                  clause = clause.rsplit(" ")[0]
153              elif clause.lower().endswith(" desc"):
154                  clause = clause.rsplit(" ")[0]
155                  reverse = True
156  
157              # Order by columns must be in select clause
158              if clause in query["select"]:
159                  results = sorted(results, key=operator.itemgetter(clause), reverse=reverse)
160  
161          return results
162  
163      def defaultsort(self, results):
164          """
165          Default sorting algorithm for results. Sorts by score descending, if available.
166  
167          Args:
168              results: query results
169  
170          Returns:
171              results ordered by score descending
172          """
173  
174          # Sort standard query using score column, if present
175          if results and "score" in results[0]:
176              return sorted(results, key=lambda x: x["score"], reverse=True)
177  
178          return results