# test_tools_extended.py
import random

import pytest
from fastapi.testclient import TestClient

from restai.config import RESTAI_DEFAULT_PASSWORD
from restai.main import app

# Basic-auth credentials for the built-in administrator account.
ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)

# Random suffix so repeated runs are unlikely to reuse the same username.
_suffix = f"{random.randint(0, 999999)}"
test_username = f"tools_user_{_suffix}"
test_password = "tools_test_pass"

# Basic-auth credentials for the non-admin user registered by test_setup.
_USER_AUTH = (test_username, test_password)

# Connection target shared by the "no server" Ollama tests.
# NOTE(review): 99999 is outside the valid TCP port range (0-65535), so the
# expected 500 may stem from host/port handling rather than a refused
# connection -- confirm this is the failure mode these tests mean to cover.
_UNREACHABLE_OLLAMA = {"host": "localhost", "port": 99999}


@pytest.fixture(scope="module")
def client():
    """Provide a single TestClient for every test in this module.

    The client is entered as a context manager and kept open until the
    last test of the module has run.
    """
    with TestClient(app) as test_client:
        yield test_client


def test_setup(client):
    """Register the non-admin account that the permission test relies on."""
    new_user = {
        "username": test_username,
        "password": test_password,
        "admin": False,
        "private": False,
    }
    response = client.post("/users", json=new_user, auth=ADMIN)
    assert response.status_code in (200, 201)


def test_list_classifiers(client):
    """The classifier listing returns a ``classifiers`` list plus a ``default``."""
    response = client.get("/tools/classifiers", auth=ADMIN)
    assert response.status_code == 200

    body = response.json()
    assert "classifiers" in body and "default" in body
    assert isinstance(body["classifiers"], list)


def test_classifier_endpoint(client):
    """Classify a short sentence against two candidate labels.

    Either outcome is accepted: 200 with a result payload, or 500 when the
    classifier cannot run.
    """
    request_body = {
        "sequence": "This is great",
        "labels": ["positive", "negative"],
    }
    response = client.post("/tools/classifier", json=request_body, auth=ADMIN)

    # 200 if classifier is available, 500 if model not downloaded
    assert response.status_code in (200, 500)
    if response.status_code != 200:
        return

    body = response.json()
    assert any(key in body for key in ("labels", "scores", "sequence"))


def test_openai_compat_models_admin_only(client):
    """A non-admin asking for /tools/openai-compat/models/{id} gets a 403."""
    # Relies on test_setup having created the non-admin user earlier in the run.
    response = client.get("/tools/openai-compat/models/1", auth=_USER_AUTH)
    assert response.status_code == 403


def test_ollama_models_no_server(client):
    """Listing Ollama models against an unreachable server yields a 500."""
    response = client.post(
        "/tools/ollama/models",
        json=_UNREACHABLE_OLLAMA,
        auth=ADMIN,
    )
    assert response.status_code == 500


def test_ollama_pull_no_server(client):
    """Pulling an Ollama model from an unreachable server yields a 500."""
    pull_request = {"name": "test", **_UNREACHABLE_OLLAMA}
    response = client.post("/tools/ollama/pull", json=pull_request, auth=ADMIN)
    assert response.status_code == 500


def test_cleanup(client):
    """Remove the user created by test_setup (best effort; status unchecked)."""
    client.delete(f"/users/{test_username}", auth=ADMIN)