# test_api_key_quota.py
"""Per-API-key monthly token quota tests.

Covers:
  * PATCH /users/{u}/apikeys/{id} updates token_quota_monthly + reset_usage
  * check_api_key_quota raises 429 when tokens_used_this_month >= quota
  * check_api_key_quota rolls the counter over when quota_reset_at lapses
  * record_api_key_tokens increments the counter

The actual inference path (helper.py) is exercised in other tests; here
we unit-test the two budget.py helpers directly and the PATCH endpoint
end-to-end.
"""
from __future__ import annotations

import random
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace

import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient

from restai.budget import (
    _first_of_next_month,
    check_api_key_quota,
    record_api_key_tokens,
)
from restai.config import RESTAI_DEFAULT_PASSWORD
from restai.database import get_db_wrapper
from restai.main import app
from restai.models.databasemodels import ApiKeyDatabase


# Basic-auth credentials used for every HTTP call in this module.
ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)


@pytest.fixture(scope="module")
def client():
    """Yield one TestClient for the whole module, closed on teardown."""
    with TestClient(app) as c:
        yield c


@pytest.fixture
def api_key(client):
    """Create a fresh API key for the admin user, yield the DB row id,
    clean up after."""
    suffix = str(random.randint(0, 1_000_000))
    r = client.post(
        "/users/admin/apikeys",
        json={"description": f"quota_test_{suffix}"},
        auth=ADMIN,
    )
    assert r.status_code == 201, r.text
    key_id = r.json()["id"]
    yield key_id
    # Best-effort cleanup: the response is deliberately not checked.
    client.delete(f"/users/admin/apikeys/{key_id}", auth=ADMIN)


def _load_key(db, key_id):
    """Return the ApiKeyDatabase row with id *key_id* from the wrapper's session.

    Fails with an explicit assertion message when the row is missing, rather
    than letting the caller trip over an AttributeError on ``None``.
    """
    key = db.db.query(ApiKeyDatabase).filter(ApiKeyDatabase.id == key_id).first()
    assert key is not None, f"API key row {key_id} not found"
    return key


# ─── unit tests ─────────────────────────────────────────────────────────

def test_check_api_key_quota_noop_without_api_key_id():
    """Basic / cookie auth has no api_key_id — must be a no-op."""
    user = SimpleNamespace(api_key_id=None)
    db = SimpleNamespace(db=SimpleNamespace())  # not touched
    # Would AttributeError if the function dereferenced db.
    check_api_key_quota(user, db)


def test_check_api_key_quota_noop_when_unlimited(api_key):
    """token_quota_monthly=NULL is unlimited — skip the cap check."""
    db = get_db_wrapper()
    try:
        key = _load_key(db, api_key)
        assert key.token_quota_monthly is None
        key.tokens_used_this_month = 999_999_999
        db.db.commit()
        user = SimpleNamespace(api_key_id=api_key)
        check_api_key_quota(user, db)  # must NOT raise
    finally:
        db.db.close()


def test_check_api_key_quota_raises_when_exceeded(api_key):
    """tokens_used_this_month == quota → HTTPException with status 429."""
    db = get_db_wrapper()
    try:
        key = _load_key(db, api_key)
        key.token_quota_monthly = 1000
        key.tokens_used_this_month = 1000
        # Future reset date so the rollover branch doesn't fire.
        key.quota_reset_at = datetime.now(timezone.utc) + timedelta(days=7)
        db.db.commit()
        user = SimpleNamespace(api_key_id=api_key)
        with pytest.raises(HTTPException) as ei:
            check_api_key_quota(user, db)
        # Inspect the exception outside the `with` block: statements placed
        # after the raising call inside it would never execute.
        assert ei.value.status_code == 429
        assert "quota reached" in str(ei.value.detail).lower()
    finally:
        db.db.close()


def test_check_api_key_quota_rolls_over_on_lapsed_reset(api_key):
    """quota_reset_at in the past → counter zeros, new reset date set,
    NO 429."""
    db = get_db_wrapper()
    try:
        key = _load_key(db, api_key)
        key.token_quota_monthly = 10
        key.tokens_used_this_month = 999
        key.quota_reset_at = datetime(2024, 1, 1, tzinfo=timezone.utc)  # past
        db.db.commit()

        user = SimpleNamespace(api_key_id=api_key)
        check_api_key_quota(user, db)  # must NOT raise

        db.db.refresh(key)
        assert key.tokens_used_this_month == 0
        # SQLite strips tzinfo on storage; normalize before comparing.
        reset = key.quota_reset_at
        if reset.tzinfo is None:
            reset = reset.replace(tzinfo=timezone.utc)
        assert reset > datetime.now(timezone.utc)
    finally:
        db.db.close()


def test_record_api_key_tokens_bumps_counter(api_key):
    """record_api_key_tokens adds the given count to tokens_used_this_month."""
    db = get_db_wrapper()
    try:
        key = _load_key(db, api_key)
        key.tokens_used_this_month = 100
        db.db.commit()
        record_api_key_tokens(api_key, 50, db)
        db.db.refresh(key)
        assert key.tokens_used_this_month == 150
    finally:
        db.db.close()


def test_record_api_key_tokens_silent_on_unknown_id():
    """Recording tokens against a non-existent key id must not raise."""
    db = get_db_wrapper()
    try:
        # Shouldn't raise — key may have been deleted between auth and log.
        record_api_key_tokens(999_999_999, 10, db)
    finally:
        db.db.close()


def test_first_of_next_month_rolls_december():
    """A December date rolls over to January 1st of the following year."""
    dec = datetime(2025, 12, 15, tzinfo=timezone.utc)
    out = _first_of_next_month(dec)
    assert out == datetime(2026, 1, 1, tzinfo=timezone.utc)


def test_first_of_next_month_rolls_mid_year():
    """A mid-year datetime maps to midnight on the 1st of the next month."""
    mar = datetime(2026, 3, 9, 14, 30, tzinfo=timezone.utc)
    out = _first_of_next_month(mar)
    assert out == datetime(2026, 4, 1, tzinfo=timezone.utc)


# ─── PATCH endpoint ─────────────────────────────────────────────────────

def test_patch_sets_quota(client, api_key):
    """PATCH with token_quota_monthly + description echoes both back."""
    r = client.patch(
        f"/users/admin/apikeys/{api_key}",
        json={"token_quota_monthly": 5000, "description": "monthly-capped"},
        auth=ADMIN,
    )
    assert r.status_code == 200, r.text
    body = r.json()
    assert body["token_quota_monthly"] == 5000
    assert body["description"] == "monthly-capped"


def test_patch_clears_quota_with_zero(client, api_key):
    """PATCH with token_quota_monthly=0 clears a previously set cap."""
    # First set a cap — and verify it took. A fresh key already reports no
    # quota, so without this check the final assertion could pass vacuously.
    setup = client.patch(
        f"/users/admin/apikeys/{api_key}",
        json={"token_quota_monthly": 100},
        auth=ADMIN,
    )
    assert setup.status_code == 200, setup.text
    assert setup.json()["token_quota_monthly"] == 100
    # Then clear with 0.
    r = client.patch(
        f"/users/admin/apikeys/{api_key}",
        json={"token_quota_monthly": 0},
        auth=ADMIN,
    )
    assert r.status_code == 200, r.text
    assert r.json()["token_quota_monthly"] is None


def test_patch_reset_usage(client, api_key):
    """PATCH with reset_usage=True zeroes the counter and reports a reset date."""
    # Stamp some usage directly.
    db = get_db_wrapper()
    try:
        key = _load_key(db, api_key)
        key.tokens_used_this_month = 12345
        db.db.commit()
    finally:
        db.db.close()

    r = client.patch(
        f"/users/admin/apikeys/{api_key}",
        json={"reset_usage": True},
        auth=ADMIN,
    )
    assert r.status_code == 200, r.text
    body = r.json()
    assert body["tokens_used_this_month"] == 0
    assert body["quota_reset_at"] is not None