# tests/hermes_cli/test_custom_provider_context_length.py
  1  """Regression tests for custom_providers per-model context_length resolution.
  2  
  3  Covers the fix for #15779 — mid-session /model switch to a named custom
  4  provider must honor ``custom_providers[].models.<id>.context_length`` the
  5  same way startup already does.
  6  """
  7  from __future__ import annotations
  8  
  9  from unittest.mock import patch
 10  
 11  from hermes_cli.config import get_custom_provider_context_length
 12  
 13  
 14  class TestGetCustomProviderContextLength:
 15      def test_returns_override_for_matching_entry(self):
 16          custom = [
 17              {
 18                  "name": "my-endpoint",
 19                  "base_url": "https://example.invalid/v1",
 20                  "models": {"gpt-5.5": {"context_length": 1_050_000}},
 21              }
 22          ]
 23          assert (
 24              get_custom_provider_context_length(
 25                  "gpt-5.5", "https://example.invalid/v1", custom
 26              )
 27              == 1_050_000
 28          )
 29  
 30      def test_trailing_slash_insensitive(self):
 31          custom = [
 32              {
 33                  "base_url": "https://example.invalid/v1/",
 34                  "models": {"m": {"context_length": 500_000}},
 35              }
 36          ]
 37          # config has trailing slash, runtime doesn't — must match
 38          assert (
 39              get_custom_provider_context_length(
 40                  "m", "https://example.invalid/v1", custom
 41              )
 42              == 500_000
 43          )
 44          # and the reverse
 45          custom2 = [
 46              {
 47                  "base_url": "https://example.invalid/v1",
 48                  "models": {"m": {"context_length": 500_000}},
 49              }
 50          ]
 51          assert (
 52              get_custom_provider_context_length(
 53                  "m", "https://example.invalid/v1/", custom2
 54              )
 55              == 500_000
 56          )
 57  
 58      def test_returns_none_when_url_does_not_match(self):
 59          custom = [
 60              {
 61                  "base_url": "https://example.invalid/v1",
 62                  "models": {"m": {"context_length": 400_000}},
 63              }
 64          ]
 65          assert (
 66              get_custom_provider_context_length(
 67                  "m", "https://other.invalid/v1", custom
 68              )
 69              is None
 70          )
 71  
 72      def test_returns_none_when_model_does_not_match(self):
 73          custom = [
 74              {
 75                  "base_url": "https://example.invalid/v1",
 76                  "models": {"gpt-5.5": {"context_length": 400_000}},
 77              }
 78          ]
 79          assert (
 80              get_custom_provider_context_length(
 81                  "different-model", "https://example.invalid/v1", custom
 82              )
 83              is None
 84          )
 85  
 86      def test_returns_none_for_string_value(self):
 87          """'256K' string is not a valid int — skip silently.
 88  
 89          (The inline startup path still emits a user-visible warning; the
 90          helper itself returns None so downstream fallbacks can run.)
 91          """
 92          custom = [
 93              {
 94                  "base_url": "https://example.invalid/v1",
 95                  "models": {"m": {"context_length": "256K"}},
 96              }
 97          ]
 98          assert (
 99              get_custom_provider_context_length(
100                  "m", "https://example.invalid/v1", custom
101              )
102              is None
103          )
104  
105      def test_returns_none_for_zero_or_negative(self):
106          for bad in (0, -1, -100):
107              custom = [
108                  {
109                      "base_url": "https://example.invalid/v1",
110                      "models": {"m": {"context_length": bad}},
111                  }
112              ]
113              assert (
114                  get_custom_provider_context_length(
115                      "m", "https://example.invalid/v1", custom
116                  )
117                  is None
118              ), f"value {bad!r} should be rejected"
119  
120      def test_empty_inputs_return_none(self):
121          assert get_custom_provider_context_length("", "http://x", [{"base_url": "http://x", "models": {"": {"context_length": 1}}}]) is None
122          assert get_custom_provider_context_length("m", "", [{"base_url": "", "models": {"m": {"context_length": 1}}}]) is None
123          assert get_custom_provider_context_length("m", "http://x", None) is None
124          assert get_custom_provider_context_length("m", "http://x", []) is None
125  
126      def test_ignores_non_dict_entries(self):
127          """Malformed entries must not crash the lookup."""
128          custom = [
129              "not a dict",
130              None,
131              {"base_url": "https://example.invalid/v1", "models": "not a dict"},
132              {"base_url": "https://example.invalid/v1", "models": {"m": "not a dict"}},
133              {
134                  "base_url": "https://example.invalid/v1",
135                  "models": {"m": {"context_length": 400_000}},
136              },
137          ]
138          assert (
139              get_custom_provider_context_length(
140                  "m", "https://example.invalid/v1", custom
141              )
142              == 400_000
143          )
144  
145  
class TestGetModelContextLengthHonorsOverride:
    """agent.model_metadata.get_model_context_length must honor the
    custom_providers override at step 0b — before any probe, cache hit,
    or models.dev lookup can override it.
    """

    @staticmethod
    def _mock_all_probes():
        """Return a context manager that disables every downstream resolution step.

        Fixes two defects in the old list-of-patchers version: the docstring
        claimed a context manager while a bare list was returned, and if a
        later ``start()`` raised, earlier patches stayed active with no
        cleanup — leaking patched state into other tests. The ExitStack
        ``pop_all()`` idiom unwinds already-started patches on failure and
        hands ownership of cleanup to the caller's ``with`` block.
        """
        from contextlib import ExitStack

        from agent import model_metadata as _mm

        stack = ExitStack()
        with stack:  # unwinds if any patch activation below raises
            for attr, rv in (
                ("get_cached_context_length", None),
                ("fetch_endpoint_model_metadata", {}),
                ("fetch_model_metadata", {}),
                ("is_local_endpoint", False),
                ("_is_known_provider_base_url", False),
            ):
                stack.enter_context(patch.object(_mm, attr, return_value=rv))
            return stack.pop_all()

    def test_custom_providers_override_wins_over_default_fallback(self):
        from agent.model_metadata import get_model_context_length
        custom = [
            {
                "base_url": "https://example.invalid/v1",
                "models": {"gpt-5.5": {"context_length": 1_050_000}},
            }
        ]
        with self._mock_all_probes():
            ctx = get_model_context_length(
                "gpt-5.5",
                base_url="https://example.invalid/v1",
                provider="custom",
                custom_providers=custom,
            )
        assert ctx == 1_050_000

    def test_explicit_config_context_length_still_wins(self):
        """Top-level model.context_length (step 0) outranks custom_providers (step 0b).

        Users who set both should see the top-level value — that's the
        documented precedence and matches the long-standing step-0 behavior.
        """
        from agent.model_metadata import get_model_context_length
        custom = [
            {
                "base_url": "https://example.invalid/v1",
                "models": {"m": {"context_length": 1_050_000}},
            }
        ]
        ctx = get_model_context_length(
            "m",
            base_url="https://example.invalid/v1",
            provider="custom",
            config_context_length=500_000,  # explicit top-level wins
            custom_providers=custom,
        )
        assert ctx == 500_000

    def test_no_override_falls_through_to_default(self):
        """With custom_providers=None and all probes disabled, resolver
        returns DEFAULT_FALLBACK_CONTEXT (256K after the stepdown bump).
        """
        from agent.model_metadata import get_model_context_length, DEFAULT_FALLBACK_CONTEXT
        with self._mock_all_probes():
            ctx = get_model_context_length(
                "unknown-model",
                base_url="https://example.invalid/v1",
                provider="custom",
                custom_providers=None,
            )
        assert ctx == DEFAULT_FALLBACK_CONTEXT
227  
228  
class TestContextProbeTiers:
    def test_256k_is_top_tier_and_default(self):
        """The stepdown probe starts at 256K and 256K is the new default."""
        from agent.model_metadata import CONTEXT_PROBE_TIERS, DEFAULT_FALLBACK_CONTEXT

        tiers = list(CONTEXT_PROBE_TIERS)
        # Largest tier and the fallback default are both 256K.
        assert tiers[0] == 256_000
        assert DEFAULT_FALLBACK_CONTEXT == 256_000
        # Every adjacent pair must strictly decrease — no plateaus.
        for a, b in zip(tiers, tiers[1:]):
            assert a > b, f"tiers must strictly descend, got {a} then {b}"
        # 128K remains reachable for users who depend on probing down to it.
        assert 128_000 in tiers