test_custom_provider_context_length.py
1 """Regression tests for custom_providers per-model context_length resolution. 2 3 Covers the fix for #15779 — mid-session /model switch to a named custom 4 provider must honor ``custom_providers[].models.<id>.context_length`` the 5 same way startup already does. 6 """ 7 from __future__ import annotations 8 9 from unittest.mock import patch 10 11 from hermes_cli.config import get_custom_provider_context_length 12 13 14 class TestGetCustomProviderContextLength: 15 def test_returns_override_for_matching_entry(self): 16 custom = [ 17 { 18 "name": "my-endpoint", 19 "base_url": "https://example.invalid/v1", 20 "models": {"gpt-5.5": {"context_length": 1_050_000}}, 21 } 22 ] 23 assert ( 24 get_custom_provider_context_length( 25 "gpt-5.5", "https://example.invalid/v1", custom 26 ) 27 == 1_050_000 28 ) 29 30 def test_trailing_slash_insensitive(self): 31 custom = [ 32 { 33 "base_url": "https://example.invalid/v1/", 34 "models": {"m": {"context_length": 500_000}}, 35 } 36 ] 37 # config has trailing slash, runtime doesn't — must match 38 assert ( 39 get_custom_provider_context_length( 40 "m", "https://example.invalid/v1", custom 41 ) 42 == 500_000 43 ) 44 # and the reverse 45 custom2 = [ 46 { 47 "base_url": "https://example.invalid/v1", 48 "models": {"m": {"context_length": 500_000}}, 49 } 50 ] 51 assert ( 52 get_custom_provider_context_length( 53 "m", "https://example.invalid/v1/", custom2 54 ) 55 == 500_000 56 ) 57 58 def test_returns_none_when_url_does_not_match(self): 59 custom = [ 60 { 61 "base_url": "https://example.invalid/v1", 62 "models": {"m": {"context_length": 400_000}}, 63 } 64 ] 65 assert ( 66 get_custom_provider_context_length( 67 "m", "https://other.invalid/v1", custom 68 ) 69 is None 70 ) 71 72 def test_returns_none_when_model_does_not_match(self): 73 custom = [ 74 { 75 "base_url": "https://example.invalid/v1", 76 "models": {"gpt-5.5": {"context_length": 400_000}}, 77 } 78 ] 79 assert ( 80 get_custom_provider_context_length( 81 "different-model", "https://example.invalid/v1", custom 82 ) 83 is None 84 ) 85 86 def test_returns_none_for_string_value(self): 87 """'256K' string is not a valid int — skip silently. 88 89 (The inline startup path still emits a user-visible warning; the 90 helper itself returns None so downstream fallbacks can run.) 91 """ 92 custom = [ 93 { 94 "base_url": "https://example.invalid/v1", 95 "models": {"m": {"context_length": "256K"}}, 96 } 97 ] 98 assert ( 99 get_custom_provider_context_length( 100 "m", "https://example.invalid/v1", custom 101 ) 102 is None 103 ) 104 105 def test_returns_none_for_zero_or_negative(self): 106 for bad in (0, -1, -100): 107 custom = [ 108 { 109 "base_url": "https://example.invalid/v1", 110 "models": {"m": {"context_length": bad}}, 111 } 112 ] 113 assert ( 114 get_custom_provider_context_length( 115 "m", "https://example.invalid/v1", custom 116 ) 117 is None 118 ), f"value {bad!r} should be rejected" 119 120 def test_empty_inputs_return_none(self): 121 assert get_custom_provider_context_length("", "http://x", [{"base_url": "http://x", "models": {"": {"context_length": 1}}}]) is None 122 assert get_custom_provider_context_length("m", "", [{"base_url": "", "models": {"m": {"context_length": 1}}}]) is None 123 assert get_custom_provider_context_length("m", "http://x", None) is None 124 assert get_custom_provider_context_length("m", "http://x", []) is None 125 126 def test_ignores_non_dict_entries(self): 127 """Malformed entries must not crash the lookup.""" 128 custom = [ 129 "not a dict", 130 None, 131 {"base_url": "https://example.invalid/v1", "models": "not a dict"}, 132 {"base_url": "https://example.invalid/v1", "models": {"m": "not a dict"}}, 133 { 134 "base_url": "https://example.invalid/v1", 135 "models": {"m": {"context_length": 400_000}}, 136 }, 137 ] 138 assert ( 139 get_custom_provider_context_length( 140 "m", "https://example.invalid/v1", custom 141 ) 142 == 400_000 143 ) 144 145 146 class TestGetModelContextLengthHonorsOverride: 147 """agent.model_metadata.get_model_context_length must honor the 148 custom_providers override at step 0b — before any probe, cache hit, 149 or models.dev lookup can override it. 150 """ 151 152 def _mock_all_probes(self): 153 """Context manager that disables every downstream resolution step.""" 154 from agent import model_metadata as _mm 155 return [ 156 patch.object(_mm, "get_cached_context_length", return_value=None), 157 patch.object(_mm, "fetch_endpoint_model_metadata", return_value={}), 158 patch.object(_mm, "fetch_model_metadata", return_value={}), 159 patch.object(_mm, "is_local_endpoint", return_value=False), 160 patch.object(_mm, "_is_known_provider_base_url", return_value=False), 161 ] 162 163 def test_custom_providers_override_wins_over_default_fallback(self): 164 from agent.model_metadata import get_model_context_length 165 custom = [ 166 { 167 "base_url": "https://example.invalid/v1", 168 "models": {"gpt-5.5": {"context_length": 1_050_000}}, 169 } 170 ] 171 patches = self._mock_all_probes() 172 for p in patches: 173 p.start() 174 try: 175 ctx = get_model_context_length( 176 "gpt-5.5", 177 base_url="https://example.invalid/v1", 178 provider="custom", 179 custom_providers=custom, 180 ) 181 finally: 182 for p in patches: 183 p.stop() 184 assert ctx == 1_050_000 185 186 def test_explicit_config_context_length_still_wins(self): 187 """Top-level model.context_length (step 0) outranks custom_providers (step 0b). 188 189 Users who set both should see the top-level value — that's the 190 documented precedence and matches the long-standing step-0 behavior. 191 """ 192 from agent.model_metadata import get_model_context_length 193 custom = [ 194 { 195 "base_url": "https://example.invalid/v1", 196 "models": {"m": {"context_length": 1_050_000}}, 197 } 198 ] 199 ctx = get_model_context_length( 200 "m", 201 base_url="https://example.invalid/v1", 202 provider="custom", 203 config_context_length=500_000, # explicit top-level wins 204 custom_providers=custom, 205 ) 206 assert ctx == 500_000 207 208 def test_no_override_falls_through_to_default(self): 209 """With custom_providers=None and all probes disabled, resolver 210 returns DEFAULT_FALLBACK_CONTEXT (256K after the stepdown bump). 211 """ 212 from agent.model_metadata import get_model_context_length, DEFAULT_FALLBACK_CONTEXT 213 patches = self._mock_all_probes() 214 for p in patches: 215 p.start() 216 try: 217 ctx = get_model_context_length( 218 "unknown-model", 219 base_url="https://example.invalid/v1", 220 provider="custom", 221 custom_providers=None, 222 ) 223 finally: 224 for p in patches: 225 p.stop() 226 assert ctx == DEFAULT_FALLBACK_CONTEXT 227 228 229 class TestContextProbeTiers: 230 def test_256k_is_top_tier_and_default(self): 231 """The stepdown probe starts at 256K and 256K is the new default.""" 232 from agent.model_metadata import CONTEXT_PROBE_TIERS, DEFAULT_FALLBACK_CONTEXT 233 234 assert CONTEXT_PROBE_TIERS[0] == 256_000 235 assert DEFAULT_FALLBACK_CONTEXT == 256_000 236 # Tiers still descend monotonically 237 for a, b in zip(CONTEXT_PROBE_TIERS, CONTEXT_PROBE_TIERS[1:]): 238 assert a > b, f"tiers must strictly descend, got {a} then {b}" 239 # 128K is still a tier (users relying on it probe-down get there) 240 assert 128_000 in CONTEXT_PROBE_TIERS