/ tests / agent / test_image_gen_registry.py
test_image_gen_registry.py
  1  """Tests for agent/image_gen_registry.py — provider registration & active lookup."""
  2  
  3  from __future__ import annotations
  4  
  5  import pytest
  6  
  7  from agent import image_gen_registry
  8  from agent.image_gen_provider import ImageGenProvider
  9  
 10  
 11  class _FakeProvider(ImageGenProvider):
 12      def __init__(self, name: str, available: bool = True):
 13          self._name = name
 14          self._available = available
 15  
 16      @property
 17      def name(self) -> str:
 18          return self._name
 19  
 20      def is_available(self) -> bool:
 21          return self._available
 22  
 23      def generate(self, prompt, aspect_ratio="landscape", **kw):
 24          return {"success": True, "image": f"{self._name}://{prompt}"}
 25  
 26  
 27  @pytest.fixture(autouse=True)
 28  def _reset_registry():
 29      image_gen_registry._reset_for_tests()
 30      yield
 31      image_gen_registry._reset_for_tests()
 32  
 33  
 34  class TestRegisterProvider:
 35      def test_register_and_lookup(self):
 36          provider = _FakeProvider("fake")
 37          image_gen_registry.register_provider(provider)
 38          assert image_gen_registry.get_provider("fake") is provider
 39  
 40      def test_rejects_non_provider(self):
 41          with pytest.raises(TypeError):
 42              image_gen_registry.register_provider("not a provider")  # type: ignore[arg-type]
 43  
 44      def test_rejects_empty_name(self):
 45          class Empty(ImageGenProvider):
 46              @property
 47              def name(self) -> str:
 48                  return ""
 49  
 50              def generate(self, prompt, aspect_ratio="landscape", **kw):
 51                  return {}
 52  
 53          with pytest.raises(ValueError):
 54              image_gen_registry.register_provider(Empty())
 55  
 56      def test_reregister_overwrites(self):
 57          a = _FakeProvider("same")
 58          b = _FakeProvider("same")
 59          image_gen_registry.register_provider(a)
 60          image_gen_registry.register_provider(b)
 61          assert image_gen_registry.get_provider("same") is b
 62  
 63      def test_list_is_sorted(self):
 64          image_gen_registry.register_provider(_FakeProvider("zeta"))
 65          image_gen_registry.register_provider(_FakeProvider("alpha"))
 66          names = [p.name for p in image_gen_registry.list_providers()]
 67          assert names == ["alpha", "zeta"]
 68  
 69  
 70  class TestGetActiveProvider:
 71      def test_single_provider_autoresolves(self, tmp_path, monkeypatch):
 72          monkeypatch.setenv("HERMES_HOME", str(tmp_path))
 73          image_gen_registry.register_provider(_FakeProvider("solo"))
 74          active = image_gen_registry.get_active_provider()
 75          assert active is not None and active.name == "solo"
 76  
 77      def test_fal_preferred_on_multi_without_config(self, tmp_path, monkeypatch):
 78          monkeypatch.setenv("HERMES_HOME", str(tmp_path))
 79          image_gen_registry.register_provider(_FakeProvider("fal"))
 80          image_gen_registry.register_provider(_FakeProvider("openai"))
 81          active = image_gen_registry.get_active_provider()
 82          assert active is not None and active.name == "fal"
 83  
 84      def test_explicit_config_wins(self, tmp_path, monkeypatch):
 85          import yaml
 86  
 87          monkeypatch.setenv("HERMES_HOME", str(tmp_path))
 88          (tmp_path / "config.yaml").write_text(
 89              yaml.safe_dump({"image_gen": {"provider": "openai"}})
 90          )
 91          image_gen_registry.register_provider(_FakeProvider("fal"))
 92          image_gen_registry.register_provider(_FakeProvider("openai"))
 93          active = image_gen_registry.get_active_provider()
 94          assert active is not None and active.name == "openai"
 95  
 96      def test_missing_configured_provider_falls_back(self, tmp_path, monkeypatch):
 97          import yaml
 98  
 99          monkeypatch.setenv("HERMES_HOME", str(tmp_path))
100          (tmp_path / "config.yaml").write_text(
101              yaml.safe_dump({"image_gen": {"provider": "replicate"}})
102          )
103          # Only FAL is registered — configured provider doesn't exist
104          image_gen_registry.register_provider(_FakeProvider("fal"))
105          active = image_gen_registry.get_active_provider()
106          # Falls back to FAL preference (legacy default) rather than None
107          assert active is not None and active.name == "fal"
108  
109      def test_none_when_empty(self, tmp_path, monkeypatch):
110          monkeypatch.setenv("HERMES_HOME", str(tmp_path))
111          assert image_gen_registry.get_active_provider() is None