# test_cli_context_warning.py
"""Tests for the low context length warning in the CLI banner."""

import os
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

# Substring that identifies the low-context warning in console output.
_WARNING_MARKER = "too low"


@pytest.fixture
def _isolate(tmp_path, monkeypatch):
    """Isolate HERMES_HOME so tests don't touch real config."""
    home = tmp_path / ".hermes"
    home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(home))


@pytest.fixture
def cli_obj(_isolate):
    """Create a minimal HermesCLI instance for banner testing.

    The instance is built with ``HermesCLI.__new__`` so ``__init__`` is
    skipped; the attributes these tests rely on are assigned by hand, and
    ``agent`` is a stand-in exposing only ``context_compressor.context_length``.
    """
    cli_config = {
        "display": {"tool_progress": "new"},
        "terminal": {},
    }
    with patch("cli.load_cli_config", return_value=cli_config), \
            patch("cli.get_tool_definitions", return_value=[]), \
            patch("cli.build_welcome_banner"):
        from cli import HermesCLI
        obj = HermesCLI.__new__(HermesCLI)
        obj.model = "test-model"
        obj.enabled_toolsets = ["hermes-core"]
        obj.compact = False
        obj.console = MagicMock()
        obj.session_id = None
        obj.api_key = "test"
        obj.base_url = ""
        obj.provider = "test"
        obj._provider_source = None
        # Mock agent with context compressor; tests set context_length.
        obj.agent = SimpleNamespace(
            context_compressor=SimpleNamespace(context_length=None)
        )
        return obj


def _render_banner(cli_obj):
    """Run ``cli_obj.show_banner()`` with tool loading and the banner stubbed."""
    with patch("cli.get_tool_definitions", return_value=[]), \
            patch("cli.build_welcome_banner"):
        cli_obj.show_banner()


def _printed_containing(cli_obj, needle):
    """Return the stringified ``console.print`` calls that contain *needle*.

    Each recorded call is converted with ``str()``, so *needle* is matched
    against the call's repr, which includes its arguments.
    """
    return [
        text
        for text in map(str, cli_obj.console.print.call_args_list)
        if needle in text
    ]


class TestLowContextWarning:
    """Tests that the CLI warns about low context lengths."""

    def test_no_warning_for_normal_context(self, cli_obj):
        """No warning when context is 32k+."""
        cli_obj.agent.context_compressor.context_length = 32768
        _render_banner(cli_obj)

        # Check that no yellow warning was printed
        assert _printed_containing(cli_obj, _WARNING_MARKER) == []

    def test_warning_for_low_context(self, cli_obj):
        """Warning shown when context is 4096 (Ollama default)."""
        cli_obj.agent.context_compressor.context_length = 4096
        _render_banner(cli_obj)

        warning_calls = _printed_containing(cli_obj, _WARNING_MARKER)
        assert len(warning_calls) == 1
        assert "4,096" in warning_calls[0]

    def test_warning_for_2048_context(self, cli_obj):
        """Warning shown for 2048 tokens (common LM Studio default)."""
        cli_obj.agent.context_compressor.context_length = 2048
        _render_banner(cli_obj)

        warning_calls = _printed_containing(cli_obj, _WARNING_MARKER)
        assert len(warning_calls) == 1
        # Same thousands-separated formatting as the 4096 case.
        assert "2,048" in warning_calls[0]

    def test_warning_at_boundary(self, cli_obj):
        """Warning still shown at exactly 8192 — borderline, but warned about."""
        cli_obj.agent.context_compressor.context_length = 8192
        _render_banner(cli_obj)

        assert len(_printed_containing(cli_obj, _WARNING_MARKER)) == 1

    def test_no_warning_above_boundary(self, cli_obj):
        """No warning at 16384."""
        cli_obj.agent.context_compressor.context_length = 16384
        _render_banner(cli_obj)

        assert _printed_containing(cli_obj, _WARNING_MARKER) == []

    def test_ollama_specific_hint(self, cli_obj):
        """Ollama-specific fix shown when port 11434 detected."""
        cli_obj.agent.context_compressor.context_length = 4096
        cli_obj.base_url = "http://localhost:11434/v1"
        _render_banner(cli_obj)

        assert len(_printed_containing(cli_obj, "OLLAMA_CONTEXT_LENGTH")) == 1

    def test_lm_studio_specific_hint(self, cli_obj):
        """LM Studio-specific fix shown when port 1234 detected."""
        cli_obj.agent.context_compressor.context_length = 2048
        cli_obj.base_url = "http://localhost:1234/v1"
        _render_banner(cli_obj)

        assert len(_printed_containing(cli_obj, "LM Studio")) == 1

    def test_generic_hint_for_other_servers(self, cli_obj):
        """Generic fix shown for unknown servers."""
        cli_obj.agent.context_compressor.context_length = 4096
        cli_obj.base_url = "http://localhost:8080/v1"
        _render_banner(cli_obj)

        assert len(_printed_containing(cli_obj, "config.yaml")) == 1

    def test_no_warning_when_no_context_length(self, cli_obj):
        """No warning when context length is not yet known."""
        cli_obj.agent.context_compressor.context_length = None
        _render_banner(cli_obj)

        assert _printed_containing(cli_obj, _WARNING_MARKER) == []

    def test_compact_banner_does_not_crash_on_narrow_terminal(self, cli_obj):
        """Compact mode should still have ctx_len defined for warning logic."""
        cli_obj.agent.context_compressor.context_length = 4096

        # NOTE(review): assumes a 70-column terminal selects the compact
        # banner path in show_banner — confirm against cli.py.
        narrow = os.terminal_size((70, 40))
        with patch("shutil.get_terminal_size", return_value=narrow), \
                patch("cli._build_compact_banner", return_value="compact banner"):
            cli_obj.show_banner()

        assert len(_printed_containing(cli_obj, _WARNING_MARKER)) == 1