test_embeddings.py
"""Integration tests for the ``/embeddings`` REST endpoints.

The tests run against the application through FastAPI's ``TestClient`` and
authenticate with HTTP basic auth.  They are order-dependent:
``test_create_embedding`` stores the id of the embedding it creates in the
module-level ``test_embedding_id``, and the get/update/delete tests that
follow operate on that id.
"""
import random
import pytest
from fastapi.testclient import TestClient

from restai.config import RESTAI_DEFAULT_PASSWORD
from restai.main import app

# Random suffixes keep names from colliding with leftovers of earlier runs.
test_embedding_name = "test_embedding_" + str(random.randint(0, 1000000))
test_user = "test_emb_user_" + str(random.randint(0, 1000000))
# Filled in by test_create_embedding; read by the tests that come after it.
test_embedding_id = None


@pytest.fixture(scope="module")
def client():
    """Yield one ``TestClient`` shared by every test in this module."""
    with TestClient(app) as c:
        yield c


def test_get_embeddings(client):
    """Listing embeddings as admin returns 200 and a JSON list."""
    response = client.get("/embeddings", auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert response.status_code == 200
    assert isinstance(response.json(), list)


def test_create_embedding(client):
    """Creating an embedding as admin returns 201 and echoes the name.

    The returned ``id`` is saved in ``test_embedding_id`` for later tests.
    """
    global test_embedding_id
    response = client.post(
        "/embeddings",
        json={
            "name": test_embedding_name,
            "class_name": "Ollama",
            "options": "{}",
            "privacy": "public",
            "dimension": 768,
        },
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    assert response.status_code == 201
    data = response.json()
    assert data["name"] == test_embedding_name
    test_embedding_id = data["id"]


def test_create_embedding_non_admin(client):
    """A non-admin user gets 403 when trying to create an embedding."""
    # Create a non-admin user
    client.post(
        "/users",
        json={
            "username": test_user,
            "password": "testpass",
            "admin": False,
            "private": False,
        },
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )

    try:
        response = client.post(
            "/embeddings",
            json={
                "name": "should_fail_embedding",
                "class_name": "Ollama",
                "options": "{}",
                "privacy": "public",
                "dimension": 768,
            },
            auth=(test_user, "testpass"),
        )
        assert response.status_code == 403
    finally:
        # Clean up user even when the assertion above fails, so a failing
        # run does not leave the temporary account behind.
        client.delete(
            f"/users/{test_user}",
            auth=("admin", RESTAI_DEFAULT_PASSWORD),
        )


def test_get_embedding(client):
    """Fetching the created embedding returns the fields it was created with."""
    response = client.get(
        f"/embeddings/{test_embedding_id}",
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    assert response.status_code == 200
    data = response.json()
    assert data["name"] == test_embedding_name
    assert data["class_name"] == "Ollama"
    assert data["privacy"] == "public"


def test_update_embedding(client):
    """Patching description and dimension is reflected by a following GET."""
    response = client.patch(
        f"/embeddings/{test_embedding_id}",
        json={"description": "Updated test embedding", "dimension": 512},
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    assert response.status_code == 200

    # Verify update
    response = client.get(
        f"/embeddings/{test_embedding_id}",
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    assert response.status_code == 200
    data = response.json()
    assert data["description"] == "Updated test embedding"
    assert data["dimension"] == 512


def test_delete_embedding(client):
    """Deleting the created embedding as admin returns 200."""
    response = client.delete(
        f"/embeddings/{test_embedding_id}",
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    assert response.status_code == 200


def test_delete_embedding_not_found(client):
    """Deleting embedding id 999999 is expected to return 404."""
    response = client.delete(
        "/embeddings/999999",
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    assert response.status_code == 404