# rag.py
"""RAG (retrieval-augmented generation) project implementation.

Implements two flows on top of a LlamaIndex vector index:

* ``RAG.chat``     -- conversational, with per-session chat memory.
* ``RAG.question`` -- one-shot Q&A, with an NL-to-SQL path when the project
  has a database connection string configured.

Both flows support streaming (server-sent-event frames), optional ColBERT /
LLM reranking, a similarity-score cutoff, and an optional knowledge-graph
score boost (``EntityBoostPostprocessor``).
"""
import json
from typing import Optional

from fastapi import HTTPException

from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.prompts import PromptTemplate
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core.postprocessor.llm_rerank import LLMRerank
from llama_index.postprocessor.colbert_rerank import ColbertRerank
from restai.chat import Chat
from restai.database import DBWrapper
from restai.eval import eval_rag
from restai.guard import Guard
from restai.llm import LLM
from restai.models.models import QuestionModel, ChatModel, User
from restai.project import Project
from restai.tools import tokens_from_string
from restai.projects.base import ProjectBase
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from llama_index.core.indices.struct_store.sql_query import NLSQLTableQueryEngine
from sqlalchemy import create_engine

_ALLOWED_DB_SCHEMES = {"postgresql", "postgresql+psycopg2", "mysql", "mysql+pymysql", "sqlite"}
_ALLOWED_SCHEME_BASES = {s.split("+")[0] for s in _ALLOWED_DB_SCHEMES}

# Loopback and cloud-metadata endpoints that user-supplied connection strings
# must never target (SSRF hardening -- see _validate_connection_string).
_BLOCKED_DB_HOSTS = {"localhost", "127.0.0.1", "::1", "169.254.169.254", "metadata.google.internal"}


def _validate_connection_string(conn: str):
    """Reject connection strings with dangerous schemes or targeting localhost/metadata.

    Raises:
        HTTPException(400): if the string cannot be parsed, uses a scheme
            outside the allowed set, targets a loopback/metadata host, or is a
            SQLite URL pointing at an absolute path outside the application
            directory.
    """
    from urllib.parse import urlparse
    try:
        parsed = urlparse(conn)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid connection string format")

    scheme_full = parsed.scheme or ""
    scheme_base = scheme_full.split("+")[0]

    if scheme_base not in _ALLOWED_SCHEME_BASES:
        raise HTTPException(
            status_code=400,
            detail=f"Database scheme '{scheme_full}' is not allowed. Permitted: {', '.join(sorted(_ALLOWED_DB_SCHEMES))}",
        )

    # BUGFIX: the docstring promised localhost/metadata blocking but only the
    # scheme was checked. Reject loopback and metadata endpoints explicitly.
    host = (parsed.hostname or "").lower()
    if host in _BLOCKED_DB_HOSTS or host.startswith("127."):
        raise HTTPException(
            status_code=400,
            detail="Connection strings targeting loopback or metadata hosts are not allowed",
        )

    # Block SQLite absolute paths that could read system files.
    # urlparse("sqlite:///etc/passwd").path == "/etc/passwd"
    if scheme_base == "sqlite" and parsed.path:
        import os
        path = parsed.path
        if path.startswith("/"):
            # BUGFIX: a plain startswith(os.getcwd()) prefix test accepted
            # sibling directories such as "/app-evil" when cwd is "/app".
            # Resolve symlinks and compare whole path components instead.
            app_dir = os.path.realpath(os.getcwd())
            target = os.path.realpath(path)
            try:
                inside = os.path.commonpath([app_dir, target]) == app_dir
            except ValueError:
                # e.g. different drives on Windows -- definitely not inside.
                inside = False
            if not inside:
                raise HTTPException(
                    status_code=400,
                    detail="SQLite absolute paths outside the application directory are not allowed",
                )


class EntityBoostPostprocessor:
    """Custom postprocessor that boosts retrieval scores for chunks whose source
    contains entities mentioned in the user's query. Multiplicative boost — does
    not filter out non-matching chunks. Falls back gracefully if no entities
    found.
    """

    def __init__(self, brain, db, project_id: int, query: str, boost_factor: float = 1.5):
        self.brain = brain
        self.db = db
        self.project_id = project_id
        self.query = query
        self.boost_factor = boost_factor
        # Lazily-computed cache of matched chunk sources; None = not computed.
        self._matched_sources: Optional[set] = None

    def _compute_matched_sources(self) -> set:
        """Return the set of chunk sources that mention entities found in the query.

        Any failure degrades to an empty set so the boost feature can never
        break retrieval.
        """
        if self._matched_sources is not None:
            return self._matched_sources
        try:
            import re as _re
            from restai.knowledge_graph import find_entities_in_text, normalize_entity_name
            from restai.models.databasemodels import KGEntityDatabase, KGEntityMentionDatabase

            # Primary path: word-boundary match the query against entities ALREADY
            # in this project's graph. NER on short queries is unreliable; the DB
            # knows what we have, so direct matching is more robust.
            project_entities = (
                self.db.db.query(KGEntityDatabase)
                .filter(KGEntityDatabase.project_id == self.project_id)
                .all()
            )
            if not project_entities:
                self._matched_sources = set()
                return self._matched_sources

            # Strip punctuation and pad with spaces so the f" {name} " substring
            # test below behaves like a whole-word match.
            query_padded = " " + _re.sub(r"[^\w\s]", " ", (self.query or "").lower()) + " "
            matched_ids = {
                e.id for e in project_entities
                if e.normalized and f" {e.normalized} " in query_padded
            }

            # Supplement with NER hits in case the query phrasing is different
            try:
                ner_hits = find_entities_in_text(self.query, self.brain)
                if ner_hits:
                    ner_normalized = [normalize_entity_name(n) for n, _ in ner_hits]
                    extra_ids = {
                        e.id for e in project_entities
                        if e.normalized in ner_normalized
                    }
                    matched_ids |= extra_ids
            except Exception:
                pass  # NER is best-effort only

            if not matched_ids:
                self._matched_sources = set()
                return self._matched_sources

            sources = {
                row.source for row in self.db.db.query(KGEntityMentionDatabase)
                .filter(KGEntityMentionDatabase.entity_id.in_(list(matched_ids)))
                .all()
            }
            self._matched_sources = sources
        except Exception:
            # Best-effort feature: never let KG lookups break retrieval.
            self._matched_sources = set()
        return self._matched_sources

    def postprocess_nodes(self, nodes, query_bundle=None, query_str=None):
        """Multiply the score of nodes whose source matched an entity, then re-sort."""
        matched = self._compute_matched_sources()
        if not matched:
            return nodes
        for node in nodes:
            try:
                node_source = node.node.metadata.get("source") if hasattr(node, "node") else None
                if node_source and node_source in matched:
                    if node.score is not None:
                        node.score = node.score * self.boost_factor
            except Exception:
                pass  # a malformed node must not break the whole result set
        # Re-sort after boosting
        try:
            nodes.sort(key=lambda n: n.score or 0, reverse=True)
        except Exception:
            pass
        return nodes


class RAG(ProjectBase):
    """Project type backed by a vector index, with an optional SQL connection."""

    async def chat(self, project: Project, chatModel: ChatModel, user: User, db: DBWrapper):
        """Conversational RAG flow with chat memory.

        Yields SSE-formatted strings when ``chatModel.stream`` is set,
        otherwise a single result dict.
        """
        if project.vector is None:
            yield {
                "question": chatModel.question,
                "answer": "Knowledge base unavailable — vector store connection failed. Please check that the vector database is running.",
                "sources": [],
                "type": "chat",
                "tokens": {"input": 0, "output": 0},
                "project": project.props.name,
                "guard": False,
            }
            return

        model: Optional[LLM] = self.brain.get_llm(project.props.llm, db)
        # Reserve ~25% of the context window for the model's answer.
        context_window = model.props.context_window if model else 4096
        token_limit = int(context_window * 0.75)
        chat: Chat = Chat(chatModel, self.brain.chat_store, token_limit=token_limit, llm=model.llm if model else None)

        output = {
            "id": chat.chat_id,
            "question": chatModel.question,
            "sources": [],
            "cached": False,
            "guard": False,
            "type": "chat",
            "project": project.props.name,
        }

        # Input guardrail: short-circuit with the guard's own output on a hit.
        if self.check_input_guard(project, chatModel.question, user, db, output):
            yield output
            return

        threshold = project.props.options.score or 0.0
        k = project.props.options.k or 1

        sysTemplate = project.props.system or self.brain.defaultSystem

        # Rerankers need a wider candidate pool to pick the top-k from.
        if project.props.options.colbert_rerank or project.props.options.llm_rerank:
            final_k = k * 2
        else:
            final_k = k

        retriever = VectorIndexRetriever(
            index=project.vector.index,
            similarity_top_k=final_k,
        )

        postprocessors = []

        if project.props.options.enable_knowledge_graph:
            postprocessors.append(
                EntityBoostPostprocessor(
                    brain=self.brain, db=db, project_id=project.props.id, query=chatModel.question,
                )
            )

        if project.props.options.colbert_rerank:
            postprocessors.append(
                ColbertRerank(
                    top_n=k,
                    model="colbert-ir/colbertv2.0",
                    tokenizer="colbert-ir/colbertv2.0",
                    keep_retrieval_score=True,
                )
            )

        if project.props.options.llm_rerank:
            postprocessors.append(
                LLMRerank(
                    choice_batch_size=k,
                    top_n=k,
                    llm=model.llm,
                )
            )

        # Similarity cutoff runs last, after any reranking.
        postprocessors.append(SimilarityPostprocessor(similarity_cutoff=threshold))

        # NOTE(review): `model` is treated as Optional above but dereferenced
        # unconditionally here -- confirm get_llm() cannot return None, or a
        # missing LLM fails with AttributeError.
        chat_engine = ContextChatEngine.from_defaults(
            retriever=retriever,
            system_prompt=sysTemplate,
            memory=chat.memory,
            node_postprocessors=postprocessors,
            llm=model.llm,
        )

        try:
            if chatModel.stream:
                response = chat_engine.stream_chat(chatModel.question)
            else:
                response = chat_engine.chat(chatModel.question)

            for node in response.source_nodes:
                source = {"score": node.score, "id": node.node_id, "text": node.text}

                if "source" in node.metadata:
                    source["source"] = node.metadata.get("source", "unknown")
                if "keywords" in node.metadata:
                    source["keywords"] = node.metadata["keywords"]

                output["sources"].append(source)

            if chatModel.stream:
                parts = []
                if hasattr(response, "response_gen"):
                    for text in response.response_gen:
                        parts.append(text)
                        yield "data: " + json.dumps({"text": text}) + "\n\n"

                answer = "".join(parts).strip()
                # Empty answer or no supporting sources -> censored/default reply.
                if not answer or len(output["sources"]) == 0:
                    censorship = project.props.censorship or self.brain.defaultCensorship
                    output["answer"] = censorship
                    if not parts:
                        yield "data: " + json.dumps({"text": censorship}) + "\n\n"
                else:
                    output["answer"] = answer

                self.brain.post_processing_reasoning(output)
                self.brain.post_processing_counting(output)

                yield "data: " + json.dumps(output) + "\n"
                yield "event: close\n\n"
            else:
                if len(response.source_nodes) == 0:
                    output["answer"] = (
                        project.props.censorship or self.brain.defaultCensorship
                    )
                else:
                    output["answer"] = response.response

                    # Only cache real (non-censored) answers.
                    if project.cache:
                        project.cache.add(chatModel.question, response.response)

                self.brain.post_processing_reasoning(output)
                self.brain.post_processing_counting(output)

                yield output
        except Exception as e:
            if chatModel.stream:
                yield "data: Inference failed\n"
                yield "event: error\n\n"
            raise e

    async def question(
        self, project: Project, questionModel: QuestionModel, user: User, db: DBWrapper
    ):
        """One-shot Q&A flow.

        Routes to an NL-to-SQL engine when the project has a database
        connection configured; otherwise answers from the vector index.
        Yields SSE-formatted strings when ``questionModel.stream`` is set,
        otherwise a single result dict.
        """
        if project.vector is None and not project.props.options.connection:
            yield {
                "question": questionModel.question,
                "answer": "Knowledge base unavailable — vector store connection failed. Please check that the vector database is running.",
                "sources": [],
                "type": "question",
                "tokens": {"input": 0, "output": 0},
                "project": project.props.name,
                "guard": False,
            }
            return

        output = {
            "question": questionModel.question,
            "type": "question",
            "sources": [],
            "cached": False,
            "guard": False,
            "tokens": {"input": 0, "output": 0},
            "project": project.props.name,
        }

        if self.check_input_guard(project, questionModel.question, user, db, output):
            yield output
            return

        model = self.brain.get_llm(project.props.llm, db)

        # SQL query path: when a database connection is configured, use NL-to-SQL
        if project.props.options.connection:
            if questionModel.stream:
                raise HTTPException(
                    status_code=400,
                    detail="Streaming is not supported for SQL queries."
                )

            conn_str = project.props.options.connection
            _validate_connection_string(conn_str)
            engine = create_engine(conn_str)
            try:
                sql_database = SQLDatabase(engine)

                # Per-request table list wins over the project-level setting.
                tables = None
                if hasattr(questionModel, 'tables') and questionModel.tables is not None:
                    tables = questionModel.tables
                elif project.props.options.tables:
                    tables = [table.strip() for table in project.props.options.tables.split(',')]

                sysTemplate = (
                    questionModel.system or project.props.system or self.brain.defaultSystem
                )
                question = sysTemplate + "\n Question: " + questionModel.question

                query_engine = NLSQLTableQueryEngine(
                    llm=model.llm,
                    sql_database=sql_database,
                    tables=tables,
                )

                response = query_engine.query(question)

                output["answer"] = response.response
                # BUGFIX: metadata['sql_query'] raised KeyError when the engine
                # produced no SQL; degrade to an empty source instead.
                output["sources"] = [response.metadata.get("sql_query", "")]
                output["tokens"] = {
                    "input": tokens_from_string(output["question"]),
                    "output": tokens_from_string(output["answer"])
                }
                yield output
                return
            finally:
                # Always release pooled connections, even on failure.
                engine.dispose()

        # Vector retrieval path.
        sysTemplate = (
            questionModel.system or project.props.system or self.brain.defaultSystem
        )

        # Per-request overrides win over project options, then defaults.
        k = questionModel.k or project.props.options.k or 2
        threshold = questionModel.score or project.props.options.score or 0.0

        # Rerankers need a wider candidate pool to pick the top-k from.
        if (
            questionModel.colbert_rerank
            or questionModel.llm_rerank
            or project.props.options.colbert_rerank
            or project.props.options.llm_rerank
        ):
            final_k = k * 2
        else:
            final_k = k

        retriever = VectorIndexRetriever(
            index=project.vector.index,
            similarity_top_k=final_k,
        )

        qa_prompt_tmpl = (
            "Context information is below.\n"
            "---------------------\n"
            "{context_str}\n"
            "---------------------\n"
            "Given the context information and not prior knowledge, "
            "answer the query.\n"
            "Query: {query_str}\n"
            "Answer: "
        )

        qa_prompt = PromptTemplate(qa_prompt_tmpl)

        model.llm.system_prompt = sysTemplate

        response_synthesizer = get_response_synthesizer(
            llm=model.llm, text_qa_template=qa_prompt, streaming=questionModel.stream
        )

        postprocessors = []

        if project.props.options.enable_knowledge_graph:
            postprocessors.append(
                EntityBoostPostprocessor(
                    brain=self.brain, db=db, project_id=project.props.id, query=questionModel.question,
                )
            )

        if questionModel.colbert_rerank or project.props.options.colbert_rerank:
            postprocessors.append(
                ColbertRerank(
                    top_n=k,
                    model="colbert-ir/colbertv2.0",
                    tokenizer="colbert-ir/colbertv2.0",
                    keep_retrieval_score=True,
                )
            )

        if questionModel.llm_rerank or project.props.options.llm_rerank:
            postprocessors.append(
                LLMRerank(
                    choice_batch_size=k,
                    top_n=k,
                    llm=model.llm,
                )
            )

        # Similarity cutoff runs last, after any reranking.
        postprocessors.append(SimilarityPostprocessor(similarity_cutoff=threshold))

        query_engine = RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            node_postprocessors=postprocessors,
        )

        try:
            response = query_engine.query(questionModel.question)

            if hasattr(response, "source_nodes"):
                for node in response.source_nodes:
                    output["sources"].append(
                        {
                            "source": node.metadata.get("source", "unknown"),
                            # BUGFIX: metadata["keywords"] raised KeyError for
                            # chunks indexed without keywords; chat() already
                            # treats this field as optional.
                            "keywords": node.metadata.get("keywords"),
                            "score": node.score,
                            "id": node.node_id,
                            "text": node.text,
                        }
                    )

            if questionModel.eval and not questionModel.stream:
                metric = eval_rag(
                    questionModel.question,
                    response,
                    self.brain.get_llm("openai_gpt4", db).llm,
                )
                output["evaluation"] = {"reason": metric.reason, "score": metric.score}

            if questionModel.stream:
                parts = []
                if hasattr(response, "response_gen"):
                    for text in response.response_gen:
                        parts.append(text)
                        yield "data: " + json.dumps({"text": text}) + "\n\n"

                answer = "".join(parts).strip()
                # Empty answer or no supporting sources -> censored/default reply.
                if not answer or len(response.source_nodes) == 0:
                    censorship = project.props.censorship or self.brain.defaultCensorship
                    output["answer"] = censorship
                    if not parts:
                        yield "data: " + json.dumps({"text": censorship}) + "\n\n"
                else:
                    output["answer"] = answer

                self.brain.post_processing_reasoning(output)
                self.brain.post_processing_counting(output)

                yield "data: " + json.dumps(output) + "\n"
                yield "event: close\n\n"
            else:
                if len(response.source_nodes) == 0:
                    output["answer"] = (
                        project.props.censorship or self.brain.defaultCensorship
                    )
                else:
                    output["answer"] = response.response

                    # Only cache real (non-censored) answers.
                    if project.cache:
                        project.cache.add(questionModel.question, response.response)

                self.brain.post_processing_reasoning(output)
                self.brain.post_processing_counting(output)

                yield output
        except Exception as e:
            if questionModel.stream:
                yield "data: Inference failed\n"
                yield "event: error\n\n"
            raise e