2222import wave
2323import numpy as np
2424
# Exception type used to report rejected TTS settings or model selections
class InvalidSettingsError(Exception):
    """Signal that a requested setting or model name failed validation."""
28+
2529# List of supported language codes
2630supported_languages = {
2731 "ar" :"Arabic" ,
4347 "hi" :"Hindi"
4448}
4549
# Generation settings applied until the caller supplies its own values.
# The keys are expanded as keyword arguments into the model's inference calls.
default_tts_settings = dict(
    temperature=0.75,
    length_penalty=1.0,
    repetition_penalty=5.0,
    top_k=50,
    top_p=0.85,
    speed=1,
    enable_text_splitting=True,
)
59+
# Official model versions as plain version numbers
official_model_list_v2 = ["2.0.0", "2.0.1", "2.0.2", "2.0.3"]
# The same versions in the "v"-prefixed form, plus the "main" entry
official_model_list = ["v" + version for version in official_model_list_v2] + ["main"]
4862
# Inverse lookup: language name -> language code
reversed_supported_languages = dict(
    zip(supported_languages.values(), supported_languages.keys())
)
5064
5165class TTSWrapper :
52- def __init__ (self ,output_folder = "./output" , speaker_folder = "./speakers" ,lowvram = False ,model_source = "local" ,model_version = "2.0.2" ,device = "cuda" ,deepspeed = False ,enable_cache_results = True ):
66+ def __init__ (self ,output_folder = "./output" , speaker_folder = "./speakers" ,model_folder = "./xtts_folder" , lowvram = False ,model_source = "local" ,model_version = "2.0.2" ,device = "cuda" ,deepspeed = False ,enable_cache_results = True ):
5367
5468 self .cuda = device # If the user has chosen what to use, we rewrite the value to the value we want to use
5569 self .device = 'cpu' if lowvram else (self .cuda if torch .cuda .is_available () else "cpu" )
@@ -59,12 +73,13 @@ def __init__(self,output_folder = "./output", speaker_folder="./speakers",lowvra
5973
6074 self .model_source = model_source
6175 self .model_version = model_version
76+ self .tts_settings = default_tts_settings
6277
6378 self .deepspeed = deepspeed
6479
6580 self .speaker_folder = speaker_folder
6681 self .output_folder = output_folder
67- self .custom_models_folder = "./models"
82+ self .model_folder = model_folder
6883
6984 self .create_directories ()
7085 check_tts_version ()
@@ -90,6 +105,14 @@ def check_model_version_old_format(self,model_version):
90105 return "v" + model_version
91106 return model_version
92107
def get_models_list(self):
    """Return the names of the sub-directories found in ``self.model_folder``.

    Plain files are ignored; only entries that are directories are kept.
    Names are returned in the order ``os.listdir`` reports them.
    """
    root = self.model_folder
    model_dirs = []
    for entry in os.listdir(root):
        # Keep directories only; anything else in the folder is skipped
        if os.path.isdir(os.path.join(root, entry)):
            model_dirs.append(entry)
    return model_dirs
114+
115+
93116 def get_wav_header (self , channels :int = 1 , sample_rate :int = 24000 , width :int = 2 ) -> bytes :
94117 wav_buf = io .BytesIO ()
95118 with wave .open (wav_buf , "wb" ) as out :
@@ -147,12 +170,11 @@ def load_model(self,load=True):
147170 self .model = TTS ("tts_models/multilingual/multi-dataset/xtts_v2" )
148171
149172 if self .model_source == "apiManual" :
150- this_dir = Path (__file__ ).parent .resolve () / "models"
173+ this_dir = Path (self .model_folder )
174+
151175 if self .isModelOfficial (self .model_version ):
152176 download_model (this_dir ,self .model_version )
153- else :
154- this_dir = Path (self .custom_models_folder ).resolve ()
155-
177+
156178 config_path = this_dir / f'{ self .model_version } ' / 'config.json'
157179 checkpoint_dir = this_dir / f'{ self .model_version } '
158180
@@ -170,13 +192,11 @@ def load_model(self,load=True):
170192 logger .info ("Model successfully loaded " )
171193
172194 def load_local_model (self ,load = True ):
173- this_model_dir = Path (__file__ ). parent . resolve ()
195+ this_model_dir = Path (self . model_folder )
174196
175197 if self .isModelOfficial (self .model_version ):
176198 download_model (this_model_dir ,self .model_version )
177- this_model_dir = this_model_dir / "models"
178- else :
179- this_model_dir = Path (self .custom_models_folder )
199+ this_model_dir = this_model_dir
180200
181201 config = XttsConfig ()
182202 config_path = this_model_dir / f'{ self .model_version } ' / 'config.json'
@@ -188,6 +208,34 @@ def load_local_model(self,load=True):
188208 self .model .load_checkpoint (config ,use_deepspeed = self .deepspeed , checkpoint_dir = str (checkpoint_dir ))
189209 self .model .to (self .device )
190210
def switch_model(self, model_name):
    """Replace the currently loaded model with ``model_name``.

    Args:
        model_name: Name of a sub-directory of the model folder, as
            reported by ``get_models_list``.

    Raises:
        InvalidSettingsError: If ``model_name`` equals the current
            ``self.model_version``, or if no folder with that name is
            listed in the model folder.
    """
    # Check to see if the same name is selected (cheap check first, so the
    # model folder is only scanned when a switch is actually requested)
    if model_name == self.model_version:
        raise InvalidSettingsError("The model with this name is already loaded in memory")

    # Check if the model is in the list at all
    model_list = self.get_models_list()
    if model_name not in model_list:
        raise InvalidSettingsError(f"A model with `{model_name}` name is not in the models folder, the current available models: {model_list}")

    # Drop the reference to the old model, then clear the GPU cache
    self.model = ""
    torch.cuda.empty_cache()
    logger.info("Model successfully unloaded from memory")

    # Load the requested model through the loader matching the model source
    logger.info(f"Start loading {model_name} model")
    self.model_version = model_name
    if self.model_source == "local":
        self.load_local_model()
    else:
        self.load_model()

    logger.info("Model successfully loaded")
238+
191239 # LOWVRAM FUNCS
192240 def switch_model_device (self ):
193241 # We check for lowram and the existence of cuda
@@ -222,7 +270,7 @@ def create_latents_for_all(self):
222270
223271 # DIRECTORIES FUNCS
224272 def create_directories (self ):
225- directories = [self .output_folder , self .speaker_folder ,self .custom_models_folder ]
273+ directories = [self .output_folder , self .speaker_folder ,self .model_folder ]
226274
227275 for sanctuary in directories :
228276 # List of folders to be checked for existence
@@ -249,6 +297,50 @@ def set_out_folder(self, folder):
249297 else :
250298 raise ValueError ("Provided path is not a valid directory" )
251299
def set_tts_settings(self, temperature, speed, length_penalty,
                     repetition_penalty, top_p, top_k, enable_text_splitting):
    """Validate the generation settings and store them in ``self.tts_settings``.

    Nothing is stored unless every value passes validation.

    Args:
        temperature: Number in [0.01, 1].
        speed: Number in [0.2, 2].
        length_penalty: Any real number (integers are accepted).
        repetition_penalty: Number in [0.1, 10.0].
        top_p: Number in [0.01, 1].
        top_k: Integer in [1, 100].
        enable_text_splitting: ``True`` or ``False``.

    Raises:
        InvalidSettingsError: If any value has the wrong type or is out of
            range. Non-numeric input is reported with this exception rather
            than escaping as a ``TypeError`` from the range comparison.
    """
    import numbers  # function-scope import keeps this block self-contained

    def is_number(value):
        # bool is a subclass of int; reject it so True/False are not
        # silently treated as 1/0 for a numeric setting.
        return isinstance(value, numbers.Real) and not isinstance(value, bool)

    def in_range(value, low, high):
        return is_number(value) and low <= value <= high

    # Check temperature
    if not in_range(temperature, 0.01, 1):
        raise InvalidSettingsError("Temperature must be between 0.01 and 1.")

    # Check speed
    if not in_range(speed, 0.2, 2):
        raise InvalidSettingsError("Speed must be between 0.2 and 2.")

    # Check length_penalty (no explicit range; whole numbers such as 1 are valid)
    if not is_number(length_penalty):
        raise InvalidSettingsError("Length penalty must be a number.")

    # Check repetition_penalty
    if not in_range(repetition_penalty, 0.1, 10.0):
        raise InvalidSettingsError("Repetition penalty must be between 0.1 and 10.0.")

    # Check top_p
    if not in_range(top_p, 0.01, 1):
        raise InvalidSettingsError("Top_p must be between 0.01 and 1 and must be a float.")

    # Check top_k: enforce the integer requirement the message already states
    if not (isinstance(top_k, numbers.Integral) and in_range(top_k, 1, 100)):
        raise InvalidSettingsError("Top_k must be an integer between 1 and 100.")

    # Check enable_text_splitting
    if not isinstance(enable_text_splitting, bool):
        raise InvalidSettingsError("Enable text splitting must be either True or False.")

    # All validations passed - bind a fresh dict so previously shared
    # settings objects are never mutated.
    self.tts_settings = {
        "temperature": temperature,
        "speed": speed,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
        "top_k": top_k,
        "enable_text_splitting": enable_text_splitting,
    }
    print("Successfully updated TTS settings.")
343+
252344 # GET FUNCS
253345 def get_wav_files (self , directory ):
254346 """ Finds all the wav files in a directory. """
@@ -361,12 +453,7 @@ async def stream_generation(self,text,speaker_name,speaker_wav,language,output_f
361453 language ,
362454 speaker_embedding = speaker_embedding ,
363455 gpt_cond_latent = gpt_cond_latent ,
364- temperature = 0.75 ,
365- length_penalty = 1.0 ,
366- repetition_penalty = 5.0 ,
367- top_k = 50 ,
368- top_p = 0.85 ,
369- enable_text_splitting = True ,
456+ ** self .tts_settings , # Expands the object with the settings and applies them for generation
370457 stream_chunk_size = 100 ,
371458 )
372459
@@ -402,12 +489,7 @@ def local_generation(self,text,speaker_name,speaker_wav,language,output_file):
402489 language ,
403490 gpt_cond_latent = gpt_cond_latent ,
404491 speaker_embedding = speaker_embedding ,
405- temperature = 0.75 ,
406- length_penalty = 1.0 ,
407- repetition_penalty = 5.0 ,
408- top_k = 50 ,
409- top_p = 0.85 ,
410- enable_text_splitting = True
492+ ** self .tts_settings , # Expands the object with the settings and applies them for generation
411493 )
412494
413495 torchaudio .save (output_file , torch .tensor (out ["wav" ]).unsqueeze (0 ), 24000 )
0 commit comments