embeddings.py
"""
Defines API paths for embeddings endpoints.
"""

from contextlib import contextmanager
from io import BytesIO
from typing import List, Optional

# Import the Image submodule explicitly. A bare "import PIL" does not load PIL.Image, so PIL.Image.Image and
# PIL.Image.open below would only resolve if another module happened to import PIL.Image first.
import PIL.Image

from fastapi import APIRouter, Body, File, Form, HTTPException, Request, UploadFile
from fastapi.encoders import jsonable_encoder

from .. import application
from ..responses import ResponseFactory
from ..route import EncodingAPIRoute

from ...app import ReadOnlyError
from ...graph import Graph

router = APIRouter(route_class=EncodingAPIRoute)


def _encode(results, request):
    """
    Encodes search results and wraps them in the raw response type selected for this request.

    Args:
        results: search results
        request: FastAPI request

    Returns:
        response created by ResponseFactory for request, holding the encoded results
    """

    # Encode using standard FastAPI encoder but leave bytes, BytesIO and images as-is and convert graphs with savedict
    results = jsonable_encoder(
        results, custom_encoder={bytes: lambda x: x, BytesIO: lambda x: x, PIL.Image.Image: lambda x: x, Graph: lambda x: x.savedict()}
    )

    # Return raw response to prevent duplicate encoding
    response = ResponseFactory.create(request)
    return response(results)


@contextmanager
def _writable():
    """
    Translates a ReadOnlyError raised inside the managed block into an HTTP 403 error.

    Raises:
        HTTPException: status 403 with the ReadOnlyError message as detail
    """

    try:
        yield
    except ReadOnlyError as e:
        # Fall back to str(e) so an argument-less ReadOnlyError doesn't surface as an IndexError (HTTP 500)
        raise HTTPException(status_code=403, detail=e.args[0] if e.args else str(e)) from e


@router.get("/search")
def search(query: str, request: Request):
    """
    Finds documents most similar to the input query. This method will run either an index search
    or an index + database search depending on if a database is available.

    Args:
        query: input query
        request: FastAPI request

    Returns:
        list of {id: value, score: value} for index search, list of dict for an index + database search
    """

    # Execute search
    results = application.get().search(query, request=request)

    return _encode(results, request)


# pylint: disable=W0621
@router.post("/batchsearch")
def batchsearch(
    request: Request,
    queries: List[str] = Body(...),
    limit: int = Body(default=None),
    weights: float = Body(default=None),
    index: str = Body(default=None),
    parameters: List[dict] = Body(default=None),
    graph: bool = Body(default=False),
):
    """
    Finds documents most similar to the input queries. This method will run either an index search
    or an index + database search depending on if a database is available.

    Args:
        request: FastAPI request
        queries: input queries
        limit: maximum results
        weights: hybrid score weights, if applicable
        index: index name, if applicable
        parameters: list of dicts of named parameters to bind to placeholders
        graph: return graph results if True

    Returns:
        list of {id: value, score: value} per query for index search, list of dict per query for an index + database search
    """

    # Execute search
    results = application.get().batchsearch(queries, limit, weights, index, parameters, graph)

    return _encode(results, request)


@router.post("/add")
def add(documents: List[dict] = Body(...)):
    """
    Adds a batch of documents for indexing.

    Args:
        documents: list of {id: value, text: value, tags: value}

    Raises:
        HTTPException: 403 if the application raises ReadOnlyError
    """

    with _writable():
        application.get().add(documents)


@router.post("/addobject")
def addobject(data: List[bytes] = File(), uid: List[str] = Form(default=None), field: str = Form(default=None)):
    """
    Adds a batch of binary documents for indexing.

    Args:
        data: list of binary objects
        uid: optional list of corresponding ids
        field: optional object field name

    Raises:
        HTTPException: 422 if uid is provided and its length differs from data, 403 if the application raises ReadOnlyError
    """

    if uid and len(data) != len(uid):
        raise HTTPException(status_code=422, detail="Length of data and document lists must match")

    with _writable():
        # Add objects
        application.get().addobject(data, uid, field)


@router.post("/addimage")
def addimage(data: List[UploadFile] = File(), uid: List[str] = Form(), field: str = Form(default=None)):
    """
    Adds a batch of images for indexing.

    Args:
        data: list of uploaded image files
        uid: list of corresponding ids
        field: optional object field name

    Raises:
        HTTPException: 422 if data and uid lengths differ, 403 if the application raises ReadOnlyError
    """

    if uid and len(data) != len(uid):
        raise HTTPException(status_code=422, detail="Length of data and uid lists must match")

    with _writable():
        # Add images
        # NOTE(review): PIL.Image.open is lazy (pixel data is read on first use / load()); confirm the application
        # consumes the images before the upload file handles are closed at the end of the request
        application.get().addobject([PIL.Image.open(content.file) for content in data], uid, field)


@router.get("/index")
def index():
    """
    Builds an embeddings index for previously batched documents.

    Raises:
        HTTPException: 403 if the application raises ReadOnlyError
    """

    with _writable():
        application.get().index()


@router.get("/upsert")
def upsert():
    """
    Runs an embeddings upsert operation for previously batched documents.

    Raises:
        HTTPException: 403 if the application raises ReadOnlyError
    """

    with _writable():
        application.get().upsert()


@router.post("/delete")
def delete(ids: List = Body(...)):
    """
    Deletes from an embeddings index. Returns list of ids deleted.

    Args:
        ids: list of ids to delete

    Returns:
        ids deleted

    Raises:
        HTTPException: 403 if the application raises ReadOnlyError
    """

    with _writable():
        return application.get().delete(ids)


@router.post("/reindex")
def reindex(config: dict = Body(...), function: str = Body(default=None)):
    """
    Recreates this embeddings index using config. This method only works if document content storage is enabled.

    Args:
        config: new config
        function: optional function to prepare content for indexing

    Raises:
        HTTPException: 403 if the application raises ReadOnlyError
    """

    with _writable():
        application.get().reindex(config, function)


@router.get("/count")
def count():
    """
    Total number of elements in this embeddings index.

    Returns:
        number of elements in embeddings index
    """

    return application.get().count()


@router.post("/explain")
def explain(query: str = Body(...), texts: List[str] = Body(default=None), limit: int = Body(default=None)):
    """
    Explains the importance of each input token in text for a query.

    Args:
        query: query text
        texts: optional list of text
        limit: optional limit forwarded to the application's explain

    Returns:
        list of dict where a higher score represents higher importance relative to the query
    """

    return application.get().explain(query, texts, limit)


@router.post("/batchexplain")
def batchexplain(queries: List[str] = Body(...), texts: List[str] = Body(default=None), limit: int = Body(default=None)):
    """
    Explains the importance of each input token in text for a list of queries.

    Args:
        queries: list of query text
        texts: optional list of text
        limit: optional limit forwarded to the application's batchexplain

    Returns:
        explain results for each query, where a higher score represents higher importance relative to the query
    """

    return application.get().batchexplain(queries, texts, limit)


@router.get("/transform")
def transform(text: str, category: Optional[str] = None, index: Optional[str] = None):
    """
    Transforms text into an embeddings array.

    Args:
        text: input text
        category: category for instruction-based embeddings
        index: index name, if applicable

    Returns:
        embeddings array
    """

    return application.get().transform(text, category, index)


@router.post("/batchtransform")
def batchtransform(texts: List[str] = Body(...), category: Optional[str] = None, index: Optional[str] = None):
    """
    Transforms list of text into embeddings arrays.

    Args:
        texts: list of text (request body)
        category: category for instruction-based embeddings (query parameter)
        index: index name, if applicable (query parameter)

    Returns:
        embeddings arrays
    """

    return application.get().batchtransform(texts, category, index)