# tests/tools/test_managed_media_gateways.py
  1  import sys
  2  import types
  3  from importlib.util import module_from_spec, spec_from_file_location
  4  from pathlib import Path
  5  
  6  import pytest
  7  
  8  
# Directory containing the real tool modules under test (repo_root/tools).
TOOLS_DIR = Path(__file__).resolve().parents[2] / "tools"
 10  
 11  
 12  def _load_tool_module(module_name: str, filename: str):
 13      spec = spec_from_file_location(module_name, TOOLS_DIR / filename)
 14      assert spec and spec.loader
 15      module = module_from_spec(spec)
 16      sys.modules[module_name] = module
 17      spec.loader.exec_module(module)
 18      return module
 19  
 20  
 21  @pytest.fixture(autouse=True)
 22  def _restore_tool_and_agent_modules():
 23      original_modules = {
 24          name: module
 25          for name, module in sys.modules.items()
 26          if name == "tools"
 27          or name.startswith("tools.")
 28          or name == "agent"
 29          or name.startswith("agent.")
 30          or name in {"fal_client", "openai"}
 31      }
 32      try:
 33          yield
 34      finally:
 35          for name in list(sys.modules):
 36              if (
 37                  name == "tools"
 38                  or name.startswith("tools.")
 39                  or name == "agent"
 40                  or name.startswith("agent.")
 41                  or name in {"fal_client", "openai"}
 42              ):
 43                  sys.modules.pop(name, None)
 44          sys.modules.update(original_modules)
 45  
 46  
@pytest.fixture(autouse=True)
def _enable_managed_nous_tools(monkeypatch):
    """Patch the source modules so managed_nous_tools_enabled() returns True
    even after tool modules are dynamically reloaded."""
    # Simulate a logged-in, non-free-tier Nous account. Patching the string
    # targets (source modules) means freshly reloaded tool modules that
    # import these functions at call time still see the patched versions.
    monkeypatch.setattr("hermes_cli.auth.get_nous_auth_status", lambda: {"logged_in": True})
    monkeypatch.setattr("hermes_cli.models.check_nous_free_tier", lambda: False)
 53  
 54  
 55  def _install_fake_tools_package():
 56      tools_package = types.ModuleType("tools")
 57      tools_package.__path__ = [str(TOOLS_DIR)]  # type: ignore[attr-defined]
 58      sys.modules["tools"] = tools_package
 59      sys.modules["tools.debug_helpers"] = types.SimpleNamespace(
 60          DebugSession=lambda *args, **kwargs: types.SimpleNamespace(
 61              active=False,
 62              session_id="debug-session",
 63              log_call=lambda *a, **k: None,
 64              save=lambda: None,
 65              get_session_info=lambda: {},
 66          )
 67      )
 68      sys.modules["tools.managed_tool_gateway"] = _load_tool_module(
 69          "tools.managed_tool_gateway",
 70          "managed_tool_gateway.py",
 71      )
 72  
 73  
 74  def _install_fake_fal_client(captured):
 75      def submit(model, arguments=None, headers=None):
 76          raise AssertionError("managed FAL gateway mode should use fal_client.SyncClient")
 77  
 78      class FakeResponse:
 79          def json(self):
 80              return {
 81                  "request_id": "req-123",
 82                  "response_url": "http://127.0.0.1:3009/requests/req-123",
 83                  "status_url": "http://127.0.0.1:3009/requests/req-123/status",
 84                  "cancel_url": "http://127.0.0.1:3009/requests/req-123/cancel",
 85              }
 86  
 87      def _maybe_retry_request(client, method, url, json=None, timeout=None, headers=None):
 88          captured["submit_via"] = "managed_client"
 89          captured["http_client"] = client
 90          captured["method"] = method
 91          captured["submit_url"] = url
 92          captured["arguments"] = json
 93          captured["timeout"] = timeout
 94          captured["headers"] = headers
 95          return FakeResponse()
 96  
 97      class SyncRequestHandle:
 98          def __init__(self, request_id, response_url, status_url, cancel_url, client):
 99              captured["request_id"] = request_id
100              captured["response_url"] = response_url
101              captured["status_url"] = status_url
102              captured["cancel_url"] = cancel_url
103              captured["handle_client"] = client
104  
105      class SyncClient:
106          def __init__(self, key=None, default_timeout=120.0):
107              captured["sync_client_inits"] = captured.get("sync_client_inits", 0) + 1
108              captured["client_key"] = key
109              captured["client_timeout"] = default_timeout
110              self.default_timeout = default_timeout
111              self._client = object()
112  
113      fal_client_module = types.SimpleNamespace(
114          submit=submit,
115          SyncClient=SyncClient,
116          client=types.SimpleNamespace(
117              _maybe_retry_request=_maybe_retry_request,
118              _raise_for_status=lambda response: None,
119              SyncRequestHandle=SyncRequestHandle,
120          ),
121      )
122      sys.modules["fal_client"] = fal_client_module
123      return fal_client_module
124  
125  
126  def _install_fake_openai_module(captured, transcription_response=None):
127      class FakeSpeechResponse:
128          def stream_to_file(self, output_path):
129              captured["stream_to_file"] = output_path
130  
131      class FakeOpenAI:
132          def __init__(self, api_key, base_url, **kwargs):
133              captured["api_key"] = api_key
134              captured["base_url"] = base_url
135              captured["client_kwargs"] = kwargs
136              captured["close_calls"] = captured.get("close_calls", 0)
137  
138              def create_speech(**kwargs):
139                  captured["speech_kwargs"] = kwargs
140                  return FakeSpeechResponse()
141  
142              def create_transcription(**kwargs):
143                  captured["transcription_kwargs"] = kwargs
144                  return transcription_response
145  
146              self.audio = types.SimpleNamespace(
147                  speech=types.SimpleNamespace(
148                      create=create_speech
149                  ),
150                  transcriptions=types.SimpleNamespace(
151                      create=create_transcription
152                  ),
153              )
154  
155          def close(self):
156              captured["close_calls"] += 1
157  
158      fake_module = types.SimpleNamespace(
159          OpenAI=FakeOpenAI,
160          APIError=Exception,
161          APIConnectionError=Exception,
162          APITimeoutError=Exception,
163      )
164      sys.modules["openai"] = fake_module
165  
166  
def test_managed_fal_submit_uses_gateway_origin_and_nous_token(monkeypatch):
    """A managed submit must go through SyncClient to the gateway URL, with
    the Nous token as the key and an idempotency header on the request."""
    recorded = {}
    _install_fake_tools_package()
    _install_fake_fal_client(recorded)
    monkeypatch.delenv("FAL_KEY", raising=False)
    monkeypatch.setenv("FAL_QUEUE_GATEWAY_URL", "http://127.0.0.1:3009")
    monkeypatch.setenv("TOOL_GATEWAY_USER_TOKEN", "nous-token")

    tool = _load_tool_module("tools.image_generation_tool", "image_generation_tool.py")
    monkeypatch.setattr(tool.uuid, "uuid4", lambda: "fal-submit-123")

    tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "test prompt", "num_images": 1})

    assert recorded["submit_via"] == "managed_client"
    assert recorded["client_key"] == "nous-token"
    assert recorded["submit_url"] == "http://127.0.0.1:3009/fal-ai/flux-2-pro"
    assert recorded["method"] == "POST"
    assert recorded["arguments"] == {"prompt": "test prompt", "num_images": 1}
    assert recorded["headers"] == {"x-idempotency-key": "fal-submit-123"}
    assert recorded["sync_client_inits"] == 1
193  
194  
def test_managed_fal_submit_reuses_cached_sync_client(monkeypatch):
    """Two submits from the same tool module must share one SyncClient."""
    recorded = {}
    _install_fake_tools_package()
    _install_fake_fal_client(recorded)
    monkeypatch.delenv("FAL_KEY", raising=False)
    monkeypatch.setenv("FAL_QUEUE_GATEWAY_URL", "http://127.0.0.1:3009")
    monkeypatch.setenv("TOOL_GATEWAY_USER_TOKEN", "nous-token")

    tool = _load_tool_module("tools.image_generation_tool", "image_generation_tool.py")

    tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "first"})
    client_after_first = recorded["http_client"]
    tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "second"})

    assert recorded["sync_client_inits"] == 1
    assert recorded["http_client"] is client_after_first
214  
215  
def test_openai_tts_uses_managed_audio_gateway_when_direct_key_absent(monkeypatch, tmp_path):
    """With no direct OpenAI key, TTS should route through the managed audio
    gateway using the Nous token, tag the call for idempotency, and close
    the client afterwards."""
    recorded = {}
    _install_fake_tools_package()
    _install_fake_openai_module(recorded)
    for var in ("VOICE_TOOLS_OPENAI_KEY", "OPENAI_API_KEY"):
        monkeypatch.delenv(var, raising=False)
    monkeypatch.setenv("TOOL_GATEWAY_DOMAIN", "nousresearch.com")
    monkeypatch.setenv("TOOL_GATEWAY_USER_TOKEN", "nous-token")

    tts_tool = _load_tool_module("tools.tts_tool", "tts_tool.py")
    monkeypatch.setattr(tts_tool.uuid, "uuid4", lambda: "tts-call-123")
    destination = tmp_path / "speech.mp3"
    tts_tool._generate_openai_tts("hello world", str(destination), {"openai": {}})

    assert recorded["api_key"] == "nous-token"
    assert recorded["base_url"] == "https://openai-audio-gateway.nousresearch.com/v1"
    assert recorded["speech_kwargs"]["model"] == "gpt-4o-mini-tts"
    assert recorded["speech_kwargs"]["extra_headers"] == {"x-idempotency-key": "tts-call-123"}
    assert recorded["stream_to_file"] == str(destination)
    assert recorded["close_calls"] == 1
236  
237  
def test_openai_tts_accepts_openai_api_key_as_direct_fallback(monkeypatch, tmp_path):
    """OPENAI_API_KEY alone should force the direct api.openai.com path
    instead of the managed gateway."""
    recorded = {}
    _install_fake_tools_package()
    _install_fake_openai_module(recorded)
    monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
    monkeypatch.setenv("OPENAI_API_KEY", "openai-direct-key")
    monkeypatch.setenv("TOOL_GATEWAY_DOMAIN", "nousresearch.com")
    monkeypatch.setenv("TOOL_GATEWAY_USER_TOKEN", "nous-token")

    tts_tool = _load_tool_module("tools.tts_tool", "tts_tool.py")
    tts_tool._generate_openai_tts("hello world", str(tmp_path / "speech.mp3"), {"openai": {}})

    assert recorded["api_key"] == "openai-direct-key"
    assert recorded["base_url"] == "https://api.openai.com/v1"
    assert recorded["close_calls"] == 1
254  
255  
def test_transcription_uses_model_specific_response_formats(monkeypatch, tmp_path):
    """whisper-1 transcriptions request plain-text responses; the gpt-4o
    model requests JSON and reads .text off the response object."""
    whisper_capture = {}
    _install_fake_tools_package()
    _install_fake_openai_module(whisper_capture, transcription_response="hello from whisper")
    # Point config loading at a temp HERMES_HOME declaring openai as the
    # STT provider.
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    (tmp_path / "config.yaml").write_text("stt:\n  provider: openai\n")
    monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.setenv("TOOL_GATEWAY_DOMAIN", "nousresearch.com")
    monkeypatch.setenv("TOOL_GATEWAY_USER_TOKEN", "nous-token")

    transcription_tools = _load_tool_module(
        "tools.transcription_tools",
        "transcription_tools.py",
    )
    # Stub config loading directly on the freshly loaded module object
    # (no monkeypatch needed: the module is discarded after this test).
    transcription_tools._load_stt_config = lambda: {"provider": "openai"}
    audio_path = tmp_path / "audio.wav"
    audio_path.write_bytes(b"RIFF0000WAVEfmt ")  # minimal WAV-like header

    whisper_result = transcription_tools.transcribe_audio(str(audio_path), model="whisper-1")
    assert whisper_result["success"] is True
    assert whisper_capture["base_url"] == "https://openai-audio-gateway.nousresearch.com/v1"
    assert whisper_capture["transcription_kwargs"]["response_format"] == "text"
    assert whisper_capture["close_calls"] == 1

    # Second pass: install a fresh fake openai module and reload the tool
    # module so the gpt-4o path records into its own capture dict.
    json_capture = {}
    _install_fake_openai_module(
        json_capture,
        transcription_response=types.SimpleNamespace(text="hello from gpt-4o"),
    )
    transcription_tools = _load_tool_module(
        "tools.transcription_tools",
        "transcription_tools.py",
    )

    json_result = transcription_tools.transcribe_audio(
        str(audio_path),
        model="gpt-4o-mini-transcribe",
    )
    assert json_result["success"] is True
    assert json_result["transcript"] == "hello from gpt-4o"
    assert json_capture["transcription_kwargs"]["response_format"] == "json"
    assert json_capture["close_calls"] == 1