/ tests / cli / test_cli_context_warning.py
test_cli_context_warning.py
  1  """Tests for the low context length warning in the CLI banner."""
  2  
  3  import os
  4  from types import SimpleNamespace
  5  from unittest.mock import MagicMock, patch
  6  
  7  import pytest
  8  
  9  
 10  @pytest.fixture
 11  def _isolate(tmp_path, monkeypatch):
 12      """Isolate HERMES_HOME so tests don't touch real config."""
 13      home = tmp_path / ".hermes"
 14      home.mkdir()
 15      monkeypatch.setenv("HERMES_HOME", str(home))
 16  
 17  
 18  @pytest.fixture
 19  def cli_obj(_isolate):
 20      """Create a minimal HermesCLI instance for banner testing."""
 21      with patch("cli.load_cli_config", return_value={
 22          "display": {"tool_progress": "new"},
 23          "terminal": {},
 24      }), patch("cli.get_tool_definitions", return_value=[]), \
 25           patch("cli.build_welcome_banner"):
 26          from cli import HermesCLI
 27          obj = HermesCLI.__new__(HermesCLI)
 28          obj.model = "test-model"
 29          obj.enabled_toolsets = ["hermes-core"]
 30          obj.compact = False
 31          obj.console = MagicMock()
 32          obj.session_id = None
 33          obj.api_key = "test"
 34          obj.base_url = ""
 35          obj.provider = "test"
 36          obj._provider_source = None
 37          # Mock agent with context compressor
 38          obj.agent = SimpleNamespace(
 39              context_compressor=SimpleNamespace(context_length=None)
 40          )
 41          return obj
 42  
 43  
 44  class TestLowContextWarning:
 45      """Tests that the CLI warns about low context lengths."""
 46  
 47      def test_no_warning_for_normal_context(self, cli_obj):
 48          """No warning when context is 32k+."""
 49          cli_obj.agent.context_compressor.context_length = 32768
 50          with patch("cli.get_tool_definitions", return_value=[]), \
 51               patch("cli.build_welcome_banner"):
 52              cli_obj.show_banner()
 53  
 54          # Check that no yellow warning was printed
 55          calls = [str(c) for c in cli_obj.console.print.call_args_list]
 56          warning_calls = [c for c in calls if "too low" in c]
 57          assert len(warning_calls) == 0
 58  
 59      def test_warning_for_low_context(self, cli_obj):
 60          """Warning shown when context is 4096 (Ollama default)."""
 61          cli_obj.agent.context_compressor.context_length = 4096
 62          with patch("cli.get_tool_definitions", return_value=[]), \
 63               patch("cli.build_welcome_banner"):
 64              cli_obj.show_banner()
 65  
 66          calls = [str(c) for c in cli_obj.console.print.call_args_list]
 67          warning_calls = [c for c in calls if "too low" in c]
 68          assert len(warning_calls) == 1
 69          assert "4,096" in warning_calls[0]
 70  
 71      def test_warning_for_2048_context(self, cli_obj):
 72          """Warning shown for 2048 tokens (common LM Studio default)."""
 73          cli_obj.agent.context_compressor.context_length = 2048
 74          with patch("cli.get_tool_definitions", return_value=[]), \
 75               patch("cli.build_welcome_banner"):
 76              cli_obj.show_banner()
 77  
 78          calls = [str(c) for c in cli_obj.console.print.call_args_list]
 79          warning_calls = [c for c in calls if "too low" in c]
 80          assert len(warning_calls) == 1
 81  
 82      def test_no_warning_at_boundary(self, cli_obj):
 83          """No warning at exactly 8192 — 8192 is borderline but included in warning."""
 84          cli_obj.agent.context_compressor.context_length = 8192
 85          with patch("cli.get_tool_definitions", return_value=[]), \
 86               patch("cli.build_welcome_banner"):
 87              cli_obj.show_banner()
 88  
 89          calls = [str(c) for c in cli_obj.console.print.call_args_list]
 90          warning_calls = [c for c in calls if "too low" in c]
 91          assert len(warning_calls) == 1  # 8192 is still warned about
 92  
 93      def test_no_warning_above_boundary(self, cli_obj):
 94          """No warning at 16384."""
 95          cli_obj.agent.context_compressor.context_length = 16384
 96          with patch("cli.get_tool_definitions", return_value=[]), \
 97               patch("cli.build_welcome_banner"):
 98              cli_obj.show_banner()
 99  
100          calls = [str(c) for c in cli_obj.console.print.call_args_list]
101          warning_calls = [c for c in calls if "too low" in c]
102          assert len(warning_calls) == 0
103  
104      def test_ollama_specific_hint(self, cli_obj):
105          """Ollama-specific fix shown when port 11434 detected."""
106          cli_obj.agent.context_compressor.context_length = 4096
107          cli_obj.base_url = "http://localhost:11434/v1"
108          with patch("cli.get_tool_definitions", return_value=[]), \
109               patch("cli.build_welcome_banner"):
110              cli_obj.show_banner()
111  
112          calls = [str(c) for c in cli_obj.console.print.call_args_list]
113          ollama_hints = [c for c in calls if "OLLAMA_CONTEXT_LENGTH" in c]
114          assert len(ollama_hints) == 1
115  
116      def test_lm_studio_specific_hint(self, cli_obj):
117          """LM Studio-specific fix shown when port 1234 detected."""
118          cli_obj.agent.context_compressor.context_length = 2048
119          cli_obj.base_url = "http://localhost:1234/v1"
120          with patch("cli.get_tool_definitions", return_value=[]), \
121               patch("cli.build_welcome_banner"):
122              cli_obj.show_banner()
123  
124          calls = [str(c) for c in cli_obj.console.print.call_args_list]
125          lms_hints = [c for c in calls if "LM Studio" in c]
126          assert len(lms_hints) == 1
127  
128      def test_generic_hint_for_other_servers(self, cli_obj):
129          """Generic fix shown for unknown servers."""
130          cli_obj.agent.context_compressor.context_length = 4096
131          cli_obj.base_url = "http://localhost:8080/v1"
132          with patch("cli.get_tool_definitions", return_value=[]), \
133               patch("cli.build_welcome_banner"):
134              cli_obj.show_banner()
135  
136          calls = [str(c) for c in cli_obj.console.print.call_args_list]
137          generic_hints = [c for c in calls if "config.yaml" in c]
138          assert len(generic_hints) == 1
139  
140      def test_no_warning_when_no_context_length(self, cli_obj):
141          """No warning when context length is not yet known."""
142          cli_obj.agent.context_compressor.context_length = None
143          with patch("cli.get_tool_definitions", return_value=[]), \
144               patch("cli.build_welcome_banner"):
145              cli_obj.show_banner()
146  
147          calls = [str(c) for c in cli_obj.console.print.call_args_list]
148          warning_calls = [c for c in calls if "too low" in c]
149          assert len(warning_calls) == 0
150  
151      def test_compact_banner_does_not_crash_on_narrow_terminal(self, cli_obj):
152          """Compact mode should still have ctx_len defined for warning logic."""
153          cli_obj.agent.context_compressor.context_length = 4096
154  
155          with patch("shutil.get_terminal_size", return_value=os.terminal_size((70, 40))), \
156               patch("cli._build_compact_banner", return_value="compact banner"):
157              cli_obj.show_banner()
158  
159          calls = [str(c) for c in cli_obj.console.print.call_args_list]
160          warning_calls = [c for c in calls if "too low" in c]
161          assert len(warning_calls) == 1