sync.py
1 """Knowledge Base Sync — source-specific sync functions used by the cron script and manual trigger.""" 2 3 import json 4 import logging 5 import os 6 import tempfile 7 import threading 8 from collections import defaultdict 9 from datetime import datetime, timezone 10 11 from restai.database import get_db_wrapper 12 13 logger = logging.getLogger(__name__) 14 15 16 def _extract_entities_for_documents(project, documents, db, brain): 17 """Group documents by source metadata and extract entities for each source. 18 Only runs if knowledge graph is enabled on the project.""" 19 if not project.props.options.enable_knowledge_graph: 20 return 21 if not brain: 22 return 23 try: 24 from restai.knowledge_graph import extract_and_persist 25 grouped = defaultdict(list) 26 for doc in documents: 27 src = doc.metadata.get("source") if hasattr(doc, "metadata") else None 28 if src: 29 grouped[src].append(doc.text) 30 for src, texts in grouped.items(): 31 try: 32 extract_and_persist(project.props.id, src, "\n".join(texts), brain, db) 33 except Exception as e: 34 logger.warning(f"Entity extraction failed for source '{src}': {e}") 35 except Exception as e: 36 logger.warning(f"Sync entity extraction failed: {e}") 37 38 39 def _sync_source(project, source, db, brain=None): 40 """Sync a single SyncSource into the project's knowledge base.""" 41 if source.type == "url": 42 _sync_url(project, source, db, brain) 43 elif source.type == "s3": 44 _sync_s3(project, source, db, brain) 45 elif source.type == "confluence": 46 _sync_confluence(project, source, db, brain) 47 elif source.type == "sharepoint": 48 _sync_sharepoint(project, source, db, brain) 49 elif source.type == "gdrive": 50 _sync_gdrive(project, source, db, brain) 51 else: 52 logger.warning(f"Unknown sync source type: {source.type}") 53 54 55 def _sync_url(project, source, db, brain=None): 56 """Sync a web URL source.""" 57 from urllib.parse import urlparse 58 from restai.helper import _is_private_ip 59 from restai.loaders.url import SeleniumWebReader 60 from restai.vectordb.tools import index_documents_classic, extract_keywords_for_metadata 61 62 logger.info(f"Syncing URL source '{source.name}': {source.url}") 63 64 # SSRF guard — refuse loopback / private / link-local destinations. 65 # An admin-configured source still shouldn't be allowed to point the 66 # cron's headless browser at the cloud metadata service or an 67 # internal LAN host. 
    try:
        hostname = urlparse(source.url).hostname
    except Exception:
        logger.warning(f"Skipping URL source '{source.name}': invalid url {source.url!r}")
        return
    if not hostname:
        logger.warning(f"Skipping URL source '{source.name}': no hostname in {source.url!r}")
        return
    try:
        if _is_private_ip(hostname):
            logger.warning(
                f"Skipping URL source '{source.name}': refusing to sync private/internal address {hostname}"
            )
            return
    except ValueError as e:
        logger.warning(f"Skipping URL source '{source.name}': {e}")
        return

    loader = SeleniumWebReader()
    documents = loader.load_data(urls=[source.url])
    documents = extract_keywords_for_metadata(documents)

    for doc in documents:
        doc.metadata["source"] = source.name

    if project.vector:
        try:
            deleted = project.vector.delete_source(source.name)
            logger.info(f"Deleted {len(deleted) if deleted else 0} old chunks for source '{source.name}'")
        except Exception as e:
            logger.warning(f"Failed to delete old chunks for source '{source.name}': {e}")

    n_chunks = index_documents_classic(project, documents, source.splitter, source.chunks)
    project.vector.save()
    _extract_entities_for_documents(project, documents, db, brain)
    logger.info(f"URL source '{source.name}' synced: {len(documents)} documents, {n_chunks} chunks")


def _sync_s3(project, source, db, brain=None):
    """Sync files from an S3 bucket."""
    from restai.vectordb.tools import index_documents_classic, extract_keywords_for_metadata
    from modules.loaders import find_file_loader

    try:
        import boto3
    except ImportError:
        raise RuntimeError("boto3 is required for S3 sync. Install with: pip install boto3")

    logger.info(f"Syncing S3 source '{source.name}': s3://{source.s3_bucket}/{source.s3_prefix or ''}")

    client_kwargs = {}
    if source.s3_region:
        client_kwargs["region_name"] = source.s3_region
    if source.s3_access_key and source.s3_secret_key:
        client_kwargs["aws_access_key_id"] = source.s3_access_key
        client_kwargs["aws_secret_access_key"] = source.s3_secret_key

    s3 = boto3.client("s3", **client_kwargs)

    list_kwargs = {"Bucket": source.s3_bucket}
    if source.s3_prefix:
        list_kwargs["Prefix"] = source.s3_prefix

    all_documents = []
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(**list_kwargs):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key.endswith("/"):
                continue

            ext = os.path.splitext(key)[1].lower()
            loader_cls = find_file_loader(ext)
            if loader_cls is None:
                logger.debug(f"Skipping unsupported file type: {key}")
                continue

            # Download to a temp file so the loader can read it from disk.
            with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
                s3.download_fileobj(source.s3_bucket, key, tmp)
                tmp_path = tmp.name

            try:
                loader = loader_cls()
                docs = loader.load_data(file=tmp_path)
                doc_source = f"{source.name}/{os.path.basename(key)}"
                for doc in docs:
                    doc.metadata["source"] = doc_source
                docs = extract_keywords_for_metadata(docs)
                all_documents.extend(docs)
            finally:
                os.unlink(tmp_path)

    if not all_documents:
        logger.info(f"S3 source '{source.name}': no documents found")
        return

    if project.vector:
        try:
            existing_sources = project.vector.list()
            for src in existing_sources:
                if src == source.name or src.startswith(f"{source.name}/"):
                    project.vector.delete_source(src)
        except Exception as e:
            logger.warning(f"Failed to delete old S3 chunks for '{source.name}': {e}")

    n_chunks = index_documents_classic(project, all_documents, source.splitter, source.chunks)
    project.vector.save()
    _extract_entities_for_documents(project, all_documents, db, brain)
    logger.info(f"S3 source '{source.name}' synced: {len(all_documents)} documents, {n_chunks} chunks")


def _sync_confluence(project, source, db, brain=None):
    """Sync pages from a Confluence Cloud space."""
    import requests
    from html.parser import HTMLParser
    from io import StringIO

    from llama_index.core.schema import Document
    from restai.vectordb.tools import index_documents_classic, extract_keywords_for_metadata

    base_url = (source.confluence_base_url or "").rstrip("/")
    space_key = source.confluence_space_key
    email = source.confluence_email
    api_token = source.confluence_api_token

    if not base_url or not space_key or not email or not api_token:
        raise ValueError("Confluence source requires base_url, space_key, email, and api_token")

    logger.info(f"Syncing Confluence source '{source.name}': {base_url}/wiki/spaces/{space_key}")

    auth = (email, api_token)
    headers = {"Accept": "application/json"}

    class _HTMLStripper(HTMLParser):
        """Collect the text nodes of a storage-format page body, dropping markup."""

        def __init__(self):
            super().__init__()
            self._text = StringIO()

        def handle_data(self, d):
            self._text.write(d)

        def get_text(self):
            return self._text.getvalue()

    all_documents = []
    url = f"{base_url}/wiki/api/v2/spaces/{space_key}/pages"
    params = {"limit": 50, "body-format": "storage"}

    while url:
        resp = requests.get(url, auth=auth, headers=headers, params=params, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        for page in data.get("results", []):
            title = page.get("title", "Untitled")
            body_html = page.get("body", {}).get("storage", {}).get("value", "")

            if not body_html:
                continue

            stripper = _HTMLStripper()
            stripper.feed(body_html)
            text = stripper.get_text().strip()

            if not text:
                continue

            doc_source = f"{source.name}/{title}"
            all_documents.append(Document(
                text=text,
                metadata={
                    "source": doc_source,
                    "title": title,
                    "url": f"{base_url}/wiki/spaces/{space_key}/pages/{page.get('id', '')}",
                },
            ))

        # Cursor pagination: follow the relative "next" link until it disappears.
        next_link = data.get("_links", {}).get("next")
        if next_link:
            url = f"{base_url}{next_link}" if next_link.startswith("/") else next_link
            params = None
        else:
            url = None

    if not all_documents:
        logger.info(f"Confluence source '{source.name}': no pages found")
        return

    all_documents = extract_keywords_for_metadata(all_documents)

    if project.vector:
        try:
            existing_sources = project.vector.list()
            for src in existing_sources:
                if src == source.name or src.startswith(f"{source.name}/"):
                    project.vector.delete_source(src)
        except Exception as e:
            logger.warning(f"Failed to delete old Confluence chunks for '{source.name}': {e}")

    n_chunks = index_documents_classic(project, all_documents, source.splitter, source.chunks)
    project.vector.save()
    _extract_entities_for_documents(project, all_documents, db, brain)
    logger.info(f"Confluence source '{source.name}' synced: {len(all_documents)} pages, {n_chunks} chunks")


def _sync_sharepoint(project, source, db, brain=None):
269 """Sync files from a SharePoint Online document library via Microsoft Graph API.""" 270 import requests as req 271 from restai.vectordb.tools import index_documents_classic, extract_keywords_for_metadata 272 from modules.loaders import find_file_loader 273 274 tenant_id = source.sharepoint_tenant_id 275 client_id = source.sharepoint_client_id 276 client_secret = source.sharepoint_client_secret 277 site_name = source.sharepoint_site_name 278 folder_path = source.sharepoint_folder 279 280 if not tenant_id or not client_id or not client_secret or not site_name: 281 raise ValueError("SharePoint source requires tenant_id, client_id, client_secret, and site_name") 282 283 logger.info(f"Syncing SharePoint source '{source.name}': site={site_name}") 284 285 token_url = f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token" 286 token_resp = req.post(token_url, data={ 287 "grant_type": "client_credentials", 288 "client_id": client_id, 289 "client_secret": client_secret, 290 "scope": "https://graph.microsoft.com/.default", 291 }, timeout=15) 292 token_resp.raise_for_status() 293 access_token = token_resp.json()["access_token"] 294 295 headers = {"Authorization": f"Bearer {access_token}"} 296 graph = "https://graph.microsoft.com/v1.0" 297 298 site_resp = req.get(f"{graph}/sites?search={site_name}", headers=headers, timeout=15) 299 site_resp.raise_for_status() 300 sites = site_resp.json().get("value", []) 301 if not sites: 302 raise ValueError(f"SharePoint site '{site_name}' not found") 303 site_id = sites[0]["id"] 304 305 drive_resp = req.get(f"{graph}/sites/{site_id}/drive", headers=headers, timeout=15) 306 drive_resp.raise_for_status() 307 drive_id = drive_resp.json()["id"] 308 309 if folder_path: 310 folder_path = folder_path.strip("/") 311 list_url = f"{graph}/drives/{drive_id}/root:/{folder_path}:/children" 312 else: 313 list_url = f"{graph}/drives/{drive_id}/root/children" 314 315 all_documents = [] 316 url = list_url 317 318 while url: 319 resp = req.get(url, headers=headers, timeout=30) 320 resp.raise_for_status() 321 data = resp.json() 322 323 for item in data.get("value", []): 324 if "folder" in item: 325 continue 326 327 name = item.get("name", "") 328 ext = os.path.splitext(name)[1].lower() 329 loader_cls = find_file_loader(ext) 330 if loader_cls is None: 331 logger.debug(f"Skipping unsupported file type: {name}") 332 continue 333 334 download_url = item.get("@microsoft.graph.downloadUrl") 335 if not download_url: 336 continue 337 338 file_resp = req.get(download_url, timeout=60) 339 file_resp.raise_for_status() 340 341 with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp: 342 tmp.write(file_resp.content) 343 tmp_path = tmp.name 344 345 try: 346 loader = loader_cls() 347 docs = loader.load_data(file=tmp_path) 348 doc_source = f"{source.name}/{name}" 349 for doc in docs: 350 doc.metadata["source"] = doc_source 351 docs = extract_keywords_for_metadata(docs) 352 all_documents.extend(docs) 353 finally: 354 os.unlink(tmp_path) 355 356 url = data.get("@odata.nextLink") 357 358 if not all_documents: 359 logger.info(f"SharePoint source '{source.name}': no documents found") 360 return 361 362 if project.vector: 363 try: 364 existing_sources = project.vector.list() 365 for src in existing_sources: 366 if src == source.name or src.startswith(f"{source.name}/"): 367 project.vector.delete_source(src) 368 except Exception as e: 369 logger.warning(f"Failed to delete old SharePoint chunks for '{source.name}': {e}") 370 371 n_chunks = index_documents_classic(project, 
    n_chunks = index_documents_classic(project, all_documents, source.splitter, source.chunks)
    project.vector.save()
    _extract_entities_for_documents(project, all_documents, db, brain)
    logger.info(f"SharePoint source '{source.name}' synced: {len(all_documents)} files, {n_chunks} chunks")


def _sync_gdrive(project, source, db, brain=None):
    """Sync files from a Google Drive folder via a service account."""
    import requests as req
    from llama_index.core.schema import Document
    from restai.vectordb.tools import index_documents_classic, extract_keywords_for_metadata
    from modules.loaders import find_file_loader

    sa_json = source.gdrive_service_account_json
    folder_id = source.gdrive_folder_id

    if not sa_json or not folder_id:
        raise ValueError("Google Drive source requires service_account_json and folder_id")

    logger.info(f"Syncing Google Drive source '{source.name}': folder={folder_id}")

    import json as _json
    import time as _time

    try:
        import jwt as _jwt
    except ImportError:
        raise RuntimeError("PyJWT is required for Google Drive sync. Install with: pip install pyjwt[crypto]")

    # Service-account flow: sign a short-lived JWT with the account's private key
    # and exchange it for an OAuth access token scoped to read-only Drive access.
    sa = _json.loads(sa_json)
    now = int(_time.time())
    payload = {
        "iss": sa["client_email"],
        "scope": "https://www.googleapis.com/auth/drive.readonly",
        "aud": "https://oauth2.googleapis.com/token",
        "iat": now,
        "exp": now + 3600,
    }
    signed_jwt = _jwt.encode(payload, sa["private_key"], algorithm="RS256")

    token_resp = req.post("https://oauth2.googleapis.com/token", data={
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": signed_jwt,
    }, timeout=15)
    token_resp.raise_for_status()
    access_token = token_resp.json()["access_token"]

    headers = {"Authorization": f"Bearer {access_token}"}

    all_documents = []
    page_token = None
    query = f"'{folder_id}' in parents and trashed = false and mimeType != 'application/vnd.google-apps.folder'"

    while True:
        params = {
            "q": query,
            "fields": "nextPageToken, files(id, name, mimeType)",
            "pageSize": 100,
        }
        if page_token:
            params["pageToken"] = page_token

        resp = req.get("https://www.googleapis.com/drive/v3/files", headers=headers, params=params, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        for item in data.get("files", []):
            name = item["name"]
            mime = item["mimeType"]
            file_id = item["id"]

            # Google-native formats can't be downloaded directly; pick an export format.
            export_mime = None
            export_ext = None
            if mime == "application/vnd.google-apps.document":
                export_mime = "text/plain"
                export_ext = ".txt"
            elif mime == "application/vnd.google-apps.spreadsheet":
                export_mime = "text/csv"
                export_ext = ".csv"
            elif mime == "application/vnd.google-apps.presentation":
                export_mime = "text/plain"
                export_ext = ".txt"
            else:
                ext = os.path.splitext(name)[1].lower()
                if not find_file_loader(ext):
                    logger.debug(f"Skipping unsupported file: {name}")
                    continue
                export_ext = ext

            if export_mime:
                dl_resp = req.get(
                    f"https://www.googleapis.com/drive/v3/files/{file_id}/export",
                    headers=headers, params={"mimeType": export_mime}, timeout=60,
                )
            else:
                dl_resp = req.get(
                    f"https://www.googleapis.com/drive/v3/files/{file_id}?alt=media",
                    headers=headers, timeout=60,
                )
            dl_resp.raise_for_status()

            # Plain-text exports become Documents directly; everything else goes through a file loader.
            if export_mime and export_mime.startswith("text/"):
                text = dl_resp.text.strip()
                if text:
                    doc_source = f"{source.name}/{name}"
                    all_documents.append(Document(
                        text=text,
                        metadata={"source": doc_source, "title": name},
                    ))
                continue

            loader_cls = find_file_loader(export_ext)
            if not loader_cls:
                continue

            with tempfile.NamedTemporaryFile(suffix=export_ext, delete=False) as tmp:
                tmp.write(dl_resp.content)
                tmp_path = tmp.name

            try:
                loader = loader_cls()
                docs = loader.load_data(file=tmp_path)
                doc_source = f"{source.name}/{name}"
                for doc in docs:
                    doc.metadata["source"] = doc_source
                all_documents.extend(docs)
            finally:
                os.unlink(tmp_path)

        page_token = data.get("nextPageToken")
        if not page_token:
            break

    if not all_documents:
        logger.info(f"Google Drive source '{source.name}': no documents found")
        return

    all_documents = extract_keywords_for_metadata(all_documents)

    if project.vector:
        try:
            existing_sources = project.vector.list()
            for src in existing_sources:
                if src == source.name or src.startswith(f"{source.name}/"):
                    project.vector.delete_source(src)
        except Exception as e:
            logger.warning(f"Failed to delete old Google Drive chunks for '{source.name}': {e}")

    n_chunks = index_documents_classic(project, all_documents, source.splitter, source.chunks)
    project.vector.save()
    _extract_entities_for_documents(project, all_documents, db, brain)
    logger.info(f"Google Drive source '{source.name}' synced: {len(all_documents)} files, {n_chunks} chunks")


# --- Manual trigger (used by the "Sync Now" button in the frontend) ---

def run_sync_now(project_id: int, brain):
    """Run a one-off sync in a background thread. Ignores intervals — syncs all sources immediately."""
    def _run():
        from restai.models.databasemodels import ProjectDatabase

        db = get_db_wrapper()
        try:
            project = brain.find_project(project_id, db)
            if not project or project.props.type != "rag":
                return
            opts = project.props.options
            sources = opts.sync_sources if opts else None
            if not sources:
                return
            for i, source in enumerate(sources):
                try:
                    logger.info(f"Manual sync source '{source.name}' for project {project_id}")
                    _sync_source(project, source, db, brain)
                    # Persist last_sync for this source so the UI and the cron
                    # scheduler can see when it last ran.
                    proj_db = db.db.query(ProjectDatabase).filter(ProjectDatabase.id == project_id).first()
                    if proj_db:
                        current_opts = json.loads(proj_db.options) if proj_db.options else {}
                        src_list = current_opts.get("sync_sources", [])
                        if i < len(src_list):
                            src_list[i]["last_sync"] = datetime.now(timezone.utc).isoformat()
                            current_opts["sync_sources"] = src_list
                            proj_db.options = json.dumps(current_opts)
                            db.db.commit()
                except Exception as e:
                    logger.error(f"Manual sync failed for source '{source.name}': {e}")
        finally:
            db.db.close()

    threading.Thread(target=_run, name=f"sync-manual-{project_id}", daemon=True).start()
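

# Usage note (illustrative sketch, not part of the sync API): the manual
# trigger can be exercised from a short script, assuming a Brain instance is
# constructed the same way the API server builds one. The `restai.brain.Brain`
# import path below is an assumption; adjust it to the project's actual
# factory if it differs.
#
#   from restai.brain import Brain
#   run_sync_now(project_id=1, brain=Brain())
#
# run_sync_now returns immediately; the work happens on a daemon thread named
# "sync-manual-<project_id>", so a short-lived caller must keep the process
# alive (e.g. join the spawned thread) or the sync is cut off when the
# interpreter exits.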