# context_references.py
"""Expand ``@`` context references embedded in user messages.

Recognised references (see ``REFERENCE_PATTERN``):

* ``@file:path``, ``@file:path:10-20``, ``@file:"path with spaces":5`` - attach a
  text file, optionally restricted to a 1-based inclusive line range.
* ``@folder:path`` - attach a directory listing.
* ``@diff`` / ``@staged`` - attach ``git diff`` / ``git diff --staged`` output.
* ``@git:N`` - attach ``git log -N -p`` (N clamped to 1..10).
* ``@url:https://...`` - attach content extracted from a web page.

Reference tokens are removed from the message, and their expansions are appended
under an "Attached Context" section together with any warnings.
"""

from __future__ import annotations

import asyncio
import inspect
import json
import mimetypes
import os
import re
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Awaitable, Callable

from agent.model_metadata import estimate_tokens_rough

_QUOTED_REFERENCE_VALUE = r'(?:`[^`\n]+`|"[^"\n]+"|\'[^\'\n]+\')'
REFERENCE_PATTERN = re.compile(
    rf"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>{_QUOTED_REFERENCE_VALUE}(?::\d+(?:-\d+)?)?|\S+))"
)
TRAILING_PUNCTUATION = ",.;!?"
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube", ".docker", ".azure", ".config/gh")
_SENSITIVE_HERMES_DIRS = (Path("skills") / ".hub",)
_SENSITIVE_HOME_FILES = (
    Path(".ssh") / "authorized_keys",
    Path(".ssh") / "id_rsa",
    Path(".ssh") / "id_ed25519",
    Path(".ssh") / "config",
    Path(".bashrc"),
    Path(".zshrc"),
    Path(".profile"),
    Path(".bash_profile"),
    Path(".zprofile"),
    Path(".netrc"),
    Path(".pgpass"),
    Path(".npmrc"),
    Path(".pypirc"),
)
# Suffixes that ``mimetypes`` may map to a non-``text/*`` type (e.g. ``.ts`` ->
# video/mp2t, ``.sh`` -> application/x-sh) but which are plain text in practice.
# Files with these suffixes still go through the NUL-byte sniff in
# ``_is_binary_file``, so genuinely binary content is still rejected.
_TEXT_FILE_SUFFIXES = (
    ".py", ".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".js", ".ts",
    ".tsx", ".jsx", ".mjs", ".cjs", ".sh", ".bash", ".zsh", ".xml", ".tex", ".sql",
)
# When splicing text around a removed reference, no space is inserted before
# these characters or after the opening brackets.
_NO_SPACE_BEFORE = ",.;:!?)]}"
_NO_SPACE_AFTER = "([{"


@dataclass(frozen=True)
class ContextReference:
    """One ``@`` reference found in a message."""

    raw: str  # exact text of the reference span in the message
    kind: str  # "file", "folder", "git", "url", "diff" or "staged"
    target: str  # path / count / URL with quotes removed; "" for diff/staged
    start: int  # start offset of the span in the message
    end: int  # end offset (exclusive) of the span in the message
    line_start: int | None = None  # 1-based first line for @file ranges
    line_end: int | None = None  # 1-based last line (inclusive) for @file ranges


@dataclass
class ContextReferenceResult:
    """Outcome of preprocessing a message for ``@`` references."""

    message: str  # message to send (expanded, or the original when blocked)
    original_message: str
    references: list[ContextReference] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)
    injected_tokens: int = 0  # rough token estimate of all attached blocks
    expanded: bool = False  # True when blocks or warnings were appended
    blocked: bool = False  # True when the hard token limit refused injection


def parse_context_references(message: str) -> list[ContextReference]:
    """Find all ``@`` references in *message*, in order of appearance.

    Sentence punctuation that directly follows a reference (``@file:a.py,``) is
    not treated as part of it: both ``target`` and the ``raw``/``end`` span
    exclude it, so the punctuation stays in the message when the reference is
    removed.
    """
    refs: list[ContextReference] = []
    if not message:
        return refs

    for match in REFERENCE_PATTERN.finditer(message):
        simple = match.group("simple")
        if simple:
            refs.append(
                ContextReference(
                    raw=match.group(0),
                    kind=simple,
                    target="",
                    start=match.start(),
                    end=match.end(),
                )
            )
            continue

        kind = match.group("kind")
        raw_value = match.group("value") or ""
        value = _strip_trailing_punctuation(raw_value)
        # The value group ends the match, so shrinking the span by the number of
        # stripped characters leaves the punctuation outside the reference.
        end = match.end() - (len(raw_value) - len(value))
        line_start = None
        line_end = None
        target = _strip_reference_wrappers(value)

        if kind == "file":
            target, line_start, line_end = _parse_file_reference_value(value)

        refs.append(
            ContextReference(
                raw=message[match.start():end],
                kind=kind,
                target=target,
                start=match.start(),
                end=end,
                line_start=line_start,
                line_end=line_end,
            )
        )

    return refs


def preprocess_context_references(
    message: str,
    *,
    cwd: str | Path,
    context_length: int,
    url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
    allowed_root: str | Path | None = None,
) -> ContextReferenceResult:
    """Synchronous wrapper around :func:`preprocess_context_references_async`.

    When no event loop is running the coroutine is run with ``asyncio.run``.
    When called from inside a running loop, it is run on a fresh loop in a
    single worker thread, and this call blocks until it finishes.
    """
    coro = preprocess_context_references_async(
        message,
        cwd=cwd,
        context_length=context_length,
        url_fetcher=url_fetcher,
        allowed_root=allowed_root,
    )
    # Safe for both CLI (no loop) and gateway (loop already running).
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    if loop and loop.is_running():
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()
    return asyncio.run(coro)


async def preprocess_context_references_async(
    message: str,
    *,
    cwd: str | Path,
    context_length: int,
    url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
    allowed_root: str | Path | None = None,
) -> ContextReferenceResult:
    """Expand every ``@`` reference in *message* and build the final prompt.

    Args:
        message: user message that may contain ``@`` references.
        cwd: directory used to resolve relative paths and to run git/rg.
        context_length: model context size in tokens; attachments above 50% of
            it are refused and above 25% produce a warning.
        url_fetcher: optional callable (sync or async) returning page content
            for ``@url`` references; defaults to ``_default_url_fetcher``.
        allowed_root: directory that file/folder references must stay inside;
            defaults to *cwd*.

    Returns:
        A :class:`ContextReferenceResult`. When the hard limit is exceeded the
        original message is returned unchanged with ``blocked=True``.
    """
    refs = parse_context_references(message)
    if not refs:
        return ContextReferenceResult(message=message, original_message=message)

    cwd_path = Path(cwd).expanduser().resolve()
    # Default to the current working directory so @ references cannot escape
    # the active workspace unless a caller explicitly widens the root.
    allowed_root_path = (
        Path(allowed_root).expanduser().resolve() if allowed_root is not None else cwd_path
    )
    warnings: list[str] = []
    blocks: list[str] = []
    injected_tokens = 0

    for ref in refs:
        warning, block = await _expand_reference(
            ref,
            cwd_path,
            url_fetcher=url_fetcher,
            allowed_root=allowed_root_path,
        )
        if warning:
            warnings.append(warning)
        if block:
            blocks.append(block)
            injected_tokens += estimate_tokens_rough(block)

    hard_limit = max(1, int(context_length * 0.50))
    soft_limit = max(1, int(context_length * 0.25))
    if injected_tokens > hard_limit:
        warnings.append(
            f"@ context injection refused: {injected_tokens} tokens exceeds the 50% hard limit ({hard_limit})."
        )
        return ContextReferenceResult(
            message=message,
            original_message=message,
            references=refs,
            warnings=warnings,
            injected_tokens=injected_tokens,
            expanded=False,
            blocked=True,
        )

    if injected_tokens > soft_limit:
        warnings.append(
            f"@ context injection warning: {injected_tokens} tokens exceeds the 25% soft limit ({soft_limit})."
        )

    stripped = _remove_reference_tokens(message, refs)
    final = stripped
    if warnings:
        final = f"{final}\n\n--- Context Warnings ---\n" + "\n".join(f"- {warning}" for warning in warnings)
    if blocks:
        final = f"{final}\n\n--- Attached Context ---\n\n" + "\n\n".join(blocks)

    return ContextReferenceResult(
        message=final.strip(),
        original_message=message,
        references=refs,
        warnings=warnings,
        injected_tokens=injected_tokens,
        expanded=bool(blocks or warnings),
        blocked=False,
    )


async def _expand_reference(
    ref: ContextReference,
    cwd: Path,
    *,
    url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
    allowed_root: Path | None = None,
) -> tuple[str | None, str | None]:
    """Expand one reference into ``(warning, block)``; exactly one is set.

    Any exception raised while expanding is converted into a warning so one bad
    reference never aborts the whole message.
    """
    try:
        if ref.kind == "file":
            return _expand_file_reference(ref, cwd, allowed_root=allowed_root)
        if ref.kind == "folder":
            return _expand_folder_reference(ref, cwd, allowed_root=allowed_root)
        if ref.kind == "diff":
            return _expand_git_reference(ref, cwd, ["diff"], "git diff")
        if ref.kind == "staged":
            return _expand_git_reference(ref, cwd, ["diff", "--staged"], "git diff --staged")
        if ref.kind == "git":
            try:
                requested = int(ref.target or "1")
            except ValueError:
                return f"{ref.raw}: expected a commit count, e.g. @git:3", None
            # Clamp to 1..10 commits to bound the size of the log.
            count = max(1, min(requested, 10))
            return _expand_git_reference(ref, cwd, ["log", f"-{count}", "-p"], f"git log -{count} -p")
        if ref.kind == "url":
            content = await _fetch_url_content(ref.target, url_fetcher=url_fetcher)
            if not content:
                return f"{ref.raw}: no content extracted", None
            return None, f"🌐 {ref.raw} ({estimate_tokens_rough(content)} tokens)\n{content}"
    except Exception as exc:
        return f"{ref.raw}: {exc}", None

    return f"{ref.raw}: unsupported reference type", None


def _expand_file_reference(
    ref: ContextReference,
    cwd: Path,
    *,
    allowed_root: Path | None = None,
) -> tuple[str | None, str | None]:
    """Attach a text file (or a line range of it) as a fenced code block.

    Raises ValueError (via the path checks) when the path escapes
    *allowed_root* or points at a sensitive file.
    """
    path = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
    _ensure_reference_path_allowed(path)
    if not path.exists():
        return f"{ref.raw}: file not found", None
    if not path.is_file():
        return f"{ref.raw}: path is not a file", None
    if _is_binary_file(path):
        return f"{ref.raw}: binary files are not supported", None

    text = path.read_text(encoding="utf-8")
    if ref.line_start is not None:
        lines = text.splitlines()
        # Line numbers are 1-based and the end line is inclusive.
        start_idx = max(ref.line_start - 1, 0)
        end_idx = min(ref.line_end or ref.line_start, len(lines))
        text = "\n".join(lines[start_idx:end_idx])

    lang = _code_fence_language(path)
    label = ref.raw
    return None, f"📄 {label} ({estimate_tokens_rough(text)} tokens)\n```{lang}\n{text}\n```"


def _expand_folder_reference(
    ref: ContextReference,
    cwd: Path,
    *,
    allowed_root: Path | None = None,
) -> tuple[str | None, str | None]:
    """Attach a tree-style listing of a directory inside *allowed_root*."""
    path = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
    _ensure_reference_path_allowed(path)
    if not path.exists():
        return f"{ref.raw}: folder not found", None
    if not path.is_dir():
        return f"{ref.raw}: path is not a folder", None

    listing = _build_folder_listing(path, cwd)
    return None, f"📁 {ref.raw} ({estimate_tokens_rough(listing)} tokens)\n{listing}"


def _expand_git_reference(
    ref: ContextReference,
    cwd: Path,
    args: list[str],
    label: str,
) -> tuple[str | None, str | None]:
    """Run ``git <args>`` in *cwd* (30 s timeout) and attach its stdout as a diff block."""
    try:
        result = subprocess.run(
            ["git", *args],
            cwd=cwd,
            capture_output=True,
            text=True,
            timeout=30,
        )
    except subprocess.TimeoutExpired:
        return f"{ref.raw}: git command timed out (30s)", None
    if result.returncode != 0:
        stderr = (result.stderr or "").strip() or "git command failed"
        return f"{ref.raw}: {stderr}", None
    content = result.stdout.strip()
    if not content:
        content = "(no output)"
    return None, f"🧾 {label} ({estimate_tokens_rough(content)} tokens)\n```diff\n{content}\n```"


async def _fetch_url_content(
    url: str,
    *,
    url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
) -> str:
    """Fetch *url* with *url_fetcher* (sync or async) and return stripped text."""
    fetcher = url_fetcher or _default_url_fetcher
    content = fetcher(url)
    if inspect.isawaitable(content):
        content = await content
    return str(content or "").strip()


async def _default_url_fetcher(url: str) -> str:
    """Extract *url* as markdown via ``tools.web_tools.web_extract_tool``.

    Returns "" when the JSON payload contains no documents (including a missing
    or null ``data`` section).
    """
    from tools.web_tools import web_extract_tool

    raw = await web_extract_tool([url], format="markdown", use_llm_processing=True)
    payload = json.loads(raw)
    data = payload.get("data") if isinstance(payload, dict) else None
    docs = data.get("documents") if isinstance(data, dict) else None
    if not docs:
        return ""
    doc = docs[0]
    return str(doc.get("content") or doc.get("raw_content") or "").strip()


def _resolve_path(cwd: Path, target: str, *, allowed_root: Path | None = None) -> Path:
    """Resolve *target* (``~`` expanded, relative to *cwd*) to an absolute path.

    Raises:
        ValueError: the resolved path (after following symlinks) lies outside
            *allowed_root*.
    """
    path = Path(os.path.expanduser(target))
    if not path.is_absolute():
        path = cwd / path
    resolved = path.resolve()
    if allowed_root is not None:
        try:
            resolved.relative_to(allowed_root)
        except ValueError as exc:
            raise ValueError("path is outside the allowed workspace") from exc
    return resolved


def _ensure_reference_path_allowed(path: Path) -> None:
    """Refuse credential files and sensitive home / Hermes-home directories.

    Raises:
        ValueError: *path* is a blocked file or lies inside a blocked directory.
    """
    from hermes_constants import get_hermes_home

    home = Path(os.path.expanduser("~")).resolve()
    hermes_home = get_hermes_home().resolve()

    blocked_exact = {home / rel for rel in _SENSITIVE_HOME_FILES}
    blocked_exact.add(hermes_home / ".env")
    blocked_dirs = [home / rel for rel in _SENSITIVE_HOME_DIRS]
    blocked_dirs.extend(hermes_home / rel for rel in _SENSITIVE_HERMES_DIRS)

    if path in blocked_exact:
        raise ValueError("path is a sensitive credential file and cannot be attached")

    for blocked_dir in blocked_dirs:
        try:
            path.relative_to(blocked_dir)
        except ValueError:
            continue
        raise ValueError("path is a sensitive credential or internal Hermes path and cannot be attached")


def _strip_trailing_punctuation(value: str) -> str:
    """Drop trailing sentence punctuation and unbalanced closing brackets.

    ``"a.py,"`` -> ``"a.py"``; ``"f(x)"`` is kept, but the ``)`` in ``"a.py)"``
    is removed because it has no matching opener.
    """
    stripped = value.rstrip(TRAILING_PUNCTUATION)
    while stripped.endswith((")", "]", "}")):
        closer = stripped[-1]
        opener = {")": "(", "]": "[", "}": "{"}[closer]
        if stripped.count(closer) > stripped.count(opener):
            stripped = stripped[:-1]
            continue
        break
    return stripped


def _strip_reference_wrappers(value: str) -> str:
    """Remove one pair of matching surrounding backticks or quotes, if present."""
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "`\"'":
        return value[1:-1]
    return value


def _parse_file_reference_value(value: str) -> tuple[str, int | None, int | None]:
    """Split an ``@file`` value into ``(path, line_start, line_end)``.

    Accepts ``path``, ``path:N``, ``path:N-M`` and quoted paths optionally
    followed by ``:N`` / ``:N-M``. A single line number ``N`` yields ``(N, N)``.
    """
    quoted_match = re.match(
        r'^(?P<quote>`|"|\')(?P<path>.+?)(?P=quote)(?::(?P<start>\d+)(?:-(?P<end>\d+))?)?$',
        value,
    )
    if quoted_match:
        line_start = quoted_match.group("start")
        line_end = quoted_match.group("end")
        return (
            quoted_match.group("path"),
            int(line_start) if line_start is not None else None,
            int(line_end or line_start) if line_start is not None else None,
        )

    range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
    if range_match:
        line_start = int(range_match.group("start"))
        return (
            range_match.group("path"),
            line_start,
            int(range_match.group("end") or range_match.group("start")),
        )

    return _strip_reference_wrappers(value), None, None


def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
    """Return *message* with every reference span cut out.

    Whitespace is normalised only at each cut: spaces/tabs on both sides are
    replaced by a single space between words, and by nothing before punctuation
    or closing brackets, after opening brackets, or next to a line break. The
    rest of the message (indentation, blank lines, pasted code) is untouched.
    *refs* must be in message order, as returned by ``parse_context_references``.
    """
    pieces: list[str] = []
    cursor = 0
    for ref in refs:
        pieces.append(message[cursor:ref.start])
        cursor = ref.end
    pieces.append(message[cursor:])

    text = pieces[0]
    for piece in pieces[1:]:
        left = text.rstrip(" \t")
        right = piece.lstrip(" \t")
        needs_space = (
            bool(left)
            and bool(right)
            and not left[-1].isspace()
            and left[-1] not in _NO_SPACE_AFTER
            and not right[0].isspace()
            and right[0] not in _NO_SPACE_BEFORE
        )
        text = f"{left} {right}" if needs_space else left + right
    return text.strip()


def _is_binary_file(path: Path) -> bool:
    """Heuristically decide whether *path* is binary.

    A non-``text/*`` MIME guess marks the file binary unless its suffix is in
    ``_TEXT_FILE_SUFFIXES``; otherwise the first 4 KiB are sniffed for NUL bytes.
    Only that prefix is read, not the whole file.
    """
    mime, _ = mimetypes.guess_type(path.name)
    if mime and not mime.startswith("text/") and not path.name.lower().endswith(_TEXT_FILE_SUFFIXES):
        return True
    with path.open("rb") as handle:
        chunk = handle.read(4096)
    return b"\x00" in chunk


def _display_path(path: Path, cwd: Path) -> str:
    """Show *path* relative to *cwd* when possible, otherwise as an absolute path."""
    try:
        return str(path.relative_to(cwd))
    except ValueError:
        return str(path)


def _tree_sort_key(entry: Path, root: Path) -> tuple[tuple[int, str], ...]:
    """Sort key giving a depth-first tree order under *root*, directories first.

    Each path component becomes ``(0, name)``, except the last component of a
    file, which becomes ``(1, name)``. A directory's contents therefore follow
    it immediately, and sub-directories precede sibling files.
    """
    parts = entry.relative_to(root).parts
    if not parts:
        return ()
    key = [(0, part) for part in parts[:-1]]
    key.append((0 if entry.is_dir() else 1, parts[-1]))
    return tuple(key)


def _build_folder_listing(path: Path, cwd: Path, limit: int = 200) -> str:
    """Render up to about *limit* entries under *path* as an indented bullet tree.

    Files are annotated with a line count (or byte size for binary/unreadable
    files). A trailing ``- ...`` marks a listing that hit the limit.
    """
    lines = [f"{_display_path(path, cwd)}/"]
    entries = _iter_visible_entries(path, cwd, limit=limit)
    for entry in sorted(entries, key=lambda item: _tree_sort_key(item, path)):
        depth = len(entry.relative_to(path).parts) - 1
        indent = "  " * max(depth, 0)
        if entry.is_dir():
            lines.append(f"{indent}- {entry.name}/")
        else:
            meta = _file_metadata(entry)
            lines.append(f"{indent}- {entry.name} ({meta})")
    if len(entries) >= limit:
        lines.append("- ...")
    return "\n".join(lines)


def _iter_visible_entries(path: Path, cwd: Path, limit: int) -> list[Path]:
    """Collect files and directories strictly below *path* (order unspecified).

    Prefers ``rg --files``, which honours ignore files and skips hidden files,
    and adds the directories containing each file. Falls back to ``os.walk``,
    skipping dot-entries and ``__pycache__`` and stopping at *limit* entries.
    """
    rg_entries = _rg_files(path, cwd, limit=limit)
    if rg_entries is not None:
        # dict keeps insertion order and de-duplicates shared parent dirs.
        collected: dict[Path, None] = {}
        for rel in rg_entries:
            # Relative rg output is relative to cwd; absolute output (used when
            # path is outside cwd) is left unchanged by the join.
            full = cwd / rel
            for parent in full.parents:
                if path in parent.parents:
                    collected.setdefault(parent, None)
            collected.setdefault(full, None)
        return [entry for entry in collected if entry.exists()]

    output: list[Path] = []
    for root, dirs, files in os.walk(path):
        dirs[:] = sorted(d for d in dirs if not d.startswith(".") and d != "__pycache__")
        files = sorted(f for f in files if not f.startswith("."))
        root_path = Path(root)
        for d in dirs:
            output.append(root_path / d)
            if len(output) >= limit:
                return output
        for f in files:
            output.append(root_path / f)
            if len(output) >= limit:
                return output
    return output


def _rg_files(path: Path, cwd: Path, limit: int) -> list[Path] | None:
    """List up to *limit* files under *path* with ``rg --files``.

    Returns None when rg is unavailable, times out (10 s) or exits non-zero
    (including when no files match), so the caller can fall back to os.walk.
    """
    try:
        target = str(path.relative_to(cwd))
    except ValueError:
        # The allowed root may be wider than cwd; rg then prints absolute paths.
        target = str(path)
    try:
        result = subprocess.run(
            ["rg", "--files", "--", target],
            cwd=cwd,
            capture_output=True,
            text=True,
            timeout=10,
        )
    except (FileNotFoundError, OSError, subprocess.TimeoutExpired):
        return None
    if result.returncode != 0:
        return None
    files = [Path(line.strip()) for line in result.stdout.splitlines() if line.strip()]
    return files[:limit]


def _file_metadata(path: Path) -> str:
    """Describe a file as ``"<n> lines"`` or, for binary/unreadable files, ``"<n> bytes"``."""
    try:
        if not _is_binary_file(path):
            text = path.read_text(encoding="utf-8")
            # Count newline-terminated lines plus a final unterminated one.
            line_count = text.count("\n") + (0 if not text or text.endswith("\n") else 1)
            return f"{line_count} lines"
    except (OSError, UnicodeDecodeError):
        # Unreadable or non-UTF-8 text: deliberately fall back to the size.
        pass
    return f"{path.stat().st_size} bytes"


def _code_fence_language(path: Path) -> str:
    """Return the markdown code-fence language for *path*'s suffix, or ""."""
    mapping = {
        ".py": "python",
        ".js": "javascript",
        ".ts": "typescript",
        ".tsx": "tsx",
        ".jsx": "jsx",
        ".json": "json",
        ".md": "markdown",
        ".sh": "bash",
        ".bash": "bash",
        ".yml": "yaml",
        ".yaml": "yaml",
        ".toml": "toml",
        ".xml": "xml",
        ".sql": "sql",
    }
    return mapping.get(path.suffix.lower(), "")