cache.py
import logging
import math
import shutil
import uuid

import chromadb

from restai.config import CHROMADB_HOST, CHROMADB_PORT
from restai.vectordb.tools import find_embeddings_path

logger = logging.getLogger(__name__)

# Score threshold used by Cache.verify when the project's
# ``options.cache_threshold`` is None.
_DEFAULT_CACHE_THRESHOLD = 0.85

# Reuse PersistentClient per cache path within the same worker process.
_cache_client_cache = {}


def _get_cache_client(path):
    """Return a ChromaDB client for the cache stored at *path*.

    When ``CHROMADB_HOST`` is configured, a new ``HttpClient`` for that
    server is returned and *path* is ignored. Otherwise a
    ``PersistentClient`` rooted at *path* is created on first use and
    reused for subsequent calls in this process.
    """
    if CHROMADB_HOST:
        return chromadb.HttpClient(host=CHROMADB_HOST, port=CHROMADB_PORT)
    if path not in _cache_client_cache:
        _cache_client_cache[path] = chromadb.PersistentClient(path=path)
    return _cache_client_cache[path]


class Cache:
    """Per-project question/answer cache stored in a ChromaDB collection.

    Each cached question is stored as a collection document, with the
    question and its answer kept in the document metadata. ``verify`` looks
    up the nearest stored question and returns its answer when it is close
    enough.
    """

    def __init__(self, project):
        self.project = project
        cache_path = find_embeddings_path(self._collection_name)
        self.client = _get_cache_client(cache_path)
        self.collection = self.client.get_or_create_collection(
            name=self._collection_name
        )

    @property
    def _collection_name(self):
        """Name used for both the cache collection and its storage path."""
        return self.project.props.name + "_cache"

    def verify(self, question):
        """Return the cached answer for *question*, or ``None`` on a miss.

        The single nearest stored question is retrieved. Its distance ``d``
        is turned into a score ``exp(-d)``, which is 1.0 at distance 0 and
        decreases as the distance grows. The cached answer is returned only
        when the score is strictly greater than the project's
        ``options.cache_threshold``, or 0.85 when that option is ``None``.
        """
        results = self.collection.query(
            query_texts=[question],
            n_results=1,
            include=["metadatas", "documents", "distances"],
        )

        # One result list per query text; an empty list means no stored match.
        if not results["ids"] or not results["ids"][0]:
            return None

        # NOTE(review): exp(-d) assumes a non-negative distance metric
        # (ChromaDB's default is L2) -- confirm if the collection metric changes.
        similarity = math.exp(-results["distances"][0][0])
        threshold = self.project.props.options.cache_threshold
        if threshold is None:
            threshold = _DEFAULT_CACHE_THRESHOLD

        if similarity <= threshold:
            return None

        # Entries written by ``add`` always carry "answer"; anything else is
        # treated as a miss instead of raising.
        metadata = results["metadatas"][0][0] or {}
        return metadata.get("answer")

    def add(self, question, answer):
        """Store *question* with its *answer* under a fresh random id.

        Always returns ``True``.
        """
        self.collection.add(
            documents=[question],
            metadatas=[{"question": question, "answer": answer}],
            ids=[str(uuid.uuid4())],
        )
        return True

    def clear(self):
        """Clear all cached entries without deleting the cache itself.

        The collection is dropped and then recreated empty. A failure to
        drop it (for example, because it no longer exists) is tolerated.
        The collection is still recreated, so ``self.collection`` never
        keeps referring to a deleted collection.

        Raises:
            Any exception from recreating the collection is propagated,
            since the cache would otherwise be left unusable without notice.
        """
        name = self._collection_name
        try:
            self.client.delete_collection(name)
        except Exception:
            # Best effort: the collection may already be gone.
            logger.debug(
                "Could not delete cache collection %s", name, exc_info=True
            )
        self.collection = self.client.get_or_create_collection(name=name)

    def count(self):
        """Return the number of cached entries."""
        return self.collection.count()

    def delete(self):
        """Permanently remove this project's cache (best effort).

        Steps:
            1. Drop the collection through the client. This is the only step
               that has an effect when a remote ``CHROMADB_HOST`` server is
               in use.
            2. Remove the local persistence directory.
            3. Forget the process-local client cached for that path.

        Failures are logged rather than raised. ``KeyboardInterrupt`` and
        ``SystemExit`` are no longer swallowed.
        """
        name = self._collection_name
        try:
            self.client.delete_collection(name)
        except Exception:
            # The collection may not exist; file removal below still runs.
            logger.debug(
                "Could not delete cache collection %s", name, exc_info=True
            )
        try:
            cache_path = find_embeddings_path(name)
            shutil.rmtree(cache_path, ignore_errors=True)
            _cache_client_cache.pop(cache_path, None)
        except Exception:
            logger.warning(
                "Failed to remove cache files for %s", name, exc_info=True
            )