# restai/routers/bulk_ingest.py
  1  """Bulk file ingest queue for RAG projects.
  2  
  3  Three endpoints, all scoped to a project the user has access to:
  4  
  5  * ``POST /projects/{id}/ingest-bulk`` — accepts one or more files,
  6    writes each to a tempfile, creates a ``queued`` row in
  7    ``bulk_ingest_jobs``, and returns the job ids. Returns 202 because
  8    the actual ingest happens in the cron.
  9  * ``GET /projects/{id}/ingest-bulk`` — paginated list of recent jobs
 10    for the project, newest first, so the admin UI can poll / render a
 11    progress table.
 12  * ``DELETE /projects/{id}/ingest-bulk/{jobID}`` — cancel/reap a job.
 13    If it's still queued, marks as ``error`` ("cancelled") and deletes
 14    the tempfile. Done/error rows are just deleted outright.
 15  """
from __future__ import annotations

import os
import tempfile
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, File, HTTPException, Path as PathParam, Query, UploadFile

from restai import config
from restai.auth import check_not_restricted, get_current_username_project
from restai.database import DBWrapper, get_db_wrapper
from restai.models.databasemodels import BulkIngestJobDatabase
from restai.models.models import User, sanitize_filename


router = APIRouter()


# On-disk staging area for queued uploads. One subdir so clean-up /
# permissions are easy to reason about. Created lazily on first write.
_QUEUE_DIR = os.path.join(tempfile.gettempdir(), "restai_bulk_ingest")


def _ensure_queue_dir() -> str:
    os.makedirs(_QUEUE_DIR, exist_ok=True)
    return _QUEUE_DIR


def _job_to_dict(job: BulkIngestJobDatabase) -> dict:
    """Serialize a job row for the JSON responses below."""
    return {
        "id": job.id,
        "project_id": job.project_id,
        "filename": job.filename,
        "mime_type": job.mime_type,
        "size_bytes": job.size_bytes,
        "method": job.method,
        "status": job.status,
        "error_message": job.error_message,
        "documents_count": job.documents_count,
        "chunks_count": job.chunks_count,
        "created_at": job.created_at.isoformat() if job.created_at else None,
        "started_at": job.started_at.isoformat() if job.started_at else None,
        "completed_at": job.completed_at.isoformat() if job.completed_at else None,
    }
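
# Shape of the serialized dict, for reference (values illustrative):
#
#     {"id": 7, "project_id": 42, "filename": "notes.pdf",
#      "mime_type": "application/pdf", "size_bytes": 10240,
#      "method": "auto", "status": "queued", "error_message": None,
#      "documents_count": None, "chunks_count": None,
#      "created_at": "2024-01-01T00:00:00+00:00",
#      "started_at": None, "completed_at": None}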


@router.post("/projects/{projectID}/ingest-bulk", status_code=202, tags=["Knowledge"])
async def enqueue_bulk_ingest(
    projectID: int = PathParam(description="Project ID"),
    files: list[UploadFile] = File(...),
    method: str = "auto",
    splitter: str = "sentence",
    chunks: int = 256,
    user: User = Depends(get_current_username_project),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Accept one or more files and queue them for async ingestion.

    Returns ``{"queued": [job_id, ...]}`` — poll the list endpoint for
    status. Only RAG projects accept bulk ingest.
    """
    check_not_restricted(user)
    if splitter not in ("sentence", "token"):
        raise HTTPException(status_code=422, detail="splitter must be 'sentence' or 'token'")
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    project = db_wrapper.get_project_by_id(projectID)
    if project is None:
        raise HTTPException(status_code=404, detail="Project not found")
    if project.type != "rag":
        raise HTTPException(status_code=400, detail="Bulk ingest only available for RAG projects")

    max_bytes = config.MAX_UPLOAD_SIZE
    queue_dir = _ensure_queue_dir()
    queued_ids: list[int] = []

    # First pass: read and size-check every file before anything is
    # queued, so a late oversized file can't leave a half-queued batch
    # behind — easier to reason about than silent partial success.
    staged: list[tuple[str, bytes, str | None]] = []
    for upload in files:
        safe_name = sanitize_filename(upload.filename or "upload.bin")
        contents = await upload.read()
        if len(contents) > max_bytes:
            raise HTTPException(
                status_code=413,
                detail=f"'{safe_name}' exceeds max upload size ({max_bytes // (1024*1024)} MB)",
            )
        staged.append((safe_name, contents, upload.content_type))

    # Second pass: write the tempfiles and create the queued rows.
    for safe_name, contents, content_type in staged:
        # Tempfile name carries the project + job intent so an admin
        # inspecting /tmp/restai_bulk_ingest/ can correlate.
        fd, path = tempfile.mkstemp(prefix=f"proj{projectID}_", suffix=f"_{safe_name}", dir=queue_dir)
        try:
            with os.fdopen(fd, "wb") as fh:
                fh.write(contents)
        except Exception:
            try:
                os.unlink(path)
            except OSError:
                pass
            raise

        job = BulkIngestJobDatabase(
            project_id=projectID,
            filename=safe_name,
            mime_type=content_type,
            size_bytes=len(contents),
            file_path=path,
            method=method or "auto",
            splitter=splitter,
            chunks=chunks,
            status="queued",
            created_at=datetime.now(timezone.utc),
        )
        db_wrapper.db.add(job)
        db_wrapper.db.commit()
        db_wrapper.db.refresh(job)
        queued_ids.append(job.id)

    return {"queued": queued_ids, "count": len(queued_ids)}

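# The cron-side consumer lives in another module; a rough sketch of the
# loop it is expected to run, with ``ingest_file`` as a hypothetical
# stand-in for the real ingest routine:
#
#     for job in db.query(BulkIngestJobDatabase).filter_by(status="queued"):
#         job.status, job.started_at = "processing", datetime.now(timezone.utc)
#         db.commit()
#         try:
#             docs, chunks = ingest_file(job)  # hypothetical helper
#             job.status = "done"
#             job.documents_count, job.chunks_count = docs, chunks
#         except Exception as exc:
#             job.status, job.error_message = "error", str(exc)
#         finally:
#             job.completed_at = datetime.now(timezone.utc)
#             try:
#                 os.unlink(job.file_path)
#             except OSError:
#                 pass
#             db.commit()
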
@router.get("/projects/{projectID}/ingest-bulk", tags=["Knowledge"])
async def list_bulk_ingest_jobs(
    projectID: int = PathParam(description="Project ID"),
    limit: int = Query(50, ge=1, le=500),
    user: User = Depends(get_current_username_project),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Recent bulk-ingest jobs for this project, newest first."""
    jobs = (
        db_wrapper.db.query(BulkIngestJobDatabase)
        .filter(BulkIngestJobDatabase.project_id == projectID)
        .order_by(BulkIngestJobDatabase.created_at.desc())
        .limit(limit)
        .all()
    )
    return {"jobs": [_job_to_dict(j) for j in jobs]}


@router.delete("/projects/{projectID}/ingest-bulk/{jobID}", tags=["Knowledge"])
async def delete_bulk_ingest_job(
    projectID: int = PathParam(description="Project ID"),
    jobID: int = PathParam(description="Job ID"),
    user: User = Depends(get_current_username_project),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Cancel or reap a bulk-ingest job: the staged tempfile is removed
    if it is still on disk, and the row is deleted outright, whatever
    its status."""
    check_not_restricted(user)
    job = (
        db_wrapper.db.query(BulkIngestJobDatabase)
        .filter(
            BulkIngestJobDatabase.id == jobID,
            BulkIngestJobDatabase.project_id == projectID,
        )
        .first()
    )
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")

    # Always try to remove the tempfile if it's still there.
    if job.file_path:
        try:
            os.unlink(job.file_path)
        except OSError:
            pass

    db_wrapper.db.delete(job)
    db_wrapper.db.commit()
    return {"deleted": jobID}