/ tests / test_embeddings.py
test_embeddings.py
  1  import random
  2  import pytest
  3  from fastapi.testclient import TestClient
  4  
  5  from restai.config import RESTAI_DEFAULT_PASSWORD
  6  from restai.main import app
  7  
  8  test_embedding_name = "test_embedding_" + str(random.randint(0, 1000000))
  9  test_user = "test_emb_user_" + str(random.randint(0, 1000000))
 10  test_embedding_id = None
 11  
 12  
 13  @pytest.fixture(scope="module")
 14  def client():
 15      with TestClient(app) as c:
 16          yield c
 17  
 18  
 19  def test_get_embeddings(client):
 20      response = client.get("/embeddings", auth=("admin", RESTAI_DEFAULT_PASSWORD))
 21      assert response.status_code == 200
 22      assert isinstance(response.json(), list)
 23  
 24  
 25  def test_create_embedding(client):
 26      global test_embedding_id
 27      response = client.post(
 28          "/embeddings",
 29          json={
 30              "name": test_embedding_name,
 31              "class_name": "Ollama",
 32              "options": "{}",
 33              "privacy": "public",
 34              "dimension": 768,
 35          },
 36          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 37      )
 38      assert response.status_code == 201
 39      data = response.json()
 40      assert data["name"] == test_embedding_name
 41      test_embedding_id = data["id"]
 42  
 43  
 44  def test_create_embedding_non_admin(client):
 45      # Create a non-admin user
 46      client.post(
 47          "/users",
 48          json={
 49              "username": test_user,
 50              "password": "testpass",
 51              "admin": False,
 52              "private": False,
 53          },
 54          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 55      )
 56  
 57      response = client.post(
 58          "/embeddings",
 59          json={
 60              "name": "should_fail_embedding",
 61              "class_name": "Ollama",
 62              "options": "{}",
 63              "privacy": "public",
 64              "dimension": 768,
 65          },
 66          auth=(test_user, "testpass"),
 67      )
 68      assert response.status_code == 403
 69  
 70      # Clean up user
 71      client.delete(
 72          f"/users/{test_user}",
 73          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 74      )
 75  
 76  
 77  def test_get_embedding(client):
 78      response = client.get(
 79          f"/embeddings/{test_embedding_id}",
 80          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 81      )
 82      assert response.status_code == 200
 83      data = response.json()
 84      assert data["name"] == test_embedding_name
 85      assert data["class_name"] == "Ollama"
 86      assert data["privacy"] == "public"
 87  
 88  
 89  def test_update_embedding(client):
 90      response = client.patch(
 91          f"/embeddings/{test_embedding_id}",
 92          json={"description": "Updated test embedding", "dimension": 512},
 93          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 94      )
 95      assert response.status_code == 200
 96  
 97      # Verify update
 98      response = client.get(
 99          f"/embeddings/{test_embedding_id}",
100          auth=("admin", RESTAI_DEFAULT_PASSWORD),
101      )
102      assert response.status_code == 200
103      data = response.json()
104      assert data["description"] == "Updated test embedding"
105      assert data["dimension"] == 512
106  
107  
def test_delete_embedding(client):
    """Delete the embedding created earlier and expect HTTP 200.

    Reads the module-level ``test_embedding_id``, which
    ``test_create_embedding`` assigns; that test must have run first.
    """
    # Fail with a clear message rather than deleting "/embeddings/None".
    assert test_embedding_id is not None, (
        "test_embedding_id is unset; test_create_embedding must succeed first"
    )

    response = client.delete(
        f"/embeddings/{test_embedding_id}",
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    assert response.status_code == 200
114  
115  
def test_delete_embedding_not_found(client):
    """Deleting embedding id 999999 as admin is expected to give HTTP 404."""
    admin_auth = ("admin", RESTAI_DEFAULT_PASSWORD)
    outcome = client.delete("/embeddings/999999", auth=admin_auth)
    assert outcome.status_code == 404