/ restai / sync.py
sync.py
  1  """Knowledge Base Sync — source-specific sync functions used by the cron script and manual trigger."""
  2  
  3  import json
  4  import logging
  5  import os
  6  import tempfile
  7  import threading
  8  from collections import defaultdict
  9  from datetime import datetime, timezone
 10  
 11  from restai.database import get_db_wrapper
 12  
# Module-wide logger; every sync routine below reports progress, skipped
# sources and failures through it.
logger = logging.getLogger(__name__)
 14  
 15  
 16  def _extract_entities_for_documents(project, documents, db, brain):
 17      """Group documents by source metadata and extract entities for each source.
 18      Only runs if knowledge graph is enabled on the project."""
 19      if not project.props.options.enable_knowledge_graph:
 20          return
 21      if not brain:
 22          return
 23      try:
 24          from restai.knowledge_graph import extract_and_persist
 25          grouped = defaultdict(list)
 26          for doc in documents:
 27              src = doc.metadata.get("source") if hasattr(doc, "metadata") else None
 28              if src:
 29                  grouped[src].append(doc.text)
 30          for src, texts in grouped.items():
 31              try:
 32                  extract_and_persist(project.props.id, src, "\n".join(texts), brain, db)
 33              except Exception as e:
 34                  logger.warning(f"Entity extraction failed for source '{src}': {e}")
 35      except Exception as e:
 36          logger.warning(f"Sync entity extraction failed: {e}")
 37  
 38  
 39  def _sync_source(project, source, db, brain=None):
 40      """Sync a single SyncSource into the project's knowledge base."""
 41      if source.type == "url":
 42          _sync_url(project, source, db, brain)
 43      elif source.type == "s3":
 44          _sync_s3(project, source, db, brain)
 45      elif source.type == "confluence":
 46          _sync_confluence(project, source, db, brain)
 47      elif source.type == "sharepoint":
 48          _sync_sharepoint(project, source, db, brain)
 49      elif source.type == "gdrive":
 50          _sync_gdrive(project, source, db, brain)
 51      else:
 52          logger.warning(f"Unknown sync source type: {source.type}")
 53  
 54  
 55  def _sync_url(project, source, db, brain=None):
 56      """Sync a web URL source."""
 57      from urllib.parse import urlparse
 58      from restai.helper import _is_private_ip
 59      from restai.loaders.url import SeleniumWebReader
 60      from restai.vectordb.tools import index_documents_classic, extract_keywords_for_metadata
 61  
 62      logger.info(f"Syncing URL source '{source.name}': {source.url}")
 63  
 64      # SSRF guard — refuse loopback / private / link-local destinations.
 65      # An admin-configured source still shouldn't be allowed to point the
 66      # cron's headless browser at the cloud metadata service or an
 67      # internal LAN host.
 68      try:
 69          hostname = urlparse(source.url).hostname
 70      except Exception:
 71          logger.warning(f"Skipping URL source '{source.name}': invalid url {source.url!r}")
 72          return
 73      if not hostname:
 74          logger.warning(f"Skipping URL source '{source.name}': no hostname in {source.url!r}")
 75          return
 76      try:
 77          if _is_private_ip(hostname):
 78              logger.warning(
 79                  f"Skipping URL source '{source.name}': refusing to sync private/internal address {hostname}"
 80              )
 81              return
 82      except ValueError as e:
 83          logger.warning(f"Skipping URL source '{source.name}': {e}")
 84          return
 85  
 86      loader = SeleniumWebReader()
 87      documents = loader.load_data(urls=[source.url])
 88      documents = extract_keywords_for_metadata(documents)
 89  
 90      for doc in documents:
 91          doc.metadata["source"] = source.name
 92  
 93      if project.vector:
 94          try:
 95              deleted = project.vector.delete_source(source.name)
 96              logger.info(f"Deleted {len(deleted) if deleted else 0} old chunks for source '{source.name}'")
 97          except Exception as e:
 98              logger.warning(f"Failed to delete old chunks for source '{source.name}': {e}")
 99      n_chunks = index_documents_classic(project, documents, source.splitter, source.chunks)
100      project.vector.save()
101      _extract_entities_for_documents(project, documents, db, brain)
102      logger.info(f"URL source '{source.name}' synced: {len(documents)} documents, {n_chunks} chunks")
103  
104  
def _sync_s3(project, source, db, brain=None):
    """Re-index every supported file under ``s3://{s3_bucket}/{s3_prefix}``.

    Objects are listed with the ``list_objects_v2`` paginator. Two kinds of
    key are skipped:

    - "directory" placeholders (keys ending in ``/``);
    - keys whose extension has no loader in ``modules.loaders.find_file_loader``.

    Each remaining object is downloaded to a temporary file and loaded. Its
    documents are tagged with
    ``metadata["source"] = "{source.name}/{basename(key)}"`` and enriched with
    keyword metadata. When at least one document was produced, every vector
    source named ``source.name`` or starting with ``source.name + "/"`` is
    deleted, then the new documents are indexed and the store is saved.

    Explicit ``s3_access_key``/``s3_secret_key`` are passed to boto3 only when
    both are set; otherwise boto3's default credential chain applies.

    Raises:
        RuntimeError: if boto3 is not installed.
        Any boto3 or loader exception propagates. It is raised before the
        existing index entries are touched.
    """
    from restai.vectordb.tools import index_documents_classic, extract_keywords_for_metadata
    from modules.loaders import find_file_loader

    try:
        import boto3
    except ImportError as exc:
        raise RuntimeError("boto3 is required for S3 sync. Install with: pip install boto3") from exc

    logger.info(f"Syncing S3 source '{source.name}': s3://{source.s3_bucket}/{source.s3_prefix or ''}")

    client_kwargs = {}
    if source.s3_region:
        client_kwargs["region_name"] = source.s3_region
    if source.s3_access_key and source.s3_secret_key:
        client_kwargs["aws_access_key_id"] = source.s3_access_key
        client_kwargs["aws_secret_access_key"] = source.s3_secret_key

    s3 = boto3.client("s3", **client_kwargs)

    list_kwargs = {"Bucket": source.s3_bucket}
    if source.s3_prefix:
        list_kwargs["Prefix"] = source.s3_prefix

    all_documents = []
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(**list_kwargs):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key.endswith("/"):
                continue

            ext = os.path.splitext(key)[1].lower()
            loader_cls = find_file_loader(ext)
            if loader_cls is None:
                logger.debug(f"Skipping unsupported file type: {key}")
                continue

            # delete=False lets the loader reopen the file by path after it is
            # closed. The try/finally starts before the download, so the file
            # is removed even when the download itself fails.
            tmp = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
            tmp_path = tmp.name
            try:
                with tmp:
                    s3.download_fileobj(source.s3_bucket, key, tmp)
                docs = loader_cls().load_data(file=tmp_path)
                # NOTE(review): keys sharing a basename under different
                # "folders" end up with the same source name.
                doc_source = f"{source.name}/{os.path.basename(key)}"
                for doc in docs:
                    doc.metadata["source"] = doc_source
                all_documents.extend(extract_keywords_for_metadata(docs))
            finally:
                os.unlink(tmp_path)

    if not all_documents:
        logger.info(f"S3 source '{source.name}': no documents found")
        return

    if project.vector:
        try:
            existing_sources = project.vector.list()
            for src in existing_sources:
                if src == source.name or src.startswith(f"{source.name}/"):
                    project.vector.delete_source(src)
        except Exception as e:
            logger.warning(f"Failed to delete old S3 chunks for '{source.name}': {e}")

    n_chunks = index_documents_classic(project, all_documents, source.splitter, source.chunks)
    project.vector.save()
    _extract_entities_for_documents(project, all_documents, db, brain)
    logger.info(f"S3 source '{source.name}' synced: {len(all_documents)} documents, {n_chunks} chunks")
178  
179  
180  def _sync_confluence(project, source, db, brain=None):
181      """Sync pages from a Confluence Cloud space."""
182      import requests
183      from llama_index.core.schema import Document
184      from restai.vectordb.tools import index_documents_classic, extract_keywords_for_metadata
185  
186      base_url = (source.confluence_base_url or "").rstrip("/")
187      space_key = source.confluence_space_key
188      email = source.confluence_email
189      api_token = source.confluence_api_token
190  
191      if not base_url or not space_key or not email or not api_token:
192          raise ValueError("Confluence source requires base_url, space_key, email, and api_token")
193  
194      logger.info(f"Syncing Confluence source '{source.name}': {base_url}/wiki/spaces/{space_key}")
195  
196      auth = (email, api_token)
197      headers = {"Accept": "application/json"}
198  
199      all_documents = []
200      url = f"{base_url}/wiki/api/v2/spaces/{space_key}/pages"
201      params = {"limit": 50, "body-format": "storage"}
202  
203      while url:
204          resp = requests.get(url, auth=auth, headers=headers, params=params, timeout=30)
205          resp.raise_for_status()
206          data = resp.json()
207  
208          for page in data.get("results", []):
209              title = page.get("title", "Untitled")
210              body_html = page.get("body", {}).get("storage", {}).get("value", "")
211  
212              if not body_html:
213                  continue
214  
215              from html.parser import HTMLParser
216              from io import StringIO
217  
218              class _HTMLStripper(HTMLParser):
219                  def __init__(self):
220                      super().__init__()
221                      self._text = StringIO()
222                  def handle_data(self, d):
223                      self._text.write(d)
224                  def get_text(self):
225                      return self._text.getvalue()
226  
227              stripper = _HTMLStripper()
228              stripper.feed(body_html)
229              text = stripper.get_text().strip()
230  
231              if not text:
232                  continue
233  
234              doc_source = f"{source.name}/{title}"
235              all_documents.append(Document(
236                  text=text,
237                  metadata={"source": doc_source, "title": title, "url": f"{base_url}/wiki/spaces/{space_key}/pages/{page.get('id', '')}"},
238              ))
239  
240          next_link = data.get("_links", {}).get("next")
241          if next_link:
242              url = f"{base_url}{next_link}" if next_link.startswith("/") else next_link
243              params = None
244          else:
245              url = None
246  
247      if not all_documents:
248          logger.info(f"Confluence source '{source.name}': no pages found")
249          return
250  
251      all_documents = extract_keywords_for_metadata(all_documents)
252  
253      if project.vector:
254          try:
255              existing_sources = project.vector.list()
256              for src in existing_sources:
257                  if src == source.name or src.startswith(f"{source.name}/"):
258                      project.vector.delete_source(src)
259          except Exception as e:
260              logger.warning(f"Failed to delete old Confluence chunks for '{source.name}': {e}")
261  
262      n_chunks = index_documents_classic(project, all_documents, source.splitter, source.chunks)
263      project.vector.save()
264      _extract_entities_for_documents(project, all_documents, db, brain)
265      logger.info(f"Confluence source '{source.name}' synced: {len(all_documents)} pages, {n_chunks} chunks")
266  
267  
def _sync_sharepoint(project, source, db, brain=None):
    """Re-index the files of a SharePoint Online document library via Microsoft Graph.

    Steps:

    1. Obtain a token with the OAuth2 client-credentials flow.
    2. Search for the site named ``sharepoint_site_name``. A hit whose
       ``name`` or ``displayName`` equals it (case-insensitively) is
       preferred; otherwise the first hit is used.
    3. Open the site's default drive and list the direct children of its root
       or of ``sharepoint_folder``.

    Sub-folders are not descended into. Files with no registered loader or no
    ``@microsoft.graph.downloadUrl`` are skipped. Each remaining file is
    downloaded, loaded, tagged with
    ``metadata["source"] = "{source.name}/{file name}"`` and enriched with
    keyword metadata. When at least one document was produced, every vector
    source named ``source.name`` or prefixed with ``source.name + "/"`` is
    replaced.

    Raises:
        ValueError: when tenant/client credentials or the site name are
            missing, or the site search returns nothing.
        requests.HTTPError: on any failed token, Graph or download request.
    """
    import requests as req
    from urllib.parse import quote
    from restai.vectordb.tools import index_documents_classic, extract_keywords_for_metadata
    from modules.loaders import find_file_loader

    tenant_id = source.sharepoint_tenant_id
    client_id = source.sharepoint_client_id
    client_secret = source.sharepoint_client_secret
    site_name = source.sharepoint_site_name
    folder_path = source.sharepoint_folder

    if not tenant_id or not client_id or not client_secret or not site_name:
        raise ValueError("SharePoint source requires tenant_id, client_id, client_secret, and site_name")

    logger.info(f"Syncing SharePoint source '{source.name}': site={site_name}")

    token_url = f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token"
    token_resp = req.post(token_url, data={
        "grant_type": "client_credentials",
        "client_id": client_id,
        "client_secret": client_secret,
        "scope": "https://graph.microsoft.com/.default",
    }, timeout=15)
    token_resp.raise_for_status()
    access_token = token_resp.json()["access_token"]

    headers = {"Authorization": f"Bearer {access_token}"}
    graph = "https://graph.microsoft.com/v1.0"

    # Pass the search term as a query parameter so it is URL-encoded
    # (spaces, '&', '#', ... in site names previously broke the request).
    site_resp = req.get(f"{graph}/sites", headers=headers, params={"search": site_name}, timeout=15)
    site_resp.raise_for_status()
    sites = site_resp.json().get("value", [])
    if not sites:
        raise ValueError(f"SharePoint site '{site_name}' not found")
    # Search is fuzzy; prefer an exact name match over whatever ranks first.
    wanted = site_name.strip().casefold()
    site = next(
        (
            s for s in sites
            if wanted in ((s.get("name") or "").casefold(), (s.get("displayName") or "").casefold())
        ),
        sites[0],
    )
    site_id = site["id"]

    drive_resp = req.get(f"{graph}/sites/{site_id}/drive", headers=headers, timeout=15)
    drive_resp.raise_for_status()
    drive_id = drive_resp.json()["id"]

    # Strip first so a folder of "/" falls back to the library root instead of
    # producing the invalid path "root:/:/children".
    folder_path = (folder_path or "").strip("/")
    if folder_path:
        # Percent-encode path segments ('#', '?', '%' ...); '/' separators stay.
        list_url = f"{graph}/drives/{drive_id}/root:/{quote(folder_path)}:/children"
    else:
        list_url = f"{graph}/drives/{drive_id}/root/children"

    all_documents = []
    url = list_url

    while url:
        resp = req.get(url, headers=headers, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        for item in data.get("value", []):
            if "folder" in item:
                continue

            name = item.get("name", "")
            ext = os.path.splitext(name)[1].lower()
            loader_cls = find_file_loader(ext)
            if loader_cls is None:
                logger.debug(f"Skipping unsupported file type: {name}")
                continue

            # Pre-authenticated URL: no Authorization header is sent with it.
            download_url = item.get("@microsoft.graph.downloadUrl")
            if not download_url:
                continue

            file_resp = req.get(download_url, timeout=60)
            file_resp.raise_for_status()

            # try/finally begins before the write so the temp file is always removed.
            tmp = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
            tmp_path = tmp.name
            try:
                with tmp:
                    tmp.write(file_resp.content)
                docs = loader_cls().load_data(file=tmp_path)
                doc_source = f"{source.name}/{name}"
                for doc in docs:
                    doc.metadata["source"] = doc_source
                all_documents.extend(extract_keywords_for_metadata(docs))
            finally:
                os.unlink(tmp_path)

        url = data.get("@odata.nextLink")

    if not all_documents:
        logger.info(f"SharePoint source '{source.name}': no documents found")
        return

    if project.vector:
        try:
            existing_sources = project.vector.list()
            for src in existing_sources:
                if src == source.name or src.startswith(f"{source.name}/"):
                    project.vector.delete_source(src)
        except Exception as e:
            logger.warning(f"Failed to delete old SharePoint chunks for '{source.name}': {e}")

    n_chunks = index_documents_classic(project, all_documents, source.splitter, source.chunks)
    project.vector.save()
    _extract_entities_for_documents(project, all_documents, db, brain)
    logger.info(f"SharePoint source '{source.name}' synced: {len(all_documents)} files, {n_chunks} chunks")
375  
376  
def _sync_gdrive(project, source, db, brain=None):
    """Re-index the files directly inside a Google Drive folder using a service account.

    Authentication: a JWT is signed (RS256) with the service account's
    ``private_key`` for the ``drive.readonly`` scope and exchanged for an
    access token.

    The folder's direct children are listed 100 at a time. Trashed items and
    sub-folders are excluded by the query. Shared-drive items are included via
    ``supportsAllDrives``/``includeItemsFromAllDrives``. Files are handled as
    follows:

    - Google Docs and Slides are exported as ``text/plain`` and Sheets as
      ``text/csv``. Each becomes one ``Document`` with metadata ``source`` and
      ``title``.
    - Other files are downloaded only when ``find_file_loader`` knows their
      extension, then loaded via a temporary file.

    Every document gets ``metadata["source"] = "{source.name}/{file name}"``.
    When at least one document was produced, keyword metadata is added and
    every vector source named ``source.name`` or prefixed with
    ``source.name + "/"`` is replaced.

    Raises:
        ValueError: if the service account JSON or folder id is missing.
        requests.HTTPError: on any failed token, list or download request.
    """
    import requests as req
    from llama_index.core.schema import Document
    from restai.vectordb.tools import index_documents_classic, extract_keywords_for_metadata
    from modules.loaders import find_file_loader

    sa_json = source.gdrive_service_account_json
    folder_id = source.gdrive_folder_id

    if not sa_json or not folder_id:
        raise ValueError("Google Drive source requires service_account_json and folder_id")

    logger.info(f"Syncing Google Drive source '{source.name}': folder={folder_id}")

    import time as _time
    import jwt as _jwt

    sa = json.loads(sa_json)
    now = int(_time.time())
    payload = {
        "iss": sa["client_email"],
        "scope": "https://www.googleapis.com/auth/drive.readonly",
        "aud": "https://oauth2.googleapis.com/token",
        "iat": now,
        "exp": now + 3600,
    }
    signed_jwt = _jwt.encode(payload, sa["private_key"], algorithm="RS256")

    token_resp = req.post("https://oauth2.googleapis.com/token", data={
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": signed_jwt,
    }, timeout=15)
    token_resp.raise_for_status()
    access_token = token_resp.json()["access_token"]

    headers = {"Authorization": f"Bearer {access_token}"}

    # Google-native formats have no binary content and must be exported.
    native_exports = {
        "application/vnd.google-apps.document": "text/plain",
        "application/vnd.google-apps.spreadsheet": "text/csv",
        "application/vnd.google-apps.presentation": "text/plain",
    }

    # Escape per the Drive query syntax so a quote or backslash in the id
    # cannot break out of the string literal.
    escaped_folder = str(folder_id).replace("\\", "\\\\").replace("'", "\\'")
    query = f"'{escaped_folder}' in parents and trashed = false and mimeType != 'application/vnd.google-apps.folder'"

    all_documents = []
    page_token = None

    while True:
        params = {
            "q": query,
            "fields": "nextPageToken, files(id, name, mimeType)",
            "pageSize": 100,
            # Without these, items in shared drives are not returned.
            "supportsAllDrives": "true",
            "includeItemsFromAllDrives": "true",
        }
        if page_token:
            params["pageToken"] = page_token

        resp = req.get("https://www.googleapis.com/drive/v3/files", headers=headers, params=params, timeout=30)
        resp.raise_for_status()
        data = resp.json()

        for item in data.get("files", []):
            name = item["name"]
            mime = item["mimeType"]
            file_id = item["id"]
            doc_source = f"{source.name}/{name}"

            export_mime = native_exports.get(mime)
            if export_mime:
                dl_resp = req.get(
                    f"https://www.googleapis.com/drive/v3/files/{file_id}/export",
                    headers=headers, params={"mimeType": export_mime}, timeout=60,
                )
                dl_resp.raise_for_status()
                # Decode explicitly as UTF-8: ``resp.text`` falls back to
                # ISO-8859-1 for text/* responses without a charset, which
                # garbles non-ASCII text. "-sig" also drops a leading BOM.
                text = dl_resp.content.decode("utf-8-sig", errors="replace").strip()
                if text:
                    all_documents.append(Document(
                        text=text,
                        metadata={"source": doc_source, "title": name},
                    ))
                continue

            ext = os.path.splitext(name)[1].lower()
            loader_cls = find_file_loader(ext)
            if not loader_cls:
                logger.debug(f"Skipping unsupported file: {name}")
                continue

            dl_resp = req.get(
                f"https://www.googleapis.com/drive/v3/files/{file_id}",
                headers=headers, params={"alt": "media", "supportsAllDrives": "true"}, timeout=60,
            )
            dl_resp.raise_for_status()

            # try/finally begins before the write so the temp file is always removed.
            tmp = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
            tmp_path = tmp.name
            try:
                with tmp:
                    tmp.write(dl_resp.content)
                docs = loader_cls().load_data(file=tmp_path)
                for doc in docs:
                    doc.metadata["source"] = doc_source
                all_documents.extend(docs)
            finally:
                os.unlink(tmp_path)

        page_token = data.get("nextPageToken")
        if not page_token:
            break

    if not all_documents:
        logger.info(f"Google Drive source '{source.name}': no documents found")
        return

    all_documents = extract_keywords_for_metadata(all_documents)

    if project.vector:
        try:
            existing_sources = project.vector.list()
            for src in existing_sources:
                if src == source.name or src.startswith(f"{source.name}/"):
                    project.vector.delete_source(src)
        except Exception as e:
            logger.warning(f"Failed to delete old Google Drive chunks for '{source.name}': {e}")

    n_chunks = index_documents_classic(project, all_documents, source.splitter, source.chunks)
    project.vector.save()
    _extract_entities_for_documents(project, all_documents, db, brain)
    logger.info(f"Google Drive source '{source.name}' synced: {len(all_documents)} files, {n_chunks} chunks")
519  
520  
521  # --- Manual trigger (used by the "Sync Now" button in the frontend) ---
522  
523  def run_sync_now(project_id: int, brain):
524      """Run a one-off sync in a background thread. Ignores intervals — syncs all sources immediately."""
525      def _run():
526          from restai.models.databasemodels import ProjectDatabase
527  
528          db = get_db_wrapper()
529          try:
530              project = brain.find_project(project_id, db)
531              if not project or project.props.type != "rag":
532                  return
533              opts = project.props.options
534              sources = opts.sync_sources if opts else None
535              if not sources:
536                  return
537              for i, source in enumerate(sources):
538                  try:
539                      logger.info(f"Manual sync source '{source.name}' for project {project_id}")
540                      _sync_source(project, source, db, brain)
541                      proj_db = db.db.query(ProjectDatabase).filter(ProjectDatabase.id == project_id).first()
542                      if proj_db:
543                          current_opts = json.loads(proj_db.options) if proj_db.options else {}
544                          src_list = current_opts.get("sync_sources", [])
545                          if i < len(src_list):
546                              src_list[i]["last_sync"] = datetime.now(timezone.utc).isoformat()
547                              current_opts["sync_sources"] = src_list
548                              proj_db.options = json.dumps(current_opts)
549                              db.db.commit()
550                  except Exception as e:
551                      logger.error(f"Manual sync failed for source '{source.name}': {e}")
552          finally:
553              db.db.close()
554  
555      threading.Thread(target=_run, name=f"sync-manual-{project_id}", daemon=True).start()