@@ -111,19 +111,18 @@ def switch_model_device(self):
111
111
# Clearing the cache to free up VRAM
112
112
torch .cuda .empty_cache ()
113
113
114
def get_or_create_latents(self, speaker_name, speaker_wav):
    """Return the conditioning latents for a speaker, computing them once.

    Results are cached in ``self.latents_cache`` under ``speaker_name``;
    ``speaker_wav`` is only consulted when the name has no cache entry.

    Args:
        speaker_name: Cache key identifying the speaker.
        speaker_wav: Reference audio handed unchanged to
            ``self.model.get_conditioning_latents``.

    Returns:
        The ``(gpt_cond_latent, speaker_embedding)`` tuple for the speaker.
    """
    # Cache hit: reuse what was computed earlier for this name.
    if speaker_name in self.latents_cache:
        return self.latents_cache[speaker_name]

    logger.info(f"creating latents for {speaker_name}: {speaker_wav}")
    cond_latent, embedding = self.model.get_conditioning_latents(speaker_wav)
    latents = (cond_latent, embedding)
    self.latents_cache[speaker_name] = latents
    return latents
119
120
120
121
def create_latents_for_all(self):
    """Pre-compute and cache latents for every speaker from ``_get_speakers``.

    Each entry's ``speaker_name`` and ``speaker_wav`` are passed to
    ``get_or_create_latents``; a summary line is logged afterwards.
    """
    all_speakers = self._get_speakers()

    for entry in all_speakers:
        name, wav = entry['speaker_name'], entry['speaker_wav']
        self.get_or_create_latents(name, wav)

    logger.info(f"Latents created for all {len(all_speakers)} speakers.")
129
128
@@ -137,7 +136,7 @@ def create_directories(self):
137
136
if not os .path .exists (absolute_path ):
138
137
# If the folder does not exist, create it
139
138
os .makedirs (absolute_path )
140
- print (f"Folder in the path { absolute_path } has been created" )
139
+ logger . info (f"Folder in the path { absolute_path } has been created" )
141
140
142
141
def set_speaker_folder (self , folder ):
143
142
if os .path .exists (folder ) and os .path .isdir (folder ):
@@ -155,38 +154,78 @@ def set_out_folder(self, folder):
155
154
else :
156
155
raise ValueError ("Provided path is not a valid directory" )
157
156
158
def get_wav_files(self, directory):
    """Find all the wav files directly inside a directory.

    Args:
        directory: Path of the directory to scan (not recursive).

    Returns:
        A sorted list of file names (not full paths) ending in ``.wav``.

    Raises:
        OSError: If ``directory`` cannot be listed.
    """
    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order stable between runs and platforms. The isfile check drops
    # sub-directories whose name merely ends in ".wav".
    return sorted(
        name
        for name in os.listdir(directory)
        if name.endswith('.wav') and os.path.isfile(os.path.join(directory, name))
    )
161
+
162
def _get_speakers(self):
    """Collect info on every speaker found in ``self.speaker_folder``.

    A sub-directory containing at least one wav file is a multi-sample
    speaker named after the directory. A top-level ``*.wav`` file is a
    single-sample speaker named after the file without its extension.
    Anything else is ignored.

    Returns:
        A list of dicts with keys ``speaker_name``, ``speaker_wav`` (a list
        of paths for multi-sample speakers, a single path otherwise) and
        ``preview`` (a path relative to the speaker folder).
    """
    found = []
    for entry in os.listdir(self.speaker_folder):
        entry_path = os.path.join(self.speaker_folder, entry)

        if os.path.isdir(entry_path):
            # Multi-sample voice: every wav in the directory is a sample.
            samples = self.get_wav_files(entry_path)
            if not samples:
                # No wav files in the directory, so it is not a speaker.
                continue
            found.append({
                'speaker_name': entry,
                'speaker_wav': [os.path.join(entry_path, s) for s in samples],
                # The first file found serves as the preview.
                'preview': os.path.join(entry, samples[0]),
            })
        elif entry.endswith('.wav'):
            found.append({
                'speaker_name': os.path.splitext(entry)[0],
                'speaker_wav': entry_path,
                'preview': entry,
            })

    return found
161
198
162
199
def get_speakers(self):
    """Return the names of all available speakers.

    Names are the ``speaker_name`` values reported by ``_get_speakers``,
    in the same order.
    """
    return [info['speaker_name'] for info in self._get_speakers()]
203
+
166
204
# Special format for SillyTavern
167
205
def get_speakers_special(self):
    """Return the speakers as ``name`` / ``voice_id`` / ``preview_url`` dicts.

    The preview URL is ``<root>/sample/<preview>``. ``<root>`` is the
    ``TUNNEL_URL`` environment variable when it is set to a non-empty value,
    otherwise ``BASE_URL`` (default ``127.0.0.1:8020``).
    """
    base_url = os.getenv('BASE_URL', '127.0.0.1:8020')
    tunnel_url = os.getenv('TUNNEL_URL', '')

    # The root does not depend on the speaker, so pick it once.
    url_root = base_url if tunnel_url == "" else tunnel_url

    return [
        {
            'name': info['speaker_name'],
            'voice_id': info['speaker_name'],
            'preview_url': f"{url_root}/sample/{info['preview']}",
        }
        for info in self._get_speakers()
    ]
227
+
228
+
190
229
def list_languages(self):
    """Return the module-level ``reversed_supported_languages`` object."""
    return reversed_supported_languages
192
231
@@ -197,11 +236,11 @@ def clean_text(self,text):
197
236
text = re .sub (r'"\s?(.*?)\s?"' , r"'\1'" , text )
198
237
return text
199
238
200
- def local_generation (self ,text ,speaker_wav ,language ,output_file ):
239
+ def local_generation (self ,text ,speaker_name , speaker_wav ,language ,output_file ):
201
240
# Log time
202
241
generate_start_time = time .time () # Record the start time of loading the model
203
242
204
- gpt_cond_latent , speaker_embedding = self .get_or_create_latents (speaker_wav )
243
+ gpt_cond_latent , speaker_embedding = self .get_or_create_latents (speaker_name , speaker_wav )
205
244
206
245
out = self .model .inference (
207
246
text ,
@@ -230,29 +269,37 @@ def api_generation(self,text,speaker_wav,language,output_file):
230
269
language = language ,
231
270
file_path = output_file ,
232
271
)
233
-
234
def get_speaker_wav(self, speaker_name_or_path):
    """Get the speaker_wav(s) for a given speaker name or wav path.

    Args:
        speaker_name_or_path: One of
            * an absolute path ending in ``.wav`` (returned unchanged),
            * a relative name ending in ``.wav`` (joined onto
              ``self.speaker_folder``; existence is not checked),
            * a speaker name, resolved inside ``self.speaker_folder`` to
              either a directory of samples or ``<name>.wav``.

    Returns:
        A single wav path, or a list of wav paths for a multi-sample
        speaker (a directory).

    Raises:
        ValueError: If the speaker directory has no wav files, or if neither
            a directory nor a ``<name>.wav`` file exists for the name.
    """
    if speaker_name_or_path.endswith('.wav'):
        # It's a file name.
        if os.path.isabs(speaker_name_or_path):
            # Absolute path; nothing to do.
            return speaker_name_or_path
        # Make it a full path inside the speaker folder.
        return os.path.join(self.speaker_folder, speaker_name_or_path)

    # It's a speaker name.
    full_path = os.path.join(self.speaker_folder, speaker_name_or_path)

    if os.path.isdir(full_path):
        # Multi-sample speaker: a directory wins over a same-named wav file.
        speaker_wav = [
            os.path.join(full_path, wav) for wav in self.get_wav_files(full_path)
        ]
        if not speaker_wav:
            raise ValueError(f"no wav files found in {full_path}")
        return speaker_wav

    wav_file = f"{full_path}.wav"
    if os.path.isfile(wav_file):
        return wav_file

    raise ValueError(f"Speaker {speaker_name_or_path} not found.")
251
298
252
299
253
300
def process_tts_to_file (self , text , speaker_name_or_path , language , file_name_or_path = "out.wav" ):
254
301
try :
255
- speaker_wav = self .get_speaker_path (speaker_name_or_path )
302
+ speaker_wav = self .get_speaker_wav (speaker_name_or_path )
256
303
# Determine output path based on whether a full path or a file name was provided
257
304
if os .path .isabs (file_name_or_path ):
258
305
# An absolute path was provided by user; use as is.
@@ -268,7 +315,7 @@ def process_tts_to_file(self, text, speaker_name_or_path, language, file_name_or
268
315
269
316
# Define generation if model via api or locally
270
317
if self .model_source == "local" :
271
- self .local_generation (clear_text ,speaker_wav ,language ,output_file )
318
+ self .local_generation (clear_text ,speaker_name_or_path , speaker_wav ,language ,output_file )
272
319
else :
273
320
self .api_generation (clear_text ,speaker_wav ,language ,output_file )
274
321
@@ -282,4 +329,4 @@ def process_tts_to_file(self, text, speaker_name_or_path, language, file_name_or
282
329
283
330
284
331
285
-
332
+
0 commit comments