# restai/projects/agent.py
"""agent project type — direct LLM chat with optional tool calling.

Built on the agent2 runtime (`restai.agent2`), the agent loop that does not
depend on llamaindex. Supports built-in tools, MCP servers, multimodal image
input, fallback LLMs, output guards, history compression, a ReAct fallback
for models without native tool calling, and token-by-token streaming.

Agent projects without any tools configured behave like a plain LLM chat —
the runtime exits after one turn with no extra overhead. Add tools or MCP
servers in the project's Tools tab to turn them into actual agents.
"""
import json
import logging
from uuid import uuid4

from fastapi import HTTPException

from restai.agent2 import (
    Agent2Runtime,
    Agent2UnsupportedLLMError,
    MCPSessionPool,
    adapt_function_tools,
    build_provider_for_llm,
)
from restai.agent2.memory import get_session, save_session
from restai.agent2.tool_adapter import AdaptedTool
from restai.agent2.types import ImageBlock, ToolUseBlock
from restai.database import DBWrapper
from restai.models.models import ChatModel, QuestionModel, User
from restai.project import Project
from restai.projects.base import ProjectBase
from restai.tools import tokens_from_string
from restai import memory_bank


_IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp")

def _is_image_attachment(f) -> bool:
    """True if the attachment should go through the multimodal vision flow."""
    mime = (getattr(f, "mime_type", None) or "").lower()
    if mime.startswith("image/"):
        return True
    name = (getattr(f, "name", "") or "").lower()
    return name.endswith(_IMAGE_EXTS)

def _augment_system_prompt_with_memory_bank(project, db, base_system_prompt: str | None) -> str | None:
    """When the project has memory_bank_enabled, prepend the rendered memory
    bank block to the system prompt. Cheap on the no-bank-yet path: a single
    indexed query that returns zero rows yields an empty string immediately.
    Failures degrade silently — the worst case is the LLM not seeing memory
    on this turn, never a broken request."""
    try:
        if not getattr(project.props.options, "memory_bank_enabled", False):
            return base_system_prompt
        max_tokens = int(getattr(project.props.options, "memory_bank_max_tokens", 2000) or 2000)
        block = memory_bank.render_for_prompt(db, project.props.id, max_tokens)
    except Exception:
        return base_system_prompt
    if not block:
        return base_system_prompt
    if base_system_prompt:
        return f"{block}\n\n{base_system_prompt}"
    return block

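# Illustrative sketch only: assuming `memory_bank.render_for_prompt` returns a
# text block such as "[Memory bank]\n- prefers metric units", the runtime would
# receive a system prompt of the form
#
#   [Memory bank]
#   - prefers metric units
#
#   <original project system prompt>
#
# i.e. the rendered block is prepended, separated by one blank line.
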
def _project_has_terminal(project) -> bool:
    """True iff the project has the `terminal` tool enabled. Only that tool
    (the Docker-sandbox shell) reads files from `/home/user/uploads/` — for
    projects without it, pushing attachments into a container is dead weight
    and can even fail loudly (container creation, tar size limits), so we
    short-circuit the upload entirely."""
    try:
        raw = (getattr(project.props.options, "tools", None) or "")
    except Exception:
        raw = ""
    names = {t.strip().lower() for t in raw.split(",") if t.strip()}
    return "terminal" in names

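# For reference, `project.props.options.tools` is a plain comma-separated
# string; a hypothetical value like "terminal, interpreter" enables two tools.
# Names are stripped and lower-cased above, so " Terminal " still counts.
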
def _route_attachments(files, chat_id, prompt, brain, existing_image=None, project=None):
    """Route non-image file attachments to the Docker sandbox (or politely
    drop them when `terminal` isn't configured).

    `helper._normalize_image_inputs` canonicalizes the two image-input
    paths before we're called: by the time this function runs, any image
    that arrived in `files[]` has already been promoted to
    `chatModel.image` and removed from `files`. So `files` here should
    only contain non-image attachments. The image branch below is kept as
    a defensive fallback for direct callers that bypass the helper.

    Returns ``(augmented_prompt, image_data_url_or_existing)``. If the
    caller already passed an explicit `image` on the request, it wins
    over anything in `files` (backward-compat with the old `image` field).
    """
    if not files:
        return prompt, existing_image

    images = [f for f in files if _is_image_attachment(f)]
    docs = [f for f in files if not _is_image_attachment(f)]

    image_url = existing_image
    if images and image_url is None:
        primary = images[0]
        mime = primary.mime_type or "image/png"
        image_url = f"data:{mime};base64,{primary.content}"

    if docs and project is not None and _project_has_terminal(project):
        prompt, _ = _upload_files_and_augment_prompt(docs, chat_id, prompt, brain)
    elif docs:
        # File attachments came in but this project can't read them — let
        # the LLM know instead of silently dropping them.
        names = ", ".join(f.name for f in docs[:5])
        if len(docs) > 5:
            names += f", …(+{len(docs) - 5} more)"
        prompt += (
            "\n\n[Attached file(s) ignored: this project has no tool that can "
            f"process them ({names}). Enable the `terminal` tool on the "
            "project to let the agent read uploaded files.]"
        )

    return prompt, image_url

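# Routing sketch (assuming attachment objects expose `name`, `mime_type` and
# base64 `content`, as the code above expects):
#   "diagram.png" -> returned as "data:image/png;base64,<content>" unless the
#                    request already carried an explicit `image`
#   "notes.txt"   -> uploaded to the sandbox and a manifest appended to the
#                    prompt when `terminal` is enabled, otherwise an
#                    "[Attached file(s) ignored: ...]" note is appended.
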
def _upload_files_and_augment_prompt(files, chat_id, prompt, brain):
    """Push user-attached files into the agent's sandbox container and return
    the original prompt augmented with a manifest the LLM can see.

    Returns ``(prompt, warning_or_none)``. When Docker isn't configured we
    skip the upload and append a note so the LLM knows the files weren't
    available.
    """
    if not files:
        return prompt, None

    docker = getattr(brain, "docker_manager", None)
    if docker is None:
        note = "\n\n[The user attached files but the agent sandbox (Docker) isn't configured on this RESTai instance, so the files cannot be processed.]"
        return prompt + note, "no_docker"

    import base64
    decoded: list[tuple[str, bytes]] = []
    for f in files:
        try:
            raw = base64.b64decode(f.content, validate=False)
        except Exception:
            continue
        if raw:
            decoded.append((f.name, raw))

    if not decoded:
        return prompt, None

    try:
        manifest = docker.put_files(chat_id or "ephemeral", decoded)
    except Exception as e:
        return prompt + f"\n\n[File upload to sandbox failed: {e}]", "upload_failed"

    if not manifest:
        return prompt, None

    lines = ["", "[Files attached by the user (available in /home/user/uploads/ — use the terminal tool to inspect them):]"]
    for entry in manifest:
        lines.append(f"  - {entry['path']}  ({entry['size']} bytes)")
    return prompt + "\n" + "\n".join(lines), None

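# Example of the appendix the LLM sees on a successful upload (path and size
# are illustrative; the real values come from `docker_manager.put_files`):
#
#   [Files attached by the user (available in /home/user/uploads/ — use the terminal tool to inspect them):]
#     - /home/user/uploads/report.csv  (10240 bytes)
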
def _make_project_tool_adapted(tool_row, brain) -> AdaptedTool:
    """Create an AdaptedTool from a ProjectToolDatabase row.
    The tool runs code in the Docker sandbox."""
    import json as _json

    try:
        schema = _json.loads(tool_row.parameters) if isinstance(tool_row.parameters, str) else tool_row.parameters
    except (_json.JSONDecodeError, TypeError):
        schema = {"type": "object", "properties": {}, "required": []}

    tool_code = tool_row.code
    tool_name = tool_row.name

    async def _run_project_tool(**kwargs):
        _brain = kwargs.pop("_brain", brain)
        _chat_id = kwargs.pop("_chat_id", None)
        kwargs.pop("_project_id", None)
        if not _brain or not getattr(_brain, "docker_manager", None):
            return "ERROR: Docker is not configured."
        args_json = _json.dumps(kwargs)
        script = f"import json, sys\nargs = json.loads(sys.stdin.readline() or '{{}}')\n{tool_code}"
        return _brain.docker_manager.run_script(_chat_id or "ephemeral", script, stdin_data=args_json)

    return AdaptedTool(
        name=tool_name,
        description=tool_row.description or tool_name,
        input_schema=schema,
        fn=_run_project_tool,
        is_async=True,
        accepts_kwargs=True,
    )

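# Sandbox contract implied above, shown with a hypothetical stored tool whose
# code is `print(args["city"])` and a call with {"city": "Porto"}: the script
# sent to the container is
#
#   import json, sys
#   args = json.loads(sys.stdin.readline() or '{}')
#   print(args["city"])
#
# and the arguments arrive as one JSON line on stdin ('{"city": "Porto"}'),
# so stored tool code only ever has to read the `args` dict.
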
def _wrap_image_error(err: Exception, has_image: bool) -> Exception:
    """Wrap an LLM provider error with a clearer message when an image was
    likely the cause (the model doesn't support vision)."""
    if not has_image:
        return err
    return HTTPException(
        status_code=400,
        detail=(
            "This LLM rejected the request, likely because it does not support "
            "image input. Try a vision-capable model (e.g. OllamaMultiModal, "
            "gpt-4o, claude-3+, gemini-2.0-flash) or remove the image. "
            f"Original error: {err}"
        ),
    )

class Agent(ProjectBase):

    def _build_runtime(
        self,
        project: Project,
        db: DBWrapper,
        system_prompt: str | None,
        extra_tools: list | None = None,
    ) -> Agent2Runtime:
        llm_db = db.get_llm_by_name(project.props.llm)
        if llm_db is None:
            raise ValueError(f"LLM '{project.props.llm}' not found")

        provider, prov_config = build_provider_for_llm(llm_db)

        raw_tool_names = set(
            t.strip() for t in (project.props.options.tools or "").split(",") if t.strip()
        )
        raw_tools = self.brain.get_tools(raw_tool_names) if raw_tool_names else []
        adapted = adapt_function_tools(raw_tools)
        if extra_tools:
            adapted.extend(extra_tools)

        # Load agent-created project tools from DB
        from restai.database import DBWrapper as _DBW
        _db = _DBW()
        try:
            project_tools = _db.get_project_tools(project.props.id)
            for pt in project_tools:
                if pt.enabled:
                    adapted.append(_make_project_tool_adapted(pt, self.brain))
        finally:
            _db.db.close()

        return Agent2Runtime(
            provider=provider,
            config=prov_config,
            tools=adapted,
            system_prompt=system_prompt or "",
            max_turns=project.props.options.max_iterations,
            mode=project.props.options.agent_mode or "auto",
        )

    @staticmethod
    def _record_step(steps: list, reasoning_buf: list, tool_call: ToolUseBlock, tool_output: str):
        reasoning_buf.append("Action: " + tool_call.name)
        reasoning_buf.append("Action Input: " + json.dumps(tool_call.input))
        reasoning_buf.append("Action Output: " + tool_output)
        steps.append({
            "actions": [{
                "action": tool_call.name,
                "input": tool_call.input,
                "output": tool_output,
            }],
            "output": "",
        })

    @staticmethod
    def _count_tokens(output: dict) -> None:
        """Lightweight tiktoken-based token estimate."""
        try:
            output["tokens"] = {
                "input": tokens_from_string(output.get("question") or ""),
                "output": tokens_from_string(output.get("answer") or ""),
                "accuracy": "low",
            }
        except Exception:
            output["tokens"] = {"input": 0, "output": 0, "accuracy": "low"}

    def _finalize_reasoning(self, output: dict, reasoning_buf: list, steps: list) -> None:
        """Build the reasoning dict from tool steps if any, then let
        `post_processing_reasoning` extract `<think>...</think>` blocks from
        the answer if no tool reasoning was recorded. Strips think tags from
        the answer either way."""
        if steps:
            output["reasoning"] = {"output": "\n".join(reasoning_buf), "steps": steps}
        self.brain.post_processing_reasoning(output)

    async def _drive_runtime(
        self,
        runtime,
        *,
        prompt: str,
        session,
        image_block,
        stream: bool,
        project: Project,
        output: dict,
    ):
        """Drive the runtime's event loop, yield text deltas as `str`, mutate
        `output["answer"]` and `output["reasoning"]` along the way. Used by
        both `chat()` and `question()` to share the (otherwise identical)
        per-event handling."""
        import re as _re

        steps: list = []
        reasoning_buf: list = []
        # Pair tool calls with their results so the reasoning panel renders correctly
        pending_tool_calls: dict = {}
        # Per-call timing + structured trace. Keyed by tool_use_id so we
        # can compute latency = tool_result_ts - tool_use_ts. Flushed
        # into `output["tool_trace"]` when we finalize — the log viewer
        # renders this as a timeline.
        import time as _time
        tool_call_started_at: dict = {}
        tool_trace: list = []
        # `draw_image` tool URLs collected from tool results — appended to the
        # final answer if the LLM didn't echo them (some models summarize tool
        # results instead of quoting them, which would silently swallow the
        # image link).
        image_urls: list[str] = []
        _image_url_re = _re.compile(
            r"!\[[^\]]*\]\((https?://[^)\s]+/image/cache/[A-Fa-f0-9]+\.[A-Za-z0-9]+|/image/cache/[A-Fa-f0-9]+\.[A-Za-z0-9]+)\)"
        )

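        # Rough shape of the event stream consumed below (payloads are whatever
        # `runtime.run_iter` emits; only the fields read here are listed):
        #   text_delta  -> data["text"] chunks, re-yielded to the caller
        #   assistant   -> message whose ToolUseBlocks open pending tool calls
        #   tool_result -> blocks matched back to their call via tool_use_id
        #   final       -> data["final_text"] / data["stop_reason"]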
        async for event in runtime.run_iter(
            prompt,
            session=session,
            image=image_block,
            stream=stream,
        ):
            if event.type == "text_delta":
                delta = event.data.get("text", "")
                if delta:
                    yield delta

            elif event.type == "assistant":
                msg = event.message
                if msg:
                    for block in msg.content:
                        if isinstance(block, ToolUseBlock):
                            pending_tool_calls[block.id] = block
                            tool_call_started_at[block.id] = _time.monotonic()

            elif event.type == "tool_result":
                msg = event.message
                if msg:
                    for block in msg.content:
                        tool_use_id = getattr(block, "tool_use_id", None)
                        tool_call = pending_tool_calls.pop(tool_use_id, None)
                        content = getattr(block, "content", "") or ""
                        if tool_call is not None:
                            self._record_step(steps, reasoning_buf, tool_call, content)
                            # Per-tool trace row. Latency comes from the
                            # assistant → tool_result gap. Status is
                            # best-effort: the convention across our
                            # builtin tools is to return `"ERROR: ..."`
                            # or `"OK: ..."`, so a prefix check is good
                            # enough without wrapping every tool.
                            started = tool_call_started_at.pop(tool_use_id, None)
                            latency_ms = (
                                int((_time.monotonic() - started) * 1000) if started is not None else None
                            )
                            status = "error" if str(content).strip().startswith("ERROR:") else "ok"
                            try:
                                input_preview = json.dumps(tool_call.input, default=str)
                            except Exception:
                                input_preview = str(tool_call.input)
                            if len(input_preview) > 500:
                                input_preview = input_preview[:500] + "…"
                            err_preview = None
                            if status == "error":
                                err_preview = str(content)[:500]
                            tool_trace.append({
                                "tool": tool_call.name,
                                "args": input_preview,
                                "latency_ms": latency_ms,
                                "status": status,
                                "error": err_preview,
                            })
                            # Capture every image-cache URL the tool emitted so
                            # we can guarantee it ends up in front of the user.
                            for m in _image_url_re.finditer(content):
                                url = m.group(1)
                                if url not in image_urls:
                                    image_urls.append(url)

            elif event.type == "final":
                output["answer"] = event.data.get("final_text", "") or ""
                if event.data.get("stop_reason") == "max_turns" and not output["answer"]:
                    output["answer"] = (
                        project.props.censorship
                        or "I'm sorry, I tried my best but couldn't reach a final answer."
                    )

        # If the LLM dropped a draw_image URL on its way to writing the final
        # answer, splice it back in. Most models echo the markdown verbatim
        # when instructed; this is the belt-and-braces safety net for the
        # ones that summarize ("Image generated!") without the link.
        if image_urls:
            answer = output.get("answer") or ""
            missing = [u for u in image_urls if u not in answer]
            if missing:
                appendix = "\n\n" + "\n\n".join(f"![]({u})" for u in missing)
                output["answer"] = (answer + appendix).strip()

        self._finalize_reasoning(output, reasoning_buf, steps)

        # Hand the tool trace off to log_inference via the output dict.
        # Empty list → None so we don't bloat the DB with "[]" rows.
        if tool_trace:
            output["tool_trace"] = tool_trace

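    # Wire-format note: with `stream=True` the two entry points below emit
    # SSE-style frames, roughly (terminators abbreviated):
    #   data: {"text": "<delta>"}      -- one frame per streamed text chunk
    #   data: {...full output dict...}
    #   event: close
    # Non-streaming calls yield the output dict exactly once instead.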
    async def chat(self, project: Project, chatModel: ChatModel, user: User, db: DBWrapper):
        chat_id = chatModel.id or str(uuid4())

        output = {
            "question": chatModel.question,
            "type": "agent",
            "sources": [],
            "guard": False,
            "tokens": {"input": 0, "output": 0},
            "project": project.props.name,
            "id": chat_id,
        }

        if self.check_input_guard(project, chatModel.question, user, db, output):
            if chatModel.stream:
                yield "data: " + json.dumps({"text": output.get("answer", "")}) + "\n\n"
                yield "data: " + json.dumps(output) + "\n"
                yield "event: close\n\n"
            else:
                yield output
            return

        streamed_any_text = False
        try:
            async with MCPSessionPool() as mcp_pool:
                try:
                    mcp_tools = await mcp_pool.connect_servers(
                        project.props.options.mcp_servers or []
                    )
                except Exception:
                    mcp_tools = []

                try:
                    sys_prompt = _augment_system_prompt_with_memory_bank(
                        project, db, project.props.system,
                    )
                    runtime = self._build_runtime(
                        project, db, sys_prompt, extra_tools=mcp_tools
                    )
                except Agent2UnsupportedLLMError as e:
                    err_msg = str(e)
                    if chatModel.stream:
                        yield "data: " + json.dumps({"text": err_msg}) + "\n\n"
                        output["answer"] = err_msg
                        yield "data: " + json.dumps(output) + "\n"
                        yield "event: close\n\n"
                    else:
                        output["answer"] = err_msg
                        yield output
                    return

                runtime._chat_id = chat_id
                runtime._brain = self.brain
                runtime._project_id = project.props.id
                session = await get_session(self.brain, chat_id)

                prompt_text, image_url = _route_attachments(
                    getattr(chatModel, "files", None), chat_id, chatModel.question, self.brain,
                    existing_image=chatModel.image, project=project,
                )
                image_block = ImageBlock.from_data_url(image_url) if image_url else None

                try:
                    async for delta in self._drive_runtime(
                        runtime,
                        prompt=prompt_text,
                        session=session,
                        image_block=image_block,
                        stream=chatModel.stream,
                        project=project,
                        output=output,
                    ):
                        streamed_any_text = True
                        yield "data: " + json.dumps({"text": delta}) + "\n\n"

                    await save_session(self.brain, chat_id, session)
                    self._count_tokens(output)
                    self.check_output_guard(project, user, db, output)

                    if chatModel.stream:
                        # Emit the final answer text only if streaming didn't
                        # already deliver it (e.g. fell back to ReAct mid-run).
                        if not streamed_any_text and output.get("answer"):
                            yield "data: " + json.dumps({"text": output["answer"]}) + "\n\n"
                        yield "data: " + json.dumps(output) + "\n"
                        yield "event: close\n\n"

                except Exception as e:
                    wrapped = _wrap_image_error(e, bool(chatModel.image))
                    err_msg = project.props.censorship or f"Agent failed: {wrapped}"
                    output["answer"] = err_msg
                    self._count_tokens(output)
                    if chatModel.stream:
                        yield "data: " + json.dumps({"text": err_msg}) + "\n\n"
                        yield "data: " + json.dumps(output) + "\n"
                        yield "event: close\n\n"
                        streamed_any_text = True
        except BaseException as e:
            # Catch ExceptionGroup from MCP session pool cleanup failures
            # to prevent "No response returned" crashes
            if "answer" not in output:
                logging.warning("Agent chat failed during MCP cleanup: %s", e)
                output["answer"] = project.props.censorship or "An error occurred processing your request."

        # Non-streaming yield MUST be outside the `async with MCPSessionPool()`
        # block. When the caller does `async for line in gen: return line`, the
        # generator is abandoned after the first yield. If that yield happens
        # inside the async-with, the pool's __aexit__ runs in a GC/finalizer
        # task → "exit cancel scope in different task" → corrupted HTTP response.
        if chatModel.stream:
            # If streaming failed before emitting anything, emit the error as SSE
            if "answer" in output and not streamed_any_text:
                yield "data: " + json.dumps({"text": output["answer"]}) + "\n\n"
                yield "data: " + json.dumps(output) + "\n"
                yield "event: close\n\n"
        else:
            if "answer" not in output:
                output["answer"] = project.props.censorship or "No response generated."
            yield output

    async def question(
        self, project: Project, questionModel: QuestionModel, user: User, db: DBWrapper
    ):
        output = {
            "question": questionModel.question,
            "type": "agent",
            "sources": [],
            "guard": False,
            "tokens": {"input": 0, "output": 0},
            "project": project.props.name,
        }

        if self.check_input_guard(project, questionModel.question, user, db, output):
            if questionModel.stream:
                yield "data: " + json.dumps({"text": output.get("answer", "")}) + "\n\n"
                yield "data: " + json.dumps(output) + "\n"
                yield "event: close\n\n"
            else:
                yield output
            return

        system_prompt = questionModel.system or project.props.system
        system_prompt = _augment_system_prompt_with_memory_bank(project, db, system_prompt)

        async with MCPSessionPool() as mcp_pool:
            try:
                mcp_tools = await mcp_pool.connect_servers(
                    project.props.options.mcp_servers or []
                )
            except Exception:
                mcp_tools = []

            try:
                runtime = self._build_runtime(
                    project, db, system_prompt, extra_tools=mcp_tools
                )
            except Agent2UnsupportedLLMError as e:
                output["answer"] = str(e)
                if questionModel.stream:
                    yield "data: " + json.dumps({"text": output["answer"]}) + "\n\n"
                    yield "data: " + json.dumps(output) + "\n"
                    yield "event: close\n\n"
                else:
                    yield output
                return

            runtime._brain = self.brain
            runtime._project_id = project.props.id
            streamed_any_text = False

            # Ephemeral chat id so file uploads still land in a sandbox the
            # terminal tool can read from inside this same invocation.
            eph_chat = f"q_{uuid4().hex[:12]}"
            runtime._chat_id = eph_chat
            prompt_text, image_url = _route_attachments(
                getattr(questionModel, "files", None), eph_chat, questionModel.question, self.brain,
                existing_image=questionModel.image, project=project,
            )
            image_block = ImageBlock.from_data_url(image_url) if image_url else None

            try:
                async for delta in self._drive_runtime(
                    runtime,
                    prompt=prompt_text,
                    session=None,
                    image_block=image_block,
                    stream=questionModel.stream,
                    project=project,
                    output=output,
                ):
                    streamed_any_text = True
                    yield "data: " + json.dumps({"text": delta}) + "\n\n"

                self._count_tokens(output)
                self.check_output_guard(project, user, db, output)
            except Exception as e:
                wrapped = _wrap_image_error(e, bool(questionModel.image))
                err_msg = project.props.censorship or f"Agent failed: {wrapped}"
                output["answer"] = err_msg
                self._count_tokens(output)
                if questionModel.stream:
                    yield "data: " + json.dumps({"text": err_msg}) + "\n\n"
                    yield "data: " + json.dumps(output) + "\n"
                    yield "event: close\n\n"
                    return

        if questionModel.stream:
            if not streamed_any_text and output.get("answer"):
                yield "data: " + json.dumps({"text": output["answer"]}) + "\n\n"
            yield "data: " + json.dumps(output) + "\n"
            yield "event: close\n\n"
        else:
            yield output