# tests/test_input_validation.py
  1  import pytest
  2  from fastapi.testclient import TestClient
  3  
  4  from restai.config import RESTAI_DEFAULT_PASSWORD
  5  from restai.main import app
  6  
  7  INVALID_NAMES = [
  8      "has/slash",
  9      "has space",
 10      "has@at",
 11      "has&amp",
 12      "has?query",
 13      "has#hash",
 14      "has%percent",
 15      "has+plus",
 16      "has=equals",
 17      "../traversal",
 18  ]
 19  
 20  VALID_NAMES = [
 21      "simple",
 22      "with-hyphen",
 23      "with_underscore",
 24      "with.dot",
 25      "MixedCase123",
 26  ]
 27  
 28  
 29  @pytest.fixture(scope="module")
 30  def client():
 31      with TestClient(app) as c:
 32          yield c
 33  
 34  
 35  def test_project_name_validation(client):
 36      for name in INVALID_NAMES:
 37          response = client.post(
 38              "/projects",
 39              json={"name": name, "type": "agent", "llm": "fake", "team_id": 1},
 40              auth=("admin", RESTAI_DEFAULT_PASSWORD),
 41          )
 42          assert response.status_code == 422, f"Expected 422 for project name {name!r}, got {response.status_code}"
 43  
 44  
 45  def test_user_name_validation(client):
 46      for name in INVALID_NAMES:
 47          response = client.post(
 48              "/users",
 49              json={"username": name, "password": "testpass"},
 50              auth=("admin", RESTAI_DEFAULT_PASSWORD),
 51          )
 52          assert response.status_code == 422, f"Expected 422 for username {name!r}, got {response.status_code}"
 53  
 54  
 55  def test_team_name_validation(client):
 56      """Team names allow special chars (teams use IDs in URLs). Only empty names are rejected."""
 57      for name in ["", "   "]:
 58          response = client.post(
 59              "/teams",
 60              json={"name": name},
 61              auth=("admin", RESTAI_DEFAULT_PASSWORD),
 62          )
 63          assert response.status_code == 422, f"Expected 422 for empty team name {name!r}, got {response.status_code}"
 64  
 65  
 66  def test_llm_name_validation(client):
 67      for name in INVALID_NAMES:
 68          response = client.post(
 69              "/llms",
 70              json={
 71                  "name": name,
 72                  "class_name": "OpenAI",
 73                  "options": {"model": "test"},
 74                  "privacy": "public",
 75              },
 76              auth=("admin", RESTAI_DEFAULT_PASSWORD),
 77          )
 78          assert response.status_code == 422, f"Expected 422 for LLM name {name!r}, got {response.status_code}"
 79  
 80  
 81  def test_embedding_name_validation(client):
 82      for name in INVALID_NAMES:
 83          response = client.post(
 84              "/embeddings",
 85              json={
 86                  "name": name,
 87                  "class_name": "OpenAI",
 88                  "options": "{}",
 89                  "privacy": "public",
 90              },
 91              auth=("admin", RESTAI_DEFAULT_PASSWORD),
 92          )
 93          assert response.status_code == 422, f"Expected 422 for embedding name {name!r}, got {response.status_code}"
 94  
 95  
 96  def test_valid_names_accepted(client):
 97      """Ensure valid names don't trigger validation errors (they may fail for other reasons like missing LLM)."""
 98      for name in VALID_NAMES:
 99          response = client.post(
100              "/users",
101              json={"username": name, "password": "testpass"},
102              auth=("admin", RESTAI_DEFAULT_PASSWORD),
103          )
104          # Should not be 422 (validation error) — may be 201 or other status
105          assert response.status_code != 422, f"Valid username {name!r} was incorrectly rejected"
106  
107      # Clean up created users
108      for name in VALID_NAMES:
109          client.delete(f"/users/{name}", auth=("admin", RESTAI_DEFAULT_PASSWORD))
110  
111  
def test_llm_invalid_privacy(client):
    """LLM creation with an unknown ``privacy`` value must return 422."""
    payload = {
        "name": "test-llm",
        "class_name": "OpenAI",
        "options": {"model": "test"},
        "privacy": "secret",  # not an accepted privacy value
    }
    credentials = ("admin", RESTAI_DEFAULT_PASSWORD)
    resp = client.post("/llms", json=payload, auth=credentials)
    assert resp.status_code == 422, f"Expected 422 for invalid privacy, got {resp.status_code}"
124  
125  
def test_llm_invalid_class_name(client):
    """LLM creation with an unknown provider ``class_name`` must return 422."""
    bogus_llm = dict(
        name="test-llm",
        class_name="FakeProvider",
        options={"model": "test"},
        privacy="public",
    )
    resp = client.post("/llms", json=bogus_llm, auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert resp.status_code == 422, f"Expected 422 for invalid class_name, got {resp.status_code}"
138  
139  
def test_embedding_invalid_class_name(client):
    """Embedding creation with an unknown ``class_name`` must return 422."""
    admin = ("admin", RESTAI_DEFAULT_PASSWORD)
    body = {
        "name": "test-emb",
        "class_name": "FakeEmbedding",
        "options": "{}",  # a string here, unlike the dict the LLM tests send
        "privacy": "public",
    }
    resp = client.post("/embeddings", json=body, auth=admin)
    assert resp.status_code == 422, f"Expected 422 for invalid embedding class_name, got {resp.status_code}"
152  
153  
def test_project_invalid_type(client):
    """Project creation with an unknown ``type`` must return 422."""
    project = {"name": "test-proj", "type": "magic", "llm": "fake", "team_id": 1}
    status = client.post(
        "/projects", json=project, auth=("admin", RESTAI_DEFAULT_PASSWORD)
    ).status_code
    assert status == 422, f"Expected 422 for invalid project type, got {status}"
161  
162  
def test_llm_valid_enums_accepted(client):
    """Valid ``class_name`` / ``privacy`` enum values must not trigger a 422.

    Creation may still succeed or fail for unrelated reasons; only a
    validation error (422) fails the test. Any LLM actually created is
    deleted afterwards.
    """
    response = client.post(
        "/llms",
        json={
            "name": "test-valid-llm",
            "class_name": "OpenAI",
            "options": {"model": "gpt-test", "api_key": "sk-fake"},
            "privacy": "public",
        },
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    # Include the status and body so a rejection shows which field failed.
    assert response.status_code != 422, (
        f"Valid LLM enums were incorrectly rejected ({response.status_code}): {response.text}"
    )
    # Clean up on any 2xx, not only 201, so a different success code can't leak the LLM.
    if 200 <= response.status_code < 300:
        client.delete("/llms/test-valid-llm", auth=("admin", RESTAI_DEFAULT_PASSWORD))