/ tests / agent / test_context_references.py
test_context_references.py
  1  from __future__ import annotations
  2  
  3  import asyncio
  4  import subprocess
  5  from pathlib import Path
  6  from unittest.mock import patch
  7  
  8  import pytest
  9  
 10  
 11  def _git(cwd: Path, *args: str) -> str:
 12      result = subprocess.run(
 13          ["git", *args],
 14          cwd=cwd,
 15          check=True,
 16          capture_output=True,
 17          text=True,
 18      )
 19      return result.stdout.strip()
 20  
 21  
 22  @pytest.fixture
 23  def sample_repo(tmp_path: Path) -> Path:
 24      repo = tmp_path / "repo"
 25      repo.mkdir()
 26      _git(repo, "init")
 27      _git(repo, "config", "user.name", "Hermes Tests")
 28      _git(repo, "config", "user.email", "tests@example.com")
 29  
 30      (repo / "src").mkdir()
 31      (repo / "src" / "main.py").write_text(
 32          "def alpha():\n"
 33          "    return 'a'\n\n"
 34          "def beta():\n"
 35          "    return 'b'\n",
 36          encoding="utf-8",
 37      )
 38      (repo / "src" / "helper.py").write_text("VALUE = 1\n", encoding="utf-8")
 39      (repo / "README.md").write_text("# Demo\n", encoding="utf-8")
 40      (repo / "blob.bin").write_bytes(b"\x00\x01\x02binary")
 41  
 42      _git(repo, "add", ".")
 43      _git(repo, "commit", "-m", "initial")
 44  
 45      (repo / "src" / "main.py").write_text(
 46          "def alpha():\n"
 47          "    return 'changed'\n\n"
 48          "def beta():\n"
 49          "    return 'b'\n",
 50          encoding="utf-8",
 51      )
 52      (repo / "src" / "helper.py").write_text("VALUE = 2\n", encoding="utf-8")
 53      _git(repo, "add", "src/helper.py")
 54      return repo
 55  
 56  
 57  def test_parse_typed_references_ignores_emails_and_handles():
 58      from agent.context_references import parse_context_references
 59  
 60      message = (
 61          "email me at user@example.com and ping @teammate "
 62          "but include @file:src/main.py:1-2 plus @diff and @git:2 "
 63          "and @url:https://example.com/docs"
 64      )
 65  
 66      refs = parse_context_references(message)
 67  
 68      assert [ref.kind for ref in refs] == ["file", "diff", "git", "url"]
 69      assert refs[0].target == "src/main.py"
 70      assert refs[0].line_start == 1
 71      assert refs[0].line_end == 2
 72      assert refs[2].target == "2"
 73  
 74  
 75  def test_parse_references_strips_trailing_punctuation():
 76      from agent.context_references import parse_context_references
 77  
 78      refs = parse_context_references(
 79          "review @file:README.md, then see (@url:https://example.com/docs)."
 80      )
 81  
 82      assert [ref.kind for ref in refs] == ["file", "url"]
 83      assert refs[0].target == "README.md"
 84      assert refs[1].target == "https://example.com/docs"
 85  
 86  
 87  def test_parse_quoted_references_with_spaces_and_preserve_unquoted_ranges():
 88      from agent.context_references import parse_context_references
 89  
 90      refs = parse_context_references(
 91          'review @file:"C:\\Users\\Simba\\My Project\\main.py":7-9 '
 92          'and @folder:"docs and specs" plus @file:src/main.py:1-2'
 93      )
 94  
 95      assert [ref.kind for ref in refs] == ["file", "folder", "file"]
 96      assert refs[0].target == r"C:\Users\Simba\My Project\main.py"
 97      assert refs[0].line_start == 7
 98      assert refs[0].line_end == 9
 99      assert refs[1].target == "docs and specs"
100      assert refs[2].target == "src/main.py"
101      assert refs[2].line_start == 1
102      assert refs[2].line_end == 2
103  
104  
def test_expand_file_range_and_folder_listing(sample_repo: Path):
    """A ranged @file and a @folder reference are expanded into attached context."""
    from agent.context_references import preprocess_context_references

    outcome = preprocess_context_references(
        "Review @file:src/main.py:1-2 and @folder:src/",
        cwd=sample_repo,
        context_length=100_000,
    )

    assert outcome.expanded
    text = outcome.message

    # The reference token is removed from the user's prose.
    assert "Review and" in text
    assert "Review @file:src/main.py:1-2" not in text

    # Lines 1-2 of the working-tree main.py are attached; later lines are not.
    assert "--- Attached Context ---" in text
    assert "def alpha():" in text
    assert "return 'changed'" in text
    assert "def beta():" not in text

    # The folder listing names the directory and both of its files.
    for entry in ("src/", "main.py", "helper.py"):
        assert entry in text

    assert outcome.injected_tokens > 0
    assert not outcome.warnings
126  
127  
def test_folder_listing_falls_back_when_rg_is_blocked(sample_repo: Path):
    """Folder expansion still lists files, without warnings, when running ``rg`` raises."""
    from agent.context_references import preprocess_context_references

    # Keep a handle on the genuine implementation before patching so that
    # non-rg commands can be delegated to it from inside the stub.
    original_run = subprocess.run

    def run_without_rg(*args, **kwargs):
        command = kwargs.get("args") if not args else args[0]
        if isinstance(command, list) and command and command[0] == "rg":
            raise PermissionError("rg blocked by policy")
        return original_run(*args, **kwargs)

    with patch(
        "agent.context_references.subprocess.run",
        side_effect=run_without_rg,
    ):
        outcome = preprocess_context_references(
            "Review @folder:src/",
            cwd=sample_repo,
            context_length=100_000,
        )

    assert outcome.expanded
    text = outcome.message
    assert "src/" in text
    assert "main.py" in text
    assert "helper.py" in text
    assert not outcome.warnings
151  
152  
def test_expand_quoted_file_reference_with_spaces(tmp_path: Path):
    """A quoted @file path containing spaces is resolved and its line range honoured."""
    from agent.context_references import preprocess_context_references

    workspace = tmp_path / "repo"
    docs_dir = workspace / "docs and specs"
    docs_dir.mkdir(parents=True)
    notes = docs_dir / "release notes.txt"
    notes.write_text("line 1\nline 2\nline 3\n", encoding="utf-8")

    outcome = preprocess_context_references(
        'Review @file:"docs and specs/release notes.txt":2-3',
        cwd=workspace,
        context_length=100_000,
    )

    assert outcome.expanded
    text = outcome.message
    assert text.startswith("Review")

    # Only the requested range (lines 2-3) is attached.
    assert "line 1" not in text
    assert "line 2" in text
    assert "line 3" in text

    assert "release notes.txt" in text
    assert not outcome.warnings
175  
176  
def test_expand_git_diff_staged_and_log(sample_repo: Path):
    """@diff, @staged and @git:N attach the matching git output to the message."""
    from agent.context_references import preprocess_context_references

    outcome = preprocess_context_references(
        "Inspect @diff and @staged and @git:1",
        cwd=sample_repo,
        context_length=100_000,
    )

    assert outcome.expanded
    text = outcome.message

    # Each expansion is labelled with the git command it stands for.
    assert "git diff" in text
    assert "git diff --staged" in text
    assert "git log -1 -p" in text

    # Content from the commit log, the unstaged diff and the staged diff.
    assert "initial" in text
    assert "return 'changed'" in text
    assert "VALUE = 2" in text
193  
194  
def test_binary_and_missing_files_become_warnings(sample_repo: Path):
    """A binary file and a nonexistent file each yield a warning instead of content."""
    from agent.context_references import preprocess_context_references

    outcome = preprocess_context_references(
        "Check @file:blob.bin and @file:nope.txt",
        cwd=sample_repo,
        context_length=100_000,
    )

    assert outcome.expanded
    assert len(outcome.warnings) == 2

    lowered = outcome.message.lower()
    assert "binary" in lowered
    assert "not found" in lowered
208  
209  
def test_soft_budget_warns_and_hard_budget_refuses(sample_repo: Path):
    """A small context window triggers a "25%" warning; a tiny one blocks expansion with "50%"."""
    from agent.context_references import preprocess_context_references

    # Soft limit: expansion still happens, accompanied by a warning.
    soft = preprocess_context_references(
        "Check @file:src/main.py",
        cwd=sample_repo,
        context_length=100,
    )
    assert soft.expanded
    assert any("25%" in note for note in soft.warnings)

    # Hard limit: expansion is refused and the original reference is left in place.
    hard = preprocess_context_references(
        "Check @file:src/main.py and @file:README.md",
        cwd=sample_repo,
        context_length=20,
    )
    assert not hard.expanded
    assert hard.blocked
    assert "@file:src/main.py" in hard.message
    assert any("50%" in note for note in hard.warnings)
230  
231  
@pytest.mark.asyncio
async def test_async_url_expansion_uses_fetcher(sample_repo: Path):
    """The async entry point obtains @url content through the injected ``url_fetcher``."""
    from agent.context_references import preprocess_context_references_async

    async def fake_fetch(url: str) -> str:
        # The fetcher must receive exactly the URL written in the message.
        assert url == "https://example.com/spec"
        return "# Spec\n\nImportant details."

    outcome = await preprocess_context_references_async(
        "Use @url:https://example.com/spec",
        cwd=sample_repo,
        context_length=100_000,
        url_fetcher=fake_fetch,
    )

    assert outcome.expanded
    assert "Important details." in outcome.message
    assert outcome.injected_tokens > 0
250  
251  
def test_sync_url_expansion_uses_async_fetcher(sample_repo: Path):
    """The synchronous entry point accepts a coroutine ``url_fetcher`` and uses its result."""
    from agent.context_references import preprocess_context_references

    async def fake_fetch(url: str) -> str:
        # Yield to the event loop once so the fetcher is genuinely awaited.
        await asyncio.sleep(0)
        return f"Content for {url}"

    outcome = preprocess_context_references(
        "Use @url:https://example.com/spec",
        cwd=sample_repo,
        context_length=100_000,
        url_fetcher=fake_fetch,
    )

    assert outcome.expanded
    assert "Content for https://example.com/spec" in outcome.message
268  
269  
def test_restricts_paths_to_allowed_root(tmp_path: Path):
    """A ``..`` path escaping ``allowed_root`` is refused with a warning; in-root files still expand."""
    from agent.context_references import preprocess_context_references

    workspace = tmp_path / "workspace"
    workspace.mkdir()
    inside_file = workspace / "notes.txt"
    inside_file.write_text("inside\n", encoding="utf-8")

    # Lives next to the workspace, i.e. outside the allowed root.
    outside_file = tmp_path / "secret.txt"
    outside_file.write_text("outside\n", encoding="utf-8")

    outcome = preprocess_context_references(
        "read @file:../secret.txt and @file:notes.txt",
        cwd=workspace,
        context_length=100_000,
        allowed_root=workspace,
    )

    assert outcome.expanded
    text = outcome.message
    assert "```\noutside\n```" not in text
    assert "inside" in text
    assert any(
        "outside the allowed workspace" in note for note in outcome.warnings
    )
290  
291  
def test_defaults_allowed_root_to_cwd(tmp_path: Path):
    """Without an explicit ``allowed_root``, an absolute path outside ``cwd`` is refused."""
    from agent.context_references import preprocess_context_references

    workspace = tmp_path / "workspace"
    workspace.mkdir()

    # Sibling of the workspace, referenced below by its absolute path.
    outside_file = tmp_path / "secret.txt"
    outside_file.write_text("outside\n", encoding="utf-8")

    outcome = preprocess_context_references(
        f"read @file:{outside_file}",
        cwd=workspace,
        context_length=100_000,
    )

    assert outcome.expanded
    assert "```\noutside\n```" not in outcome.message
    assert any(
        "outside the allowed workspace" in note for note in outcome.warnings
    )
309  
310  
@pytest.mark.asyncio
async def test_blocks_sensitive_home_and_hermes_paths(tmp_path: Path, monkeypatch):
    """Credential files under HERMES_HOME and ``~/.ssh`` are never attached, even inside the allowed root."""
    from agent.context_references import preprocess_context_references_async

    # Point both HOME and HERMES_HOME into the temporary directory.
    hermes_home = tmp_path / ".hermes"
    monkeypatch.setenv("HOME", str(tmp_path))
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    hermes_home.mkdir(parents=True)
    (hermes_home / ".env").write_text("API_KEY=super-secret\n", encoding="utf-8")

    ssh_dir = tmp_path / ".ssh"
    ssh_dir.mkdir(parents=True)
    (ssh_dir / "id_rsa").write_text("PRIVATE-KEY\n", encoding="utf-8")

    outcome = await preprocess_context_references_async(
        "read @file:.hermes/.env and @file:.ssh/id_rsa",
        cwd=tmp_path,
        allowed_root=tmp_path,
        context_length=100_000,
    )

    assert outcome.expanded
    text = outcome.message
    assert "API_KEY=super-secret" not in text
    assert "PRIVATE-KEY" not in text
    assert any("sensitive credential" in note for note in outcome.warnings)
335      assert "PRIVATE-KEY" not in result.message
336      assert any("sensitive credential" in warning for warning in result.warnings)