agent.py
"""agent project type — direct LLM chat with optional tool calling.

Built on the agent2 runtime (`restai.agent2`), which is the non-llamaindex
agent loop. Supports built-in tools, MCP servers, multimodal image input,
fallback LLMs, output guards, history compression, a ReAct fallback for
models without native tool calling, and token-by-token streaming.

Agent projects without any tools configured behave like a plain LLM chat —
the runtime exits after one turn with no extra overhead. Add tools or MCP
servers in the project's Tools tab to turn them into actual agents.
"""
import json
import logging
from uuid import uuid4

from fastapi import HTTPException

from restai.agent2 import (
    Agent2Runtime,
    Agent2UnsupportedLLMError,
    MCPSessionPool,
    adapt_function_tools,
    build_provider_for_llm,
)
from restai.agent2.memory import get_session, save_session
from restai.agent2.tool_adapter import AdaptedTool
from restai.agent2.types import ImageBlock, ToolUseBlock
from restai.database import DBWrapper
from restai.models.models import ChatModel, QuestionModel, User
from restai.project import Project
from restai.projects.base import ProjectBase
from restai.tools import tokens_from_string
from restai import memory_bank


_IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp")


def _is_image_attachment(f) -> bool:
    """True if the attachment should go through the multimodal vision flow."""
    mime = (getattr(f, "mime_type", None) or "").lower()
    if mime.startswith("image/"):
        return True
    name = (getattr(f, "name", "") or "").lower()
    return name.endswith(_IMAGE_EXTS)


def _augment_system_prompt_with_memory_bank(project, db, base_system_prompt: str | None) -> str | None:
    """When the project has memory_bank_enabled, prepend the rendered memory
    bank block to the system prompt. Cheap on the no-bank-yet path: a single
    indexed query that returns zero rows yields an empty string immediately.
    Failures degrade silently — the worst case is the LLM not seeing memory
    on this turn, never a broken request."""
    try:
        if not getattr(project.props.options, "memory_bank_enabled", False):
            return base_system_prompt
        max_tokens = int(getattr(project.props.options, "memory_bank_max_tokens", 2000) or 2000)
        block = memory_bank.render_for_prompt(db, project.props.id, max_tokens)
    except Exception:
        return base_system_prompt
    if not block:
        return base_system_prompt
    if base_system_prompt:
        return f"{block}\n\n{base_system_prompt}"
    return block


def _project_has_terminal(project) -> bool:
    """True iff the project has the `terminal` tool enabled. Only that tool
    (the Docker-sandbox shell) reads files from `/home/user/uploads/` — for
    projects without it, pushing attachments into a container is dead weight
    and can even fail loudly (container creation, tar size limits), so we
    short-circuit the upload entirely."""
    try:
        raw = (getattr(project.props.options, "tools", None) or "")
    except Exception:
        raw = ""
    names = {t.strip().lower() for t in raw.split(",") if t.strip()}
    return "terminal" in names

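
# Attachment routing in short: image attachments ride the multimodal path (as
# a data URL on the request), other files are uploaded to the project's Docker
# sandbox when the `terminal` tool is enabled, and are otherwise ignored with
# an explanatory note appended to the prompt.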
def _route_attachments(files, chat_id, prompt, brain, existing_image=None, project=None):
    """Route non-image file attachments to the Docker sandbox (or politely
    drop them when `terminal` isn't configured).

    `helper._normalize_image_inputs` canonicalizes the two image-input
    paths before we're called: by the time this function runs, any image
    that arrived in `files[]` has already been promoted to
    `chatModel.image` and removed from `files`. So `files` here should
    only contain non-image attachments. The image branch below is kept as
    a defensive fallback for direct callers that bypass the helper.

    Returns ``(augmented_prompt, image_data_url_or_existing)``. If the
    caller already passed an explicit `image` on the request, it wins
    over anything in `files` (backward-compat with the old `image` field).
    """
    if not files:
        return prompt, existing_image

    images = [f for f in files if _is_image_attachment(f)]
    docs = [f for f in files if not _is_image_attachment(f)]

    image_url = existing_image
    if images and image_url is None:
        primary = images[0]
        mime = primary.mime_type or "image/png"
        image_url = f"data:{mime};base64,{primary.content}"

    if docs and project is not None and _project_has_terminal(project):
        prompt, _ = _upload_files_and_augment_prompt(docs, chat_id, prompt, brain)
    elif docs:
        # File attachments came in but this project can't read them — let
        # the LLM know instead of silently dropping them.
        names = ", ".join(f.name for f in docs[:5])
        if len(docs) > 5:
            names += f", …(+{len(docs) - 5} more)"
        prompt += (
            "\n\n[Attached file(s) ignored: this project has no tool that can "
            f"process them ({names}). Enable the `terminal` tool on the "
            "project to let the agent read uploaded files.]"
        )

    return prompt, image_url


def _upload_files_and_augment_prompt(files, chat_id, prompt, brain):
    """Push user-attached files into the agent's sandbox container and return
    the original prompt augmented with a manifest the LLM can see.

    Returns ``(prompt, warning_or_none)``. When Docker isn't configured we
    skip the upload and append a note so the LLM knows the files weren't
    available.
    """
    if not files:
        return prompt, None

    docker = getattr(brain, "docker_manager", None)
    if docker is None:
        note = "\n\n[The user attached files but the agent sandbox (Docker) isn't configured on this RESTai instance, so the files cannot be processed.]"
        return prompt + note, "no_docker"

    import base64
    decoded: list[tuple[str, bytes]] = []
    for f in files:
        try:
            raw = base64.b64decode(f.content, validate=False)
        except Exception:
            continue
        if raw:
            decoded.append((f.name, raw))

    if not decoded:
        return prompt, None

    try:
        manifest = docker.put_files(chat_id or "ephemeral", decoded)
    except Exception as e:
        return prompt + f"\n\n[File upload to sandbox failed: {e}]", "upload_failed"

    if not manifest:
        return prompt, None

    lines = ["", "[Files attached by the user (available in /home/user/uploads/ — use the terminal tool to inspect them):]"]
    for entry in manifest:
        lines.append(f" - {entry['path']} ({entry['size']} bytes)")
    return prompt + "\n" + "\n".join(lines), None

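
# Agent-created project tools are stored as Python source in the DB. At call
# time the model-supplied arguments are serialized to JSON and piped on stdin
# to a small loader script, which exposes them to the stored code as a dict
# named `args` and runs inside the per-chat Docker sandbox. The `_brain`,
# `_chat_id` and `_project_id` kwargs are stripped before serialization; they
# look like plumbing injected by the runtime (see how chat()/question() set
# them on the runtime below) rather than model-provided arguments.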
def _make_project_tool_adapted(tool_row, brain) -> AdaptedTool:
    """Create an AdaptedTool from a ProjectToolDatabase row.
    The tool runs code in the Docker sandbox."""
    import json as _json

    try:
        schema = _json.loads(tool_row.parameters) if isinstance(tool_row.parameters, str) else tool_row.parameters
    except (_json.JSONDecodeError, TypeError):
        schema = {"type": "object", "properties": {}, "required": []}

    tool_code = tool_row.code
    tool_name = tool_row.name

    async def _run_project_tool(**kwargs):
        _brain = kwargs.pop("_brain", brain)
        _chat_id = kwargs.pop("_chat_id", None)
        kwargs.pop("_project_id", None)
        if not _brain or not getattr(_brain, "docker_manager", None):
            return "ERROR: Docker is not configured."
        args_json = _json.dumps(kwargs)
        script = f"import json, sys\nargs = json.loads(sys.stdin.readline() or '{{}}')\n{tool_code}"
        return _brain.docker_manager.run_script(_chat_id or "ephemeral", script, stdin_data=args_json)

    return AdaptedTool(
        name=tool_name,
        description=tool_row.description or tool_name,
        input_schema=schema,
        fn=_run_project_tool,
        is_async=True,
        accepts_kwargs=True,
    )


def _wrap_image_error(err: Exception, has_image: bool) -> Exception:
    """Wrap an LLM provider error with a clearer message when an image was
    likely the cause (the model doesn't support vision)."""
    if not has_image:
        return err
    return HTTPException(
        status_code=400,
        detail=(
            "This LLM rejected the request, likely because it does not support "
            "image input. Try a vision-capable model (e.g. OllamaMultiModal, "
            "gpt-4o, claude-3+, gemini-2.0-flash) or remove the image. "
            f"Original error: {err}"
        ),
    )


class Agent(ProjectBase):

    def _build_runtime(
        self,
        project: Project,
        db: DBWrapper,
        system_prompt: str | None,
        extra_tools: list | None = None,
    ) -> Agent2Runtime:
        llm_db = db.get_llm_by_name(project.props.llm)
        if llm_db is None:
            raise ValueError(f"LLM '{project.props.llm}' not found")

        provider, prov_config = build_provider_for_llm(llm_db)

        raw_tool_names = set(
            t.strip() for t in (project.props.options.tools or "").split(",") if t.strip()
        )
        raw_tools = self.brain.get_tools(raw_tool_names) if raw_tool_names else []
        adapted = adapt_function_tools(raw_tools)
        if extra_tools:
            adapted.extend(extra_tools)

        # Load agent-created project tools from DB
        from restai.database import DBWrapper as _DBW
        _db = _DBW()
        try:
            project_tools = _db.get_project_tools(project.props.id)
            for pt in project_tools:
                if pt.enabled:
                    adapted.append(_make_project_tool_adapted(pt, self.brain))
        finally:
            _db.db.close()

        return Agent2Runtime(
            provider=provider,
            config=prov_config,
            tools=adapted,
            system_prompt=system_prompt or "",
            max_turns=project.props.options.max_iterations,
            mode=project.props.options.agent_mode or "auto",
        )

    @staticmethod
    def _record_step(steps: list, reasoning_buf: list, tool_call: ToolUseBlock, tool_output: str):
        reasoning_buf.append("Action: " + tool_call.name)
        reasoning_buf.append("Action Input: " + json.dumps(tool_call.input))
        reasoning_buf.append("Action Output: " + tool_output)
        steps.append({
            "actions": [{
                "action": tool_call.name,
                "input": tool_call.input,
                "output": tool_output,
            }],
            "output": "",
        })

    @staticmethod
    def _count_tokens(output: dict) -> None:
        """Lightweight tiktoken-based token estimate."""
        try:
            output["tokens"] = {
                "input": tokens_from_string(output.get("question") or ""),
                "output": tokens_from_string(output.get("answer") or ""),
                "accuracy": "low",
            }
        except Exception:
            output["tokens"] = {"input": 0, "output": 0, "accuracy": "low"}

    def _finalize_reasoning(self, output: dict, reasoning_buf: list, steps: list) -> None:
        """Build the reasoning dict from tool steps if any, then let
        `post_processing_reasoning` extract `<think>...</think>` blocks from
        the answer if no tool reasoning was recorded. Strips think tags from
        the answer either way."""
        if steps:
            output["reasoning"] = {"output": "\n".join(reasoning_buf), "steps": steps}
        self.brain.post_processing_reasoning(output)

    async def _drive_runtime(
        self,
        runtime,
        *,
        prompt: str,
        session,
        image_block,
        stream: bool,
        project: Project,
        output: dict,
    ):
        """Drive the runtime's event loop, yield text deltas as `str`, mutate
        `output["answer"]` and `output["reasoning"]` along the way. Used by
        both `chat()` and `question()` to share the (otherwise identical)
        per-event handling."""
        import re as _re

        steps: list = []
        reasoning_buf: list = []
        # Pair tool calls with their results so the reasoning panel renders correctly
        pending_tool_calls: dict = {}
        # Per-call timing + structured trace. Keyed by tool_use_id so we
        # can compute latency = tool_result_ts - tool_use_ts. Flushed
        # into `output["tool_trace"]` when we finalize — the log viewer
        # renders this as a timeline.
        import time as _time
        tool_call_started_at: dict = {}
        tool_trace: list = []
        # `draw_image` tool URLs collected from tool results — appended to the
        # final answer if the LLM didn't echo them (some models summarize tool
        # results instead of quoting them, which would silently swallow the
        # image link).
        image_urls: list[str] = []
        _image_url_re = _re.compile(
            r"!\[[^\]]*\]\((https?://[^)\s]+/image/cache/[A-Fa-f0-9]+\.[A-Za-z0-9]+|/image/cache/[A-Fa-f0-9]+\.[A-Za-z0-9]+)\)"
        )

        async for event in runtime.run_iter(
            prompt,
            session=session,
            image=image_block,
            stream=stream,
        ):
            if event.type == "text_delta":
                delta = event.data.get("text", "")
                if delta:
                    yield delta

            elif event.type == "assistant":
                msg = event.message
                if msg:
                    for block in msg.content:
                        if isinstance(block, ToolUseBlock):
                            pending_tool_calls[block.id] = block
                            tool_call_started_at[block.id] = _time.monotonic()

            elif event.type == "tool_result":
                msg = event.message
                if msg:
                    for block in msg.content:
                        tool_use_id = getattr(block, "tool_use_id", None)
                        tool_call = pending_tool_calls.pop(tool_use_id, None)
                        content = getattr(block, "content", "") or ""
                        if tool_call is not None:
                            self._record_step(steps, reasoning_buf, tool_call, content)
                            # Per-tool trace row. Latency comes from the
                            # assistant → tool_result gap. Status is
                            # best-effort: the convention across our
                            # builtin tools is to return `"ERROR: ..."`
                            # or `"OK: ..."`, so a prefix check is good
                            # enough without wrapping every tool.
                            started = tool_call_started_at.pop(tool_use_id, None)
                            latency_ms = (
                                int((_time.monotonic() - started) * 1000) if started is not None else None
                            )
                            status = "error" if str(content).strip().startswith("ERROR:") else "ok"
                            try:
                                input_preview = json.dumps(tool_call.input, default=str)
                            except Exception:
                                input_preview = str(tool_call.input)
                            if len(input_preview) > 500:
                                input_preview = input_preview[:500] + "…"
                            err_preview = None
                            if status == "error":
                                err_preview = str(content)[:500]
                            tool_trace.append({
                                "tool": tool_call.name,
                                "args": input_preview,
                                "latency_ms": latency_ms,
                                "status": status,
                                "error": err_preview,
                            })
                        # Capture every image-cache URL the tool emitted so
                        # we can guarantee it ends up in front of the user.
                        for m in _image_url_re.finditer(content):
                            url = m.group(1)
                            if url not in image_urls:
                                image_urls.append(url)

            elif event.type == "final":
                output["answer"] = event.data.get("final_text", "") or ""
                if event.data.get("stop_reason") == "max_turns" and not output["answer"]:
                    output["answer"] = (
                        project.props.censorship
                        or "I'm sorry, I tried my best but couldn't reach a final answer."
                    )

        # If the LLM dropped a draw_image URL on its way to writing the final
        # answer, splice it back in. Most models echo the markdown verbatim
        # when instructed; this is the belt-and-braces safety net for the
        # ones that summarize ("Image generated!") without the link.
        if image_urls:
            answer = output.get("answer") or ""
            missing = [u for u in image_urls if u not in answer]
            if missing:
                appendix = "\n\n" + "\n\n".join(f"![]({u})" for u in missing)
                output["answer"] = (answer + appendix).strip()

        self._finalize_reasoning(output, reasoning_buf, steps)

        # Hand the tool trace off to log_inference via the output dict.
        # An empty trace leaves the key unset so we don't bloat the DB
        # with "[]" rows.
        if tool_trace:
            output["tool_trace"] = tool_trace

    async def chat(self, project: Project, chatModel: ChatModel, user: User, db: DBWrapper):
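        """Run one agent chat turn for `project`, with persistent session memory.

        When ``chatModel.stream`` is true the generator yields SSE-style
        frames: one ``data: {"text": ...}`` frame per streamed delta, then a
        single ``data:`` frame carrying the full output dict, then an
        ``event: close`` frame. When streaming is off it yields the output
        dict once.
        """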
        chat_id = chatModel.id or str(uuid4())

        output = {
            "question": chatModel.question,
            "type": "agent",
            "sources": [],
            "guard": False,
            "tokens": {"input": 0, "output": 0},
            "project": project.props.name,
            "id": chat_id,
        }

        if self.check_input_guard(project, chatModel.question, user, db, output):
            if chatModel.stream:
                yield "data: " + json.dumps({"text": output.get("answer", "")}) + "\n\n"
                yield "data: " + json.dumps(output) + "\n"
                yield "event: close\n\n"
            else:
                yield output
            return

        streamed_any_text = False
        try:
            async with MCPSessionPool() as mcp_pool:
                try:
                    mcp_tools = await mcp_pool.connect_servers(
                        project.props.options.mcp_servers or []
                    )
                except Exception:
                    mcp_tools = []

                try:
                    sys_prompt = _augment_system_prompt_with_memory_bank(
                        project, db, project.props.system,
                    )
                    runtime = self._build_runtime(
                        project, db, sys_prompt, extra_tools=mcp_tools
                    )
                except Agent2UnsupportedLLMError as e:
                    err_msg = str(e)
                    if chatModel.stream:
                        yield "data: " + json.dumps({"text": err_msg}) + "\n\n"
                        output["answer"] = err_msg
                        yield "data: " + json.dumps(output) + "\n"
                        yield "event: close\n\n"
                    else:
                        output["answer"] = err_msg
                        yield output
                    return

                runtime._chat_id = chat_id
                runtime._brain = self.brain
                runtime._project_id = project.props.id
                session = await get_session(self.brain, chat_id)

                prompt_text, image_url = _route_attachments(
                    getattr(chatModel, "files", None), chat_id, chatModel.question, self.brain,
                    existing_image=chatModel.image, project=project,
                )
                image_block = ImageBlock.from_data_url(image_url) if image_url else None

                try:
                    async for delta in self._drive_runtime(
                        runtime,
                        prompt=prompt_text,
                        session=session,
                        image_block=image_block,
                        stream=chatModel.stream,
                        project=project,
                        output=output,
                    ):
                        streamed_any_text = True
                        yield "data: " + json.dumps({"text": delta}) + "\n\n"

                    await save_session(self.brain, chat_id, session)
                    self._count_tokens(output)
                    self.check_output_guard(project, user, db, output)

                    if chatModel.stream:
                        # Emit the final answer text only if streaming didn't
                        # already deliver it (e.g. fell back to ReAct mid-run).
                        if not streamed_any_text and output.get("answer"):
                            yield "data: " + json.dumps({"text": output["answer"]}) + "\n\n"
                        yield "data: " + json.dumps(output) + "\n"
                        yield "event: close\n\n"

                except Exception as e:
                    wrapped = _wrap_image_error(e, bool(chatModel.image))
                    err_msg = project.props.censorship or f"Agent failed: {wrapped}"
                    output["answer"] = err_msg
                    self._count_tokens(output)
                    if chatModel.stream:
                        yield "data: " + json.dumps({"text": err_msg}) + "\n\n"
                        yield "data: " + json.dumps(output) + "\n"
                        yield "event: close\n\n"
                        streamed_any_text = True
        except BaseException as e:
            # Catch ExceptionGroup from MCP session pool cleanup failures
            # to prevent "No response returned" crashes
            if "answer" not in output:
                logging.warning("Agent chat failed during MCP cleanup: %s", e)
                output["answer"] = project.props.censorship or "An error occurred processing your request."

        # Non-streaming yield MUST be outside the `async with MCPSessionPool()`
        # block. When the caller does `async for line in gen: return line`, the
        # generator is abandoned after the first yield. If that yield happens
        # inside the async-with, the pool's __aexit__ runs in a GC/finalizer
        # task → "exit cancel scope in different task" → corrupted HTTP response.
        if chatModel.stream:
            # If streaming failed before emitting anything, emit the error as SSE
            if "answer" in output and not streamed_any_text:
                yield "data: " + json.dumps({"text": output["answer"]}) + "\n\n"
            yield "data: " + json.dumps(output) + "\n"
            yield "event: close\n\n"
        else:
            if "answer" not in output:
                output["answer"] = project.props.censorship or "No response generated."
            yield output

    async def question(
        self, project: Project, questionModel: QuestionModel, user: User, db: DBWrapper
    ):
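        """Stateless one-shot variant of `chat()`.

        No session is loaded or saved (the runtime runs with ``session=None``)
        and attached files land in an ephemeral per-request sandbox id.
        ``questionModel.system``, when set, overrides the project's system
        prompt for this request. The streaming / non-streaming output shape
        matches `chat()`.
        """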
        output = {
            "question": questionModel.question,
            "type": "agent",
            "sources": [],
            "guard": False,
            "tokens": {"input": 0, "output": 0},
            "project": project.props.name,
        }

        if self.check_input_guard(project, questionModel.question, user, db, output):
            if questionModel.stream:
                yield "data: " + json.dumps({"text": output.get("answer", "")}) + "\n\n"
                yield "data: " + json.dumps(output) + "\n"
                yield "event: close\n\n"
            else:
                yield output
            return

        system_prompt = questionModel.system or project.props.system
        system_prompt = _augment_system_prompt_with_memory_bank(project, db, system_prompt)

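        # MCP servers are connected best-effort: a failed connection degrades
        # to "no MCP tools" for this request instead of failing the question.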
        async with MCPSessionPool() as mcp_pool:
            try:
                mcp_tools = await mcp_pool.connect_servers(
                    project.props.options.mcp_servers or []
                )
            except Exception:
                mcp_tools = []

            try:
                runtime = self._build_runtime(
                    project, db, system_prompt, extra_tools=mcp_tools
                )
            except Agent2UnsupportedLLMError as e:
                output["answer"] = str(e)
                if questionModel.stream:
                    yield "data: " + json.dumps({"text": output["answer"]}) + "\n\n"
                    yield "data: " + json.dumps(output) + "\n"
                    yield "event: close\n\n"
                else:
                    yield output
                return

            runtime._brain = self.brain
            runtime._project_id = project.props.id
            streamed_any_text = False

            # Ephemeral chat id so file uploads still land in a sandbox the
            # terminal tool can read from inside this same invocation.
            eph_chat = f"q_{uuid4().hex[:12]}"
            runtime._chat_id = eph_chat
            prompt_text, image_url = _route_attachments(
                getattr(questionModel, "files", None), eph_chat, questionModel.question, self.brain,
                existing_image=questionModel.image, project=project,
            )
            image_block = ImageBlock.from_data_url(image_url) if image_url else None

            try:
                async for delta in self._drive_runtime(
                    runtime,
                    prompt=prompt_text,
                    session=None,
                    image_block=image_block,
                    stream=questionModel.stream,
                    project=project,
                    output=output,
                ):
                    streamed_any_text = True
                    yield "data: " + json.dumps({"text": delta}) + "\n\n"

                self._count_tokens(output)
                self.check_output_guard(project, user, db, output)
            except Exception as e:
                wrapped = _wrap_image_error(e, bool(questionModel.image))
                err_msg = project.props.censorship or f"Agent failed: {wrapped}"
                output["answer"] = err_msg
                self._count_tokens(output)
                if questionModel.stream:
                    yield "data: " + json.dumps({"text": err_msg}) + "\n\n"
                    yield "data: " + json.dumps(output) + "\n"
                    yield "event: close\n\n"
                    return

        if questionModel.stream:
            if not streamed_any_text and output.get("answer"):
                yield "data: " + json.dumps({"text": output["answer"]}) + "\n\n"
            yield "data: " + json.dumps(output) + "\n"
            yield "event: close\n\n"
        else:
            yield output