server.py
#!flask/bin/python
import argparse
import io
import json
import os
import sys
from pathlib import Path
from threading import Lock
from typing import Union
from urllib.parse import parse_qs

from flask import Flask, render_template, render_template_string, request, send_file
from flask_cors import CORS

from TTS.config import load_config
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer


def create_argparser():
    """Build the command-line parser for the TTS server.

    Options are registered in a fixed order from a spec table, so ``--help`` lists them
    in that order.

    Returns:
        argparse.ArgumentParser: parser for model selection, custom model paths and
        server options.
    """
    affirmative = {"true", "1", "yes"}

    # Keep this exact function name: argparse quotes a ``type`` callable's ``__name__``
    # in its "invalid ... value" error messages.
    def convert_boolean(x):
        # Case-insensitive; any spelling outside ``affirmative`` is treated as False.
        return x.lower() in affirmative

    option_specs = (
        (
            "--list_models",
            {
                "type": convert_boolean,
                "nargs": "?",
                "const": True,  # a bare `--list_models` (no value) means True
                "default": False,
                "help": "list available pre-trained tts and vocoder models.",
            },
        ),
        (
            "--model_name",
            {
                "type": str,
                "default": "tts_models/en/ljspeech/tacotron2-DDC",
                "help": "Name of one of the pre-trained tts models in format <language>/<dataset>/<model_name>",
            },
        ),
        (
            "--vocoder_name",
            {"type": str, "default": None, "help": "name of one of the released vocoder models."},
        ),
        # Args for running custom models
        ("--config_path", {"default": None, "type": str, "help": "Path to model config file."}),
        ("--model_path", {"type": str, "default": None, "help": "Path to model file."}),
        (
            "--vocoder_path",
            {
                "type": str,
                "help": "Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).",
                "default": None,
            },
        ),
        (
            "--vocoder_config_path",
            {"type": str, "help": "Path to vocoder model config file.", "default": None},
        ),
        (
            "--speakers_file_path",
            {"type": str, "help": "JSON file for multi-speaker model.", "default": None},
        ),
        ("--port", {"type": int, "default": 5002, "help": "port to listen on."}),
        ("--use_cuda", {"type": convert_boolean, "default": False, "help": "true to use CUDA."}),
        ("--debug", {"type": convert_boolean, "default": False, "help": "true to enable Flask debug mode."}),
        ("--show_details", {"type": convert_boolean, "default": False, "help": "Generate model detail page."}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser


# parse the args
args = create_argparser().parse_args()

# The model catalogue (models.json) is looked up next to this script.
path = Path(__file__).parent / "models.json"
manager = ModelManager(path)

# `--list_models` only prints the catalogue and exits; nothing else is loaded.
if args.list_models:
    manager.list_models()
    sys.exit()

# update in-use models to the specified released models.
model_path = None
config_path = None
speakers_file_path = None
vocoder_path = None
vocoder_config_path = None

# NOTE: `--list_models` is already handled (list + sys.exit) right after argument
# parsing, so it is not re-checked here.

# CASE1: load pre-trained model paths (downloaded through the ModelManager when needed).
if args.model_name is not None and not args.model_path:
    model_path, config_path, model_item = manager.download_model(args.model_name)
    # An explicitly requested vocoder wins over the model's default vocoder.
    args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name

if args.vocoder_name is not None and not args.vocoder_path:
    vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)

# CASE2: set custom model paths given on the command line.
if args.model_path is not None:
    model_path = args.model_path
    config_path = args.config_path
    speakers_file_path = args.speakers_file_path

if args.vocoder_path is not None:
    vocoder_path = args.vocoder_path
    vocoder_config_path = args.vocoder_config_path

# load models
synthesizer = Synthesizer(
    tts_checkpoint=model_path,
    tts_config_path=config_path,
    tts_speakers_file=speakers_file_path,
    tts_languages_file=None,
    vocoder_checkpoint=vocoder_path,
    vocoder_config=vocoder_config_path,
    encoder_checkpoint="",
    encoder_config="",
    use_cuda=args.use_cuda,
)

# Speaker selection is offered when the model reports more than one speaker or a
# speakers file was loaded.
use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and (
    synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None
)
speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)

# Same rule for languages.
use_multi_language = hasattr(synthesizer.tts_model, "num_languages") and (
    synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None
)
language_manager = getattr(synthesizer.tts_model, "language_manager", None)

# TODO: set this from SpeakerManager
use_gst = synthesizer.tts_config.get("use_gst", False)
app = Flask(__name__)
CORS(app)
def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict, None]:
    """Transform a URI ``style_wav`` into either a string (path to a wav file used for style transfer)
    or a dict (GST tokens/values used for styling).

    Args:
        style_wav (str): value of the ``style_wav`` request parameter; may be empty.

    Returns:
        Union[str, dict, None]: the path (str) when ``style_wav`` names an existing ``.wav`` file on the
        server, otherwise the parsed JSON (expected to be a GST dict; not validated here), or None
        when ``style_wav`` is empty.

    Raises:
        json.JSONDecodeError: if ``style_wav`` is neither an existing ``.wav`` file nor valid JSON.
    """
    if not style_wav:
        return None
    # NOTE(review): this checks a client-supplied path against the server filesystem, so any
    # readable .wav on the host can be used as a style reference. Consider restricting the directory.
    if os.path.isfile(style_wav) and style_wav.endswith(".wav"):
        return style_wav  # style_wav is a .wav file located on the server
    # style_wav is a GST dictionary with {token1_id : token1_weight, ...}
    return json.loads(style_wav)


@app.route("/")
def index():
    """Render the demo page, exposing the speaker, language and GST options of the loaded model."""
    return render_template(
        "index.html",
        show_details=args.show_details,
        use_multi_speaker=use_multi_speaker,
        use_multi_language=use_multi_language,
        speaker_ids=speaker_manager.name_to_id if speaker_manager is not None else None,
        language_ids=language_manager.name_to_id if language_manager is not None else None,
        use_gst=use_gst,
    )


def _load_config_if_exists(config_file):
    """Load the config at ``config_file``; return None when it is unset or not an existing file."""
    if config_file is not None and os.path.isfile(config_file):
        return load_config(config_file)
    return None


@app.route("/details")
def details():
    """Render the model-details page from the TTS and vocoder configs that were actually loaded."""
    # Use the module-level resolved paths: they cover both downloaded (--model_name/--vocoder_name)
    # and custom (--config_path/--vocoder_config_path) models. The parser defines no
    # `tts_config`/`vocoder_config` options, so reading those from `args` raised AttributeError.
    return render_template(
        "details.html",
        show_details=args.show_details,
        model_config=_load_config_if_exists(config_path),
        vocoder_config=_load_config_if_exists(vocoder_config_path),
        args=args.__dict__,
    )


# Serializes use of the module-level synthesizer across concurrent requests
# (presumably because it is not safe to share between threads).
lock = Lock()


@app.route("/api/tts", methods=["GET"])
def tts():
    """Synthesize speech and return it as a WAV file.

    Query parameters: ``text`` (required), ``speaker_id``, ``language_id``, ``style_wav``
    (a server-side .wav path or a JSON GST dict).

    Returns 400 when ``text`` is missing/empty or ``style_wav`` cannot be parsed.
    """
    text = request.args.get("text", "")
    if not text:
        return "Missing required query parameter 'text'.", 400
    speaker_idx = request.args.get("speaker_id", "")
    language_idx = request.args.get("language_id", "")
    try:
        style_wav = style_wav_uri_to_dict(request.args.get("style_wav", ""))
    except json.JSONDecodeError:
        return "Invalid 'style_wav': expected a .wav path on the server or a JSON object.", 400
    print(f" > Model input: {text}")
    print(f" > Speaker Idx: {speaker_idx}")
    print(f" > Language Idx: {language_idx}")
    with lock:
        wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
        out = io.BytesIO()
        synthesizer.save_wav(wavs, out)
    # Rewind so send_file streams from the start wherever save_wav left the cursor.
    out.seek(0)
    return send_file(out, mimetype="audio/wav")


# Basic MaryTTS compatibility layer

# Fallback [<type>, <locale>, <dataset>, <name>] fields, matching the layout of model names
# such as "tts_models/en/ljspeech/tacotron2-DDC".
_MARY_TTS_DEFAULT_MODEL_DETAILS = ["", "en", "", "default"]


def _mary_tts_model_details():
    """Split ``args.model_name`` on "/" and pad any missing trailing fields from the defaults.

    Padding keeps index 1 (locale) and index 3 (voice name) valid, so short or custom model
    names no longer raise IndexError.
    """
    # NOTE: We currently assume there is only one model active at the same time
    parts = args.model_name.split("/") if args.model_name is not None else []
    return parts + _MARY_TTS_DEFAULT_MODEL_DETAILS[len(parts):]


@app.route("/locales", methods=["GET"])
def mary_tts_api_locales():
    """MaryTTS-compatible /locales endpoint"""
    model_details = _mary_tts_model_details()
    return render_template_string("{{ locale }}\n", locale=model_details[1])


@app.route("/voices", methods=["GET"])
def mary_tts_api_voices():
    """MaryTTS-compatible /voices endpoint"""
    model_details = _mary_tts_model_details()
    return render_template_string(
        "{{ name }} {{ locale }} {{ gender }}\n", name=model_details[3], locale=model_details[1], gender="u"
    )


@app.route("/process", methods=["GET", "POST"])
def mary_tts_api_process():
    """MaryTTS-compatible /process endpoint; returns 400 when INPUT_TEXT is missing or empty."""
    if request.method == "POST":
        data = parse_qs(request.get_data(as_text=True))
        # NOTE: we ignore param. LOCALE and VOICE for now since we have only one active model
        text = data.get("INPUT_TEXT", [""])[0]
    else:
        text = request.args.get("INPUT_TEXT", "")
    if not text:
        return "Missing required parameter 'INPUT_TEXT'.", 400
    print(f" > Model input: {text}")
    with lock:
        wavs = synthesizer.tts(text)
        out = io.BytesIO()
        synthesizer.save_wav(wavs, out)
    # Rewind so send_file streams from the start wherever save_wav left the cursor.
    out.seek(0)
    return send_file(out, mimetype="audio/wav")


def main():
    """Run the Flask development server bound to ``::`` on ``--port``."""
    app.run(debug=args.debug, host="::", port=args.port)


if __name__ == "__main__":
    main()