/ tests / test_agent2_providers.py
test_agent2_providers.py
  1  """Tests for agent2 provider factory and caching."""
  2  
  3  import types
  4  
  5  import pytest
  6  
  7  from restai.agent2.providers import (
  8      Agent2UnsupportedLLMError,
  9      AnthropicProvider,
 10      OpenAIProvider,
 11      ProviderConfig,
 12      _provider_cache,
 13      _provider_cache_key,
 14      build_provider_for_llm,
 15  )
 16  
 17  
 18  def _make_llm_row(
 19      name="test",
 20      class_name="OpenAI",
 21      options='{"model":"gpt-4o","api_key":"sk-fake"}',
 22      context_window=4096,
 23      privacy="public",
 24      description=None,
 25      input_cost=0.0,
 26      output_cost=0.0,
 27  ):
 28      return types.SimpleNamespace(
 29          id=1,
 30          name=name,
 31          class_name=class_name,
 32          options=options,
 33          context_window=context_window,
 34          privacy=privacy,
 35          description=description,
 36          input_cost=input_cost,
 37          output_cost=output_cost,
 38          teams=[],
 39      )
 40  
 41  
 42  def _clear_cache_for_row(row):
 43      """Remove a specific row's entry from the provider cache."""
 44      key = _provider_cache_key(row)
 45      _provider_cache.pop(key, None)
 46  
 47  
 48  def test_provider_config_construction():
 49      cfg = ProviderConfig(model="gpt-4o", api_key="sk-test")
 50      assert cfg.model == "gpt-4o"
 51      assert cfg.api_key == "sk-test"
 52      assert cfg.max_output_tokens == 4096
 53      assert cfg.base_url is None
 54  
 55  
 56  def test_build_provider_openai():
 57      row = _make_llm_row(class_name="OpenAI")
 58      _clear_cache_for_row(row)
 59      provider, cfg = build_provider_for_llm(row)
 60      assert isinstance(provider, OpenAIProvider)
 61      assert isinstance(cfg, ProviderConfig)
 62      assert cfg.model == "gpt-4o"
 63      _clear_cache_for_row(row)
 64  
 65  
 66  def test_build_provider_anthropic():
 67      row = _make_llm_row(
 68          class_name="Anthropic",
 69          options='{"model":"claude-3-5-sonnet-latest","api_key":"sk-ant-fake"}',
 70      )
 71      _clear_cache_for_row(row)
 72      provider, cfg = build_provider_for_llm(row)
 73      assert isinstance(provider, AnthropicProvider)
 74      assert cfg.api_key == "sk-ant-fake"
 75      _clear_cache_for_row(row)
 76  
 77  
 78  def test_build_provider_ollama():
 79      row = _make_llm_row(
 80          class_name="Ollama",
 81          options='{"model":"llama3","base_url":"http://localhost:11434"}',
 82      )
 83      _clear_cache_for_row(row)
 84      provider, cfg = build_provider_for_llm(row)
 85      assert isinstance(provider, OpenAIProvider)
 86      assert cfg.base_url.endswith("/v1")
 87      _clear_cache_for_row(row)
 88  
 89  
 90  def test_build_provider_unsupported():
 91      row = _make_llm_row(class_name="FakeProvider")
 92      _clear_cache_for_row(row)
 93      with pytest.raises(Exception):
 94          # class_name validation in LLMModel will reject "FakeProvider"
 95          build_provider_for_llm(row)
 96  
 97  
 98  def test_provider_cache_hit():
 99      row = _make_llm_row(
100          name="cache_test",
101          class_name="OpenAI",
102          options='{"model":"gpt-4o-mini","api_key":"sk-cache"}',
103      )
104      _clear_cache_for_row(row)
105      provider1, cfg1 = build_provider_for_llm(row)
106      provider2, cfg2 = build_provider_for_llm(row)
107      assert id(provider1) == id(provider2), "Expected same provider instance from cache"
108      assert id(cfg1) == id(cfg2), "Expected same config instance from cache"
109      _clear_cache_for_row(row)
110  
111  
def test_provider_cache_key_consistency():
    """Computing the cache key twice for one row gives equal tuples."""
    row = _make_llm_row()
    first, second = (_provider_cache_key(row) for _ in range(2))

    assert first == second
    assert isinstance(first, tuple)
118  
119  
def test_provider_config_context_window():
    """The row's ``context_window`` is carried through to ProviderConfig."""
    row = _make_llm_row(context_window=128000)
    _clear_cache_for_row(row)
    try:
        _, cfg = build_provider_for_llm(row)
        assert cfg.context_window == 128000
    finally:
        # Evict even on failure so no stale cache entry outlives this test.
        _clear_cache_for_row(row)