# qdrant.py
"""Qdrant vector store implementation."""
from __future__ import annotations

import logging
from typing import Any

from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.models import Distance, PointStruct, VectorParams

from ..constants import DEFAULT_RETRIEVAL_K
from ..embeddings.protocol import Embeddings
from .constants import DEFAULT_COLLECTION_NAME
from .protocol import VectorStore


class QdrantVectorStoreWrapper:
    """Wrapper for QdrantVectorStore to support pre-computed embeddings in add_texts.

    Writes made through ``add_texts`` are buffered in memory as ``PointStruct``
    objects and only sent to the server by ``persist()``. Searches are
    delegated to the wrapped ``QdrantVectorStore``.
    """

    def __init__(
        self,
        qdrant_store: QdrantVectorStore,
        client: QdrantClient,
        collection_name: str,
    ):
        """Initialize the wrapper.

        Parameters
        ----------
        qdrant_store
            The underlying QdrantVectorStore instance.
        client
            The QdrantClient instance for direct access.
        collection_name
            Name of the Qdrant collection.
        """
        self._logger = logging.getLogger(__name__)
        self._logger.info(
            "Initializing QdrantVectorStoreWrapper for collection: %s",
            collection_name,
        )

        self._store = qdrant_store
        self._client = client
        self._collection_name = collection_name
        # Points buffered by add_texts() until persist() upserts them.
        self._pending_points: list[PointStruct] = []

    def add_texts(
        self,
        texts: list[str],
        metadatas: list[dict[str, Any]] | None = None,
        ids: list[int] | None = None,
        embeddings: list[list[float]] | None = None,
    ) -> None:
        """Add texts to the vector store with pre-computed embeddings.

        Points are accumulated and will be persisted when persist() is called.
        This allows for efficient batching of multiple add_texts calls.

        Parameters
        ----------
        texts
            List of text strings to add.
        metadatas
            Optional list of metadata dictionaries. Each one is merged into
            the point payload next to the ``page_content`` key.
        ids
            Optional list of document IDs. When omitted, the IDs
            ``0 .. len(texts) - 1`` are assigned on *every* call, so repeated
            calls without explicit IDs reuse the same IDs.
            NOTE(review): presumably such points overwrite each other on
            upsert -- confirm callers always pass explicit IDs.
        embeddings
            Pre-computed embedding vectors (always provided in this pipeline).

        Raises
        ------
        ValueError
            If ``embeddings`` is None.

        Notes
        -----
        The inputs are zipped non-strictly: if their lengths differ, the
        extra trailing items of the longer inputs are silently ignored.
        """
        if embeddings is None:
            # Fail with a clear message instead of an opaque TypeError from zip().
            raise ValueError(
                "add_texts requires pre-computed embeddings, but embeddings is None."
            )
        if ids is None:
            ids = list(range(len(texts)))
        if metadatas is None:
            # Sharing one dict is safe here: it is only read, never mutated.
            metadatas = [{}] * len(texts)

        for doc_id, text, embedding, metadata in zip(
            ids, texts, embeddings, metadatas, strict=False
        ):
            # A metadata key named "page_content" overrides the text itself.
            payload = {"page_content": text, **(metadata or {})}
            self._pending_points.append(
                PointStruct(
                    id=doc_id,
                    vector=list(embedding),
                    payload=payload,
                )
            )

    def _ensure_collection_exists(self) -> None:
        """Ensure the Qdrant collection exists, raise helpful error if not.

        Raises
        ------
        ValueError
            If the server answers the collection lookup with HTTP 404.
        UnexpectedResponse
            For any other unexpected server response (re-raised unchanged).
        """
        try:
            self._client.get_collection(self._collection_name)
        except UnexpectedResponse as e:
            if e.status_code == 404:
                raise ValueError(
                    f"Qdrant collection '{self._collection_name}' does not exist. "
                    f"Please run ingestion first to create the collection and add documents. "
                ) from e
            raise

    def similarity_search(
        self,
        query: str,
        k: int = DEFAULT_RETRIEVAL_K,
        filter: Any | None = None,
    ) -> list[Document]:
        """Search for similar documents.

        Unlike ``similarity_search_with_score``, the documents are returned
        exactly as produced by the wrapped store; their metadata is not
        rebuilt from the point payload.

        Parameters
        ----------
        query
            Query text passed to the wrapped store.
        k
            Maximum number of documents to return.
        filter
            Optional filter forwarded to the wrapped store when not None.

        Raises
        ------
        ValueError
            If the collection does not exist.
        """
        self._ensure_collection_exists()
        if filter is not None:
            return self._store.similarity_search(query, k=k, filter=filter)
        return self._store.similarity_search(query, k=k)

    def similarity_search_with_score(
        self,
        query: str,
        k: int = DEFAULT_RETRIEVAL_K,
        filter: Any | None = None,
    ) -> list[tuple[Document, float]]:
        """Search for similar documents with similarity scores.

        After the search, each document's metadata is replaced by the full
        payload of its point (minus ``page_content``). This step is best
        effort: if the payload lookup fails, the documents keep the metadata
        produced by the wrapped store.

        Parameters
        ----------
        query
            Query text passed to the wrapped store.
        k
            Maximum number of results to return.
        filter
            Optional filter forwarded to the wrapped store when not None.

        Returns
        -------
        List of ``(document, score)`` tuples in the order returned by the
        wrapped store.

        Raises
        ------
        ValueError
            If the collection does not exist.
        """
        self._ensure_collection_exists()
        if filter is not None:
            results = self._store.similarity_search_with_score(query, k=k, filter=filter)
        else:
            results = self._store.similarity_search_with_score(query, k=k)

        # Fix metadata mapping - LangChain doesn't map all payload fields to metadata
        self._restore_payload_metadata([doc for doc, _ in results])
        return [(doc, score) for doc, score in results]

    def _restore_payload_metadata(self, documents: list[Document]) -> None:
        """Replace each document's metadata with its point payload, in place.

        Documents without an ``_id`` metadata entry, or whose point has no
        payload, are left untouched.
        """
        # Compare against None explicitly: 0 is a valid point id (it is the
        # first id add_texts assigns by default) and must not be skipped.
        point_ids = [
            point_id
            for point_id in (doc.metadata.get("_id") for doc in documents)
            if point_id is not None
        ]
        if not point_ids:
            return

        try:
            # One round trip for all hits instead of one request per hit.
            records = self._client.retrieve(
                collection_name=self._collection_name,
                ids=point_ids,
                with_payload=True,
            )
            payload_by_id = {
                record.id: record.payload
                for record in records
                if record.payload is not None
            }
        except Exception as e:
            # Best effort only: a failed lookup must not fail the search.
            self._logger.warning(
                "Could not fix metadata for points %s: %s", point_ids, e
            )
            return

        for doc in documents:
            payload = payload_by_id.get(doc.metadata.get("_id"))
            if payload is not None:
                doc.metadata = {
                    key: value
                    for key, value in payload.items()
                    if key != "page_content"
                }

    def add_documents(self, documents: list[Document]) -> list[str]:
        """Add documents to the vector store.

        Delegates directly to the wrapped store; the documents do not go
        through the pending-points buffer used by ``add_texts``.
        """
        return self._store.add_documents(documents)

    def persist(self) -> None:
        """Persist accumulated points to the Qdrant server.

        This method performs the actual upsert operation with all points
        that were accumulated via add_texts() calls. After persisting,
        the pending points buffer is cleared. If the upsert raises, the
        buffer is left intact.
        """
        if not self._pending_points:
            return

        self._logger.info("Persisting %d points to Qdrant", len(self._pending_points))

        self._client.upsert(
            collection_name=self._collection_name,
            points=self._pending_points,
        )

        # Only reached when the upsert succeeded.
        self._pending_points.clear()


def create_qdrant_store(config: dict[str, Any]) -> VectorStore:
    """Create a Qdrant vector store from configuration.

    If the collection does not exist yet, it is created with cosine distance
    and a vector size equal to the length of a probe embedding obtained from
    ``embedding_function.embed_query``.

    Parameters
    ----------
    config
        Configuration dictionary with keys:
        - embedding_function: Embeddings (required) - Embedding model instance
        - url: str (optional) - URL to Qdrant server (default: http://qdrant:6333)
        - collection_name: str (optional) - Name of the collection

    Returns
    -------
    VectorStore instance.

    Raises
    ------
    ValueError
        If embedding_function is not provided.
    UnexpectedResponse
        If the collection lookup fails with a status other than 404.
    """
    logger = logging.getLogger(__name__)

    url = config.get("url", "http://qdrant:6333")
    collection_name = config.get("collection_name", DEFAULT_COLLECTION_NAME)
    embedding_function: Embeddings | None = config.get("embedding_function")

    if embedding_function is None:
        raise ValueError(
            "Qdrant requires an embedding_function. "
            "Pass it via config: {'embedding_function': embedder.embedding_model}"
        )

    client = QdrantClient(url=url, check_compatibility=False)

    try:
        client.get_collection(collection_name)
    except UnexpectedResponse as e:
        if e.status_code != 404:
            raise
        collection_exists = False
    except Exception:
        logger.exception("Failed to look up Qdrant collection")
        raise
    else:
        collection_exists = True
        logger.info("Qdrant collection already exists: %s", collection_name)

    if not collection_exists:
        # Creation runs outside the lookup's except-handler so that failures
        # here are actually logged (sibling except clauses never see
        # exceptions raised inside another handler).
        try:
            # Embed a throwaway query to discover the vector dimension.
            embedding_dim = len(embedding_function.embed_query("test"))

            logger.info("Creating Qdrant collection: %s", collection_name)
            logger.info("Embedding dimension: %s", embedding_dim)

            client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
            )
        except Exception:
            logger.exception("Failed to create Qdrant collection")
            raise

    qdrant_store = QdrantVectorStore(
        client=client,
        embedding=embedding_function,
        collection_name=collection_name,
    )

    return QdrantVectorStoreWrapper(
        qdrant_store=qdrant_store,
        client=client,
        collection_name=collection_name,
    )