/ tests / run_agent / test_provider_fallback.py
test_provider_fallback.py
  1  """Tests for ordered provider fallback chain (salvage of PR #1761).
  2  
  3  Extends the single-fallback tests in test_fallback_model.py to cover
  4  the new list-based ``fallback_providers`` config format and chain
  5  advancement through multiple providers.
  6  """
  7  
  8  from unittest.mock import MagicMock, patch
  9  
 10  from run_agent import AIAgent, _pool_may_recover_from_rate_limit
 11  
 12  
 13  def _make_agent(fallback_model=None):
 14      """Create a minimal AIAgent with optional fallback config."""
 15      with (
 16          patch("run_agent.get_tool_definitions", return_value=[]),
 17          patch("run_agent.check_toolset_requirements", return_value={}),
 18          patch("run_agent.OpenAI"),
 19      ):
 20          agent = AIAgent(
 21              api_key="test-key",
 22              base_url="https://openrouter.ai/api/v1",
 23              quiet_mode=True,
 24              skip_context_files=True,
 25              skip_memory=True,
 26              fallback_model=fallback_model,
 27          )
 28          agent.client = MagicMock()
 29          return agent
 30  
 31  
 32  def _mock_client(base_url="https://openrouter.ai/api/v1", api_key="fb-key"):
 33      mock = MagicMock()
 34      mock.base_url = base_url
 35      mock.api_key = api_key
 36      return mock
 37  
 38  
 39  # ── Chain initialisation ──────────────────────────────────────────────────
 40  
 41  
 42  class TestFallbackChainInit:
 43      def test_no_fallback(self):
 44          agent = _make_agent(fallback_model=None)
 45          assert agent._fallback_chain == []
 46          assert agent._fallback_index == 0
 47          assert agent._fallback_model is None
 48  
 49      def test_single_dict_backwards_compat(self):
 50          fb = {"provider": "openai", "model": "gpt-4o"}
 51          agent = _make_agent(fallback_model=fb)
 52          assert agent._fallback_chain == [fb]
 53          assert agent._fallback_model == fb
 54  
 55      def test_list_of_providers(self):
 56          fbs = [
 57              {"provider": "openai", "model": "gpt-4o"},
 58              {"provider": "zai", "model": "glm-4.7"},
 59          ]
 60          agent = _make_agent(fallback_model=fbs)
 61          assert len(agent._fallback_chain) == 2
 62          assert agent._fallback_model == fbs[0]
 63  
 64      def test_invalid_entries_filtered(self):
 65          fbs = [
 66              {"provider": "openai", "model": "gpt-4o"},
 67              {"provider": "", "model": "glm-4.7"},
 68              {"provider": "zai"},
 69              "not-a-dict",
 70          ]
 71          agent = _make_agent(fallback_model=fbs)
 72          assert len(agent._fallback_chain) == 1
 73          assert agent._fallback_chain[0]["provider"] == "openai"
 74  
 75      def test_empty_list(self):
 76          agent = _make_agent(fallback_model=[])
 77          assert agent._fallback_chain == []
 78          assert agent._fallback_model is None
 79  
 80      def test_invalid_dict_no_provider(self):
 81          agent = _make_agent(fallback_model={"model": "gpt-4o"})
 82          assert agent._fallback_chain == []
 83  
 84  
 85  # ── Chain advancement ─────────────────────────────────────────────────────
 86  
 87  
 88  class TestFallbackChainAdvancement:
 89      def test_exhausted_returns_false(self):
 90          agent = _make_agent(fallback_model=None)
 91          assert agent._try_activate_fallback() is False
 92  
 93      def test_advances_index(self):
 94          fbs = [
 95              {"provider": "openai", "model": "gpt-4o"},
 96              {"provider": "zai", "model": "glm-4.7"},
 97          ]
 98          agent = _make_agent(fallback_model=fbs)
 99          with patch("agent.auxiliary_client.resolve_provider_client",
100                      return_value=(_mock_client(), "gpt-4o")):
101              assert agent._try_activate_fallback() is True
102              assert agent._fallback_index == 1
103              assert agent.model == "gpt-4o"
104              assert agent._fallback_activated is True
105  
106      def test_second_fallback_works(self):
107          fbs = [
108              {"provider": "openai", "model": "gpt-4o"},
109              {"provider": "zai", "model": "glm-4.7"},
110          ]
111          agent = _make_agent(fallback_model=fbs)
112          with patch("agent.auxiliary_client.resolve_provider_client",
113                      return_value=(_mock_client(), "resolved")):
114              assert agent._try_activate_fallback() is True
115              assert agent.model == "gpt-4o"
116              assert agent._try_activate_fallback() is True
117              assert agent.model == "glm-4.7"
118              assert agent._fallback_index == 2
119  
120      def test_all_exhausted_returns_false(self):
121          fbs = [{"provider": "openai", "model": "gpt-4o"}]
122          agent = _make_agent(fallback_model=fbs)
123          with patch("agent.auxiliary_client.resolve_provider_client",
124                      return_value=(_mock_client(), "gpt-4o")):
125              assert agent._try_activate_fallback() is True
126              assert agent._try_activate_fallback() is False
127  
128      def test_skips_unconfigured_provider_to_next(self):
129          """If resolve_provider_client returns None, skip to next in chain."""
130          fbs = [
131              {"provider": "broken", "model": "nope"},
132              {"provider": "openai", "model": "gpt-4o"},
133          ]
134          agent = _make_agent(fallback_model=fbs)
135          with patch("agent.auxiliary_client.resolve_provider_client") as mock_rpc:
136              mock_rpc.side_effect = [
137                  (None, None),                    # broken provider
138                  (_mock_client(), "gpt-4o"),       # fallback succeeds
139              ]
140              assert agent._try_activate_fallback() is True
141              assert agent.model == "gpt-4o"
142              assert agent._fallback_index == 2
143  
144      def test_skips_provider_that_raises_to_next(self):
145          """If resolve_provider_client raises, skip to next in chain."""
146          fbs = [
147              {"provider": "broken", "model": "nope"},
148              {"provider": "openai", "model": "gpt-4o"},
149          ]
150          agent = _make_agent(fallback_model=fbs)
151          with patch("agent.auxiliary_client.resolve_provider_client") as mock_rpc:
152              mock_rpc.side_effect = [
153                  RuntimeError("auth failed"),
154                  (_mock_client(), "gpt-4o"),
155              ]
156              assert agent._try_activate_fallback() is True
157              assert agent.model == "gpt-4o"
158  
159      def test_resolves_key_env_for_fallback_provider(self):
160          fbs = [
161              {
162                  "provider": "custom",
163                  "model": "fallback-model",
164                  "base_url": "https://fallback.example/v1",
165                  "key_env": "MY_FALLBACK_KEY",
166              }
167          ]
168          agent = _make_agent(fallback_model=fbs)
169          with (
170              patch.dict("os.environ", {"MY_FALLBACK_KEY": "env-secret"}, clear=False),
171              patch(
172                  "agent.auxiliary_client.resolve_provider_client",
173                  return_value=(
174                      _mock_client(
175                          base_url="https://fallback.example/v1",
176                          api_key="env-secret",
177                      ),
178                      "fallback-model",
179                  ),
180              ) as mock_rpc,
181          ):
182              assert agent._try_activate_fallback() is True
183              assert mock_rpc.call_args.kwargs["explicit_api_key"] == "env-secret"
184  
185  
186  # ── Pool-rotation vs fallback gating (#11314) ────────────────────────────
187  
188  
def _pool(n_entries: int, has_available: bool = True):
    """Build a ``MagicMock`` credential pool for rotation-room checks.

    Args:
        n_entries: Number of distinct placeholder credentials that
            ``entries()`` returns.
        has_available: Value reported by ``has_available()``.

    Returns:
        The configured mock pool.
    """
    fake_pool = MagicMock()
    fake_pool.configure_mock(**{
        "entries.return_value": [MagicMock() for _slot in range(n_entries)],
        "has_available.return_value": has_available,
    })
    return fake_pool
195  
196  
class TestPoolRotationRoom:
    """Gating between credential-pool rotation and provider fallback (#11314).

    The cases below expect ``_pool_may_recover_from_rate_limit`` to be True
    only for pools with several credentials where at least one is available.
    """

    def test_none_pool_returns_false(self):
        """Without a pool there is nothing to rotate to."""
        verdict = _pool_may_recover_from_rate_limit(None)
        assert verdict is False

    def test_single_credential_returns_false(self):
        """A lone credential that just returned a 429 leaves rotation nowhere to go.

        Once the cooldown expires, ``has_available()`` may report True again.
        Retrying that same entry would still hit the same daily-quota 429 and
        use up the retry budget, so the agent must fall back instead.
        """
        only_one = _pool(1)
        verdict = _pool_may_recover_from_rate_limit(only_one)
        assert verdict is False

    def test_single_credential_in_cooldown_returns_false(self):
        """A single credential that is also cooling down: fall back."""
        cooling = _pool(1, has_available=False)
        verdict = _pool_may_recover_from_rate_limit(cooling)
        assert verdict is False

    def test_two_credentials_available_returns_true(self):
        """More than one credential with one available: rotate rather than fall back."""
        pair = _pool(2)
        verdict = _pool_may_recover_from_rate_limit(pair)
        assert verdict is True

    def test_multiple_credentials_all_in_cooldown_returns_false(self):
        """Every credential is cooling down, so fall back instead of waiting."""
        all_cooling = _pool(3, has_available=False)
        verdict = _pool_may_recover_from_rate_limit(all_cooling)
        assert verdict is False

    def test_many_credentials_available_returns_true(self):
        """A large pool with availability also prefers rotation."""
        big = _pool(10)
        verdict = _pool_may_recover_from_rate_limit(big)
        assert verdict is True