bulk_ingest.py
"""Bulk file ingest queue for RAG projects.

Three endpoints, all scoped to a project the user has access to:

* ``POST /projects/{id}/ingest-bulk`` — accepts one or more files,
  writes each to a tempfile, creates a ``queued`` row in
  ``bulk_ingest_jobs``, and returns the job ids. Returns 202 because
  the actual ingest happens in the cron.
* ``GET /projects/{id}/ingest-bulk`` — paginated list of recent jobs
  for the project, newest first, so the admin UI can poll / render a
  progress table.
* ``DELETE /projects/{id}/ingest-bulk/{jobID}`` — cancel/reap a job.
  If it's still queued, marks it as ``error`` ("cancelled") and deletes
  the tempfile. Done/error rows are just deleted outright.
"""
from __future__ import annotations

import os
import tempfile
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, File, HTTPException, Path as PathParam, Query, UploadFile

from restai import config
from restai.auth import check_not_restricted, get_current_username_project
from restai.database import DBWrapper, get_db_wrapper
from restai.models.databasemodels import BulkIngestJobDatabase
from restai.models.models import User, sanitize_filename


router = APIRouter()


# On-disk staging area for queued uploads. One subdir so clean-up /
# permissions are easy to reason about. Created lazily on first write.
_QUEUE_DIR = os.path.join(tempfile.gettempdir(), "restai_bulk_ingest")


def _ensure_queue_dir() -> str:
    os.makedirs(_QUEUE_DIR, exist_ok=True)
    return _QUEUE_DIR


def _job_to_dict(job: BulkIngestJobDatabase) -> dict:
    return {
        "id": job.id,
        "project_id": job.project_id,
        "filename": job.filename,
        "mime_type": job.mime_type,
        "size_bytes": job.size_bytes,
        "method": job.method,
        "status": job.status,
        "error_message": job.error_message,
        "documents_count": job.documents_count,
        "chunks_count": job.chunks_count,
        "created_at": job.created_at.isoformat() if job.created_at else None,
        "started_at": job.started_at.isoformat() if job.started_at else None,
        "completed_at": job.completed_at.isoformat() if job.completed_at else None,
    }


@router.post("/projects/{projectID}/ingest-bulk", status_code=202, tags=["Knowledge"])
async def enqueue_bulk_ingest(
    projectID: int = PathParam(description="Project ID"),
    files: list[UploadFile] = File(...),
    method: str = "auto",
    splitter: str = "sentence",
    chunks: int = 256,
    user: User = Depends(get_current_username_project),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Accept one or more files and queue them for async ingestion.

    Returns ``{"queued": [job_id, ...]}`` — poll the list endpoint for
    status. Only RAG projects accept bulk ingest.
    """
    check_not_restricted(user)
    if splitter not in ("sentence", "token"):
        raise HTTPException(status_code=422, detail="splitter must be 'sentence' or 'token'")
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    project = db_wrapper.get_project_by_id(projectID)
    if project is None:
        raise HTTPException(status_code=404, detail="Project not found")
    if project.type != "rag":
        raise HTTPException(status_code=400, detail="Bulk ingest only available for RAG projects")

    max_bytes = config.MAX_UPLOAD_SIZE

    # First pass: read and size-check every upload before queuing
    # anything, so one oversized file refuses the whole request and the
    # admin never ends up with a half-queued batch.
    staged: list[tuple[str, bytes, str | None]] = []
    for upload in files:
        safe_name = sanitize_filename(upload.filename or "upload.bin")
        contents = await upload.read()
        if len(contents) > max_bytes:
            raise HTTPException(
                status_code=413,
                detail=f"'{safe_name}' exceeds max upload size ({max_bytes // (1024 * 1024)} MB)",
            )
        staged.append((safe_name, contents, upload.content_type))

    queue_dir = _ensure_queue_dir()
    queued_ids: list[int] = []

    # Second pass: stage each file on disk and create its queued row.
    for safe_name, contents, content_type in staged:
        # Tempfile name carries the project + job intent so an admin
        # inspecting /tmp/restai_bulk_ingest/ can correlate.
        fd, path = tempfile.mkstemp(prefix=f"proj{projectID}_", suffix=f"_{safe_name}", dir=queue_dir)
        try:
            with os.fdopen(fd, "wb") as fh:
                fh.write(contents)
        except Exception:
            try:
                os.unlink(path)
            except OSError:
                pass
            raise

        job = BulkIngestJobDatabase(
            project_id=projectID,
            filename=safe_name,
            mime_type=content_type,
            size_bytes=len(contents),
            file_path=path,
            method=method or "auto",
            splitter=splitter,
            chunks=chunks,
            status="queued",
            created_at=datetime.now(timezone.utc),
        )
        db_wrapper.db.add(job)
        db_wrapper.db.commit()
        db_wrapper.db.refresh(job)
        queued_ids.append(job.id)

    return {"queued": queued_ids, "count": len(queued_ids)}
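

# Example client call, purely illustrative (host, port, token, and the
# returned ids are hypothetical; the route and response shape are the
# ones defined above):
#
#   curl -X POST "http://localhost:9000/projects/42/ingest-bulk?splitter=sentence&chunks=256" \
#        -H "Authorization: Bearer $TOKEN" \
#        -F "files=@report.pdf" -F "files=@notes.md"
#
#   202 {"queued": [17, 18], "count": 2}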


@router.get("/projects/{projectID}/ingest-bulk", tags=["Knowledge"])
async def list_bulk_ingest_jobs(
    projectID: int = PathParam(description="Project ID"),
    limit: int = Query(50, ge=1, le=500),
    user: User = Depends(get_current_username_project),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Recent bulk-ingest jobs for this project, newest first."""
    jobs = (
        db_wrapper.db.query(BulkIngestJobDatabase)
        .filter(BulkIngestJobDatabase.project_id == projectID)
        .order_by(BulkIngestJobDatabase.created_at.desc())
        .limit(limit)
        .all()
    )
    return {"jobs": [_job_to_dict(j) for j in jobs]}
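

# Example poll, same hypothetical host/token as above. The admin UI is
# expected to re-issue this until every job reaches "done" or "error":
#
#   curl "http://localhost:9000/projects/42/ingest-bulk?limit=20" \
#        -H "Authorization: Bearer $TOKEN"
#
#   {"jobs": [{"id": 18, "status": "queued", "filename": "notes.md", ...}, ...]}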


@router.delete("/projects/{projectID}/ingest-bulk/{jobID}", tags=["Knowledge"])
async def delete_bulk_ingest_job(
    projectID: int = PathParam(description="Project ID"),
    jobID: int = PathParam(description="Job ID"),
    user: User = Depends(get_current_username_project),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Cancel or reap a bulk-ingest job. Queued jobs get marked
    ``error`` ("cancelled") and their tempfile deleted; done/error
    rows are deleted outright."""
    check_not_restricted(user)
    job = (
        db_wrapper.db.query(BulkIngestJobDatabase)
        .filter(
            BulkIngestJobDatabase.id == jobID,
            BulkIngestJobDatabase.project_id == projectID,
        )
        .first()
    )
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")

    # Always try to remove the tempfile if it's still there.
    if job.file_path:
        try:
            os.unlink(job.file_path)
        except OSError:
            pass

    if job.status == "queued":
        # Still waiting on the cron: keep the row around as an error so
        # the admin UI shows the cancellation, per the module contract.
        job.status = "error"
        job.error_message = "cancelled"
        job.completed_at = datetime.now(timezone.utc)
        db_wrapper.db.commit()
        return {"cancelled": jobID}

    # Done/error rows are reaped outright.
    db_wrapper.db.delete(job)
    db_wrapper.db.commit()
    return {"deleted": jobID}
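

# The consuming side is the cron mentioned in the module docstring and
# lives outside this router. As a hypothetical sketch of the contract
# this queue assumes (field names come from the row shape above; the
# "running" status is an assumption), each pass would:
#
#   1. pick the oldest row with status == "queued",
#   2. set status = "running" and stamp started_at,
#   3. ingest file_path using the stored method/splitter/chunks,
#   4. on success set status = "done", fill documents_count /
#      chunks_count / completed_at, and unlink file_path,
#   5. on failure set status = "error" plus error_message.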