/ tests / tools / test_tts_command_providers.py
test_tts_command_providers.py
  1  """
  2  Tests for custom command-type TTS providers.
  3  
  4  These tests cover the ``tts.providers.<name>`` registry: built-in
  5  precedence, command resolution, placeholder rendering, shell-quote
  6  context handling, timeout / failure cleanup, voice_compatible opt-in,
  7  and max_text_length lookup.
  8  
  9  Nothing here talks to a real TTS engine. The shell command itself is
 10  portable: we write bytes to ``{output_path}`` using ``python -c`` so
 11  the tests run identically on Linux, macOS, and (with minor quoting
 12  differences) Windows.
 13  """
 14  
 15  import json
 16  import os
 17  import subprocess
 18  import sys
 19  from pathlib import Path
 20  from typing import Optional
 21  from unittest.mock import patch
 22  
 23  import pytest
 24  
 25  from tools.tts_tool import (
 26      BUILTIN_TTS_PROVIDERS,
 27      COMMAND_TTS_OUTPUT_FORMATS,
 28      DEFAULT_COMMAND_TTS_MAX_TEXT_LENGTH,
 29      DEFAULT_COMMAND_TTS_OUTPUT_FORMAT,
 30      DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS,
 31      _generate_command_tts,
 32      _get_command_tts_output_format,
 33      _get_command_tts_timeout,
 34      _get_named_provider_config,
 35      _has_any_command_tts_provider,
 36      _is_command_provider_config,
 37      _is_command_tts_voice_compatible,
 38      _iter_command_providers,
 39      _render_command_tts_template,
 40      _resolve_command_provider_config,
 41      _resolve_max_text_length,
 42      _shell_quote_context,
 43      check_tts_requirements,
 44      text_to_speech_tool,
 45  )
 46  
 47  
 48  # ---------------------------------------------------------------------------
 49  # Helpers
 50  # ---------------------------------------------------------------------------
 51  
 52  def _python_copy_command(output_placeholder: str = "{output_path}") -> str:
 53      """Return a cross-platform shell command that copies {input_path} -> output."""
 54      interpreter = sys.executable
 55      return (
 56          f'"{interpreter}" -c "import shutil, sys; '
 57          f'shutil.copyfile(sys.argv[1], sys.argv[2])" '
 58          f'{{input_path}} {output_placeholder}'
 59      )
 60  
 61  
 62  # ---------------------------------------------------------------------------
 63  # _resolve_command_provider_config / built-in precedence
 64  # ---------------------------------------------------------------------------
 65  
 66  class TestResolveCommandProviderConfig:
 67      def test_builtin_names_are_never_command_providers(self):
 68          cfg = {
 69              "providers": {
 70                  "openai": {"type": "command", "command": "echo hi"},
 71                  "edge": {"type": "command", "command": "echo hi"},
 72              },
 73          }
 74          for name in BUILTIN_TTS_PROVIDERS:
 75              assert _resolve_command_provider_config(name, cfg) is None
 76  
 77      def test_missing_provider_returns_none(self):
 78          cfg = {"providers": {}}
 79          assert _resolve_command_provider_config("nope", cfg) is None
 80  
 81      def test_user_declared_command_provider_resolves(self):
 82          cfg = {
 83              "providers": {
 84                  "piper-cli": {"type": "command", "command": "piper-cli foo"},
 85              },
 86          }
 87          resolved = _resolve_command_provider_config("piper-cli", cfg)
 88          assert resolved is not None
 89          assert resolved["command"] == "piper-cli foo"
 90  
 91      def test_type_command_is_implied_when_command_is_set(self):
 92          cfg = {"providers": {"piper-cli": {"command": "piper-cli foo"}}}
 93          resolved = _resolve_command_provider_config("piper-cli", cfg)
 94          assert resolved is not None
 95  
 96      def test_other_type_values_reject(self):
 97          cfg = {"providers": {"piper-cli": {"type": "python", "command": "piper-cli foo"}}}
 98          assert _resolve_command_provider_config("piper-cli", cfg) is None
 99  
100      def test_empty_command_rejects(self):
101          cfg = {"providers": {"piper-cli": {"type": "command", "command": "   "}}}
102          assert _resolve_command_provider_config("piper-cli", cfg) is None
103  
104      def test_case_insensitive_lookup(self):
105          cfg = {"providers": {"piper-cli": {"type": "command", "command": "x"}}}
106          assert _resolve_command_provider_config("PIPER-CLI", cfg) is not None
107  
108      def test_native_piper_cannot_be_shadowed_by_command_entry(self):
109          """Regression guard for PR that added native Piper as a built-in.
110          A user's ``tts.providers.piper`` must not override the built-in."""
111          cfg = {
112              "providers": {
113                  "piper": {"type": "command", "command": "some-script"},
114              },
115          }
116          assert _resolve_command_provider_config("piper", cfg) is None
117  
118  
class TestGetNamedProviderConfig:
    """Expectations for ``_get_named_provider_config`` precedence rules."""

    def test_providers_block_wins(self):
        """The ``providers`` entry is returned when a legacy block also exists."""
        tts_config = {
            "providers": {"voxcpm": {"command": "new"}},
            "voxcpm": {"command": "legacy"},
        }
        looked_up = _get_named_provider_config(tts_config, "voxcpm")
        assert looked_up == {"command": "new"}

    def test_legacy_tts_name_block_still_resolves(self):
        """A top-level ``<name>`` block is returned when ``providers`` is absent."""
        tts_config = {"voxcpm": {"type": "command", "command": "legacy"}}
        looked_up = _get_named_provider_config(tts_config, "voxcpm")
        expected = {"type": "command", "command": "legacy"}
        assert looked_up == expected

    def test_builtin_names_do_not_leak_through_legacy_path(self):
        """``tts.openai`` must never be mistaken for a command provider."""
        tts_config = {"openai": {"command": "oops", "type": "command"}}
        looked_up = _get_named_provider_config(tts_config, "openai")
        assert looked_up == {}
135  
136  
class TestIsCommandProviderConfig:
    """Negative cases for ``_is_command_provider_config``."""

    def test_empty_dict_is_false(self):
        """An empty mapping is not a command provider config."""
        verdict = _is_command_provider_config({})
        assert verdict is False

    def test_non_dict_is_false(self):
        """Non-mapping inputs (a string, ``None``) are rejected."""
        for candidate in ("foo", None):
            assert _is_command_provider_config(candidate) is False

    def test_type_mismatch_is_false(self):
        """A ``type`` other than ``command`` is rejected even with a command set."""
        candidate = {"type": "native", "command": "x"}
        assert _is_command_provider_config(candidate) is False
147  
148  
149  # ---------------------------------------------------------------------------
150  # _iter_command_providers / _has_any_command_tts_provider
151  # ---------------------------------------------------------------------------
152  
class TestIterCommandProviders:
    """Expectations for ``_iter_command_providers`` and ``_has_any_command_tts_provider``."""

    def test_iterates_only_user_command_providers(self):
        """Built-in names and empty-command entries are excluded from iteration."""
        tts_config = {
            "providers": {
                "openai": {"type": "command", "command": "shouldnt show up"},
                "piper-cli": {"type": "command", "command": "piper-cli"},
                "voxcpm": {"type": "command", "command": "voxcpm"},
                "broken": {"type": "command", "command": ""},
            },
        }
        yielded_names = [name for name, _config in _iter_command_providers(tts_config)]
        yielded_names.sort()
        assert yielded_names == ["piper-cli", "voxcpm"]

    def test_has_any_command_provider_detects_declared(self):
        """One declared command provider is enough to report ``True``."""
        entry = {"type": "command", "command": "piper-cli"}
        tts_config = {"providers": {"piper-cli": entry}}
        assert _has_any_command_tts_provider(tts_config) is True

    def test_has_any_command_provider_when_none(self):
        """An empty ``providers`` block and an empty config both report ``False``."""
        for tts_config in ({"providers": {}}, {}):
            assert _has_any_command_tts_provider(tts_config) is False
172          assert _has_any_command_tts_provider({}) is False
173  
174  
175  # ---------------------------------------------------------------------------
176  # config getters
177  # ---------------------------------------------------------------------------
178  
class TestConfigGetters:
    """Per-provider getters: timeout, output format, and voice_compatible."""

    def test_timeout_defaults(self):
        """An empty config yields the default timeout as a float."""
        fallback = float(DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS)
        assert _get_command_tts_timeout({}) == fallback

    def test_timeout_coerces_string(self):
        """A numeric string timeout is converted to a float."""
        provider_cfg = {"timeout": "45"}
        assert _get_command_tts_timeout(provider_cfg) == 45.0

    def test_timeout_rejects_non_positive(self):
        """Zero and negative timeouts fall back to the default."""
        for bad_value in (0, -1):
            observed = _get_command_tts_timeout({"timeout": bad_value})
            assert observed == float(DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS)

    def test_timeout_rejects_garbage(self):
        """A non-numeric timeout falls back to the default."""
        observed = _get_command_tts_timeout({"timeout": "fast"})
        assert observed == float(DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS)

    def test_timeout_seconds_alias(self):
        """``timeout_seconds`` is accepted as a timeout key."""
        provider_cfg = {"timeout_seconds": 90}
        assert _get_command_tts_timeout(provider_cfg) == 90.0

    def test_output_format_defaults(self):
        """An empty config yields the default output format."""
        observed = _get_command_tts_output_format({})
        assert observed == DEFAULT_COMMAND_TTS_OUTPUT_FORMAT

    def test_output_format_path_override(self):
        """A recognised extension on the output path determines the format."""
        observed = _get_command_tts_output_format({}, "/tmp/clip.wav")
        assert observed == "wav"

    def test_output_format_unknown_path_falls_back_to_config(self):
        """An unrecognised extension defers to the configured ``format``."""
        provider_cfg = {"format": "ogg"}
        observed = _get_command_tts_output_format(provider_cfg, "/tmp/clip.xyz")
        assert observed == "ogg"

    def test_output_format_rejects_unknown(self):
        """An unsupported configured format falls back to the default."""
        observed = _get_command_tts_output_format({"format": "m4a"})
        assert observed == DEFAULT_COMMAND_TTS_OUTPUT_FORMAT

    def test_output_format_supported_set(self):
        """The supported format set is exactly mp3, wav, ogg and flac."""
        expected_formats = frozenset(("mp3", "wav", "ogg", "flac"))
        assert COMMAND_TTS_OUTPUT_FORMATS == expected_formats

    def test_voice_compatible_boolean(self):
        """Boolean ``voice_compatible`` values are returned as-is."""
        for flag in (True, False):
            assert _is_command_tts_voice_compatible({"voice_compatible": flag}) is flag

    def test_voice_compatible_string(self):
        """The strings ``"yes"`` and ``"0"`` map to ``True`` and ``False``."""
        for raw, expected in (("yes", True), ("0", False)):
            assert _is_command_tts_voice_compatible({"voice_compatible": raw}) is expected

    def test_voice_compatible_default_off(self):
        """With no ``voice_compatible`` key the result is ``False``."""
        verdict = _is_command_tts_voice_compatible({})
        assert verdict is False
221  
222  
223  # ---------------------------------------------------------------------------
224  # _resolve_max_text_length for command providers
225  # ---------------------------------------------------------------------------
226  
class TestMaxTextLengthForCommandProviders:
    """Expectations for ``_resolve_max_text_length`` with command providers."""

    def test_default_for_command_provider(self):
        """A command provider without an override gets the command default."""
        entry = {"type": "command", "command": "x"}
        tts_config = {"providers": {"piper-cli": entry}}
        limit = _resolve_max_text_length("piper-cli", tts_config)
        assert limit == DEFAULT_COMMAND_TTS_MAX_TEXT_LENGTH

    def test_override_under_providers(self):
        """``max_text_length`` inside the ``providers`` entry is honoured."""
        entry = {"type": "command", "command": "x", "max_text_length": 2500}
        tts_config = {"providers": {"piper-cli": entry}}
        limit = _resolve_max_text_length("piper-cli", tts_config)
        assert limit == 2500

    def test_override_under_legacy_tts_name_block(self):
        """``max_text_length`` inside a top-level ``<name>`` block is honoured."""
        entry = {"type": "command", "command": "x", "max_text_length": 7777}
        limit = _resolve_max_text_length("piper-cli", {"piper-cli": entry})
        assert limit == 7777

    def test_non_command_unknown_provider_still_falls_back(self):
        """An unknown provider with an empty config still gets a positive limit."""
        limit = _resolve_max_text_length("unknown", {})
        assert limit > 0
242  
243  
244  # ---------------------------------------------------------------------------
245  # _shell_quote_context / template rendering
246  # ---------------------------------------------------------------------------
247  
class TestShellQuoteContext:
    """Expectations for ``_shell_quote_context`` at the ``{output_path}`` placeholder."""

    @staticmethod
    def _context_at_output_path(template):
        """Return the quote context reported at the start of ``{output_path}``."""
        offset = template.index("{output_path}")
        return _shell_quote_context(template, offset)

    def test_bare_context(self):
        """An unquoted placeholder reports no quote context."""
        assert self._context_at_output_path('tts {output_path}') is None

    def test_inside_single_quotes(self):
        """A placeholder wrapped in single quotes reports ``'``."""
        assert self._context_at_output_path("tts '{output_path}'") == "'"

    def test_inside_double_quotes(self):
        """A placeholder wrapped in double quotes reports ``"``."""
        assert self._context_at_output_path('tts "{output_path}"') == '"'

    def test_escaped_double_quote_inside_double(self):
        """A backslash-escaped ``"`` does not close the double-quoted context."""
        template = r'tts "foo \" {output_path}"'
        assert self._context_at_output_path(template) == '"'
268  
269  
class TestRenderCommandTtsTemplate:
    """Expectations for ``_render_command_tts_template`` substitution and quoting."""

    def test_substitutes_all_placeholders(self):
        """Every placeholder used by the template is replaced by its value."""
        placeholders = {
            "input_path": "/tmp/in.txt",
            "text_path": "/tmp/in.txt",
            "output_path": "/tmp/out.mp3",
            "format": "mp3",
            "voice": "af_sky",
            "model": "tiny",
            "speed": "1.0",
        }
        rendered = _render_command_tts_template(
            "tts --voice {voice} --in {input_path} --out {output_path}",
            placeholders,
        )
        assert "af_sky" in rendered
        # The template also uses {input_path}; check it alongside the others.
        assert "/tmp/in.txt" in rendered
        assert "/tmp/out.mp3" in rendered
        # None of the values contain braces, so no raw token may survive.
        for token in ("{voice}", "{input_path}", "{output_path}"):
            assert token not in rendered

    def test_quotes_paths_with_spaces(self):
        """A path containing a space is quoted as a single shell word on POSIX."""
        placeholders = {
            "input_path": "/tmp/Jane Doe/in.txt",
            "text_path": "/tmp/Jane Doe/in.txt",
            "output_path": "/tmp/out.mp3",
            "format": "mp3",
            "voice": "",
            "model": "",
            "speed": "1.0",
        }
        rendered = _render_command_tts_template(
            "tts --in {input_path} --out {output_path}",
            placeholders,
        )
        # shlex.quote wraps space-containing paths in single quotes on POSIX.
        if os.name != "nt":
            assert "'/tmp/Jane Doe/in.txt'" in rendered

    def test_literal_braces_survive(self):
        """Doubled braces in the template render as literal single braces."""
        placeholders = {
            "input_path": "/tmp/in.txt", "text_path": "/tmp/in.txt",
            "output_path": "/tmp/out.mp3", "format": "mp3",
            "voice": "", "model": "", "speed": "1.0",
        }
        rendered = _render_command_tts_template(
            "echo '{{not a placeholder}}' && tts --in {input_path}",
            placeholders,
        )
        assert "{not a placeholder}" in rendered

    def test_injection_is_neutralized(self):
        """Embedded shell metacharacters in a placeholder value must be quoted."""
        placeholders = {
            "input_path": "/tmp/in.txt", "text_path": "/tmp/in.txt",
            "output_path": "/tmp/out; rm -rf /",
            "format": "mp3",
            "voice": "$(whoami)", "model": "", "speed": "1.0",
        }
        rendered = _render_command_tts_template(
            "tts --voice {voice} --out {output_path}",
            placeholders,
        )
        # The injection payload must not appear unquoted in the rendered
        # command. On POSIX shlex.quote wraps the value in single quotes.
        if os.name != "nt":
            assert "'$(whoami)'" in rendered or "'\\''" in rendered
            # Strip the correctly quoted form; any remaining occurrence of the
            # payload would be an unquoted (executable) copy.
            assert "; rm -rf /" not in rendered.replace(
                "'/tmp/out; rm -rf /'", "",
            )

    def test_preserves_shell_quoting_style(self):
        """A value placed inside template double quotes stays double-quoted."""
        placeholders = {
            "input_path": "/tmp/in.txt", "text_path": "/tmp/in.txt",
            "output_path": "/tmp/out.mp3", "format": "mp3",
            "voice": "bob's voice", "model": "", "speed": "1.0",
        }
        # When the template wraps the placeholder in double quotes we must
        # escape for that context, not collapse to single-quoted form.
        rendered = _render_command_tts_template(
            'tts --voice "{voice}"',
            placeholders,
        )
        assert '"bob\'s voice"' in rendered
351  
352  
353  # ---------------------------------------------------------------------------
354  # End-to-end: _generate_command_tts
355  # ---------------------------------------------------------------------------
356  
class TestGenerateCommandTts:
    """End-to-end checks that run real subprocesses through ``_generate_command_tts``."""

    def test_writes_output_file(self, tmp_path):
        """A successful command returns the output path and leaves the file on disk."""
        target = tmp_path / "clip.mp3"
        provider_cfg = {"command": _python_copy_command()}
        produced = _generate_command_tts(
            "hello world", str(target), "py-copy", provider_cfg, {},
        )
        assert produced == str(target)
        assert target.exists()
        # The command copied the input text file over to output, so it
        # contains the original UTF-8 text.
        written = target.read_text(encoding="utf-8")
        assert written == "hello world"

    def test_empty_command_raises(self, tmp_path):
        """A whitespace-only command raises ``ValueError``."""
        target = str(tmp_path / "x.mp3")
        provider_cfg = {"command": "  "}
        with pytest.raises(ValueError, match="is not configured"):
            _generate_command_tts("hello", target, "empty", provider_cfg, {})

    def test_nonzero_exit_raises_runtime(self, tmp_path):
        """A command exiting with status 3 raises ``RuntimeError`` naming the code."""
        target = str(tmp_path / "x.mp3")
        provider_cfg = {"command": f'"{sys.executable}" -c "import sys; sys.exit(3)"'}
        with pytest.raises(RuntimeError, match="exited with code 3"):
            _generate_command_tts("hello", target, "failing", provider_cfg, {})

    def test_empty_output_raises_runtime(self, tmp_path):
        """A command that succeeds without writing output raises ``RuntimeError``."""
        target = str(tmp_path / "x.mp3")
        # This command completes successfully but writes nothing.
        provider_cfg = {"command": f'"{sys.executable}" -c "pass"'}
        with pytest.raises(RuntimeError, match="produced no output"):
            _generate_command_tts("hello", target, "silent", provider_cfg, {})

    @pytest.mark.skipif(os.name == "nt", reason="POSIX-only timeout semantics")
    def test_timeout_raises_runtime(self, tmp_path):
        """A command outliving its configured timeout raises ``RuntimeError``."""
        target = str(tmp_path / "x.mp3")
        # The child sleeps for 10 seconds against a 1 second timeout.
        provider_cfg = {
            "command": f'"{sys.executable}" -c "import time; time.sleep(10)"',
            "timeout": 1,
        }
        with pytest.raises(RuntimeError, match="timed out"):
            _generate_command_tts("hello", target, "slow", provider_cfg, {})
421  
422  
423  # ---------------------------------------------------------------------------
424  # text_to_speech_tool integration
425  # ---------------------------------------------------------------------------
426  
class TestTextToSpeechToolWithCommandProvider:
    """Integration checks for ``text_to_speech_tool`` with command providers.

    Each test patches ``tools.tts_tool._load_tts_config`` so the tool sees an
    in-memory config instead of reading one from disk.
    """

    def test_command_provider_dispatches_end_to_end(self, tmp_path):
        """A configured command provider produces the file and a success payload."""
        cfg = {
            "provider": "py-copy",
            "providers": {
                "py-copy": {
                    "type": "command",
                    "command": _python_copy_command(),
                    "output_format": "mp3",
                },
            },
        }
        out = tmp_path / "clip.mp3"

        # Patch the config loader used by the tool so we don't touch disk.
        with patch("tools.tts_tool._load_tts_config", return_value=cfg):
            result = text_to_speech_tool(text="hi", output_path=str(out))
        data = json.loads(result)
        assert data["success"] is True, data
        assert data["provider"] == "py-copy"
        assert data["voice_compatible"] is False
        assert Path(data["file_path"]).exists()

    def test_voice_compatible_opt_in_toggles_flag(self, tmp_path):
        """voice_compatible=true is reflected in the response when the
        file is already .ogg (no ffmpeg needed)."""
        cfg = {
            "provider": "py-copy-ogg",
            "providers": {
                "py-copy-ogg": {
                    "type": "command",
                    "command": _python_copy_command(),
                    "output_format": "ogg",
                    "voice_compatible": True,
                },
            },
        }
        out = tmp_path / "clip.ogg"

        with patch("tools.tts_tool._load_tts_config", return_value=cfg):
            result = text_to_speech_tool(text="hi", output_path=str(out))
        data = json.loads(result)
        # Include the payload in the failure message to ease diagnosis.
        assert data["success"] is True, data
        assert data["voice_compatible"] is True
        assert data["media_tag"].startswith("[[audio_as_voice]]")

    def test_missing_command_falls_through_to_builtin(self, tmp_path):
        """A provider entry with an empty command is not a command
        provider; the tool should not raise a "command not configured"
        error but fall through to the built-in resolution path."""
        cfg = {
            "provider": "broken",
            "providers": {
                "broken": {"type": "command", "command": "   "},
            },
        }
        with patch("tools.tts_tool._load_tts_config", return_value=cfg):
            result = text_to_speech_tool(text="hi", output_path=str(tmp_path / "x.mp3"))
        data = json.loads(result)
        # The response should not carry the command-provider error text.
        # The error is lower-cased first, so the needle is all lower-case too.
        err = (data.get("error") or "").lower()
        assert "tts.providers.broken.command is not configured" not in err
494  
495  
class TestCheckTtsRequirements:
    """Expectations for ``check_tts_requirements`` with command providers."""

    def test_configured_command_provider_satisfies_requirement(self):
        """A single declared command provider makes the requirement check pass."""
        provider_entry = {"type": "command", "command": "echo x"}
        tts_config = {"providers": {"x": provider_entry}}
        with patch("tools.tts_tool._load_tts_config", return_value=tts_config):
            satisfied = check_tts_requirements()
        assert satisfied is True