/ agent / image_gen_registry.py
image_gen_registry.py
  1  """
  2  Image Generation Provider Registry
  3  ==================================
  4  
  5  Central map of registered providers. Populated by plugins at import-time via
  6  ``PluginContext.register_image_gen_provider()``; consumed by the
  7  ``image_generate`` tool to dispatch each call to the active backend.
  8  
  9  Active selection
 10  ----------------
The active provider is chosen by ``image_gen.provider`` in ``config.yaml``.
If that key is unset, or names a provider that is not registered,
:func:`get_active_provider` applies fallback logic:
 13  
 14  1. If exactly one provider is registered, use it.
 15  2. Otherwise if a provider named ``fal`` is registered, use it (legacy
 16     default — matches pre-plugin behavior).
 17  3. Otherwise return ``None`` (the tool surfaces a helpful error pointing
 18     the user at ``hermes tools``).
 19  """
 20  
 21  from __future__ import annotations
 22  
 23  import logging
 24  import threading
 25  from typing import Dict, List, Optional
 26  
 27  from agent.image_gen_provider import ImageGenProvider
 28  
 29  logger = logging.getLogger(__name__)
 30  
 31  
 32  _providers: Dict[str, ImageGenProvider] = {}
 33  _lock = threading.Lock()
 34  
 35  
 36  def register_provider(provider: ImageGenProvider) -> None:
 37      """Register an image generation provider.
 38  
 39      Re-registration (same ``name``) overwrites the previous entry and logs
 40      a debug message — this makes hot-reload scenarios (tests, dev loops)
 41      behave predictably.
 42      """
 43      if not isinstance(provider, ImageGenProvider):
 44          raise TypeError(
 45              f"register_provider() expects an ImageGenProvider instance, "
 46              f"got {type(provider).__name__}"
 47          )
 48      name = provider.name
 49      if not isinstance(name, str) or not name.strip():
 50          raise ValueError("Image gen provider .name must be a non-empty string")
 51      with _lock:
 52          existing = _providers.get(name)
 53          _providers[name] = provider
 54      if existing is not None:
 55          logger.debug("Image gen provider '%s' re-registered (was %r)", name, type(existing).__name__)
 56      else:
 57          logger.debug("Registered image gen provider '%s' (%s)", name, type(provider).__name__)
 58  
 59  
 60  def list_providers() -> List[ImageGenProvider]:
 61      """Return all registered providers, sorted by name."""
 62      with _lock:
 63          items = list(_providers.values())
 64      return sorted(items, key=lambda p: p.name)
 65  
 66  
 67  def get_provider(name: str) -> Optional[ImageGenProvider]:
 68      """Return the provider registered under *name*, or None."""
 69      if not isinstance(name, str):
 70          return None
 71      with _lock:
 72          return _providers.get(name.strip())
 73  
 74  
 75  def get_active_provider() -> Optional[ImageGenProvider]:
 76      """Resolve the currently-active provider.
 77  
 78      Reads ``image_gen.provider`` from config.yaml; falls back per the
 79      module docstring.
 80      """
 81      configured: Optional[str] = None
 82      try:
 83          from hermes_cli.config import load_config
 84  
 85          cfg = load_config()
 86          section = cfg.get("image_gen") if isinstance(cfg, dict) else None
 87          if isinstance(section, dict):
 88              raw = section.get("provider")
 89              if isinstance(raw, str) and raw.strip():
 90                  configured = raw.strip()
 91      except Exception as exc:
 92          logger.debug("Could not read image_gen.provider from config: %s", exc)
 93  
 94      with _lock:
 95          snapshot = dict(_providers)
 96  
 97      if configured:
 98          provider = snapshot.get(configured)
 99          if provider is not None:
100              return provider
101          logger.debug(
102              "image_gen.provider='%s' configured but not registered; falling back",
103              configured,
104          )
105  
106      # Fallback: single-provider case
107      if len(snapshot) == 1:
108          return next(iter(snapshot.values()))
109  
110      # Fallback: prefer legacy FAL for backward compat
111      if "fal" in snapshot:
112          return snapshot["fal"]
113  
114      return None
115  
116  
def _reset_for_tests() -> None:
    """Drop every registered provider. **Test-only.**

    Empties the module-level registry dict in place (the same dict object
    stays bound to ``_providers``) while holding the registry lock.
    """
    with _lock:
        _providers.clear()