/ tests / test_cache.py
test_cache.py
  1  """Tests for the project response cache system."""
  2  
  3  import random
  4  import pytest
  5  from fastapi.testclient import TestClient
  6  
  7  from restai.config import RESTAI_DEFAULT_PASSWORD
  8  from restai.main import app
  9  
 10  ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)
 11  
 12  suffix = str(random.randint(0, 10000000))
 13  user_name = f"cache_user_{suffix}"
 14  user_pass = "cache_pass_123"
 15  team_name = f"cache_team_{suffix}"
 16  llm_name = f"cache_llm_{suffix}"
 17  project_name = f"cache-proj-{suffix}"
 18  
 19  team_id = None
 20  project_id = None
 21  
 22  
 23  @pytest.fixture(scope="module")
 24  def client():
 25      with TestClient(app) as c:
 26          yield c
 27  
 28  
 29  def test_cache_setup(client):
 30      """Create user, team, LLM, and a block project with cache enabled."""
 31      global team_id, project_id
 32  
 33      r = client.post(
 34          "/users",
 35          json={"username": user_name, "password": user_pass, "is_admin": False, "is_private": False},
 36          auth=ADMIN,
 37      )
 38      assert r.status_code == 201
 39  
 40      r = client.post(
 41          "/llms",
 42          json={
 43              "name": llm_name,
 44              "class_name": "OpenAI",
 45              "options": {"model": "gpt-test", "api_key": "sk-fake"},
 46              "privacy": "public",
 47          },
 48          auth=ADMIN,
 49      )
 50      assert r.status_code in (200, 201)
 51  
 52      r = client.post(
 53          "/teams",
 54          json={"name": team_name, "users": [user_name], "llms": [llm_name]},
 55          auth=ADMIN,
 56      )
 57      assert r.status_code == 201
 58      team_id = r.json()["id"]
 59  
 60      # Create a block project (no LLM needed, easy to test cache)
 61      r = client.post(
 62          "/projects",
 63          json={
 64              "name": project_name,
 65              "type": "block",
 66              "team_id": team_id,
 67          },
 68          auth=ADMIN,
 69      )
 70      assert r.status_code == 201
 71      project_id = r.json()["project"]
 72  
 73      # Assign to user, enable cache, and set up a passthrough workspace
 74      r = client.patch(
 75          f"/projects/{project_id}",
 76          json={
 77              "users": [user_name],
 78              "options": {
 79                  "cache": True,
 80                  "cache_threshold": 0.85,
 81                  "blockly_workspace": {
 82                      "blocks": {
 83                          "blocks": [
 84                              {
 85                                  "type": "restai_set_output",
 86                                  "inputs": {
 87                                      "VALUE": {
 88                                          "block": {"type": "restai_get_input"}
 89                                      }
 90                                  },
 91                              }
 92                          ]
 93                      },
 94                      "variables": [],
 95                  },
 96              },
 97          },
 98          auth=ADMIN,
 99      )
100      assert r.status_code == 200
101  
102  
def test_cache_miss_first_request(client):
    """Asking a question for the first time must not be served from cache."""
    question = "cache test question alpha"
    resp = client.post(
        f"/projects/{project_id}/question",
        json={"question": question},
        auth=(user_name, user_pass),
    )
    assert resp.status_code == 200
    payload = resp.json()
    # The passthrough workspace echoes the question back as the answer.
    assert payload.get("answer") == question
    # "cached" may be absent or falsy on a miss; only True counts as a hit.
    assert payload.get("cached") is not True
114  
115  
def test_cache_hit_same_question(client):
    """Repeating the previous question must be answered from the cache."""
    question = "cache test question alpha"
    resp = client.post(
        f"/projects/{project_id}/question",
        json={"question": question},
        auth=(user_name, user_pass),
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert payload.get("cached") is True
    # The cached answer is the one stored by the earlier echo.
    assert payload.get("answer") == question
127  
128  
def test_cache_miss_different_question(client):
    """An unrelated question must not reuse the cached "alpha" answer."""
    question = "something completely unrelated xyz 12345"
    resp = client.post(
        f"/projects/{project_id}/question",
        json={"question": question},
        auth=(user_name, user_pass),
    )
    assert resp.status_code == 200
    payload = resp.json()
    # The "cached" flag is deliberately not asserted: whether this counts as a
    # hit depends on embedding similarity. The answer check still proves the
    # cached "cache test question alpha" answer was not returned, because the
    # passthrough workspace echoes this new question instead.
    assert payload.get("answer") == question
142  
143  
def test_cache_clear_endpoint(client):
    """DELETE /projects/{id}/cache empties the cache, so a repeat question misses."""
    clear_resp = client.delete(
        f"/projects/{project_id}/cache",
        auth=ADMIN,
    )
    assert clear_resp.status_code == 200
    assert clear_resp.json().get("cleared") is True

    # This question was a cache hit earlier; after clearing it must miss.
    ask_resp = client.post(
        f"/projects/{project_id}/question",
        json={"question": "cache test question alpha"},
        auth=(user_name, user_pass),
    )
    assert ask_resp.status_code == 200
    assert ask_resp.json().get("cached") is not True
163  
164  
def test_cache_clear_when_not_enabled(client):
    """Clearing the cache of a project with caching disabled reports cleared=False.

    Note: this turns caching off for the shared project and does not turn it
    back on.
    """
    disable_resp = client.patch(
        f"/projects/{project_id}",
        json={"options": {"cache": False}},
        auth=ADMIN,
    )
    assert disable_resp.status_code == 200

    clear_resp = client.delete(
        f"/projects/{project_id}/cache",
        auth=ADMIN,
    )
    assert clear_resp.status_code == 200
    assert clear_resp.json().get("cleared") is False
181  
182  
def test_cache_default_threshold():
    """A ProjectOptions built without arguments has cache_threshold 0.85."""
    from restai.models.models import ProjectOptions

    assert ProjectOptions().cache_threshold == 0.85
188  
189  
def test_cache_threshold_bounds():
    """cache_threshold must lie in the closed interval [0.0, 1.0].

    Both endpoints and an interior value are accepted. Values just outside the
    range on either side must raise ``pydantic.ValidationError``.
    """
    from restai.models.models import ProjectOptions
    import pydantic

    # Accepted: both inclusive endpoints plus an interior value.
    for valid in (0.0, 1.0, 0.5):
        ProjectOptions(cache_threshold=valid)

    # pytest.raises fails the test if no ValidationError is raised. Unlike the
    # old `try: ...; assert False` pattern, it still works when Python runs
    # with -O, which strips assert statements.
    with pytest.raises(pydantic.ValidationError):
        ProjectOptions(cache_threshold=1.5)  # above the upper bound
    with pytest.raises(pydantic.ValidationError):
        ProjectOptions(cache_threshold=-0.1)  # below the lower bound
212  
213  
def test_cache_teardown(client):
    """Delete the resources created by test_cache_setup.

    Response codes are not checked, so cleanup never fails the run. The
    project and team are only deleted when their ids were recorded.
    """
    if project_id:
        client.delete(f"/projects/{project_id}", auth=ADMIN)
    if team_id:
        client.delete(f"/teams/{team_id}", auth=ADMIN)
    # The user and the LLM are addressed by name and always deleted.
    for path in (f"/users/{user_name}", f"/llms/{llm_name}"):
        client.delete(path, auth=ADMIN)