aggregate.py
"""
Aggregate module
"""

import itertools
import operator

from .base import SQL


def _mean(values):
    """
    Computes the arithmetic mean of values.

    Args:
        values: non-empty list of numbers

    Returns:
        sum(values) / len(values)
    """

    return sum(values) / len(values)


def _ignorenulls(function):
    """
    Wraps an aggregate merge function so that None values are skipped before merging. This mirrors SQL aggregate
    semantics, where NULL inputs are ignored. SQL sum/min/max/avg over zero rows yield NULL, so a partial result
    from one shard can be None.

    Args:
        function: function that merges a non-empty list of values into a single value

    Returns:
        function that merges a list of values, returning None when every value is None
    """

    def merge(values):
        values = [value for value in values if value is not None]
        return function(values) if values else None

    return merge


class Aggregate(SQL):
    """
    Aggregates partial results from queries. Partial results come from queries when working with sharded indexes.
    """

    def __init__(self, database=None):
        # Always return token lists as this method requires them
        super().__init__(database, True)

    def __call__(self, query, results):
        """
        Analyzes query results, combines aggregate function results and applies ordering.

        Args:
            query: input query
            results: query results (list of dicts keyed by column name)

        Returns:
            aggregated query results
        """

        # Parse query
        query = super().__call__(query)

        # Check if this is a SQL query with rows to merge. Empty results have no columns to inspect.
        if "select" in query and results:
            # Get list of unique and aggregate columns. If no aggregate columns or order by found, skip
            columns = list(results[0].keys())
            aggcolumns = self.aggcolumns(columns)
            if aggcolumns or query["orderby"]:
                # Merge aggregate columns
                if aggcolumns:
                    results = self.aggregate(query, results, columns, aggcolumns)

                # Sort results and return
                return self.orderby(query, results) if query["orderby"] else self.defaultsort(results)

        # Otherwise, run default sort
        return self.defaultsort(results)

    def aggcolumns(self, columns):
        """
        Filters columns for columns that have an aggregate function call.

        Function names are matched case-insensitively (COUNT(*) and count(*) are both detected). The returned
        keys are the original column names so they match the keys of each result row.

        Args:
            columns: list of columns

        Returns:
            dict of aggregate column name -> function that merges a list of partial values into one value
        """

        aggregates = {}
        for column in columns:
            name = column.lower()
            if name.startswith(("count(", "sum(", "total(")):
                function = sum
            elif name.startswith("max("):
                function = max
            elif name.startswith("min("):
                function = min
            elif name.startswith("avg("):
                # NOTE(review): this is the unweighted mean of per-shard averages. It equals the true average
                # only when every shard contributed the same number of rows.
                function = _mean
            else:
                continue

            # Skip None (SQL NULL) partials, as SQL aggregates do
            aggregates[column] = _ignorenulls(function)

        return aggregates

    def aggregate(self, query, results, columns, aggcolumns):
        """
        Merges aggregate columns in results.

        Args:
            query: input query
            results: query results
            columns: list of select columns
            aggcolumns: dict of aggregate column -> merge function

        Returns:
            results with aggregates merged, one row per group (or a single row when there is no group by)
        """

        # Group data, if necessary
        groups = self.groupby(query, results, columns) if query["groupby"] else [results]

        # Compute column values
        rows = []
        for group in groups:
            row = {}
            for column in columns:
                if column in aggcolumns:
                    # Calculate aggregate value across all rows in the group
                    row[column] = aggcolumns[column]([r[column] for r in group])
                else:
                    # Non aggregate column value repeats within a group, use first value
                    row[column] = group[0][column]

            # Add row using original query columns
            rows.append(row)

        return rows

    def groupby(self, query, results, columns):
        """
        Groups results using query group by clause.

        Args:
            query: input query
            results: query results
            columns: list of select columns

        Returns:
            results grouped using group by clause, as a list of row lists
        """

        groupby = [column for column in columns if column.lower() in query["groupby"]]
        if groupby:
            key = operator.itemgetter(*groupby)

            # itertools.groupby only merges adjacent rows, so sort by the same key first
            results = sorted(results, key=key)
            return [list(value) for _, value in itertools.groupby(results, key)]

        return [results]

    def orderby(self, query, results):
        """
        Applies an order by clause to results.

        Args:
            query: input query
            results: query results

        Returns:
            results ordered using order by clause
        """

        # Apply clauses from last to first. Python's sort is stable, so the first clause ends up as the primary key.
        for clause in query["orderby"][::-1]:
            reverse = False

            # Strip a trailing ASC/DESC keyword, keeping the full column expression before it
            parts = clause.rsplit(None, 1)
            if len(parts) == 2 and parts[1].lower() in ("asc", "desc"):
                clause, reverse = parts[0], parts[1].lower() == "desc"

            # Order by columns must be in select clause
            if clause in query["select"]:
                results = sorted(results, key=operator.itemgetter(clause), reverse=reverse)

        return results

    def defaultsort(self, results):
        """
        Default sorting algorithm for results. Sorts by score descending, if available.

        Args:
            results: query results

        Returns:
            results ordered by score descending
        """

        # Sort standard query using score column, if present
        if results and "score" in results[0]:
            return sorted(results, key=lambda x: x["score"], reverse=True)

        return results