22
22
import wave
23
23
import numpy as np
24
24
25
class InvalidSettingsError(Exception):
    """Raised when a requested TTS setting or model name is rejected."""
28
+
25
29
# List of supported language codes
26
30
supported_languages = {
27
31
"ar" :"Arabic" ,
43
47
"hi" :"Hindi"
44
48
}
45
49
50
# Generation settings a TTSWrapper starts with; the mapping is expanded as
# keyword arguments into the generation calls.
default_tts_settings = dict(
    temperature=0.75,
    length_penalty=1.0,
    repetition_penalty=5.0,
    top_k=50,
    top_p=0.85,
    speed=1,
    enable_text_splitting=True,
)
59
+
46
60
# Official model names in the old ("v"-prefixed) format, plus the "main" entry.
official_model_list = [
    "v2.0.0",
    "v2.0.1",
    "v2.0.2",
    "v2.0.3",
    "main",
]
# The same versions in the new format, without the "v" prefix.
official_model_list_v2 = [
    "2.0.0",
    "2.0.1",
    "2.0.2",
    "2.0.3",
]

# Inverse lookup: language name -> language code.
reversed_supported_languages = dict(
    zip(supported_languages.values(), supported_languages.keys())
)
50
64
51
65
class TTSWrapper :
52
- def __init__ (self ,output_folder = "./output" , speaker_folder = "./speakers" ,lowvram = False ,model_source = "local" ,model_version = "2.0.2" ,device = "cuda" ,deepspeed = False ,enable_cache_results = True ):
66
+ def __init__ (self ,output_folder = "./output" , speaker_folder = "./speakers" ,model_folder = "./xtts_folder" , lowvram = False ,model_source = "local" ,model_version = "2.0.2" ,device = "cuda" ,deepspeed = False ,enable_cache_results = True ):
53
67
54
68
self .cuda = device # If the user has chosen what to use, we rewrite the value to the value we want to use
55
69
self .device = 'cpu' if lowvram else (self .cuda if torch .cuda .is_available () else "cpu" )
@@ -59,12 +73,13 @@ def __init__(self,output_folder = "./output", speaker_folder="./speakers",lowvra
59
73
60
74
self .model_source = model_source
61
75
self .model_version = model_version
76
+ self .tts_settings = default_tts_settings
62
77
63
78
self .deepspeed = deepspeed
64
79
65
80
self .speaker_folder = speaker_folder
66
81
self .output_folder = output_folder
67
- self .custom_models_folder = "./models"
82
+ self .model_folder = model_folder
68
83
69
84
self .create_directories ()
70
85
check_tts_version ()
@@ -90,6 +105,14 @@ def check_model_version_old_format(self,model_version):
90
105
return "v" + model_version
91
106
return model_version
92
107
108
def get_models_list(self):
    """Return the names of the directories located directly in ``self.model_folder``.

    Plain files are ignored; only entries that are directories are reported.
    """
    root = self.model_folder

    def is_directory(entry_name):
        return os.path.isdir(os.path.join(root, entry_name))

    return list(filter(is_directory, os.listdir(root)))
114
+
115
+
93
116
def get_wav_header (self , channels :int = 1 , sample_rate :int = 24000 , width :int = 2 ) -> bytes :
94
117
wav_buf = io .BytesIO ()
95
118
with wave .open (wav_buf , "wb" ) as out :
@@ -147,12 +170,11 @@ def load_model(self,load=True):
147
170
self .model = TTS ("tts_models/multilingual/multi-dataset/xtts_v2" )
148
171
149
172
if self .model_source == "apiManual" :
150
- this_dir = Path (__file__ ).parent .resolve () / "models"
173
+ this_dir = Path (self .model_folder )
174
+
151
175
if self .isModelOfficial (self .model_version ):
152
176
download_model (this_dir ,self .model_version )
153
- else :
154
- this_dir = Path (self .custom_models_folder ).resolve ()
155
-
177
+
156
178
config_path = this_dir / f'{ self .model_version } ' / 'config.json'
157
179
checkpoint_dir = this_dir / f'{ self .model_version } '
158
180
@@ -170,13 +192,11 @@ def load_model(self,load=True):
170
192
logger .info ("Model successfully loaded " )
171
193
172
194
def load_local_model (self ,load = True ):
173
- this_model_dir = Path (__file__ ). parent . resolve ()
195
+ this_model_dir = Path (self . model_folder )
174
196
175
197
if self .isModelOfficial (self .model_version ):
176
198
download_model (this_model_dir ,self .model_version )
177
- this_model_dir = this_model_dir / "models"
178
- else :
179
- this_model_dir = Path (self .custom_models_folder )
199
+ this_model_dir = this_model_dir
180
200
181
201
config = XttsConfig ()
182
202
config_path = this_model_dir / f'{ self .model_version } ' / 'config.json'
@@ -188,6 +208,34 @@ def load_local_model(self,load=True):
188
208
self .model .load_checkpoint (config ,use_deepspeed = self .deepspeed , checkpoint_dir = str (checkpoint_dir ))
189
209
self .model .to (self .device )
190
210
211
def switch_model(self, model_name):
    """Unload the current model and load the one named ``model_name``.

    ``model_name`` must be the name of a directory inside
    ``self.model_folder`` (as reported by ``get_models_list``) and must
    differ from the currently selected ``self.model_version``.

    Raises:
        InvalidSettingsError: if ``model_name`` equals the current
            ``self.model_version``, or if it is not present in the
            models folder.
    """
    # Check to see if the same name is selected
    if model_name == self.model_version:
        raise InvalidSettingsError("The model with this name is already loaded in memory")

    # Check if the model is in the list at all (the folder is only scanned
    # once the cheap same-name check has passed)
    model_list = self.get_models_list()
    if model_name not in model_list:
        raise InvalidSettingsError(
            f"A model with `{model_name}` name is not in the models folder, "
            f"the current available models: {model_list}"
        )

    # Drop this object's reference to the old model, then ask torch to
    # release its cached GPU memory
    self.model = ""
    torch.cuda.empty_cache()
    logger.info("Model successfully unloaded from memory")

    # Start load model; the loaders read self.model_version, so it has to be
    # updated before they are called
    logger.info(f"Start loading {model_name} model")
    self.model_version = model_name
    if self.model_source == "local":
        self.load_local_model()
    else:
        self.load_model()

    logger.info("Model successfully loaded")
238
+
191
239
# LOWVRAM FUNCS
192
240
def switch_model_device (self ):
193
241
# We check for lowram and the existence of cuda
@@ -222,7 +270,7 @@ def create_latents_for_all(self):
222
270
223
271
# DIRECTORIES FUNCS
224
272
def create_directories (self ):
225
- directories = [self .output_folder , self .speaker_folder ,self .custom_models_folder ]
273
+ directories = [self .output_folder , self .speaker_folder ,self .model_folder ]
226
274
227
275
for sanctuary in directories :
228
276
# List of folders to be checked for existence
@@ -249,6 +297,50 @@ def set_out_folder(self, folder):
249
297
else :
250
298
raise ValueError ("Provided path is not a valid directory" )
251
299
300
def set_tts_settings(self, temperature, speed, length_penalty,
                     repetition_penalty, top_p, top_k, enable_text_splitting):
    """Validate the generation settings and store them in ``self.tts_settings``.

    All values are validated before anything is stored, so a rejected call
    leaves the previous settings in place.

    Args:
        temperature: real number between 0.01 and 1.
        speed: real number between 0.2 and 2.
        length_penalty: any real number (no range); stored as a ``float``.
        repetition_penalty: real number between 0.1 and 10.0.
        top_p: real number between 0.01 and 1.
        top_k: integer between 1 and 100.
        enable_text_splitting: ``True`` or ``False``.

    Raises:
        InvalidSettingsError: if a value has the wrong type or lies outside
            its allowed range.
    """
    import numbers

    def is_number(value, kind=numbers.Real):
        # bool is a subclass of int, so it has to be excluded explicitly
        return isinstance(value, kind) and not isinstance(value, bool)

    def require_range(value, low, high, message, kind=numbers.Real):
        # The type is checked first so that non-numeric input is reported as
        # InvalidSettingsError rather than a TypeError from the comparison
        if not (is_number(value, kind) and low <= value <= high):
            raise InvalidSettingsError(message)

    # Check temperature
    require_range(temperature, 0.01, 1, "Temperature must be between 0.01 and 1.")

    # Check speed
    require_range(speed, 0.2, 2, "Speed must be between 0.2 and 2.")

    # Check length_penalty (no explicit range specified); integers such as 1
    # are accepted and normalised to float
    if not is_number(length_penalty):
        raise InvalidSettingsError("Length penalty must be a floating point number.")
    length_penalty = float(length_penalty)

    # Check repetition_penalty
    require_range(repetition_penalty, 0.1, 10.0,
                  "Repetition penalty must be between 0.1 and 10.0.")

    # Check top_p
    require_range(top_p, 0.01, 1,
                  "Top_p must be between 0.01 and 1 and must be a float.")

    # Check top_k (must really be an integer, as the message states)
    require_range(top_k, 1, 100, "Top_k must be an integer between 1 and 100.",
                  kind=numbers.Integral)

    # Check enable_text_splitting
    if not isinstance(enable_text_splitting, bool):
        raise InvalidSettingsError("Enable text splitting must be either True or False.")

    # All validations passed - proceed to apply settings.
    self.tts_settings = {
        "temperature": temperature,
        "speed": speed,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
        "top_k": top_k,
        "enable_text_splitting": enable_text_splitting,
    }
    print("Successfully updated TTS settings.")
343
+
252
344
# GET FUNCS
253
345
def get_wav_files (self , directory ):
254
346
""" Finds all the wav files in a directory. """
@@ -361,12 +453,7 @@ async def stream_generation(self,text,speaker_name,speaker_wav,language,output_f
361
453
language ,
362
454
speaker_embedding = speaker_embedding ,
363
455
gpt_cond_latent = gpt_cond_latent ,
364
- temperature = 0.75 ,
365
- length_penalty = 1.0 ,
366
- repetition_penalty = 5.0 ,
367
- top_k = 50 ,
368
- top_p = 0.85 ,
369
- enable_text_splitting = True ,
456
+ ** self .tts_settings , # Expands the object with the settings and applies them for generation
370
457
stream_chunk_size = 100 ,
371
458
)
372
459
@@ -402,12 +489,7 @@ def local_generation(self,text,speaker_name,speaker_wav,language,output_file):
402
489
language ,
403
490
gpt_cond_latent = gpt_cond_latent ,
404
491
speaker_embedding = speaker_embedding ,
405
- temperature = 0.75 ,
406
- length_penalty = 1.0 ,
407
- repetition_penalty = 5.0 ,
408
- top_k = 50 ,
409
- top_p = 0.85 ,
410
- enable_text_splitting = True
492
+ ** self .tts_settings , # Expands the object with the settings and applies them for generation
411
493
)
412
494
413
495
torchaudio .save (output_file , torch .tensor (out ["wav" ]).unsqueeze (0 ), 24000 )
0 commit comments