/ tests / test_llms.py
test_llms.py
  1  import random
  2  import pytest
  3  from fastapi.testclient import TestClient
  4  
  5  from restai.config import RESTAI_DEFAULT_PASSWORD
  6  from restai.main import app
  7  
  8  test_llm_name = "test_llm_" + str(random.randint(0, 1000000))
  9  test_user = "test_llm_user_" + str(random.randint(0, 1000000))
 10  test_llm_id = None
 11  
 12  
 13  @pytest.fixture(scope="module")
 14  def client():
 15      with TestClient(app) as c:
 16          yield c
 17  
 18  
 19  def test_get_llms(client):
 20      response = client.get("/llms", auth=("admin", RESTAI_DEFAULT_PASSWORD))
 21      assert response.status_code == 200
 22      assert isinstance(response.json(), list)
 23  
 24  
 25  def test_create_llm(client):
 26      global test_llm_id
 27      response = client.post(
 28          "/llms",
 29          json={
 30              "name": test_llm_name,
 31              "class_name": "OpenAI",
 32              "options": {"model": "gpt-test", "api_key": "sk-fake123"},
 33              "privacy": "public",
 34          },
 35          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 36      )
 37      assert response.status_code == 201
 38      data = response.json()
 39      assert data["name"] == test_llm_name
 40      test_llm_id = data["id"]
 41  
 42  
 43  def test_create_llm_non_admin(client):
 44      # Create a non-admin user
 45      client.post(
 46          "/users",
 47          json={
 48              "username": test_user,
 49              "password": "testpass",
 50              "admin": False,
 51              "private": False,
 52          },
 53          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 54      )
 55  
 56      response = client.post(
 57          "/llms",
 58          json={
 59              "name": "should_fail_llm",
 60              "class_name": "OpenAI",
 61              "options": {"model": "gpt-test", "api_key": "sk-fake"},
 62              "privacy": "public",
 63          },
 64          auth=(test_user, "testpass"),
 65      )
 66      assert response.status_code == 403
 67  
 68      # Clean up user
 69      client.delete(
 70          f"/users/{test_user}",
 71          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 72      )
 73  
 74  
 75  def test_get_llm(client):
 76      response = client.get(
 77          f"/llms/{test_llm_id}",
 78          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 79      )
 80      assert response.status_code == 200
 81      data = response.json()
 82      assert data["name"] == test_llm_name
 83      assert data["class_name"] == "OpenAI"
 84      assert data["privacy"] == "public"
 85      # API key should be masked
 86      options = data["options"]
 87      if isinstance(options, str):
 88          import json
 89          options = json.loads(options)
 90      assert options.get("api_key") == "********"
 91  
 92  
 93  def test_update_llm(client):
 94      response = client.patch(
 95          f"/llms/{test_llm_id}",
 96          json={"description": "Updated test LLM"},
 97          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 98      )
 99      assert response.status_code == 200
100  
101      # Verify update
102      response = client.get(
103          f"/llms/{test_llm_id}",
104          auth=("admin", RESTAI_DEFAULT_PASSWORD),
105      )
106      assert response.status_code == 200
107      assert response.json()["description"] == "Updated test LLM"
108  
109  
def test_delete_llm(client):
    """As admin, delete the LLM whose id test_create_llm recorded."""
    resp = client.delete(
        "/llms/" + str(test_llm_id),
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    assert resp.status_code == 200
116  
117  
def test_delete_llm_not_found(client):
    """Deleting an LLM id that does not exist returns 404."""
    # NOTE(review): assumes id 999999 is never allocated in the test database.
    missing_id = 999999
    resp = client.delete(
        f"/llms/{missing_id}",
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    assert resp.status_code == 404