# restai/projects/rag.py — RAG project handler (vector-index chat/question + NL-to-SQL).
  1  import json
  2  from typing import Optional
  3  
  4  from fastapi import HTTPException
  5  
  6  from llama_index.core.response_synthesizers import get_response_synthesizer
  7  from llama_index.core.retrievers import VectorIndexRetriever
  8  from llama_index.core.query_engine import RetrieverQueryEngine
  9  from llama_index.core.postprocessor import SimilarityPostprocessor
 10  from llama_index.core.prompts import PromptTemplate
 11  from llama_index.core.chat_engine import ContextChatEngine
 12  from llama_index.core.postprocessor.llm_rerank import LLMRerank
 13  from llama_index.postprocessor.colbert_rerank import ColbertRerank
 14  from restai.chat import Chat
 15  from restai.database import DBWrapper
 16  from restai.eval import eval_rag
 17  from restai.guard import Guard
 18  from restai.llm import LLM
 19  from restai.models.models import QuestionModel, ChatModel, User
 20  from restai.project import Project
 21  from restai.tools import tokens_from_string
 22  from restai.projects.base import ProjectBase
 23  from llama_index.core.utilities.sql_wrapper import SQLDatabase
 24  from llama_index.core.indices.struct_store.sql_query import NLSQLTableQueryEngine
 25  from sqlalchemy import create_engine
 26  
# Connection-string schemes (including SQLAlchemy "+driver" variants) accepted
# by _validate_connection_string for the NL-to-SQL question path.
_ALLOWED_DB_SCHEMES = {"postgresql", "postgresql+psycopg2", "mysql", "mysql+pymysql", "sqlite"}
# Base schemes with any "+driver" suffix stripped, e.g. "mysql+pymysql" -> "mysql",
# so unknown-but-legitimate driver suffixes of an allowed base still validate.
_ALLOWED_SCHEME_BASES = {s.split("+")[0] for s in _ALLOWED_DB_SCHEMES}
 29  
 30  
 31  def _validate_connection_string(conn: str):
 32      """Reject connection strings with dangerous schemes or targeting localhost/metadata."""
 33      from urllib.parse import urlparse
 34      try:
 35          parsed = urlparse(conn)
 36      except Exception:
 37          raise HTTPException(status_code=400, detail="Invalid connection string format")
 38  
 39      scheme_full = parsed.scheme or ""
 40      scheme_base = scheme_full.split("+")[0]
 41  
 42      if scheme_base not in _ALLOWED_SCHEME_BASES:
 43          raise HTTPException(
 44              status_code=400,
 45              detail=f"Database scheme '{scheme_full}' is not allowed. Permitted: {', '.join(sorted(_ALLOWED_DB_SCHEMES))}",
 46          )
 47  
 48      # Block SQLite absolute paths that could read system files.
 49      # urlparse("sqlite:///etc/passwd").path == "/etc/passwd"
 50      if scheme_base == "sqlite" and parsed.path:
 51          import os
 52          path = parsed.path
 53          if path.startswith("/") and not path.startswith(os.getcwd()):
 54              raise HTTPException(
 55                  status_code=400,
 56                  detail="SQLite absolute paths outside the application directory are not allowed",
 57              )
 58  
 59  
 60  class EntityBoostPostprocessor:
 61      """Custom postprocessor that boosts retrieval scores for chunks whose source
 62      contains entities mentioned in the user's query. Additive boost — does not
 63      filter out non-matching chunks. Falls back gracefully if no entities found.
 64      """
 65  
 66      def __init__(self, brain, db, project_id: int, query: str, boost_factor: float = 1.5):
 67          self.brain = brain
 68          self.db = db
 69          self.project_id = project_id
 70          self.query = query
 71          self.boost_factor = boost_factor
 72          self._matched_sources: Optional[set] = None
 73  
 74      def _compute_matched_sources(self) -> set:
 75          if self._matched_sources is not None:
 76              return self._matched_sources
 77          try:
 78              import re as _re
 79              from restai.knowledge_graph import find_entities_in_text, normalize_entity_name
 80              from restai.models.databasemodels import KGEntityDatabase, KGEntityMentionDatabase
 81  
 82              # Primary path: word-boundary match the query against entities ALREADY
 83              # in this project's graph. NER on short queries is unreliable; the DB
 84              # knows what we have, so direct matching is more robust.
 85              project_entities = (
 86                  self.db.db.query(KGEntityDatabase)
 87                  .filter(KGEntityDatabase.project_id == self.project_id)
 88                  .all()
 89              )
 90              if not project_entities:
 91                  self._matched_sources = set()
 92                  return self._matched_sources
 93  
 94              query_padded = " " + _re.sub(r"[^\w\s]", " ", (self.query or "").lower()) + " "
 95              matched_ids = {
 96                  e.id for e in project_entities
 97                  if e.normalized and f" {e.normalized} " in query_padded
 98              }
 99  
100              # Supplement with NER hits in case the query phrasing is different
101              try:
102                  ner_hits = find_entities_in_text(self.query, self.brain)
103                  if ner_hits:
104                      ner_normalized = [normalize_entity_name(n) for n, _ in ner_hits]
105                      extra_ids = {
106                          e.id for e in project_entities
107                          if e.normalized in ner_normalized
108                      }
109                      matched_ids |= extra_ids
110              except Exception:
111                  pass
112  
113              if not matched_ids:
114                  self._matched_sources = set()
115                  return self._matched_sources
116  
117              sources = {
118                  row.source for row in self.db.db.query(KGEntityMentionDatabase)
119                  .filter(KGEntityMentionDatabase.entity_id.in_(list(matched_ids)))
120                  .all()
121              }
122              self._matched_sources = sources
123          except Exception:
124              self._matched_sources = set()
125          return self._matched_sources
126  
127      def postprocess_nodes(self, nodes, query_bundle=None, query_str=None):
128          matched = self._compute_matched_sources()
129          if not matched:
130              return nodes
131          for node in nodes:
132              try:
133                  node_source = node.node.metadata.get("source") if hasattr(node, "node") else None
134                  if node_source and node_source in matched:
135                      if node.score is not None:
136                          node.score = node.score * self.boost_factor
137              except Exception:
138                  pass
139          # Re-sort after boosting
140          try:
141              nodes.sort(key=lambda n: n.score or 0, reverse=True)
142          except Exception:
143              pass
144          return nodes
145  
146  
class RAG(ProjectBase):
    """Retrieval-augmented-generation project handler.

    Implements the two ProjectBase entry points as async generators:

    * ``chat()``     — multi-turn conversation over the project's vector index
      via a ContextChatEngine with chat memory.
    * ``question()`` — single-shot Q&A over the vector index, or NL-to-SQL
      when the project has a database connection configured.

    Non-streaming callers receive exactly one output dict. Streaming callers
    receive SSE-formatted text frames followed by a final JSON payload and a
    close/error event.
    """

    async def chat(self, project: Project, chatModel: ChatModel, user: User, db: DBWrapper):
        """Answer ``chatModel.question`` within an ongoing chat session.

        Applies (in order): input guard, vector retrieval, optional
        knowledge-graph boosting, optional ColBERT/LLM reranking, and a
        similarity-score cutoff. Falls back to the project/brain censorship
        message when nothing survives retrieval.
        """
        # Fail fast with a canned answer when the vector store is unreachable.
        if project.vector is None:
            yield {
                "question": chatModel.question,
                "answer": "Knowledge base unavailable — vector store connection failed. Please check that the vector database is running.",
                "sources": [],
                "type": "chat",
                "tokens": {"input": 0, "output": 0},
                "project": project.props.name,
                "guard": False,
            }
            return

        # NOTE(review): model may be None here, yet model.llm is dereferenced
        # unguarded below (LLMRerank, ContextChatEngine) — confirm get_llm
        # never returns None for a configured project LLM.
        model: Optional[LLM] = self.brain.get_llm(project.props.llm, db)
        context_window = model.props.context_window if model else 4096
        # Reserve ~25% of the context window for the model's own output.
        token_limit = int(context_window * 0.75)
        chat: Chat = Chat(chatModel, self.brain.chat_store, token_limit=token_limit, llm=model.llm if model else None)

        output = {
            "id": chat.chat_id,
            "question": chatModel.question,
            "sources": [],
            "cached": False,
            "guard": False,
            "type": "chat",
            "project": project.props.name,
        }

        # Input guard may rewrite `output` in place (sets guard/answer).
        if self.check_input_guard(project, chatModel.question, user, db, output):
            yield output
            return

        threshold = project.props.options.score or 0.0
        k = project.props.options.k or 1

        sysTemplate = project.props.system or self.brain.defaultSystem

        # Over-retrieve (2x) when a reranker will trim back down to k.
        if project.props.options.colbert_rerank or project.props.options.llm_rerank:
            final_k = k * 2
        else:
            final_k = k

        retriever = VectorIndexRetriever(
            index=project.vector.index,
            similarity_top_k=final_k,
        )

        # Postprocessor order matters: boost -> rerank -> similarity cutoff.
        postprocessors = []

        if project.props.options.enable_knowledge_graph:
            postprocessors.append(
                EntityBoostPostprocessor(
                    brain=self.brain, db=db, project_id=project.props.id, query=chatModel.question,
                )
            )

        if project.props.options.colbert_rerank:
            postprocessors.append(
                ColbertRerank(
                    top_n=k,
                    model="colbert-ir/colbertv2.0",
                    tokenizer="colbert-ir/colbertv2.0",
                    keep_retrieval_score=True,
                )
            )

        if project.props.options.llm_rerank:
            postprocessors.append(
                LLMRerank(
                    choice_batch_size=k,
                    top_n=k,
                    llm=model.llm,
                )
            )

        postprocessors.append(SimilarityPostprocessor(similarity_cutoff=threshold))

        chat_engine = ContextChatEngine.from_defaults(
            retriever=retriever,
            system_prompt=sysTemplate,
            memory=chat.memory,
            node_postprocessors=postprocessors,
            llm=model.llm,
        )

        try:
            if chatModel.stream:
                response = chat_engine.stream_chat(chatModel.question)
            else:
                response = chat_engine.chat(chatModel.question)

            # Collect source chunks for the client; metadata keys are optional.
            for node in response.source_nodes:
                source = {"score": node.score, "id": node.node_id, "text": node.text}

                if "source" in node.metadata:
                    source["source"] = node.metadata.get("source", "unknown")
                if "keywords" in node.metadata:
                    source["keywords"] = node.metadata["keywords"]

                output["sources"].append(source)

            if chatModel.stream:
                # Stream token chunks as SSE "data:" frames while accumulating
                # the full answer for the final payload.
                parts = []
                if hasattr(response, "response_gen"):
                    for text in response.response_gen:
                        parts.append(text)
                        yield "data: " + json.dumps({"text": text}) + "\n\n"

                answer = "".join(parts).strip()
                # Empty answer or zero surviving sources -> censorship message.
                if not answer or len(output["sources"]) == 0:
                    censorship = project.props.censorship or self.brain.defaultCensorship
                    output["answer"] = censorship
                    # Only emit the censorship text if nothing was streamed yet.
                    if not parts:
                        yield "data: " + json.dumps({"text": censorship}) + "\n\n"
                else:
                    output["answer"] = answer

                self.brain.post_processing_reasoning(output)
                self.brain.post_processing_counting(output)

                # NOTE(review): final data frame ends with a single "\n" rather
                # than the usual SSE "\n\n" — confirm clients handle this.
                yield "data: " + json.dumps(output) + "\n"
                yield "event: close\n\n"
            else:
                if len(response.source_nodes) == 0:
                    output["answer"] = (
                        project.props.censorship or self.brain.defaultCensorship
                    )
                else:
                    output["answer"] = response.response

                    # Cache only real answers, never censorship fallbacks.
                    if project.cache:
                        project.cache.add(chatModel.question, response.response)

                self.brain.post_processing_reasoning(output)
                self.brain.post_processing_counting(output)

                yield output
        except Exception as e:
            # In streaming mode tell the client before propagating the error.
            if chatModel.stream:
                yield "data: Inference failed\n"
                yield "event: error\n\n"
            raise e

    async def question(
        self, project: Project, questionModel: QuestionModel, user: User, db: DBWrapper
    ):
        """Answer a single question (no chat memory).

        Routes to NL-to-SQL when the project has a database connection
        configured (streaming unsupported there); otherwise runs retrieval +
        synthesis over the vector index, with the same optional boosting,
        reranking and cutoff stages as chat().
        """
        # Fail fast only when BOTH the vector store and the SQL path are unavailable.
        if project.vector is None and not project.props.options.connection:
            yield {
                "question": questionModel.question,
                "answer": "Knowledge base unavailable — vector store connection failed. Please check that the vector database is running.",
                "sources": [],
                "type": "question",
                "tokens": {"input": 0, "output": 0},
                "project": project.props.name,
                "guard": False,
            }
            return

        output = {
            "question": questionModel.question,
            "type": "question",
            "sources": [],
            "cached": False,
            "guard": False,
            "tokens": {"input": 0, "output": 0},
            "project": project.props.name,
        }

        if self.check_input_guard(project, questionModel.question, user, db, output):
            yield output
            return

        model = self.brain.get_llm(project.props.llm, db)

        # SQL query path: when a database connection is configured, use NL-to-SQL
        if project.props.options.connection:
            if questionModel.stream:
                raise HTTPException(
                    status_code=400,
                    detail="Streaming is not supported for SQL queries."
                )

            conn_str = project.props.options.connection
            # Reject dangerous schemes / paths before opening a connection.
            _validate_connection_string(conn_str)
            engine = create_engine(conn_str)
            try:
                sql_database = SQLDatabase(engine)

                # Table scope: per-request override beats project configuration.
                tables = None
                if hasattr(questionModel, 'tables') and questionModel.tables is not None:
                    tables = questionModel.tables
                elif project.props.options.tables:
                    tables = [table.strip() for table in project.props.options.tables.split(',')]

                sysTemplate = (
                    questionModel.system or project.props.system or self.brain.defaultSystem
                )
                question = sysTemplate + "\n Question: " + questionModel.question

                query_engine = NLSQLTableQueryEngine(
                    llm=model.llm,
                    sql_database=sql_database,
                    tables=tables,
                )

                response = query_engine.query(question)

                output["answer"] = response.response
                # Expose the generated SQL as the sole "source".
                output["sources"] = [response.metadata['sql_query']]
                output["tokens"] = {
                    "input": tokens_from_string(output["question"]),
                    "output": tokens_from_string(output["answer"])
                }
                yield output
                return
            finally:
                # Always release pooled connections, even on failure.
                engine.dispose()

        sysTemplate = (
            questionModel.system or project.props.system or self.brain.defaultSystem
        )

        # Per-request overrides beat project options; note defaults differ
        # from chat() (k: 2 vs 1).
        k = questionModel.k or project.props.options.k or 2
        threshold = questionModel.score or project.props.options.score or 0.0

        # Over-retrieve (2x) when any reranker will trim back down to k.
        if (
            questionModel.colbert_rerank
            or questionModel.llm_rerank
            or project.props.options.colbert_rerank
            or project.props.options.llm_rerank
        ):
            final_k = k * 2
        else:
            final_k = k

        retriever = VectorIndexRetriever(
            index=project.vector.index,
            similarity_top_k=final_k,
        )

        qa_prompt_tmpl = (
            "Context information is below.\n"
            "---------------------\n"
            "{context_str}\n"
            "---------------------\n"
            "Given the context information and not prior knowledge, "
            "answer the query.\n"
            "Query: {query_str}\n"
            "Answer: "
        )

        qa_prompt = PromptTemplate(qa_prompt_tmpl)

        # NOTE(review): mutates the shared LLM object's system prompt —
        # confirm model.llm is not reused concurrently across requests.
        model.llm.system_prompt = sysTemplate

        response_synthesizer = get_response_synthesizer(
            llm=model.llm, text_qa_template=qa_prompt, streaming=questionModel.stream
        )

        # Postprocessor order matters: boost -> rerank -> similarity cutoff.
        postprocessors = []

        if project.props.options.enable_knowledge_graph:
            postprocessors.append(
                EntityBoostPostprocessor(
                    brain=self.brain, db=db, project_id=project.props.id, query=questionModel.question,
                )
            )

        if questionModel.colbert_rerank or project.props.options.colbert_rerank:
            postprocessors.append(
                ColbertRerank(
                    top_n=k,
                    model="colbert-ir/colbertv2.0",
                    tokenizer="colbert-ir/colbertv2.0",
                    keep_retrieval_score=True,
                )
            )

        if questionModel.llm_rerank or project.props.options.llm_rerank:
            postprocessors.append(
                LLMRerank(
                    choice_batch_size=k,
                    top_n=k,
                    llm=model.llm,
                )
            )

        postprocessors.append(SimilarityPostprocessor(similarity_cutoff=threshold))

        query_engine = RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            node_postprocessors=postprocessors,
        )

        try:
            response = query_engine.query(questionModel.question)

            if hasattr(response, "source_nodes"):
                for node in response.source_nodes:
                    output["sources"].append(
                        {
                            "source": node.metadata.get("source", "unknown"),
                            # NOTE(review): unguarded ["keywords"] raises KeyError
                            # if the metadata key is absent; chat() guards this
                            # with an "in" check — confirm intentional.
                            "keywords": node.metadata["keywords"],
                            "score": node.score,
                            "id": node.node_id,
                            "text": node.text,
                        }
                    )

            # Optional RAG evaluation (non-streaming only), scored by GPT-4.
            if questionModel.eval and not questionModel.stream:
                metric = eval_rag(
                    questionModel.question,
                    response,
                    self.brain.get_llm("openai_gpt4", db).llm,
                )
                output["evaluation"] = {"reason": metric.reason, "score": metric.score}

            if questionModel.stream:
                # Stream token chunks as SSE "data:" frames while accumulating
                # the full answer for the final payload.
                parts = []
                if hasattr(response, "response_gen"):
                    for text in response.response_gen:
                        parts.append(text)
                        yield "data: " + json.dumps({"text": text}) + "\n\n"

                answer = "".join(parts).strip()
                # Empty answer or zero surviving sources -> censorship message.
                if not answer or len(response.source_nodes) == 0:
                    censorship = project.props.censorship or self.brain.defaultCensorship
                    output["answer"] = censorship
                    # Only emit the censorship text if nothing was streamed yet.
                    if not parts:
                        yield "data: " + json.dumps({"text": censorship}) + "\n\n"
                else:
                    output["answer"] = answer

                self.brain.post_processing_reasoning(output)
                self.brain.post_processing_counting(output)

                # NOTE(review): final data frame ends with a single "\n" rather
                # than the usual SSE "\n\n" — confirm clients handle this.
                yield "data: " + json.dumps(output) + "\n"
                yield "event: close\n\n"
            else:
                if len(response.source_nodes) == 0:
                    output["answer"] = (
                        project.props.censorship or self.brain.defaultCensorship
                    )
                else:
                    output["answer"] = response.response

                    # Cache only real answers, never censorship fallbacks.
                    if project.cache:
                        project.cache.add(questionModel.question, response.response)

                self.brain.post_processing_reasoning(output)
                self.brain.post_processing_counting(output)

                yield output
        except Exception as e:
            # In streaming mode tell the client before propagating the error.
            if questionModel.stream:
                yield "data: Inference failed\n"
                yield "event: error\n\n"
            raise e