/ agent / context_references.py
context_references.py
  1  from __future__ import annotations
  2  
  3  import asyncio
  4  import inspect
  5  import json
  6  import mimetypes
  7  import os
  8  import re
  9  import subprocess
 10  from dataclasses import dataclass, field
 11  from pathlib import Path
 12  from typing import Awaitable, Callable
 13  
 14  from agent.model_metadata import estimate_tokens_rough
 15  
# A quoted reference value: backtick-, double- or single-quoted, no newlines.
_QUOTED_REFERENCE_VALUE = r'(?:`[^`\n]+`|"[^"\n]+"|\'[^\'\n]+\')'
# Matches @-references in a message: bare "@diff"/"@staged", or "@<kind>:<value>"
# where <kind> is file/folder/git/url and <value> is either a quoted string
# (optionally followed by ":N" or ":N-M" line range) or a run of non-whitespace.
# The lookbehind rejects a preceding word character or "/" so that text like
# "a/@b" or "user@file" is not treated as a reference.
REFERENCE_PATTERN = re.compile(
    rf"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>{_QUOTED_REFERENCE_VALUE}(?::\d+(?:-\d+)?)?|\S+))"
)
# Sentence punctuation trimmed from the tail of unquoted reference values.
TRAILING_PUNCTUATION = ",.;!?"
# Directories under $HOME that must never be attached (credential stores).
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube", ".docker", ".azure", ".config/gh")
# Directories under the Hermes home that must never be attached.
_SENSITIVE_HERMES_DIRS = (Path("skills") / ".hub",)
# Individual files under $HOME that must never be attached (keys, shell rc
# files, and credential files for package registries / databases).
_SENSITIVE_HOME_FILES = (
    Path(".ssh") / "authorized_keys",
    Path(".ssh") / "id_rsa",
    Path(".ssh") / "id_ed25519",
    Path(".ssh") / "config",
    Path(".bashrc"),
    Path(".zshrc"),
    Path(".profile"),
    Path(".bash_profile"),
    Path(".zprofile"),
    Path(".netrc"),
    Path(".pgpass"),
    Path(".npmrc"),
    Path(".pypirc"),
)
 38  
 39  
@dataclass(frozen=True)
class ContextReference:
    """A single parsed @-reference found in a user message."""

    raw: str  # exact matched text, e.g. "@file:app.py:10-20"
    kind: str  # "file", "folder", "git", "url", "diff", or "staged"
    target: str  # unquoted reference value; "" for @diff / @staged
    start: int  # offset of the match start in the original message
    end: int  # offset one past the match end (slice-style)
    line_start: int | None = None  # 1-based first line of a @file range
    line_end: int | None = None  # 1-based last line (inclusive) of a @file range
 49  
 50  
@dataclass
class ContextReferenceResult:
    """Outcome of preprocessing a message that may contain @-references."""

    message: str  # final message with references stripped and context appended
    original_message: str  # the message exactly as received
    references: list[ContextReference] = field(default_factory=list)  # parsed refs
    warnings: list[str] = field(default_factory=list)  # per-reference problems
    injected_tokens: int = 0  # rough token count of all attached context
    expanded: bool = False  # True when any context or warning was produced
    blocked: bool = False  # True when injection exceeded the hard token limit
 60  
 61  
 62  def parse_context_references(message: str) -> list[ContextReference]:
 63      refs: list[ContextReference] = []
 64      if not message:
 65          return refs
 66  
 67      for match in REFERENCE_PATTERN.finditer(message):
 68          simple = match.group("simple")
 69          if simple:
 70              refs.append(
 71                  ContextReference(
 72                      raw=match.group(0),
 73                      kind=simple,
 74                      target="",
 75                      start=match.start(),
 76                      end=match.end(),
 77                  )
 78              )
 79              continue
 80  
 81          kind = match.group("kind")
 82          value = _strip_trailing_punctuation(match.group("value") or "")
 83          line_start = None
 84          line_end = None
 85          target = _strip_reference_wrappers(value)
 86  
 87          if kind == "file":
 88              target, line_start, line_end = _parse_file_reference_value(value)
 89  
 90          refs.append(
 91              ContextReference(
 92                  raw=match.group(0),
 93                  kind=kind,
 94                  target=target,
 95                  start=match.start(),
 96                  end=match.end(),
 97                  line_start=line_start,
 98                  line_end=line_end,
 99              )
100          )
101  
102      return refs
103  
104  
def preprocess_context_references(
    message: str,
    *,
    cwd: str | Path,
    context_length: int,
    url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
    allowed_root: str | Path | None = None,
) -> ContextReferenceResult:
    """Synchronous wrapper around :func:`preprocess_context_references_async`.

    Works both with no event loop (CLI) and inside a running loop (gateway).
    """
    coro = preprocess_context_references_async(
        message,
        cwd=cwd,
        context_length=context_length,
        url_fetcher=url_fetcher,
        allowed_root=allowed_root,
    )
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    if loop is None or not loop.is_running():
        return asyncio.run(coro)

    # A loop is already running: execute the coroutine on a fresh loop in a
    # worker thread so we never re-enter the active loop.
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
130  
131  
async def preprocess_context_references_async(
    message: str,
    *,
    cwd: str | Path,
    context_length: int,
    url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
    allowed_root: str | Path | None = None,
) -> ContextReferenceResult:
    """Expand @-references in *message* into attached context blocks.

    Each reference is resolved (file/folder/git/url); failures become warnings.
    Injection is refused outright above 50% of *context_length* and merely
    warned about above 25%.
    """
    refs = parse_context_references(message)
    if not refs:
        return ContextReferenceResult(message=message, original_message=message)

    base = Path(cwd).expanduser().resolve()
    # Default to the current working directory so @ references cannot escape
    # the active workspace unless a caller explicitly widens the root.
    root = base if allowed_root is None else Path(allowed_root).expanduser().resolve()

    warnings: list[str] = []
    blocks: list[str] = []
    injected_tokens = 0
    for ref in refs:
        warning, block = await _expand_reference(
            ref,
            base,
            url_fetcher=url_fetcher,
            allowed_root=root,
        )
        if warning:
            warnings.append(warning)
        if block:
            blocks.append(block)
            injected_tokens += estimate_tokens_rough(block)

    hard_limit = max(1, int(context_length * 0.50))
    soft_limit = max(1, int(context_length * 0.25))
    if injected_tokens > hard_limit:
        warnings.append(
            f"@ context injection refused: {injected_tokens} tokens exceeds the 50% hard limit ({hard_limit})."
        )
        # Refuse to expand at all; caller gets the untouched message flagged.
        return ContextReferenceResult(
            message=message,
            original_message=message,
            references=refs,
            warnings=warnings,
            injected_tokens=injected_tokens,
            expanded=False,
            blocked=True,
        )
    if injected_tokens > soft_limit:
        warnings.append(
            f"@ context injection warning: {injected_tokens} tokens exceeds the 25% soft limit ({soft_limit})."
        )

    final = _remove_reference_tokens(message, refs)
    if warnings:
        warning_lines = "\n".join(f"- {warning}" for warning in warnings)
        final = f"{final}\n\n--- Context Warnings ---\n{warning_lines}"
    if blocks:
        final = f"{final}\n\n--- Attached Context ---\n\n" + "\n\n".join(blocks)

    return ContextReferenceResult(
        message=final.strip(),
        original_message=message,
        references=refs,
        warnings=warnings,
        injected_tokens=injected_tokens,
        expanded=bool(blocks or warnings),
        blocked=False,
    )
204  
205  
206  async def _expand_reference(
207      ref: ContextReference,
208      cwd: Path,
209      *,
210      url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
211      allowed_root: Path | None = None,
212  ) -> tuple[str | None, str | None]:
213      try:
214          if ref.kind == "file":
215              return _expand_file_reference(ref, cwd, allowed_root=allowed_root)
216          if ref.kind == "folder":
217              return _expand_folder_reference(ref, cwd, allowed_root=allowed_root)
218          if ref.kind == "diff":
219              return _expand_git_reference(ref, cwd, ["diff"], "git diff")
220          if ref.kind == "staged":
221              return _expand_git_reference(ref, cwd, ["diff", "--staged"], "git diff --staged")
222          if ref.kind == "git":
223              count = max(1, min(int(ref.target or "1"), 10))
224              return _expand_git_reference(ref, cwd, ["log", f"-{count}", "-p"], f"git log -{count} -p")
225          if ref.kind == "url":
226              content = await _fetch_url_content(ref.target, url_fetcher=url_fetcher)
227              if not content:
228                  return f"{ref.raw}: no content extracted", None
229              return None, f"๐ŸŒ {ref.raw} ({estimate_tokens_rough(content)} tokens)\n{content}"
230      except Exception as exc:
231          return f"{ref.raw}: {exc}", None
232  
233      return f"{ref.raw}: unsupported reference type", None
234  
235  
def _expand_file_reference(
    ref: ContextReference,
    cwd: Path,
    *,
    allowed_root: Path | None = None,
) -> tuple[str | None, str | None]:
    """Read a referenced file (optionally a line range) into a fenced block."""
    resolved = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
    _ensure_reference_path_allowed(resolved)
    if not resolved.exists():
        return f"{ref.raw}: file not found", None
    if not resolved.is_file():
        return f"{ref.raw}: path is not a file", None
    if _is_binary_file(resolved):
        return f"{ref.raw}: binary files are not supported", None

    content = resolved.read_text(encoding="utf-8")
    if ref.line_start is not None:
        # Clamp the requested 1-based inclusive range to the file's bounds.
        all_lines = content.splitlines()
        first = max(ref.line_start - 1, 0)
        last = min(ref.line_end or ref.line_start, len(all_lines))
        content = "\n".join(all_lines[first:last])

    fence = _code_fence_language(resolved)
    return None, f"๐Ÿ“„ {ref.raw} ({estimate_tokens_rough(content)} tokens)\n```{fence}\n{content}\n```"
261  
262  
def _expand_folder_reference(
    ref: ContextReference,
    cwd: Path,
    *,
    allowed_root: Path | None = None,
) -> tuple[str | None, str | None]:
    """Render an indented tree listing for a referenced folder."""
    folder = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
    _ensure_reference_path_allowed(folder)
    if not folder.exists():
        return f"{ref.raw}: folder not found", None
    if not folder.is_dir():
        return f"{ref.raw}: path is not a folder", None

    tree = _build_folder_listing(folder, cwd)
    return None, f"๐Ÿ“ {ref.raw} ({estimate_tokens_rough(tree)} tokens)\n{tree}"
278  
279  
280  def _expand_git_reference(
281      ref: ContextReference,
282      cwd: Path,
283      args: list[str],
284      label: str,
285  ) -> tuple[str | None, str | None]:
286      try:
287          result = subprocess.run(
288              ["git", *args],
289              cwd=cwd,
290              capture_output=True,
291              text=True,
292              timeout=30,
293          )
294      except subprocess.TimeoutExpired:
295          return f"{ref.raw}: git command timed out (30s)", None
296      if result.returncode != 0:
297          stderr = (result.stderr or "").strip() or "git command failed"
298          return f"{ref.raw}: {stderr}", None
299      content = result.stdout.strip()
300      if not content:
301          content = "(no output)"
302      return None, f"๐Ÿงพ {label} ({estimate_tokens_rough(content)} tokens)\n```diff\n{content}\n```"
303  
304  
305  async def _fetch_url_content(
306      url: str,
307      *,
308      url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
309  ) -> str:
310      fetcher = url_fetcher or _default_url_fetcher
311      content = fetcher(url)
312      if inspect.isawaitable(content):
313          content = await content
314      return str(content or "").strip()
315  
316  
async def _default_url_fetcher(url: str) -> str:
    """Extract markdown content for *url* via the built-in web-extract tool."""
    from tools.web_tools import web_extract_tool

    raw = await web_extract_tool([url], format="markdown", use_llm_processing=True)
    documents = json.loads(raw).get("data", {}).get("documents", [])
    if not documents:
        return ""
    first = documents[0]
    return str(first.get("content") or first.get("raw_content") or "").strip()
327  
328  
329  def _resolve_path(cwd: Path, target: str, *, allowed_root: Path | None = None) -> Path:
330      path = Path(os.path.expanduser(target))
331      if not path.is_absolute():
332          path = cwd / path
333      resolved = path.resolve()
334      if allowed_root is not None:
335          try:
336              resolved.relative_to(allowed_root)
337          except ValueError as exc:
338              raise ValueError("path is outside the allowed workspace") from exc
339      return resolved
340  
341  
def _ensure_reference_path_allowed(path: Path) -> None:
    """Raise ValueError when *path* targets credentials or Hermes internals."""
    from hermes_constants import get_hermes_home

    home = Path(os.path.expanduser("~")).resolve()
    hermes_home = get_hermes_home().resolve()

    # Exact-file denylist: credential files under $HOME plus the Hermes .env.
    exact = {home / rel for rel in _SENSITIVE_HOME_FILES}
    exact.add(hermes_home / ".env")
    if path in exact:
        raise ValueError("path is a sensitive credential file and cannot be attached")

    # Directory denylist: anything at or below these roots is refused.
    sensitive_dirs = [home / rel for rel in _SENSITIVE_HOME_DIRS]
    sensitive_dirs.extend(hermes_home / rel for rel in _SENSITIVE_HERMES_DIRS)
    for sensitive in sensitive_dirs:
        try:
            path.relative_to(sensitive)
        except ValueError:
            continue
        raise ValueError("path is a sensitive credential or internal Hermes path and cannot be attached")
361  
362  
def _strip_trailing_punctuation(value: str) -> str:
    """Trim sentence punctuation and unbalanced closing brackets from *value*.

    References frequently end with prose punctuation ("see @file:a.py.") or
    sit inside brackets ("(@file:a.py)").  Trailing , . ; ! ? and closing
    brackets with no matching opener inside the value are removed; balanced
    brackets (e.g. "path(1)") are kept.
    """
    stripped = value.rstrip(TRAILING_PUNCTUATION)
    while stripped.endswith((")", "]", "}")):
        closer = stripped[-1]
        opener = {")": "(", "]": "[", "}": "{"}[closer]
        if stripped.count(closer) <= stripped.count(opener):
            break  # closer is balanced by an opener inside the value; keep it
        # Drop the unbalanced closer, then re-trim punctuation it was hiding
        # (fixes values like "a.py.)" which previously kept the inner dot).
        stripped = stripped[:-1].rstrip(TRAILING_PUNCTUATION)
    return stripped
373  
374  
375  def _strip_reference_wrappers(value: str) -> str:
376      if len(value) >= 2 and value[0] == value[-1] and value[0] in "`\"'":
377          return value[1:-1]
378      return value
379  
380  
381  def _parse_file_reference_value(value: str) -> tuple[str, int | None, int | None]:
382      quoted_match = re.match(
383          r'^(?P<quote>`|"|\')(?P<path>.+?)(?P=quote)(?::(?P<start>\d+)(?:-(?P<end>\d+))?)?$',
384          value,
385      )
386      if quoted_match:
387          line_start = quoted_match.group("start")
388          line_end = quoted_match.group("end")
389          return (
390              quoted_match.group("path"),
391              int(line_start) if line_start is not None else None,
392              int(line_end or line_start) if line_start is not None else None,
393          )
394  
395      range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
396      if range_match:
397          line_start = int(range_match.group("start"))
398          return (
399              range_match.group("path"),
400              line_start,
401              int(range_match.group("end") or range_match.group("start")),
402          )
403  
404      return _strip_reference_wrappers(value), None, None
405  
406  
407  def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
408      pieces: list[str] = []
409      cursor = 0
410      for ref in refs:
411          pieces.append(message[cursor:ref.start])
412          cursor = ref.end
413      pieces.append(message[cursor:])
414      text = "".join(pieces)
415      text = re.sub(r"\s{2,}", " ", text)
416      text = re.sub(r"\s+([,.;:!?])", r"\1", text)
417      return text.strip()
418  
419  
420  def _is_binary_file(path: Path) -> bool:
421      mime, _ = mimetypes.guess_type(path.name)
422      if mime and not mime.startswith("text/") and not any(
423          path.name.endswith(ext) for ext in (".py", ".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".js", ".ts")
424      ):
425          return True
426      chunk = path.read_bytes()[:4096]
427      return b"\x00" in chunk
428  
429  
def _build_folder_listing(path: Path, cwd: Path, limit: int = 200) -> str:
    """Render an indented tree of *path* (relative to *cwd*), capped at *limit* entries."""
    base_depth = len(path.relative_to(cwd).parts)
    lines = [f"{path.relative_to(cwd)}/"]
    entries = _iter_visible_entries(path, cwd, limit=limit)
    for entry in entries:
        # Indent two spaces per level below the listed folder itself.
        depth = len(entry.relative_to(cwd).parts) - base_depth - 1
        prefix = "  " * max(depth, 0)
        if entry.is_dir():
            lines.append(f"{prefix}- {entry.name}/")
        else:
            lines.append(f"{prefix}- {entry.name} ({_file_metadata(entry)})")
    if len(entries) >= limit:
        lines.append("- ...")  # listing was truncated
    return "\n".join(lines)
444  
445  
def _iter_visible_entries(path: Path, cwd: Path, limit: int) -> list[Path]:
    """Collect up to *limit* entries under *path* for the folder listing.

    Prefers ``rg --files`` (which skips ignored files); when rg is
    unavailable, falls back to an ``os.walk`` that hides dotfiles and
    ``__pycache__``.
    """
    rg_entries = _rg_files(path, cwd, limit=limit)
    if rg_entries is not None:
        output: list[Path] = []
        seen_dirs: set[Path] = set()
        for rel in rg_entries:
            full = cwd / rel
            # rg reports files only, so reconstruct each file's parent
            # directories; keep only parents that are path itself or below it.
            for parent in full.parents:
                if parent == cwd or parent in seen_dirs or path not in {parent, *parent.parents}:
                    continue
                seen_dirs.add(parent)
                output.append(parent)
            output.append(full)
        # Dedupe, drop entries that vanished since rg ran, and sort with
        # directories ahead of files (then lexically).
        return sorted({p for p in output if p.exists()}, key=lambda p: (not p.is_dir(), str(p)))

    output = []
    for root, dirs, files in os.walk(path):
        # Prune hidden dirs and __pycache__ in place so os.walk skips them.
        dirs[:] = sorted(d for d in dirs if not d.startswith(".") and d != "__pycache__")
        files = sorted(f for f in files if not f.startswith("."))
        root_path = Path(root)
        for d in dirs:
            output.append(root_path / d)
            if len(output) >= limit:
                return output
        for f in files:
            output.append(root_path / f)
            if len(output) >= limit:
                return output
    return output
475  
476  
477  def _rg_files(path: Path, cwd: Path, limit: int) -> list[Path] | None:
478      try:
479          result = subprocess.run(
480              ["rg", "--files", str(path.relative_to(cwd))],
481              cwd=cwd,
482              capture_output=True,
483              text=True,
484              timeout=10,
485          )
486      except (FileNotFoundError, OSError, subprocess.TimeoutExpired):
487          return None
488      if result.returncode != 0:
489          return None
490      files = [Path(line.strip()) for line in result.stdout.splitlines() if line.strip()]
491      return files[:limit]
492  
493  
def _file_metadata(path: Path) -> str:
    """One-line descriptor: "N lines" for readable text, "N bytes" otherwise."""
    if not _is_binary_file(path):
        try:
            text = path.read_text(encoding="utf-8")
        except Exception:
            pass  # undecodable despite the heuristic; fall back to byte size
        else:
            count = text.count("\n") + 1
            return f"{count} lines"
    return f"{path.stat().st_size} bytes"
502  
503  
504  def _code_fence_language(path: Path) -> str:
505      mapping = {
506          ".py": "python",
507          ".js": "javascript",
508          ".ts": "typescript",
509          ".tsx": "tsx",
510          ".jsx": "jsx",
511          ".json": "json",
512          ".md": "markdown",
513          ".sh": "bash",
514          ".yml": "yaml",
515          ".yaml": "yaml",
516          ".toml": "toml",
517      }
518      return mapping.get(path.suffix.lower(), "")