/ tests / test_bulk_ingest.py
test_bulk_ingest.py
  1  """Bulk ingest queue endpoint tests.
  2  
  3  Covers the HTTP surface (enqueue + list + delete). The actual cron
  4  that drains the queue runs out-of-process and is not exercised here;
  5  we just verify that rows land correctly and the cleanup path works.
  6  """
  7  from __future__ import annotations
  8  
  9  import io
 10  import os
 11  import random
 12  
 13  import pytest
 14  from fastapi.testclient import TestClient
 15  
 16  from restai.config import RESTAI_DEFAULT_PASSWORD
 17  from restai.database import get_db_wrapper
 18  from restai.main import app
 19  from restai.models.databasemodels import BulkIngestJobDatabase
 20  
 21  
# Credentials tuple passed as ``auth=`` on every request in this module:
# the literal "admin" user paired with RESTAI_DEFAULT_PASSWORD from config.
ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)
 23  
 24  
 25  @pytest.fixture(scope="module")
 26  def client():
 27      with TestClient(app) as c:
 28          yield c
 29  
 30  
 31  @pytest.fixture(scope="module")
 32  def rag_project(client):
 33      teams = client.get("/teams", auth=ADMIN).json().get("teams", []) or []
 34      if not teams:
 35          pytest.skip("no team available")
 36      llms = (client.get("/info", auth=ADMIN).json() or {}).get("llms") or []
 37      embs = (client.get("/info", auth=ADMIN).json() or {}).get("embeddings") or []
 38      if not llms or not embs:
 39          pytest.skip("no LLMs or embeddings configured")
 40  
 41      name = f"bulk_test_{random.randint(0, 999999)}"
 42      r = client.post(
 43          "/projects",
 44          json={
 45              "name": name, "type": "rag",
 46              "llm": llms[0]["name"], "embeddings": embs[0]["name"],
 47              "team_id": teams[0]["id"], "vectorstore": "chromadb",
 48          },
 49          auth=ADMIN,
 50      )
 51      if r.status_code not in (200, 201):
 52          pytest.skip(f"could not create RAG project: {r.status_code} {r.text}")
 53      pid = r.json().get("project") or r.json().get("id")
 54      yield pid
 55      client.delete(f"/projects/{pid}", auth=ADMIN)
 56  
 57  
 58  def test_list_empty_initially(client, rag_project):
 59      r = client.get(f"/projects/{rag_project}/ingest-bulk", auth=ADMIN)
 60      assert r.status_code == 200
 61      # Existing rows from prior runs in the same DB would still be here —
 62      # what matters is that the endpoint is reachable and returns a
 63      # well-formed payload.
 64      body = r.json()
 65      assert "jobs" in body
 66      assert isinstance(body["jobs"], list)
 67  
 68  
 69  def test_enqueue_creates_queued_rows(client, rag_project):
 70      files = [
 71          ("files", ("a.txt", io.BytesIO(b"hello world a"), "text/plain")),
 72          ("files", ("b.txt", io.BytesIO(b"hello world b"), "text/plain")),
 73      ]
 74      r = client.post(
 75          f"/projects/{rag_project}/ingest-bulk",
 76          files=files, auth=ADMIN,
 77      )
 78      assert r.status_code == 202, r.text
 79      body = r.json()
 80      assert body["count"] == 2
 81      queued = body["queued"]
 82      assert len(queued) == 2
 83  
 84      try:
 85          # Verify both rows are queued and have a tempfile on disk
 86          db = get_db_wrapper()
 87          try:
 88              rows = (
 89                  db.db.query(BulkIngestJobDatabase)
 90                  .filter(BulkIngestJobDatabase.id.in_(queued))
 91                  .all()
 92              )
 93              assert len(rows) == 2
 94              for row in rows:
 95                  assert row.status == "queued"
 96                  assert row.size_bytes > 0
 97                  assert os.path.isfile(row.file_path), row.file_path
 98                  assert row.filename.endswith(".txt")
 99          finally:
100              db.db.close()
101  
102          # List should now show them
103          listing = client.get(f"/projects/{rag_project}/ingest-bulk", auth=ADMIN).json()
104          ids = [j["id"] for j in listing["jobs"]]
105          for jid in queued:
106              assert jid in ids
107      finally:
108          for jid in queued:
109              client.delete(f"/projects/{rag_project}/ingest-bulk/{jid}", auth=ADMIN)
110  
111  
def test_delete_removes_tempfile_and_row(client, rag_project):
    """Deleting a queued job removes both its DB row and its staged file."""
    files = [("files", ("c.txt", io.BytesIO(b"delete me"), "text/plain"))]
    r = client.post(f"/projects/{rag_project}/ingest-bulk", files=files, auth=ADMIN)
    # Fail here with the response body instead of a bare KeyError below.
    assert r.status_code == 202, r.text
    jid = r.json()["queued"][0]

    # Capture the staged file path while the row still exists.
    db = get_db_wrapper()
    try:
        row = db.db.query(BulkIngestJobDatabase).filter(BulkIngestJobDatabase.id == jid).first()
        assert row is not None, f"job {jid} was not persisted"
        staged_path = row.file_path
        assert os.path.isfile(staged_path)
    finally:
        db.db.close()

    r = client.delete(f"/projects/{rag_project}/ingest-bulk/{jid}", auth=ADMIN)
    assert r.status_code == 200, r.text

    # Row gone, file gone
    db = get_db_wrapper()
    try:
        row = db.db.query(BulkIngestJobDatabase).filter(BulkIngestJobDatabase.id == jid).first()
        assert row is None
    finally:
        db.db.close()
    assert not os.path.isfile(staged_path)
136  
137  
def test_enqueue_rejects_empty_upload(client, rag_project):
    """A POST that carries no files is answered with a client error."""
    response = client.post(f"/projects/{rag_project}/ingest-bulk", auth=ADMIN)
    # A missing required `files` param makes FastAPI answer 422; a 400 is
    # accepted as well.
    assert response.status_code in {400, 422}
142  
143  
def test_enqueue_rejects_non_rag_project(client):
    """Bulk ingest must be refused for agent projects, which have no vectorstore."""
    teams = client.get("/teams", auth=ADMIN).json().get("teams", []) or []
    info = client.get("/info", auth=ADMIN).json() or {}
    llms = info.get("llms") or []
    if not (teams and llms):
        pytest.skip("fixtures unavailable")

    project_name = f"bulk_agent_{random.randint(0, 999999)}"
    payload = {
        "name": project_name,
        "type": "agent",
        "llm": llms[0]["name"],
        "team_id": teams[0]["id"],
    }
    created = client.post("/projects", json=payload, auth=ADMIN)
    if created.status_code not in (200, 201):
        pytest.skip(f"could not create agent project: {created.status_code}")
    created_body = created.json()
    pid = created_body.get("project") or created_body.get("id")

    upload = [("files", ("a.txt", io.BytesIO(b"x"), "text/plain"))]
    try:
        rejected = client.post(f"/projects/{pid}/ingest-bulk", files=upload, auth=ADMIN)
        assert rejected.status_code == 400
        # The error detail is expected to mention "rag".
        assert "rag" in rejected.json()["detail"].lower()
    finally:
        # Remove the temporary agent project whatever the outcome.
        client.delete(f"/projects/{pid}", auth=ADMIN)