/ restai / cache.py
cache.py
import logging
import math
import shutil
import uuid

import chromadb

from restai.vectordb.tools import find_embeddings_path
from restai.config import CHROMADB_HOST, CHROMADB_PORT
 8  
# Memo of chromadb.PersistentClient instances keyed by cache path, so the same
# worker process reuses one client per on-disk cache instead of reopening it.
# Filled lazily by _get_cache_client(); Cache.delete() drops the entry for the
# path it removes. Not consulted when CHROMADB_HOST is set.
_cache_client_cache = {}
11  
12  
def _get_cache_client(path):
    """Return the ChromaDB client backing the cache stored at *path*.

    When ``CHROMADB_HOST`` is set, a new ``HttpClient`` for the configured
    host/port is built on every call and *path* is not used. Otherwise a
    ``PersistentClient`` is created for *path* on first use and memoized in
    ``_cache_client_cache`` for later calls.
    """
    if not CHROMADB_HOST:
        local_client = _cache_client_cache.get(path)
        if local_client is None:
            # First request for this path: open the store and remember it.
            local_client = chromadb.PersistentClient(path=path)
            _cache_client_cache[path] = local_client
        return local_client

    return chromadb.HttpClient(host=CHROMADB_HOST, port=CHROMADB_PORT)
19  
20  
class Cache:
    """Semantic answer cache for one project, stored in a ChromaDB collection.

    Each entry is a previously asked question (stored as the document) with
    its answer carried in the entry metadata. The collection is named
    ``<project name>_cache``.
    """

    # Score a stored question must exceed when the project's options leave
    # ``cache_threshold`` unset (None).
    DEFAULT_THRESHOLD = 0.85

    _log = logging.getLogger(__name__)

    def __init__(self, project):
        """Open (or create) the cache collection for *project*.

        Args:
            project: Object exposing ``props.name`` and
                ``props.options.cache_threshold``.
        """
        self.project = project
        cache_path = find_embeddings_path(self._collection_name)
        self.client = _get_cache_client(cache_path)
        self.collection = self.client.get_or_create_collection(
            name=self._collection_name
        )

    @property
    def _collection_name(self):
        """Name shared by the cache collection and its embeddings path."""
        return self.project.props.name + "_cache"

    def verify(self, question):
        """Return the cached answer for a sufficiently similar question.

        Args:
            question: Text to look up.

        Returns:
            The stored answer of the nearest cached question when its score
            is strictly greater than the threshold, otherwise ``None``.
        """
        results = self.collection.query(
            query_texts=[question],
            n_results=1,
            include=["metadatas", "documents", "distances"],
        )

        ids = results["ids"]
        if not ids or not ids[0]:
            return None

        # exp(-d) turns the reported distance into a score that is 1.0 at
        # distance 0 and shrinks as the distance grows.
        similarity = math.exp(-results["distances"][0][0])
        threshold = self.project.props.options.cache_threshold
        if threshold is None:
            threshold = self.DEFAULT_THRESHOLD

        if similarity > threshold:
            # Treat an entry without usable metadata as a miss, not an error.
            metadata = results["metadatas"][0][0] or {}
            return metadata.get("answer")

        return None

    def add(self, question, answer):
        """Store *answer* under *question* with a random UUID as the id.

        Returns:
            ``True`` once the entry has been handed to the collection.
        """
        self.collection.add(
            documents=[question],
            metadatas=[{"question": question, "answer": answer}],
            ids=[str(uuid.uuid4())],
        )
        return True

    def clear(self):
        """Clear all cached entries without deleting the cache itself.

        Best effort: failures are logged, never raised. The collection handle
        is re-acquired even when dropping the old collection fails, so
        ``self.collection`` does not keep pointing at a stale collection.
        """
        name = self._collection_name
        try:
            self.client.delete_collection(name)
        except Exception:
            self._log.warning(
                "Could not drop cache collection %r", name, exc_info=True
            )
        try:
            self.collection = self.client.get_or_create_collection(name=name)
        except Exception:
            self._log.warning(
                "Could not re-create cache collection %r", name, exc_info=True
            )

    def count(self):
        """Return the number of cached entries."""
        return self.collection.count()

    def delete(self):
        """Remove the cache: its collection, its directory, its memoized client.

        Best effort: failures are logged, never raised. The collection is
        dropped through the client first so that a cache served by a remote
        ChromaDB host is removed as well; deleting only the local directory
        would leave it behind.
        """
        name = self._collection_name
        try:
            self.client.delete_collection(name)
        except Exception:
            self._log.warning(
                "Could not drop cache collection %r", name, exc_info=True
            )
        # Catch Exception rather than BaseException so KeyboardInterrupt and
        # SystemExit still propagate.
        try:
            cache_path = find_embeddings_path(name)
            shutil.rmtree(cache_path, ignore_errors=True)
            _cache_client_cache.pop(cache_path, None)
        except Exception:
            self._log.warning(
                "Could not remove cache directory for %r", name, exc_info=True
            )