# test_llms.py
"""Integration tests for the ``/llms`` endpoints of the RestAI API.

The tests share module-level state: ``test_create_llm`` records the id of the
LLM it creates in ``test_llm_id``, and the get/update/delete tests reuse it.
They therefore rely on pytest running the tests in file order (its default).
"""

import json
import random

import pytest
from fastapi.testclient import TestClient

from restai.config import RESTAI_DEFAULT_PASSWORD
from restai.main import app

# Randomised names so repeated runs against the same database are unlikely to collide.
test_llm_name = "test_llm_" + str(random.randint(0, 1000000))
test_user = "test_llm_user_" + str(random.randint(0, 1000000))
# Populated by test_create_llm; consumed by the get/update/delete tests below.
test_llm_id = None

# Basic-auth credentials of the built-in admin account.
ADMIN_AUTH = ("admin", RESTAI_DEFAULT_PASSWORD)


def _require_llm_id():
    """Return the LLM id recorded by ``test_create_llm``, or skip the caller.

    Without this guard, a failed create makes every dependent test request
    ``/llms/None`` and fail with an unrelated status-code assertion.
    """
    if test_llm_id is None:
        pytest.skip("test_create_llm did not record an LLM id")
    return test_llm_id


@pytest.fixture(scope="module")
def client():
    """Yield one TestClient shared by every test in this module.

    Entering the client as a context manager runs the app's lifespan
    (startup/shutdown) handlers around the tests.
    """
    with TestClient(app) as c:
        yield c


def test_get_llms(client):
    """Listing LLMs as admin succeeds and returns a JSON list."""
    response = client.get("/llms", auth=ADMIN_AUTH)
    assert response.status_code == 200
    assert isinstance(response.json(), list)


def test_create_llm(client):
    """Admin can create an LLM; its id is stored for the following tests."""
    global test_llm_id
    response = client.post(
        "/llms",
        json={
            "name": test_llm_name,
            "class_name": "OpenAI",
            "options": {"model": "gpt-test", "api_key": "sk-fake123"},
            "privacy": "public",
        },
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 201
    data = response.json()
    assert data["name"] == test_llm_name
    test_llm_id = data["id"]


def test_create_llm_non_admin(client):
    """A non-admin user is refused (403) when trying to create an LLM."""
    response = client.post(
        "/users",
        json={
            "username": test_user,
            "password": "testpass",
            "admin": False,
            "private": False,
        },
        auth=ADMIN_AUTH,
    )
    # If the temporary user could not be created, the 403 check below would be
    # exercising an authentication failure instead of the authorization rule.
    assert 200 <= response.status_code < 300, response.text

    try:
        response = client.post(
            "/llms",
            json={
                "name": "should_fail_llm",
                "class_name": "OpenAI",
                "options": {"model": "gpt-test", "api_key": "sk-fake"},
                "privacy": "public",
            },
            auth=(test_user, "testpass"),
        )
        assert response.status_code == 403
    finally:
        # Always remove the temporary user, even when the assertion fails.
        client.delete(f"/users/{test_user}", auth=ADMIN_AUTH)


def test_get_llm(client):
    """Fetching the created LLM returns its fields with the API key masked."""
    llm_id = _require_llm_id()
    response = client.get(f"/llms/{llm_id}", auth=ADMIN_AUTH)
    assert response.status_code == 200
    data = response.json()
    assert data["name"] == test_llm_name
    assert data["class_name"] == "OpenAI"
    assert data["privacy"] == "public"

    # The endpoint may return ``options`` either as an object or as a
    # JSON-encoded string; normalise before inspecting it.
    options = data["options"]
    if isinstance(options, str):
        options = json.loads(options)
    assert options.get("api_key") == "********"


def test_update_llm(client):
    """Patching the description persists and is visible on a subsequent GET."""
    llm_id = _require_llm_id()
    response = client.patch(
        f"/llms/{llm_id}",
        json={"description": "Updated test LLM"},
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 200

    # Verify the update round-trips.
    response = client.get(f"/llms/{llm_id}", auth=ADMIN_AUTH)
    assert response.status_code == 200
    assert response.json()["description"] == "Updated test LLM"


def test_delete_llm(client):
    """Admin can delete the LLM created earlier in this module."""
    llm_id = _require_llm_id()
    response = client.delete(f"/llms/{llm_id}", auth=ADMIN_AUTH)
    assert response.status_code == 200


def test_delete_llm_not_found(client):
    """Deleting an id that is presumed not to exist yields 404."""
    response = client.delete("/llms/999999", auth=ADMIN_AUTH)
    assert response.status_code == 404