/ src / vector_stores / qdrant.py
qdrant.py
  1  """Qdrant vector store implementation."""
  2  from __future__ import annotations
  3  
  4  import logging
  5  from typing import Any
  6  
  7  from langchain_core.documents import Document
  8  from langchain_qdrant import QdrantVectorStore
  9  from qdrant_client import QdrantClient
 10  from qdrant_client.http.exceptions import UnexpectedResponse
 11  from qdrant_client.models import Distance, PointStruct, VectorParams
 12  
 13  from ..constants import DEFAULT_RETRIEVAL_K
 14  from ..embeddings.protocol import Embeddings
 15  from .constants import DEFAULT_COLLECTION_NAME
 16  from .protocol import VectorStore
 17  
 18  
 19  class QdrantVectorStoreWrapper:
 20      """Wrapper for QdrantVectorStore to support pre-computed embeddings in add_texts."""
 21  
 22      def __init__(
 23          self,
 24          qdrant_store: QdrantVectorStore,
 25          client: QdrantClient,
 26          collection_name: str,
 27      ):
 28          """Initialize the wrapper.
 29  
 30          Parameters
 31          ----------
 32          qdrant_store
 33              The underlying QdrantVectorStore instance.
 34          client
 35              The QdrantClient instance for direct access.
 36          collection_name
 37              Name of the Qdrant collection.
 38          """
 39          self._logger = logging.getLogger(__name__)
 40          self._logger.info(f"Initializing QdrantVectorStoreWrapper for collection: {collection_name}")
 41  
 42          self._store = qdrant_store
 43          self._client = client
 44          self._collection_name = collection_name
 45          self._pending_points: list[PointStruct] = []
 46  
 47      def add_texts(
 48          self,
 49          texts: list[str],
 50          metadatas: list[dict[str, Any]] | None = None,
 51          ids: list[int] | None = None,
 52          embeddings: list[list[float]] | None = None,
 53      ) -> None:
 54          """Add texts to the vector store with pre-computed embeddings.
 55  
 56          Points are accumulated and will be persisted when persist() is called.
 57          This allows for efficient batching of multiple add_texts calls.
 58  
 59          Parameters
 60          ----------
 61          texts
 62              List of text strings to add.
 63          metadatas
 64              Optional list of metadata dictionaries.
 65          ids
 66              Optional list of document IDs.
 67          embeddings
 68              Pre-computed embedding vectors (always provided in this pipeline).
 69          """
 70          if ids is None:
 71              ids = list(range(len(texts)))
 72          if metadatas is None:
 73              metadatas = [{}] * len(texts)
 74  
 75          for doc_id, text, embedding, metadata in zip(ids, texts, embeddings, metadatas, strict=False):
 76              vector = list(embedding)
 77  
 78              payload = {"page_content": text, **(metadata or {})}
 79  
 80              point = PointStruct(
 81                  id=doc_id,
 82                  vector=vector,
 83                  payload=payload,
 84              )
 85              self._pending_points.append(point)
 86  
 87      def _ensure_collection_exists(self) -> None:
 88          """Ensure the Qdrant collection exists, raise helpful error if not."""
 89          try:
 90              self._client.get_collection(self._collection_name)
 91          except UnexpectedResponse as e:
 92              if e.status_code == 404:
 93                  raise ValueError(
 94                      f"Qdrant collection '{self._collection_name}' does not exist. "
 95                      f"Please run ingestion first to create the collection and add documents. "
 96                  ) from e
 97              raise
 98  
 99      def similarity_search(
100          self,
101          query: str,
102          k: int = DEFAULT_RETRIEVAL_K,
103          filter: Any | None = None,
104      ) -> list[Document]:
105          """Search for similar documents."""
106          self._ensure_collection_exists()
107          if filter is not None:
108              return self._store.similarity_search(query, k=k, filter=filter)
109          return self._store.similarity_search(query, k=k)
110  
111      def similarity_search_with_score(
112          self,
113          query: str,
114          k: int = DEFAULT_RETRIEVAL_K,
115          filter: Any | None = None,
116      ) -> list[tuple[Document, float]]:
117          """Search for similar documents with similarity scores."""
118          self._ensure_collection_exists()
119          if filter is not None:
120              results = self._store.similarity_search_with_score(query, k=k, filter=filter)
121          else:
122              results = self._store.similarity_search_with_score(query, k=k)
123  
124          # Fix metadata mapping - LangChain doesn't map all payload fields to metadata
125          fixed_results = []
126          for doc, score in results:
127              point_id = doc.metadata.get('_id')
128              if point_id:
129                  try:
130                      points = self._client.retrieve(
131                          collection_name=self._collection_name,
132                          ids=[point_id],
133                          with_payload=True,
134                      )
135                      if points:
136                          payload = points[0].payload
137                          doc.metadata = {k: v for k, v in payload.items() if k != 'page_content'}
138                  except Exception as e:
139                      self._logger.info(f"Could not fix metadata for point {point_id}: {e}")
140  
141              fixed_results.append((doc, score))
142  
143          return fixed_results
144  
145      def add_documents(self, documents: list[Document]) -> list[str]:
146          """Add documents to the vector store."""
147          return self._store.add_documents(documents)
148  
149      def persist(self) -> None:
150          """Persist accumulated points to the Qdrant server.
151  
152          This method performs the actual upsert operation with all points
153          that were accumulated via add_texts() calls. After persisting,
154          the pending points buffer is cleared.
155          """
156          if not self._pending_points:
157              return
158  
159          self._logger.info(f"Persisting {len(self._pending_points)} points to Qdrant")
160  
161          self._client.upsert(
162              collection_name=self._collection_name,
163              points=self._pending_points,
164          )
165  
166          self._pending_points.clear()
167  
def create_qdrant_store(config: dict[str, Any]) -> VectorStore:
    """Create a Qdrant vector store from configuration.

    Looks up the target collection and creates it when the server reports
    that it does not exist (HTTP 404). The vector size for a new collection
    is taken from the length of a probe embedding produced by
    ``embedding_function.embed_query("test")``; the distance is cosine.

    Parameters
    ----------
    config
        Configuration dictionary with keys:
        - embedding_function: Embeddings (required) - Embedding model instance
        - url: str (optional) - URL to Qdrant server (default: http://qdrant:6333)
        - collection_name: str (optional) - Name of the collection
          (default: DEFAULT_COLLECTION_NAME)

    Returns
    -------
    VectorStore instance.

    Raises
    ------
    ValueError
        If embedding_function is not provided.
    Exception
        Any error raised while looking up or creating the collection is
        logged and re-raised unchanged.
    """

    logger = logging.getLogger(__name__)

    url = config.get("url", "http://qdrant:6333")
    collection_name = config.get("collection_name", DEFAULT_COLLECTION_NAME)
    embedding_function: Embeddings | None = config.get("embedding_function")

    if embedding_function is None:
        raise ValueError(
            "Qdrant requires an embedding_function. "
            "Pass it via config: {'embedding_function': embedder.embedding_model}"
        )

    client = QdrantClient(url=url, check_compatibility=False)

    # A single outer handler covers both the lookup and the creation path.
    # An ``except Exception`` placed as a sibling of ``except UnexpectedResponse``
    # would not see errors raised from inside that handler, so creation
    # failures would escape without being logged.
    try:
        try:
            client.get_collection(collection_name)
            collection_exists = True
        except UnexpectedResponse as e:
            if e.status_code != 404:
                raise
            collection_exists = False

        if collection_exists:
            logger.info(f"Qdrant collection already exists: {collection_name}")
        else:
            # Probe the embedding model once to learn the vector size.
            test_embedding = embedding_function.embed_query("test")
            embedding_dim = len(test_embedding)

            logger.info(f"Creating Qdrant collection: {collection_name}")
            logger.info(f"Embedding dimension: {embedding_dim}")

            client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
            )
    except Exception:
        logger.exception("Failed to create Qdrant collection")
        raise

    qdrant_store = QdrantVectorStore(
        client=client,
        embedding=embedding_function,
        collection_name=collection_name,
    )

    return QdrantVectorStoreWrapper(
        qdrant_store=qdrant_store,
        client=client,
        collection_name=collection_name,
    )