# base.py
from abc import ABC, abstractmethod

from fastapi import HTTPException

from restai.brain import Brain
from restai.database import DBWrapper
from restai.models.models import ChatModel, QuestionModel, User
from restai.project import Project


class ProjectBase(ABC):
    """Base class for project execution strategies.

    Subclasses implement :meth:`chat` and :meth:`question` for their project
    type; the shared helpers here apply the project's configured input/output
    guards, mutating the inference ``output`` dict in place and logging every
    guard decision via ``log_guard_event``.
    """

    def __init__(self, brain: Brain):
        # Shared Brain instance: supplies defaultCensorship and
        # post-processing/token-counting hooks.
        self.brain: Brain = brain

    @abstractmethod
    async def chat(self, project: Project, chat_model: ChatModel, user: User, db: DBWrapper):
        """Run a chat interaction. Subclasses must override; the default body
        rejects project types that do not support chat mode."""
        raise HTTPException(status_code=400, detail="Chat mode not available for this project type.")

    @abstractmethod
    async def question(self, project: Project, question_model: QuestionModel, user: User, db: DBWrapper):
        """Answer a single question. Subclasses must override; the default body
        rejects project types that do not support question mode."""
        raise HTTPException(status_code=400, detail="Question mode not available for this project type.")

    def _resolve_guard_mode_action(self, project: Project, blocked: bool) -> tuple:
        """Return ``(guard_mode, action)`` for a guard verification result.

        ``guard_mode`` falls back to ``"block"`` when unset — and, unlike the
        original input path, tolerates ``project.props.options`` being None
        (the same defensive access check_output_guard already used).
        ``action`` is ``"pass"`` when not blocked, ``"warn"`` when blocked in
        warn mode, otherwise ``"block"``.
        """
        options = project.props.options
        guard_mode = (options.guard_mode if options else None) or "block"
        if not blocked:
            action = "pass"
        elif guard_mode == "warn":
            action = "warn"
        else:
            action = "block"
        return guard_mode, action

    def check_input_guard(self, project: Project, question_text: str, user: User, db: DBWrapper, output: dict) -> bool:
        """Check input guard. Returns True if the request should be blocked (in block mode).

        Modifies output dict in-place (sets answer, guard flag).
        Logs guard events via background-compatible function.
        """
        if not project.props.guard:
            return False

        # Local imports: kept function-scoped as in the original (presumably
        # to avoid circular imports — TODO confirm).
        from restai.guard import Guard
        from restai.tools import log_guard_event

        guard = Guard(project.props.guard, self.brain, db)
        result = guard.verify(question_text, phase="input")
        if not result:
            # Guard produced no verdict: treat as pass, log nothing.
            return False

        guard_mode, action = self._resolve_guard_mode_action(project, result.blocked)

        log_guard_event(project, project.props.guard, user, "input", action, guard_mode, question_text, result.raw_response, db)

        if result.blocked and guard_mode == "block":
            output["answer"] = project.props.censorship or self.brain.defaultCensorship
            output["guard"] = True
            # Flag the inference log entry so the log viewer can show the
            # guard block distinctly from a normal answer.
            output["status"] = "guard_block"
            self.brain.post_processing_counting(output)
            return True
        elif result.blocked:
            # Warn mode: annotate the output but let the request proceed.
            output["guard"] = True
            output["status"] = "guard_block"

        return False

    def check_output_guard(self, project: Project, user: User, db: DBWrapper, output: dict) -> None:
        """Check the output guard against the answer in `output`. Mutates the
        output dict in place — sets `answer` to censorship and `guard=True` if
        a configured output guard blocks in `block` mode; just sets `guard=True`
        in `warn` mode. No-op when `guard_output` isn't configured or the
        answer is empty."""
        guard_name = project.props.options.guard_output if project.props.options else None
        if not guard_name or not output.get("answer"):
            return

        from restai.guard import Guard
        from restai.tools import log_guard_event

        out_guard = Guard(guard_name, self.brain, db)
        out_result = out_guard.verify(output["answer"], phase="output")
        if not out_result:
            return

        guard_mode, action = self._resolve_guard_mode_action(project, out_result.blocked)
        log_guard_event(
            project, guard_name, user, "output", action, guard_mode,
            output["answer"], out_result.raw_response, db,
        )

        if out_result.blocked and guard_mode == "block":
            # Replace the blocked answer with the project (or global) censorship text.
            output["answer"] = project.props.censorship or self.brain.defaultCensorship
            output["guard"] = True
        elif out_result.blocked:
            output["guard"] = True