/ restai / projects / base.py
base.py
 1  from restai.brain import Brain
 2  from abc import ABC, abstractmethod
 3  from restai.database import DBWrapper
 4  from restai.models.models import ChatModel, QuestionModel, User
 5  from fastapi import HTTPException
 6  from restai.project import Project
 7  
 8  
 9  class ProjectBase(ABC):
10      def __init__(self, brain: Brain):
11          self.brain: Brain = brain
12  
13      @abstractmethod
14      async def chat(self, project: Project, chat_model: ChatModel, user: User, db: DBWrapper):
15          raise HTTPException(status_code=400, detail="Chat mode not available for this project type.")
16  
17      @abstractmethod
18      async def question(self, project: Project, question_model: QuestionModel, user: User, db: DBWrapper):
19          raise HTTPException(status_code=400, detail="Question mode not available for this project type.")
20  
21      def check_input_guard(self, project: Project, question_text: str, user: User, db: DBWrapper, output: dict) -> bool:
22          """Check input guard. Returns True if the request should be blocked (in block mode).
23  
24          Modifies output dict in-place (sets answer, guard flag).
25          Logs guard events via background-compatible function.
26          """
27          if not project.props.guard:
28              return False
29  
30          from restai.guard import Guard
31          from restai.tools import log_guard_event
32  
33          guard = Guard(project.props.guard, self.brain, db)
34          result = guard.verify(question_text, phase="input")
35          if not result:
36              return False
37  
38          guard_mode = project.props.options.guard_mode or "block"
39          action = "block" if result.blocked else "pass"
40          if result.blocked and guard_mode == "warn":
41              action = "warn"
42  
43          log_guard_event(project, project.props.guard, user, "input", action, guard_mode, question_text, result.raw_response, db)
44  
45          if result.blocked and guard_mode == "block":
46              output["answer"] = project.props.censorship or self.brain.defaultCensorship
47              output["guard"] = True
48              # Flag the inference log entry so the log viewer can show the
49              # guard block distinctly from a normal answer.
50              output["status"] = "guard_block"
51              self.brain.post_processing_counting(output)
52              return True
53          elif result.blocked:
54              output["guard"] = True
55              output["status"] = "guard_block"
56  
57          return False
58  
59      def check_output_guard(self, project: Project, user: User, db: DBWrapper, output: dict) -> None:
60          """Check the output guard against the answer in `output`. Mutates the
61          output dict in place — sets `answer` to censorship and `guard=True` if
62          a configured output guard blocks in `block` mode; just sets `guard=True`
63          in `warn` mode. No-op when `guard_output` isn't configured or the
64          answer is empty."""
65          guard_name = project.props.options.guard_output if project.props.options else None
66          if not guard_name or not output.get("answer"):
67              return
68  
69          from restai.guard import Guard
70          from restai.tools import log_guard_event
71  
72          out_guard = Guard(guard_name, self.brain, db)
73          out_result = out_guard.verify(output["answer"], phase="output")
74          if not out_result:
75              return
76  
77          guard_mode = project.props.options.guard_mode or "block"
78          action = "block" if out_result.blocked else "pass"
79          if out_result.blocked and guard_mode == "warn":
80              action = "warn"
81          log_guard_event(
82              project, guard_name, user, "output", action, guard_mode,
83              output["answer"], out_result.raw_response, db,
84          )
85  
86          if out_result.blocked and guard_mode == "block":
87              output["answer"] = project.props.censorship or self.brain.defaultCensorship
88              output["guard"] = True
89          elif out_result.blocked:
90              output["guard"] = True