test_context_references.py
1 from __future__ import annotations 2 3 import asyncio 4 import subprocess 5 from pathlib import Path 6 from unittest.mock import patch 7 8 import pytest 9 10 11 def _git(cwd: Path, *args: str) -> str: 12 result = subprocess.run( 13 ["git", *args], 14 cwd=cwd, 15 check=True, 16 capture_output=True, 17 text=True, 18 ) 19 return result.stdout.strip() 20 21 22 @pytest.fixture 23 def sample_repo(tmp_path: Path) -> Path: 24 repo = tmp_path / "repo" 25 repo.mkdir() 26 _git(repo, "init") 27 _git(repo, "config", "user.name", "Hermes Tests") 28 _git(repo, "config", "user.email", "tests@example.com") 29 30 (repo / "src").mkdir() 31 (repo / "src" / "main.py").write_text( 32 "def alpha():\n" 33 " return 'a'\n\n" 34 "def beta():\n" 35 " return 'b'\n", 36 encoding="utf-8", 37 ) 38 (repo / "src" / "helper.py").write_text("VALUE = 1\n", encoding="utf-8") 39 (repo / "README.md").write_text("# Demo\n", encoding="utf-8") 40 (repo / "blob.bin").write_bytes(b"\x00\x01\x02binary") 41 42 _git(repo, "add", ".") 43 _git(repo, "commit", "-m", "initial") 44 45 (repo / "src" / "main.py").write_text( 46 "def alpha():\n" 47 " return 'changed'\n\n" 48 "def beta():\n" 49 " return 'b'\n", 50 encoding="utf-8", 51 ) 52 (repo / "src" / "helper.py").write_text("VALUE = 2\n", encoding="utf-8") 53 _git(repo, "add", "src/helper.py") 54 return repo 55 56 57 def test_parse_typed_references_ignores_emails_and_handles(): 58 from agent.context_references import parse_context_references 59 60 message = ( 61 "email me at user@example.com and ping @teammate " 62 "but include @file:src/main.py:1-2 plus @diff and @git:2 " 63 "and @url:https://example.com/docs" 64 ) 65 66 refs = parse_context_references(message) 67 68 assert [ref.kind for ref in refs] == ["file", "diff", "git", "url"] 69 assert refs[0].target == "src/main.py" 70 assert refs[0].line_start == 1 71 assert refs[0].line_end == 2 72 assert refs[2].target == "2" 73 74 75 def test_parse_references_strips_trailing_punctuation(): 76 from agent.context_references import parse_context_references 77 78 refs = parse_context_references( 79 "review @file:README.md, then see (@url:https://example.com/docs)." 80 ) 81 82 assert [ref.kind for ref in refs] == ["file", "url"] 83 assert refs[0].target == "README.md" 84 assert refs[1].target == "https://example.com/docs" 85 86 87 def test_parse_quoted_references_with_spaces_and_preserve_unquoted_ranges(): 88 from agent.context_references import parse_context_references 89 90 refs = parse_context_references( 91 'review @file:"C:\\Users\\Simba\\My Project\\main.py":7-9 ' 92 'and @folder:"docs and specs" plus @file:src/main.py:1-2' 93 ) 94 95 assert [ref.kind for ref in refs] == ["file", "folder", "file"] 96 assert refs[0].target == r"C:\Users\Simba\My Project\main.py" 97 assert refs[0].line_start == 7 98 assert refs[0].line_end == 9 99 assert refs[1].target == "docs and specs" 100 assert refs[2].target == "src/main.py" 101 assert refs[2].line_start == 1 102 assert refs[2].line_end == 2 103 104 105 def test_expand_file_range_and_folder_listing(sample_repo: Path): 106 from agent.context_references import preprocess_context_references 107 108 result = preprocess_context_references( 109 "Review @file:src/main.py:1-2 and @folder:src/", 110 cwd=sample_repo, 111 context_length=100_000, 112 ) 113 114 assert result.expanded 115 assert "Review and" in result.message 116 assert "Review @file:src/main.py:1-2" not in result.message 117 assert "--- Attached Context ---" in result.message 118 assert "def alpha():" in result.message 119 assert "return 'changed'" in result.message 120 assert "def beta():" not in result.message 121 assert "src/" in result.message 122 assert "main.py" in result.message 123 assert "helper.py" in result.message 124 assert result.injected_tokens > 0 125 assert not result.warnings 126 127 128 def test_folder_listing_falls_back_when_rg_is_blocked(sample_repo: Path): 129 from agent.context_references import preprocess_context_references 130 131 real_run = subprocess.run 132 133 def blocked_rg(*args, **kwargs): 134 cmd = args[0] if args else kwargs.get("args") 135 if isinstance(cmd, list) and cmd and cmd[0] == "rg": 136 raise PermissionError("rg blocked by policy") 137 return real_run(*args, **kwargs) 138 139 with patch("agent.context_references.subprocess.run", side_effect=blocked_rg): 140 result = preprocess_context_references( 141 "Review @folder:src/", 142 cwd=sample_repo, 143 context_length=100_000, 144 ) 145 146 assert result.expanded 147 assert "src/" in result.message 148 assert "main.py" in result.message 149 assert "helper.py" in result.message 150 assert not result.warnings 151 152 153 def test_expand_quoted_file_reference_with_spaces(tmp_path: Path): 154 from agent.context_references import preprocess_context_references 155 156 workspace = tmp_path / "repo" 157 folder = workspace / "docs and specs" 158 folder.mkdir(parents=True) 159 file_path = folder / "release notes.txt" 160 file_path.write_text("line 1\nline 2\nline 3\n", encoding="utf-8") 161 162 result = preprocess_context_references( 163 'Review @file:"docs and specs/release notes.txt":2-3', 164 cwd=workspace, 165 context_length=100_000, 166 ) 167 168 assert result.expanded 169 assert result.message.startswith("Review") 170 assert "line 1" not in result.message 171 assert "line 2" in result.message 172 assert "line 3" in result.message 173 assert "release notes.txt" in result.message 174 assert not result.warnings 175 176 177 def test_expand_git_diff_staged_and_log(sample_repo: Path): 178 from agent.context_references import preprocess_context_references 179 180 result = preprocess_context_references( 181 "Inspect @diff and @staged and @git:1", 182 cwd=sample_repo, 183 context_length=100_000, 184 ) 185 186 assert result.expanded 187 assert "git diff" in result.message 188 assert "git diff --staged" in result.message 189 assert "git log -1 -p" in result.message 190 assert "initial" in result.message 191 assert "return 'changed'" in result.message 192 assert "VALUE = 2" in result.message 193 194 195 def test_binary_and_missing_files_become_warnings(sample_repo: Path): 196 from agent.context_references import preprocess_context_references 197 198 result = preprocess_context_references( 199 "Check @file:blob.bin and @file:nope.txt", 200 cwd=sample_repo, 201 context_length=100_000, 202 ) 203 204 assert result.expanded 205 assert len(result.warnings) == 2 206 assert "binary" in result.message.lower() 207 assert "not found" in result.message.lower() 208 209 210 def test_soft_budget_warns_and_hard_budget_refuses(sample_repo: Path): 211 from agent.context_references import preprocess_context_references 212 213 soft = preprocess_context_references( 214 "Check @file:src/main.py", 215 cwd=sample_repo, 216 context_length=100, 217 ) 218 assert soft.expanded 219 assert any("25%" in warning for warning in soft.warnings) 220 221 hard = preprocess_context_references( 222 "Check @file:src/main.py and @file:README.md", 223 cwd=sample_repo, 224 context_length=20, 225 ) 226 assert not hard.expanded 227 assert hard.blocked 228 assert "@file:src/main.py" in hard.message 229 assert any("50%" in warning for warning in hard.warnings) 230 231 232 @pytest.mark.asyncio 233 async def test_async_url_expansion_uses_fetcher(sample_repo: Path): 234 from agent.context_references import preprocess_context_references_async 235 236 async def fake_fetch(url: str) -> str: 237 assert url == "https://example.com/spec" 238 return "# Spec\n\nImportant details." 239 240 result = await preprocess_context_references_async( 241 "Use @url:https://example.com/spec", 242 cwd=sample_repo, 243 context_length=100_000, 244 url_fetcher=fake_fetch, 245 ) 246 247 assert result.expanded 248 assert "Important details." in result.message 249 assert result.injected_tokens > 0 250 251 252 def test_sync_url_expansion_uses_async_fetcher(sample_repo: Path): 253 from agent.context_references import preprocess_context_references 254 255 async def fake_fetch(url: str) -> str: 256 await asyncio.sleep(0) 257 return f"Content for {url}" 258 259 result = preprocess_context_references( 260 "Use @url:https://example.com/spec", 261 cwd=sample_repo, 262 context_length=100_000, 263 url_fetcher=fake_fetch, 264 ) 265 266 assert result.expanded 267 assert "Content for https://example.com/spec" in result.message 268 269 270 def test_restricts_paths_to_allowed_root(tmp_path: Path): 271 from agent.context_references import preprocess_context_references 272 273 workspace = tmp_path / "workspace" 274 workspace.mkdir() 275 (workspace / "notes.txt").write_text("inside\n", encoding="utf-8") 276 secret = tmp_path / "secret.txt" 277 secret.write_text("outside\n", encoding="utf-8") 278 279 result = preprocess_context_references( 280 "read @file:../secret.txt and @file:notes.txt", 281 cwd=workspace, 282 context_length=100_000, 283 allowed_root=workspace, 284 ) 285 286 assert result.expanded 287 assert "```\noutside\n```" not in result.message 288 assert "inside" in result.message 289 assert any("outside the allowed workspace" in warning for warning in result.warnings) 290 291 292 def test_defaults_allowed_root_to_cwd(tmp_path: Path): 293 from agent.context_references import preprocess_context_references 294 295 workspace = tmp_path / "workspace" 296 workspace.mkdir() 297 secret = tmp_path / "secret.txt" 298 secret.write_text("outside\n", encoding="utf-8") 299 300 result = preprocess_context_references( 301 f"read @file:{secret}", 302 cwd=workspace, 303 context_length=100_000, 304 ) 305 306 assert result.expanded 307 assert "```\noutside\n```" not in result.message 308 assert any("outside the allowed workspace" in warning for warning in result.warnings) 309 310 311 @pytest.mark.asyncio 312 async def test_blocks_sensitive_home_and_hermes_paths(tmp_path: Path, monkeypatch): 313 from agent.context_references import preprocess_context_references_async 314 315 monkeypatch.setenv("HOME", str(tmp_path)) 316 monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) 317 318 hermes_env = tmp_path / ".hermes" / ".env" 319 hermes_env.parent.mkdir(parents=True) 320 hermes_env.write_text("API_KEY=super-secret\n", encoding="utf-8") 321 322 ssh_key = tmp_path / ".ssh" / "id_rsa" 323 ssh_key.parent.mkdir(parents=True) 324 ssh_key.write_text("PRIVATE-KEY\n", encoding="utf-8") 325 326 result = await preprocess_context_references_async( 327 "read @file:.hermes/.env and @file:.ssh/id_rsa", 328 cwd=tmp_path, 329 allowed_root=tmp_path, 330 context_length=100_000, 331 ) 332 333 assert result.expanded 334 assert "API_KEY=super-secret" not in result.message 335 assert "PRIVATE-KEY" not in result.message 336 assert any("sensitive credential" in warning for warning in result.warnings)