/ tests / test_api_key_quota.py
test_api_key_quota.py
  1  """Per-API-key monthly token quota tests.
  2  
  3  Covers:
  4  * PATCH /users/{u}/apikeys/{id} updates token_quota_monthly + reset_usage
  5  * check_api_key_quota raises 429 when tokens_used_this_month >= quota
  6  * check_api_key_quota rolls the counter over when quota_reset_at lapses
  7  * record_api_key_tokens increments the counter
  8  
The actual inference path (helper.py) is exercised in other tests; here
we unit-test the budget.py helpers (check_api_key_quota,
record_api_key_tokens and _first_of_next_month) directly, and the PATCH
endpoint end-to-end.
 12  """
 13  from __future__ import annotations
 14  
 15  import random
 16  from datetime import datetime, timedelta, timezone
 17  from types import SimpleNamespace
 18  
 19  import pytest
 20  from fastapi import HTTPException
 21  from fastapi.testclient import TestClient
 22  
 23  from restai.budget import (
 24      _first_of_next_month,
 25      check_api_key_quota,
 26      record_api_key_tokens,
 27  )
 28  from restai.config import RESTAI_DEFAULT_PASSWORD
 29  from restai.database import get_db_wrapper
 30  from restai.main import app
 31  from restai.models.databasemodels import ApiKeyDatabase
 32  
 33  
 34  ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)
 35  
 36  
 37  @pytest.fixture(scope="module")
 38  def client():
 39      with TestClient(app) as c:
 40          yield c
 41  
 42  
 43  @pytest.fixture
 44  def api_key(client):
 45      """Create a fresh API key for the admin user, yield the DB row id,
 46      clean up after."""
 47      suffix = str(random.randint(0, 1_000_000))
 48      r = client.post(
 49          "/users/admin/apikeys",
 50          json={"description": f"quota_test_{suffix}"},
 51          auth=ADMIN,
 52      )
 53      assert r.status_code == 201, r.text
 54      key_id = r.json()["id"]
 55      yield key_id
 56      client.delete(f"/users/admin/apikeys/{key_id}", auth=ADMIN)
 57  
 58  
 59  # ─── unit tests ─────────────────────────────────────────────────────────
 60  
 61  def test_check_api_key_quota_noop_without_api_key_id():
 62      """Basic / cookie auth has no api_key_id — must be a no-op."""
 63      user = SimpleNamespace(api_key_id=None)
 64      db = SimpleNamespace(db=SimpleNamespace())  # not touched
 65      # Would AttributeError if the function dereferenced db.
 66      check_api_key_quota(user, db)
 67  
 68  
 69  def test_check_api_key_quota_noop_when_unlimited(api_key):
 70      """token_quota_monthly=NULL is unlimited — skip the cap check."""
 71      db = get_db_wrapper()
 72      try:
 73          key = db.db.query(ApiKeyDatabase).filter(ApiKeyDatabase.id == api_key).first()
 74          assert key.token_quota_monthly is None
 75          key.tokens_used_this_month = 999_999_999
 76          db.db.commit()
 77          user = SimpleNamespace(api_key_id=api_key)
 78          check_api_key_quota(user, db)  # must NOT raise
 79      finally:
 80          db.db.close()
 81  
 82  
 83  def test_check_api_key_quota_raises_when_exceeded(api_key):
 84      db = get_db_wrapper()
 85      try:
 86          key = db.db.query(ApiKeyDatabase).filter(ApiKeyDatabase.id == api_key).first()
 87          key.token_quota_monthly = 1000
 88          key.tokens_used_this_month = 1000
 89          # Future reset date so the rollover branch doesn't fire.
 90          key.quota_reset_at = datetime.now(timezone.utc) + timedelta(days=7)
 91          db.db.commit()
 92          user = SimpleNamespace(api_key_id=api_key)
 93          with pytest.raises(HTTPException) as ei:
 94              check_api_key_quota(user, db)
 95          assert ei.value.status_code == 429
 96          assert "quota reached" in str(ei.value.detail).lower()
 97      finally:
 98          db.db.close()
 99  
100  
def test_check_api_key_quota_rolls_over_on_lapsed_reset(api_key):
    """With ``quota_reset_at`` in the past, the check zeroes the counter and
    schedules a future reset date instead of raising 429."""
    db = get_db_wrapper()
    try:
        row = (
            db.db.query(ApiKeyDatabase)
            .filter(ApiKeyDatabase.id == api_key)
            .first()
        )
        row.token_quota_monthly = 10
        row.tokens_used_this_month = 999  # far over the cap ...
        row.quota_reset_at = datetime(2024, 1, 1, tzinfo=timezone.utc)  # ... but the window lapsed
        db.db.commit()

        check_api_key_quota(SimpleNamespace(api_key_id=api_key), db)  # must NOT raise

        db.db.refresh(row)
        assert row.tokens_used_this_month == 0
        next_reset = row.quota_reset_at
        # SQLite drops tzinfo on storage; read a naive value back as UTC.
        if next_reset.tzinfo is None:
            next_reset = next_reset.replace(tzinfo=timezone.utc)
        assert next_reset > datetime.now(timezone.utc)
    finally:
        db.db.close()
124  
125  
def test_record_api_key_tokens_bumps_counter(api_key):
    """``record_api_key_tokens`` adds the given amount to the monthly counter."""
    starting_usage, increment = 100, 50
    db = get_db_wrapper()
    try:
        row = (
            db.db.query(ApiKeyDatabase)
            .filter(ApiKeyDatabase.id == api_key)
            .first()
        )
        row.tokens_used_this_month = starting_usage
        db.db.commit()

        record_api_key_tokens(api_key, increment, db)

        db.db.refresh(row)
        assert row.tokens_used_this_month == starting_usage + increment
    finally:
        db.db.close()
137  
138  
def test_record_api_key_tokens_silent_on_unknown_id():
    """Recording usage for an id with no matching key must not raise.

    The key may have been deleted between authentication and the usage log.
    """
    missing_key_id = 999_999_999
    db = get_db_wrapper()
    try:
        record_api_key_tokens(missing_key_id, 10, db)
    finally:
        db.db.close()
146  
147  
def test_first_of_next_month_rolls_december():
    """A December date rolls over to January 1st of the following year."""
    mid_december = datetime(2025, 12, 15, tzinfo=timezone.utc)
    expected = datetime(2026, 1, 1, tzinfo=timezone.utc)
    assert _first_of_next_month(mid_december) == expected
152  
153  
def test_first_of_next_month_rolls_mid_year():
    """Non-December dates land on 00:00 UTC of the 1st of the next month.

    Besides an ordinary mid-month timestamp, this pins month-end days. For
    Jan 31 -> February and May 31 -> June, a naive
    ``dt.replace(month=dt.month + 1)`` raises ValueError before the day can
    be reset to 1. A leap day is included for completeness.
    """
    cases = [
        (datetime(2026, 3, 9, 14, 30, tzinfo=timezone.utc),
         datetime(2026, 4, 1, tzinfo=timezone.utc)),
        (datetime(2026, 1, 31, 14, 30, tzinfo=timezone.utc),
         datetime(2026, 2, 1, tzinfo=timezone.utc)),
        (datetime(2026, 5, 31, 14, 30, tzinfo=timezone.utc),
         datetime(2026, 6, 1, tzinfo=timezone.utc)),
        (datetime(2024, 2, 29, 14, 30, tzinfo=timezone.utc),
         datetime(2024, 3, 1, tzinfo=timezone.utc)),
    ]
    for given, expected in cases:
        out = _first_of_next_month(given)
        assert out == expected, f"{given.isoformat()} -> {out!r}"
158  
159  
160  # ─── PATCH endpoint ─────────────────────────────────────────────────────
161  
def test_patch_sets_quota(client, api_key):
    """A single PATCH can set both a monthly quota and a new description."""
    payload = {"token_quota_monthly": 5000, "description": "monthly-capped"}
    response = client.patch(
        f"/users/admin/apikeys/{api_key}", json=payload, auth=ADMIN
    )
    assert response.status_code == 200, response.text
    updated = response.json()
    assert updated["token_quota_monthly"] == payload["token_quota_monthly"]
    assert updated["description"] == payload["description"]
172  
173  
def test_patch_clears_quota_with_zero(client, api_key):
    """Sending ``token_quota_monthly=0`` removes the cap (reported as None).

    The setup PATCH that installs a cap is asserted as well. A fresh key
    already has no quota, so if that call silently failed, the final
    ``is None`` check would pass without exercising the clearing logic.
    """
    url = f"/users/admin/apikeys/{api_key}"

    # First set a cap, and make sure it actually took effect.
    r = client.patch(url, json={"token_quota_monthly": 100}, auth=ADMIN)
    assert r.status_code == 200, r.text
    assert r.json()["token_quota_monthly"] == 100

    # Then clear it with 0.
    r = client.patch(url, json={"token_quota_monthly": 0}, auth=ADMIN)
    assert r.status_code == 200, r.text
    assert r.json()["token_quota_monthly"] is None
185  
186  
def test_patch_reset_usage(client, api_key):
    """``reset_usage=True`` zeroes the monthly counter and sets a reset date.

    Besides the response body, the row is re-read in a fresh DB session
    afterwards. The test therefore also fails if the endpoint reports a reset
    it never persisted.
    """
    # Stamp some usage directly.
    db = get_db_wrapper()
    try:
        key = db.db.query(ApiKeyDatabase).filter(ApiKeyDatabase.id == api_key).first()
        assert key is not None, f"API key {api_key} not found"
        key.tokens_used_this_month = 12345
        db.db.commit()
    finally:
        db.db.close()

    r = client.patch(
        f"/users/admin/apikeys/{api_key}",
        json={"reset_usage": True},
        auth=ADMIN,
    )
    assert r.status_code == 200, r.text
    body = r.json()
    assert body["tokens_used_this_month"] == 0
    assert body["quota_reset_at"] is not None

    # Verify the reset was committed, not just reflected in the response.
    db = get_db_wrapper()
    try:
        key = db.db.query(ApiKeyDatabase).filter(ApiKeyDatabase.id == api_key).first()
        assert key is not None, f"API key {api_key} not found"
        assert key.tokens_used_this_month == 0
        assert key.quota_reset_at is not None
    finally:
        db.db.close()