"""Tests for the project response cache system.

These tests are order-dependent: pytest runs them top to bottom, and
``test_cache_setup`` creates the resources (user, team, LLM, project) that
the later tests use; ``test_cache_teardown`` removes them.
"""

import random

import pytest
from fastapi.testclient import TestClient

from restai.config import RESTAI_DEFAULT_PASSWORD
from restai.main import app

ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)

# Random suffix keeps resource names unique across repeated runs against a
# persistent backing database.
suffix = str(random.randint(0, 10000000))
user_name = f"cache_user_{suffix}"
user_pass = "cache_pass_123"
team_name = f"cache_team_{suffix}"
llm_name = f"cache_llm_{suffix}"
project_name = f"cache-proj-{suffix}"

# Populated by test_cache_setup and consumed by the later tests.
team_id = None
project_id = None


@pytest.fixture(scope="module")
def client():
    """Yield a TestClient bound to the app, shared across the module."""
    with TestClient(app) as c:
        yield c


def test_cache_setup(client):
    """Create user, team, LLM, and a block project with cache enabled."""
    global team_id, project_id

    r = client.post(
        "/users",
        json={"username": user_name, "password": user_pass, "is_admin": False, "is_private": False},
        auth=ADMIN,
    )
    assert r.status_code == 201

    r = client.post(
        "/llms",
        json={
            "name": llm_name,
            "class_name": "OpenAI",
            "options": {"model": "gpt-test", "api_key": "sk-fake"},
            "privacy": "public",
        },
        auth=ADMIN,
    )
    assert r.status_code in (200, 201)

    r = client.post(
        "/teams",
        json={"name": team_name, "users": [user_name], "llms": [llm_name]},
        auth=ADMIN,
    )
    assert r.status_code == 201
    team_id = r.json()["id"]

    # Create a block project (no LLM needed, easy to test cache)
    r = client.post(
        "/projects",
        json={
            "name": project_name,
            "type": "block",
            "team_id": team_id,
        },
        auth=ADMIN,
    )
    assert r.status_code == 201
    project_id = r.json()["project"]

    # Assign to user, enable cache, and set up a passthrough workspace
    # (the workspace simply echoes the input back as the output, so cache
    # behaviour can be asserted against a deterministic answer).
    r = client.patch(
        f"/projects/{project_id}",
        json={
            "users": [user_name],
            "options": {
                "cache": True,
                "cache_threshold": 0.85,
                "blockly_workspace": {
                    "blocks": {
                        "blocks": [
                            {
                                "type": "restai_set_output",
                                "inputs": {
                                    "VALUE": {
                                        "block": {"type": "restai_get_input"}
                                    }
                                },
                            }
                        ]
                    },
                    "variables": [],
                },
            },
        },
        auth=ADMIN,
    )
    assert r.status_code == 200


def test_cache_miss_first_request(client):
    """First request should not be cached."""
    r = client.post(
        f"/projects/{project_id}/question",
        json={"question": "cache test question alpha"},
        auth=(user_name, user_pass),
    )
    assert r.status_code == 200
    data = r.json()
    assert data.get("answer") == "cache test question alpha"
    assert data.get("cached") is not True


def test_cache_hit_same_question(client):
    """Same question should return cached result."""
    r = client.post(
        f"/projects/{project_id}/question",
        json={"question": "cache test question alpha"},
        auth=(user_name, user_pass),
    )
    assert r.status_code == 200
    data = r.json()
    assert data.get("cached") is True
    assert data.get("answer") == "cache test question alpha"


def test_cache_miss_different_question(client):
    """A completely different question should not hit cache."""
    r = client.post(
        f"/projects/{project_id}/question",
        json={"question": "something completely unrelated xyz 12345"},
        auth=(user_name, user_pass),
    )
    assert r.status_code == 200
    data = r.json()
    # Should not be a cache hit for a very different question
    # (may or may not be cached depending on embedding similarity —
    # but with default chromadb embeddings this should miss)
    assert data.get("answer") == "something completely unrelated xyz 12345"


def test_cache_clear_endpoint(client):
    """DELETE /projects/{id}/cache should clear the cache."""
    # Clear cache
    r = client.delete(
        f"/projects/{project_id}/cache",
        auth=ADMIN,
    )
    assert r.status_code == 200
    assert r.json().get("cleared") is True

    # Same question that was cached should now miss
    r = client.post(
        f"/projects/{project_id}/question",
        json={"question": "cache test question alpha"},
        auth=(user_name, user_pass),
    )
    assert r.status_code == 200
    data = r.json()
    assert data.get("cached") is not True


def test_cache_clear_when_not_enabled(client):
    """Clearing cache on a project without cache should return cleared=False."""
    # Disable cache
    r = client.patch(
        f"/projects/{project_id}",
        json={"options": {"cache": False}},
        auth=ADMIN,
    )
    assert r.status_code == 200

    r = client.delete(
        f"/projects/{project_id}/cache",
        auth=ADMIN,
    )
    assert r.status_code == 200
    assert r.json().get("cleared") is False


def test_cache_default_threshold():
    """Default cache_threshold should be 0.85."""
    from restai.models.models import ProjectOptions

    opts = ProjectOptions()
    assert opts.cache_threshold == 0.85


def test_cache_threshold_bounds():
    """cache_threshold should be bounded 0.0 to 1.0."""
    import pydantic

    from restai.models.models import ProjectOptions

    # Boundary and interior values must be accepted.
    ProjectOptions(cache_threshold=0.0)
    ProjectOptions(cache_threshold=1.0)
    ProjectOptions(cache_threshold=0.5)

    # Out-of-range values must be rejected by pydantic validation.
    # (pytest.raises is used instead of try/except + `assert False`, which
    # silently passes when Python runs with -O and asserts are stripped.)
    with pytest.raises(pydantic.ValidationError):
        ProjectOptions(cache_threshold=1.5)
    with pytest.raises(pydantic.ValidationError):
        ProjectOptions(cache_threshold=-0.1)


def test_cache_teardown(client):
    """Clean up resources."""
    # Best-effort cleanup: ids are None if setup failed part-way through.
    if project_id:
        client.delete(f"/projects/{project_id}", auth=ADMIN)
    if team_id:
        client.delete(f"/teams/{team_id}", auth=ADMIN)
    client.delete(f"/users/{user_name}", auth=ADMIN)
    client.delete(f"/llms/{llm_name}", auth=ADMIN)