# tests/hermes_cli/test_model_switch_context_display.py
  1  """Regression test for /model context-length display on provider-capped models.
  2  
  3  Bug (April 2026): `/model gpt-5.5` on openai-codex (ChatGPT OAuth) showed
  4  "Context: 1,050,000 tokens" because the display code used the raw models.dev
  5  ``ModelInfo.context_window`` (which reports the direct-OpenAI API value) instead
  6  of the provider-aware resolver. The agent was actually running at 272K — Codex
  7  OAuth's enforced cap — so the display was lying to the user.
  8  
  9  Fix: ``resolve_display_context_length()`` prefers
 10  ``agent.model_metadata.get_model_context_length`` (which knows about Codex OAuth,
 11  Copilot, Nous, etc.) and falls back to models.dev only if that returns nothing.
 12  """
from __future__ import annotations

import contextlib
from unittest.mock import patch

from hermes_cli.model_switch import resolve_display_context_length
 18  
 19  
 20  class _FakeModelInfo:
 21      def __init__(self, ctx):
 22          self.context_window = ctx
 23  
 24  
 25  class TestResolveDisplayContextLength:
 26      def test_codex_oauth_overrides_models_dev(self):
 27          """gpt-5.5 on openai-codex must show Codex's 272K cap, not models.dev's 1.05M."""
 28          fake_mi = _FakeModelInfo(1_050_000)  # what models.dev reports
 29          with patch(
 30              "agent.model_metadata.get_model_context_length",
 31              return_value=272_000,  # what Codex OAuth actually enforces
 32          ):
 33              ctx = resolve_display_context_length(
 34                  "gpt-5.5",
 35                  "openai-codex",
 36                  base_url="https://chatgpt.com/backend-api/codex",
 37                  api_key="",
 38                  model_info=fake_mi,
 39              )
 40          assert ctx == 272_000, (
 41              "Codex OAuth's 272K cap must win over models.dev's 1.05M for gpt-5.5"
 42          )
 43  
 44      def test_falls_back_to_model_info_when_resolver_returns_none(self):
 45          fake_mi = _FakeModelInfo(1_048_576)
 46          with patch(
 47              "agent.model_metadata.get_model_context_length", return_value=None
 48          ):
 49              ctx = resolve_display_context_length(
 50                  "some-model",
 51                  "some-provider",
 52                  model_info=fake_mi,
 53              )
 54          assert ctx == 1_048_576
 55  
 56      def test_returns_none_when_both_sources_empty(self):
 57          with patch(
 58              "agent.model_metadata.get_model_context_length", return_value=None
 59          ):
 60              ctx = resolve_display_context_length(
 61                  "unknown-model",
 62                  "unknown-provider",
 63                  model_info=None,
 64              )
 65          assert ctx is None
 66  
 67      def test_resolver_exception_falls_back_to_model_info(self):
 68          fake_mi = _FakeModelInfo(200_000)
 69          with patch(
 70              "agent.model_metadata.get_model_context_length",
 71              side_effect=RuntimeError("network down"),
 72          ):
 73              ctx = resolve_display_context_length(
 74                  "x", "y", model_info=fake_mi
 75              )
 76          assert ctx == 200_000
 77  
 78      def test_prefers_resolver_even_when_model_info_has_larger_value(self):
 79          """Invariant: provider-aware resolver is authoritative, even if models.dev
 80          reports a bigger window."""
 81          fake_mi = _FakeModelInfo(2_000_000)
 82          with patch(
 83              "agent.model_metadata.get_model_context_length", return_value=128_000
 84          ):
 85              ctx = resolve_display_context_length(
 86                  "capped-model",
 87                  "capped-provider",
 88                  model_info=fake_mi,
 89              )
 90          assert ctx == 128_000
 91  
 92      def test_custom_providers_override_honored(self):
 93          """Regression for #15779: /model switch onto a custom provider must
 94          surface the configured per-model context_length, not the 128K/256K
 95          fallback.
 96          """
 97          custom_provs = [
 98              {
 99                  "name": "my-custom-endpoint",
100                  "base_url": "https://example.invalid/v1",
101                  "models": {"gpt-5.5": {"context_length": 1_050_000}},
102              }
103          ]
104          # Real resolver call — no mock — so the override path is exercised
105          # through agent.model_metadata.get_model_context_length.
106          from unittest.mock import patch as _p
107          from agent import model_metadata as _mm
108          with _p.object(_mm, "get_cached_context_length", return_value=None), \
109               _p.object(_mm, "fetch_endpoint_model_metadata", return_value={}), \
110               _p.object(_mm, "fetch_model_metadata", return_value={}), \
111               _p.object(_mm, "is_local_endpoint", return_value=False), \
112               _p.object(_mm, "_is_known_provider_base_url", return_value=False):
113              ctx = resolve_display_context_length(
114                  "gpt-5.5",
115                  "custom",
116                  base_url="https://example.invalid/v1",
117                  api_key="k",
118                  custom_providers=custom_provs,
119              )
120          assert ctx == 1_050_000, (
121              "custom_providers[].models.gpt-5.5.context_length=1.05M must win "
122              "over probe-down fallback"
123          )
124  
125      def test_custom_providers_trailing_slash_insensitive(self):
126          """Base URL comparison must tolerate trailing-slash differences
127          between config.yaml and the runtime value.
128          """
129          custom_provs = [
130              {
131                  "base_url": "https://example.invalid/v1/",
132                  "models": {"m": {"context_length": 400_000}},
133              }
134          ]
135          from unittest.mock import patch as _p
136          from agent import model_metadata as _mm
137          with _p.object(_mm, "get_cached_context_length", return_value=None), \
138               _p.object(_mm, "fetch_endpoint_model_metadata", return_value={}), \
139               _p.object(_mm, "fetch_model_metadata", return_value={}), \
140               _p.object(_mm, "is_local_endpoint", return_value=False), \
141               _p.object(_mm, "_is_known_provider_base_url", return_value=False):
142              ctx = resolve_display_context_length(
143                  "m",
144                  "custom",
145                  base_url="https://example.invalid/v1",  # no trailing slash
146                  custom_providers=custom_provs,
147              )
148          assert ctx == 400_000