/ tests / test_tools_extended.py
test_tools_extended.py
 1  import random
 2  import pytest
 3  from fastapi.testclient import TestClient
 4  
 5  from restai.config import RESTAI_DEFAULT_PASSWORD
 6  from restai.main import app
 7  
 8  ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)
 9  
10  _suffix = str(random.randint(0, 999999))
11  test_username = f"tools_user_{_suffix}"
12  test_password = "tools_test_pass"
13  
14  
@pytest.fixture(scope="module")
def client():
    """Yield one ``TestClient`` for ``app``, shared by every test in this module.

    The client is used as a context manager, so it is entered before the
    first test that requests it and exited after the module's tests finish.
    """
    test_client = TestClient(app)
    with test_client as active_client:
        yield active_client
19  
20  
def test_setup(client):
    """Register the non-admin account that the permission tests log in with."""
    new_user = {
        "username": test_username,
        "password": test_password,
        "admin": False,
        "private": False,
    }
    response = client.post("/users", json=new_user, auth=ADMIN)

    # Either 200 or 201 counts as a successful creation.
    assert response.status_code in (200, 201)
29  
30  
def test_list_classifiers(client):
    """GET /tools/classifiers answers 200 with a classifier list and a default."""
    response = client.get("/tools/classifiers", auth=ADMIN)
    assert response.status_code == 200

    payload = response.json()
    assert "classifiers" in payload
    assert isinstance(payload["classifiers"], list)
    assert "default" in payload
39  
40  
def test_classifier_endpoint(client):
    """POST /tools/classifier with a sentence and two candidate labels.

    A 500 is tolerated because the classifier model may not be downloaded
    in the environment running the tests; the body is only inspected on 200.
    """
    request_body = {
        "sequence": "This is great",
        "labels": ["positive", "negative"],
    }
    response = client.post("/tools/classifier", json=request_body, auth=ADMIN)

    # 200 if classifier is available, 500 if model not downloaded
    assert response.status_code in (200, 500)
    if response.status_code != 200:
        return

    payload = response.json()
    # At least one of the expected result fields must be present.
    assert any(key in payload for key in ("labels", "scores", "sequence"))
56  
57  
def test_openai_compat_models_admin_only(client):
    """GET /tools/openai-compat/models/{id} must answer 403 for a non-admin."""
    non_admin = (test_username, test_password)
    response = client.get("/tools/openai-compat/models/1", auth=non_admin)
    assert response.status_code == 403
65  
66  
def test_ollama_models_no_server(client):
    """POST /tools/ollama/models must answer 500 when the server is unreachable."""
    # 99999 is above the highest valid TCP port (65535), so nothing can listen there.
    unreachable = {"host": "localhost", "port": 99999}
    response = client.post("/tools/ollama/models", json=unreachable, auth=ADMIN)
    assert response.status_code == 500
75  
76  
def test_ollama_pull_no_server(client):
    """POST /tools/ollama/pull must answer 500 when the server is unreachable."""
    # 99999 is above the highest valid TCP port (65535), so nothing can listen there.
    pull_request = {"name": "test", "host": "localhost", "port": 99999}
    response = client.post("/tools/ollama/pull", json=pull_request, auth=ADMIN)
    assert response.status_code == 500
85  
86  
def test_cleanup(client):
    """Delete the test user as admin; the response is deliberately not checked."""
    user_path = f"/users/{test_username}"
    client.delete(user_path, auth=ADMIN)