# test_provider_fallback.py
"""Tests for the ordered provider fallback chain (salvage of PR #1761).

Builds on the single-fallback tests in test_fallback_model.py by covering
the list-based ``fallback_providers`` config format and advancement of the
chain through multiple providers.
"""

from unittest.mock import MagicMock, patch

from run_agent import AIAgent, _pool_may_recover_from_rate_limit

# Dotted path patched by the advancement tests while a fallback is activated.
_RESOLVER_PATH = "agent.auxiliary_client.resolve_provider_client"


def _make_agent(fallback_model=None):
    """Build a minimal AIAgent, optionally configured with fallback(s).

    Tool discovery, toolset checks and the OpenAI client class are patched
    for the duration of construction; the agent's ``client`` attribute is
    then replaced with a ``MagicMock``.
    """
    settings = dict(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        quiet_mode=True,
        skip_context_files=True,
        skip_memory=True,
        fallback_model=fallback_model,
    )
    with (
        patch("run_agent.get_tool_definitions", return_value=[]),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        agent = AIAgent(**settings)
        agent.client = MagicMock()
        return agent


def _mock_client(base_url="https://openrouter.ai/api/v1", api_key="fb-key"):
    """Return a MagicMock client exposing ``base_url`` and ``api_key``."""
    client = MagicMock()
    client.configure_mock(base_url=base_url, api_key=api_key)
    return client


def _openai_then_zai():
    """Return a fresh two-entry chain: openai/gpt-4o, then zai/glm-4.7."""
    return [
        {"provider": "openai", "model": "gpt-4o"},
        {"provider": "zai", "model": "glm-4.7"},
    ]


def _broken_then_openai():
    """Return a fresh two-entry chain: broken/nope, then openai/gpt-4o."""
    return [
        {"provider": "broken", "model": "nope"},
        {"provider": "openai", "model": "gpt-4o"},
    ]


# ── Chain initialisation ──────────────────────────────────────────────────


class TestFallbackChainInit:
    """How the ``fallback_model`` argument populates the agent's chain state."""

    def test_no_fallback(self):
        """No fallback config yields an empty chain at index zero."""
        agent = _make_agent(fallback_model=None)
        assert agent._fallback_chain == []
        assert agent._fallback_index == 0
        assert agent._fallback_model is None

    def test_single_dict_backwards_compat(self):
        """A single dict is accepted and becomes a one-entry chain."""
        entry = {"provider": "openai", "model": "gpt-4o"}
        agent = _make_agent(fallback_model=entry)
        assert agent._fallback_chain == [entry]
        assert agent._fallback_model == entry

    def test_list_of_providers(self):
        """A list keeps every entry; the first is the current fallback model."""
        chain = _openai_then_zai()
        agent = _make_agent(fallback_model=chain)
        assert len(agent._fallback_chain) == 2
        assert agent._fallback_model == chain[0]

    def test_invalid_entries_filtered(self):
        """Only the entry with both provider and model set is kept."""
        chain = [
            {"provider": "openai", "model": "gpt-4o"},
            {"provider": "", "model": "glm-4.7"},
            {"provider": "zai"},
            "not-a-dict",
        ]
        agent = _make_agent(fallback_model=chain)
        kept = agent._fallback_chain
        assert len(kept) == 1
        assert kept[0]["provider"] == "openai"

    def test_empty_list(self):
        """An empty list behaves like no fallback at all."""
        agent = _make_agent(fallback_model=[])
        assert agent._fallback_chain == []
        assert agent._fallback_model is None

    def test_invalid_dict_no_provider(self):
        """A dict without a provider key produces an empty chain."""
        agent = _make_agent(fallback_model={"model": "gpt-4o"})
        assert agent._fallback_chain == []


# ── Chain advancement ─────────────────────────────────────────────────────


class TestFallbackChainAdvancement:
    """Behaviour of ``_try_activate_fallback`` as it walks the chain."""

    def test_exhausted_returns_false(self):
        """With no chain configured, activation reports failure."""
        agent = _make_agent(fallback_model=None)
        assert agent._try_activate_fallback() is False

    def test_advances_index(self):
        """A successful activation moves the index and switches the model."""
        agent = _make_agent(fallback_model=_openai_then_zai())
        resolved = (_mock_client(), "gpt-4o")
        with patch(_RESOLVER_PATH, return_value=resolved):
            assert agent._try_activate_fallback() is True
            assert agent._fallback_index == 1
            assert agent.model == "gpt-4o"
            assert agent._fallback_activated is True

    def test_second_fallback_works(self):
        """Two consecutive activations step through both chain entries."""
        agent = _make_agent(fallback_model=_openai_then_zai())
        resolved = (_mock_client(), "resolved")
        with patch(_RESOLVER_PATH, return_value=resolved):
            assert agent._try_activate_fallback() is True
            assert agent.model == "gpt-4o"
            assert agent._try_activate_fallback() is True
            assert agent.model == "glm-4.7"
            assert agent._fallback_index == 2

    def test_all_exhausted_returns_false(self):
        """After the only entry is consumed, activation reports failure."""
        chain = [{"provider": "openai", "model": "gpt-4o"}]
        agent = _make_agent(fallback_model=chain)
        resolved = (_mock_client(), "gpt-4o")
        with patch(_RESOLVER_PATH, return_value=resolved):
            assert agent._try_activate_fallback() is True
            assert agent._try_activate_fallback() is False

    def test_skips_unconfigured_provider_to_next(self):
        """If resolve_provider_client returns None, skip to next in chain."""
        agent = _make_agent(fallback_model=_broken_then_openai())
        with patch(_RESOLVER_PATH) as resolver:
            resolver.side_effect = [
                (None, None),  # broken provider
                (_mock_client(), "gpt-4o"),  # fallback succeeds
            ]
            assert agent._try_activate_fallback() is True
            assert agent.model == "gpt-4o"
            assert agent._fallback_index == 2

    def test_skips_provider_that_raises_to_next(self):
        """If resolve_provider_client raises, skip to next in chain."""
        agent = _make_agent(fallback_model=_broken_then_openai())
        with patch(_RESOLVER_PATH) as resolver:
            resolver.side_effect = [
                RuntimeError("auth failed"),
                (_mock_client(), "gpt-4o"),
            ]
            assert agent._try_activate_fallback() is True
            assert agent.model == "gpt-4o"

    def test_resolves_key_env_for_fallback_provider(self):
        """The ``key_env`` variable's value is passed as ``explicit_api_key``."""
        chain = [
            {
                "provider": "custom",
                "model": "fallback-model",
                "base_url": "https://fallback.example/v1",
                "key_env": "MY_FALLBACK_KEY",
            }
        ]
        agent = _make_agent(fallback_model=chain)
        fallback_client = _mock_client(
            base_url="https://fallback.example/v1",
            api_key="env-secret",
        )
        with (
            patch.dict("os.environ", {"MY_FALLBACK_KEY": "env-secret"}, clear=False),
            patch(
                _RESOLVER_PATH,
                return_value=(fallback_client, "fallback-model"),
            ) as resolver,
        ):
            assert agent._try_activate_fallback() is True
            assert resolver.call_args.kwargs["explicit_api_key"] == "env-secret"


# ── Pool-rotation vs fallback gating (#11314) ────────────────────────────


def _pool(n_entries: int, has_available: bool = True):
    """Make a minimal credential-pool stand-in for rotation-room checks.

    ``entries()`` returns ``n_entries`` mock credentials and
    ``has_available()`` returns the given flag.
    """
    stub = MagicMock()
    credentials = [MagicMock() for _ in range(n_entries)]
    stub.entries.return_value = credentials
    stub.has_available.return_value = has_available
    return stub


class TestPoolRotationRoom:
    """Checks on ``_pool_may_recover_from_rate_limit`` for various pools."""

    def test_none_pool_returns_false(self):
        """No pool at all means there is nothing to rotate to."""
        verdict = _pool_may_recover_from_rate_limit(None)
        assert verdict is False

    def test_single_credential_returns_false(self):
        """With one credential that just 429'd, rotation has nowhere to go.

        The pool may still report has_available() True once cooldown expires,
        but retrying against the same entry will hit the same daily-quota
        429 and burn the retry budget. Must fall back.
        """
        verdict = _pool_may_recover_from_rate_limit(_pool(1))
        assert verdict is False

    def test_single_credential_in_cooldown_returns_false(self):
        """A lone credential that is unavailable cannot recover either."""
        lone_cooling = _pool(1, has_available=False)
        assert _pool_may_recover_from_rate_limit(lone_cooling) is False

    def test_two_credentials_available_returns_true(self):
        """With >1 credentials and at least one available, rotate instead of fallback."""
        verdict = _pool_may_recover_from_rate_limit(_pool(2))
        assert verdict is True

    def test_multiple_credentials_all_in_cooldown_returns_false(self):
        """All credentials cooling down — fall back rather than wait."""
        all_cooling = _pool(3, has_available=False)
        assert _pool_may_recover_from_rate_limit(all_cooling) is False

    def test_many_credentials_available_returns_true(self):
        """A large pool with availability still reports rotation room."""
        verdict = _pool_may_recover_from_rate_limit(_pool(10))
        assert verdict is True