/ src / python / txtai / api / routers / embeddings.py
embeddings.py
  1  """
  2  Defines API paths for embeddings endpoints.
  3  """
  4  
  5  from io import BytesIO
  6  from typing import List, Optional
  7  
  8  import PIL
  9  
 10  from fastapi import APIRouter, Body, File, Form, HTTPException, Request, UploadFile
 11  from fastapi.encoders import jsonable_encoder
 12  
 13  from .. import application
 14  from ..responses import ResponseFactory
 15  from ..route import EncodingAPIRoute
 16  
 17  from ...app import ReadOnlyError
 18  from ...graph import Graph
 19  
# Router that all embeddings endpoints below are registered on; routes use the project's EncodingAPIRoute class
router = APIRouter(route_class=EncodingAPIRoute)
 21  
 22  
 23  @router.get("/search")
 24  def search(query: str, request: Request):
 25      """
 26      Finds documents most similar to the input query. This method will run either an index search
 27      or an index + database search depending on if a database is available.
 28  
 29      Args:
 30          query: input query
 31          request: FastAPI request
 32  
 33      Returns:
 34          list of {id: value, score: value} for index search, list of dict for an index + database search
 35      """
 36  
 37      # Execute search
 38      results = application.get().search(query, request=request)
 39  
 40      # Encode using standard FastAPI encoder but skip certain classes
 41      results = jsonable_encoder(
 42          results, custom_encoder={bytes: lambda x: x, BytesIO: lambda x: x, PIL.Image.Image: lambda x: x, Graph: lambda x: x.savedict()}
 43      )
 44  
 45      # Return raw response to prevent duplicate encoding
 46      response = ResponseFactory.create(request)
 47      return response(results)
 48  
 49  
 50  # pylint: disable=W0621
 51  @router.post("/batchsearch")
 52  def batchsearch(
 53      request: Request,
 54      queries: List[str] = Body(...),
 55      limit: int = Body(default=None),
 56      weights: float = Body(default=None),
 57      index: str = Body(default=None),
 58      parameters: List[dict] = Body(default=None),
 59      graph: bool = Body(default=False),
 60  ):
 61      """
 62      Finds documents most similar to the input queries. This method will run either an index search
 63      or an index + database search depending on if a database is available.
 64  
 65      Args:
 66          queries: input queries
 67          limit: maximum results
 68          weights: hybrid score weights, if applicable
 69          index: index name, if applicable
 70          parameters: list of dicts of named parameters to bind to placeholders
 71          graph: return graph results if True
 72  
 73      Returns:
 74          list of {id: value, score: value} per query for index search, list of dict per query for an index + database search
 75      """
 76  
 77      # Execute search
 78      results = application.get().batchsearch(queries, limit, weights, index, parameters, graph)
 79  
 80      # Encode using standard FastAPI encoder but skip certain classes
 81      results = jsonable_encoder(
 82          results, custom_encoder={bytes: lambda x: x, BytesIO: lambda x: x, PIL.Image.Image: lambda x: x, Graph: lambda x: x.savedict()}
 83      )
 84  
 85      # Return raw response to prevent duplicate encoding
 86      response = ResponseFactory.create(request)
 87      return response(results)
 88  
 89  
 90  @router.post("/add")
 91  def add(documents: List[dict] = Body(...)):
 92      """
 93      Adds a batch of documents for indexing.
 94  
 95      Args:
 96          documents: list of {id: value, text: value, tags: value}
 97      """
 98  
 99      try:
100          application.get().add(documents)
101      except ReadOnlyError as e:
102          raise HTTPException(status_code=403, detail=e.args[0]) from e
103  
104  
@router.post("/addobject")
def addobject(data: List[bytes] = File(), uid: List[str] = Form(default=None), field: str = Form(default=None)):
    """
    Adds a batch of binary documents for indexing.

    Args:
        data: list of binary objects
        uid: optional list of ids, one per binary object
        field: optional object field name

    Raises:
        HTTPException: 422 if uid is given and its length differs from data,
                       403 if the application raises ReadOnlyError
    """

    # ids are optional here, but when provided they must pair one-to-one with the objects
    if uid and len(data) != len(uid):
        raise HTTPException(status_code=422, detail="Length of data and uid lists must match")

    try:
        # Add objects
        application.get().addobject(data, uid, field)
    except ReadOnlyError as e:
        raise HTTPException(status_code=403, detail=e.args[0]) from e
124  
125  
@router.post("/addimage")
def addimage(data: List[UploadFile] = File(), uid: List[str] = Form(), field: str = Form(default=None)):
    """
    Adds a batch of images for indexing.

    Args:
        data: list of uploaded image files
        uid: list of corresponding ids, one per image
        field: optional object field name

    Raises:
        HTTPException: 422 if the data and uid lengths differ or an upload is not a readable image,
                       403 if the application raises ReadOnlyError
    """

    if uid and len(data) != len(uid):
        raise HTTPException(status_code=422, detail="Length of data and uid lists must match")

    # Decode uploads up front: client-supplied bytes that PIL cannot identify are a client error (422),
    # not an unhandled server error (500)
    images = []
    for content in data:
        try:
            images.append(PIL.Image.open(content.file))
        except PIL.UnidentifiedImageError as e:
            raise HTTPException(status_code=422, detail=f"Unable to read image: {content.filename}") from e

    try:
        # Add images
        application.get().addobject(images, uid, field)
    except ReadOnlyError as e:
        raise HTTPException(status_code=403, detail=e.args[0]) from e
145  
146  
@router.get("/index")
def index():
    """
    Builds an embeddings index from the documents batched by earlier add calls.

    Raises:
        HTTPException: 403 when the application raises ReadOnlyError
    """

    try:
        instance = application.get()
        instance.index()
    except ReadOnlyError as error:
        raise HTTPException(status_code=403, detail=error.args[0]) from error
157  
158  
@router.get("/upsert")
def upsert():
    """
    Upserts the documents batched by earlier add calls into the embeddings index.

    Raises:
        HTTPException: 403 when the application raises ReadOnlyError
    """

    try:
        instance = application.get()
        instance.upsert()
    except ReadOnlyError as error:
        raise HTTPException(status_code=403, detail=error.args[0]) from error
169  
170  
@router.post("/delete")
def delete(ids: List = Body(...)):
    """
    Removes entries from an embeddings index.

    Args:
        ids: list of ids to delete

    Returns:
        ids deleted, as returned by the application

    Raises:
        HTTPException: 403 when the application raises ReadOnlyError
    """

    try:
        removed = application.get().delete(ids)
    except ReadOnlyError as error:
        raise HTTPException(status_code=403, detail=error.args[0]) from error

    return removed
187  
188  
@router.post("/reindex")
def reindex(config: dict = Body(...), function: str = Body(default=None)):
    """
    Rebuilds this embeddings index with a new configuration. Only works when document
    content storage is enabled.

    Args:
        config: new config
        function: optional function to prepare content for indexing

    Raises:
        HTTPException: 403 when the application raises ReadOnlyError
    """

    try:
        instance = application.get()
        instance.reindex(config, function)
    except ReadOnlyError as error:
        raise HTTPException(status_code=403, detail=error.args[0]) from error
203  
204  
@router.get("/count")
def count():
    """
    Reports how many elements this embeddings index holds.

    Returns:
        number of elements in embeddings index
    """

    instance = application.get()
    return instance.count()
215  
216  
@router.post("/explain")
def explain(query: str = Body(...), texts: List[str] = Body(default=None), limit: int = Body(default=None)):
    """
    Explains the importance of each input token in text for a query.

    Args:
        query: query text
        texts: optional list of text to explain
        limit: optional limit passed through to the application's explain
               (NOTE(review): presumably caps the number of results when texts is omitted — confirm)

    Returns:
        list of dict where a higher score represents higher importance relative to the query
    """

    instance = application.get()
    return instance.explain(query, texts, limit)
231  
232  
@router.post("/batchexplain")
def batchexplain(queries: List[str] = Body(...), texts: List[str] = Body(default=None), limit: int = Body(default=None)):
    """
    Explains the importance of each input token in text for each query in a batch.

    Args:
        queries: list of query text
        texts: optional list of text to explain
        limit: optional limit passed through to the application's batchexplain
               (NOTE(review): presumably caps the number of results when texts is omitted — confirm)

    Returns:
        explain results from the application, presumably one list of dict per query
        (higher score = higher importance relative to the query) — confirm against application.batchexplain
    """

    instance = application.get()
    return instance.batchexplain(queries, texts, limit)
247  
248  
@router.get("/transform")
def transform(text: str, category: Optional[str] = None, index: Optional[str] = None):
    """
    Converts a single text into an embeddings array.

    Args:
        text: input text
        category: category for instruction-based embeddings
        index: index name, if applicable

    Returns:
        embeddings array
    """

    instance = application.get()
    return instance.transform(text, category, index)
264  
265  
@router.post("/batchtransform")
def batchtransform(texts: List[str] = Body(...), category: Optional[str] = None, index: Optional[str] = None):
    """
    Converts a list of texts into embeddings arrays.

    Args:
        texts: list of text
        category: category for instruction-based embeddings
        index: index name, if applicable

    Returns:
        embeddings arrays
    """

    instance = application.get()
    return instance.batchtransform(texts, category, index)
279  
280      return application.get().batchtransform(texts, category, index)