image_gen_registry.py
1 """ 2 Image Generation Provider Registry 3 ================================== 4 5 Central map of registered providers. Populated by plugins at import-time via 6 ``PluginContext.register_image_gen_provider()``; consumed by the 7 ``image_generate`` tool to dispatch each call to the active backend. 8 9 Active selection 10 ---------------- 11 The active provider is chosen by ``image_gen.provider`` in ``config.yaml``. 12 If unset, :func:`get_active_provider` applies fallback logic: 13 14 1. If exactly one provider is registered, use it. 15 2. Otherwise if a provider named ``fal`` is registered, use it (legacy 16 default — matches pre-plugin behavior). 17 3. Otherwise return ``None`` (the tool surfaces a helpful error pointing 18 the user at ``hermes tools``). 19 """ 20 21 from __future__ import annotations 22 23 import logging 24 import threading 25 from typing import Dict, List, Optional 26 27 from agent.image_gen_provider import ImageGenProvider 28 29 logger = logging.getLogger(__name__) 30 31 32 _providers: Dict[str, ImageGenProvider] = {} 33 _lock = threading.Lock() 34 35 36 def register_provider(provider: ImageGenProvider) -> None: 37 """Register an image generation provider. 38 39 Re-registration (same ``name``) overwrites the previous entry and logs 40 a debug message — this makes hot-reload scenarios (tests, dev loops) 41 behave predictably. 42 """ 43 if not isinstance(provider, ImageGenProvider): 44 raise TypeError( 45 f"register_provider() expects an ImageGenProvider instance, " 46 f"got {type(provider).__name__}" 47 ) 48 name = provider.name 49 if not isinstance(name, str) or not name.strip(): 50 raise ValueError("Image gen provider .name must be a non-empty string") 51 with _lock: 52 existing = _providers.get(name) 53 _providers[name] = provider 54 if existing is not None: 55 logger.debug("Image gen provider '%s' re-registered (was %r)", name, type(existing).__name__) 56 else: 57 logger.debug("Registered image gen provider '%s' (%s)", name, type(provider).__name__) 58 59 60 def list_providers() -> List[ImageGenProvider]: 61 """Return all registered providers, sorted by name.""" 62 with _lock: 63 items = list(_providers.values()) 64 return sorted(items, key=lambda p: p.name) 65 66 67 def get_provider(name: str) -> Optional[ImageGenProvider]: 68 """Return the provider registered under *name*, or None.""" 69 if not isinstance(name, str): 70 return None 71 with _lock: 72 return _providers.get(name.strip()) 73 74 75 def get_active_provider() -> Optional[ImageGenProvider]: 76 """Resolve the currently-active provider. 77 78 Reads ``image_gen.provider`` from config.yaml; falls back per the 79 module docstring. 80 """ 81 configured: Optional[str] = None 82 try: 83 from hermes_cli.config import load_config 84 85 cfg = load_config() 86 section = cfg.get("image_gen") if isinstance(cfg, dict) else None 87 if isinstance(section, dict): 88 raw = section.get("provider") 89 if isinstance(raw, str) and raw.strip(): 90 configured = raw.strip() 91 except Exception as exc: 92 logger.debug("Could not read image_gen.provider from config: %s", exc) 93 94 with _lock: 95 snapshot = dict(_providers) 96 97 if configured: 98 provider = snapshot.get(configured) 99 if provider is not None: 100 return provider 101 logger.debug( 102 "image_gen.provider='%s' configured but not registered; falling back", 103 configured, 104 ) 105 106 # Fallback: single-provider case 107 if len(snapshot) == 1: 108 return next(iter(snapshot.values())) 109 110 # Fallback: prefer legacy FAL for backward compat 111 if "fal" in snapshot: 112 return snapshot["fal"] 113 114 return None 115 116 117 def _reset_for_tests() -> None: 118 """Clear the registry. **Test-only.**""" 119 with _lock: 120 _providers.clear()