test_agent2_providers.py
"""Tests for agent2 provider factory and caching."""

import types

import pytest

from restai.agent2.providers import (
    Agent2UnsupportedLLMError,
    AnthropicProvider,
    OpenAIProvider,
    ProviderConfig,
    _provider_cache,
    _provider_cache_key,
    build_provider_for_llm,
)


def _make_llm_row(
    name="test",
    class_name="OpenAI",
    options='{"model":"gpt-4o","api_key":"sk-fake"}',
    context_window=4096,
    privacy="public",
    description=None,
    input_cost=0.0,
    output_cost=0.0,
):
    """Build a lightweight stand-in for an LLM database row.

    Uses ``types.SimpleNamespace`` so tests need no real ORM model; only
    the attributes the provider factory reads are populated. ``options``
    is a JSON string, mirroring how the real row stores it.
    """
    return types.SimpleNamespace(
        id=1,
        name=name,
        class_name=class_name,
        options=options,
        context_window=context_window,
        privacy=privacy,
        description=description,
        input_cost=input_cost,
        output_cost=output_cost,
        teams=[],
    )


def _clear_cache_for_row(row):
    """Remove a specific row's entry from the module-level provider cache."""
    key = _provider_cache_key(row)
    _provider_cache.pop(key, None)


def test_provider_config_construction():
    """ProviderConfig stores explicit fields and applies documented defaults."""
    cfg = ProviderConfig(model="gpt-4o", api_key="sk-test")
    assert cfg.model == "gpt-4o"
    assert cfg.api_key == "sk-test"
    assert cfg.max_output_tokens == 4096
    assert cfg.base_url is None


def test_build_provider_openai():
    """An OpenAI row yields an OpenAIProvider with config from its options JSON."""
    row = _make_llm_row(class_name="OpenAI")
    _clear_cache_for_row(row)  # guard against state leaked by other tests
    try:
        provider, cfg = build_provider_for_llm(row)
        assert isinstance(provider, OpenAIProvider)
        assert isinstance(cfg, ProviderConfig)
        assert cfg.model == "gpt-4o"
    finally:
        # Clean up even when an assertion fails, so later tests never see
        # this row's cached provider.
        _clear_cache_for_row(row)


def test_build_provider_anthropic():
    """An Anthropic row yields an AnthropicProvider carrying its API key."""
    row = _make_llm_row(
        class_name="Anthropic",
        options='{"model":"claude-3-5-sonnet-latest","api_key":"sk-ant-fake"}',
    )
    _clear_cache_for_row(row)
    try:
        provider, cfg = build_provider_for_llm(row)
        assert isinstance(provider, AnthropicProvider)
        assert cfg.api_key == "sk-ant-fake"
    finally:
        _clear_cache_for_row(row)


def test_build_provider_ollama():
    """An Ollama row is served through the OpenAI-compatible endpoint (/v1)."""
    row = _make_llm_row(
        class_name="Ollama",
        options='{"model":"llama3","base_url":"http://localhost:11434"}',
    )
    _clear_cache_for_row(row)
    try:
        provider, cfg = build_provider_for_llm(row)
        assert isinstance(provider, OpenAIProvider)
        assert cfg.base_url.endswith("/v1")
    finally:
        _clear_cache_for_row(row)


def test_build_provider_unsupported():
    """An unknown class_name is rejected by the factory."""
    row = _make_llm_row(class_name="FakeProvider")
    _clear_cache_for_row(row)
    # NOTE(review): Exception is very broad; Agent2UnsupportedLLMError is
    # imported above and looks intended here — but the comment below says the
    # rejection comes from LLMModel validation, so narrowing the expected
    # type needs confirmation against the factory implementation.
    with pytest.raises(Exception):
        # class_name validation in LLMModel will reject "FakeProvider"
        build_provider_for_llm(row)


def test_provider_cache_hit():
    """Two builds for an identical row return the same cached objects."""
    row = _make_llm_row(
        name="cache_test",
        class_name="OpenAI",
        options='{"model":"gpt-4o-mini","api_key":"sk-cache"}',
    )
    _clear_cache_for_row(row)
    try:
        provider1, cfg1 = build_provider_for_llm(row)
        provider2, cfg2 = build_provider_for_llm(row)
        # Identity (not equality): the cache must hand back the same objects.
        assert provider1 is provider2, "Expected same provider instance from cache"
        assert cfg1 is cfg2, "Expected same config instance from cache"
    finally:
        _clear_cache_for_row(row)


def test_provider_cache_key_consistency():
    """Cache keys are deterministic for a given row and are hashable tuples."""
    row = _make_llm_row()
    key1 = _provider_cache_key(row)
    key2 = _provider_cache_key(row)
    assert key1 == key2
    assert isinstance(key1, tuple)


def test_provider_config_context_window():
    """The row's context_window is propagated into the built config."""
    row = _make_llm_row(context_window=128000)
    _clear_cache_for_row(row)
    try:
        _, cfg = build_provider_for_llm(row)
        assert cfg.context_window == 128000
    finally:
        _clear_cache_for_row(row)