Skip to content

Commit ebb9bdb

Browse files
authored
Merge pull request #24 from lendot/sample-wav-multiple
Support for speakers with multiple speaker_wav files
2 parents eb14eeb + 631427b commit ebb9bdb

File tree

3 files changed

+104
-57
lines changed

3 files changed

+104
-57
lines changed

xtts_api_server/RealtimeTTS/engines/coqui_engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def send_command(self, command, data):
401401
message = {'command': command, 'data': data}
402402
self.parent_synthesize_pipe.send(message)
403403

404-
def set_cloning_reference(self, cloning_reference_wav: str):
404+
def set_cloning_reference(self, cloning_reference_wav: Union[str, List[str]]):
405405
"""
406406
Send an 'update_reference' command and wait for a response.
407407
"""
@@ -594,7 +594,7 @@ def get_voices(self):
594594

595595
return voice_file_names
596596

597-
def set_voice(self, voice: str):
597+
def set_voice(self, voice: Union[str, List[str]]):
598598
"""
599599
Sets the voice to be used for speech synthesis.
600600
"""
@@ -637,4 +637,4 @@ def shutdown(self):
637637

638638
# Wait for the process to terminate
639639
self.synthesize_process.join()
640-
logging.info('Worker process has been terminated')
640+
logging.info('Worker process has been terminated')

xtts_api_server/server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def get_folders():
147147
output_folder = XTTS.output_folder
148148
return {"speaker_folder": speaker_folder, "output_folder": output_folder}
149149

150-
@app.get("/sample/{file_name}")
150+
@app.get("/sample/{file_name:path}")
151151
def get_sample(file_name: str):
152152
file_path = os.path.join(XTTS.speaker_folder, file_name)
153153
if os.path.isfile(file_path):
@@ -184,7 +184,7 @@ async def tts_to_audio(request: SynthesisRequest):
184184
raise HTTPException(status_code=400,
185185
detail="Language code sent is either unsupported or misspelled.")
186186

187-
speaker_wav = XTTS.get_speaker_path(request.speaker_wav)
187+
speaker_wav = XTTS.get_speaker_wav(request.speaker_wav)
188188
language = request.language[0:2]
189189

190190
if stream.is_playing() and not STREAM_PLAY_SYNC:
@@ -262,4 +262,4 @@ async def tts_to_file(request: SynthesisFileRequest):
262262
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
263263

264264
if __name__ == "__main__":
265-
uvicorn.run(app,host="0.0.0.0",port=8002)
265+
uvicorn.run(app,host="0.0.0.0",port=8002)

xtts_api_server/tts_funcs.py

Lines changed: 98 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -111,19 +111,18 @@ def switch_model_device(self):
111111
# Clearing the cache to free up VRAM
112112
torch.cuda.empty_cache()
113113

114-
def get_or_create_latents(self, speaker_wav):
115-
if speaker_wav not in self.latents_cache:
114+
def get_or_create_latents(self, speaker_name, speaker_wav):
115+
if speaker_name not in self.latents_cache:
116+
logger.info(f"creating latents for {speaker_name}: {speaker_wav}")
116117
gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(speaker_wav)
117-
self.latents_cache[speaker_wav] = (gpt_cond_latent, speaker_embedding)
118-
return self.latents_cache[speaker_wav]
118+
self.latents_cache[speaker_name] = (gpt_cond_latent, speaker_embedding)
119+
return self.latents_cache[speaker_name]
119120

120121
def create_latents_for_all(self):
121-
speakers_list = self.get_speakers()
122+
speakers_list = self._get_speakers()
122123

123-
for speaker_name in speakers_list:
124-
speaker_wav = os.path.join(self.speaker_folder, speaker_name+".wav")
125-
126-
self.get_or_create_latents(speaker_wav)
124+
for speaker in speakers_list:
125+
self.get_or_create_latents(speaker['speaker_name'],speaker['speaker_wav'])
127126

128127
logger.info(f"Latents created for all {len(speakers_list)} speakers.")
129128

@@ -137,7 +136,7 @@ def create_directories(self):
137136
if not os.path.exists(absolute_path):
138137
# If the folder does not exist, create it
139138
os.makedirs(absolute_path)
140-
print(f"Folder in the path {absolute_path} has been created")
139+
logger.info(f"Folder in the path {absolute_path} has been created")
141140

142141
def set_speaker_folder(self, folder):
143142
if os.path.exists(folder) and os.path.isdir(folder):
@@ -155,38 +154,78 @@ def set_out_folder(self, folder):
155154
else:
156155
raise ValueError("Provided path is not a valid directory")
157156

158-
def list_speakers(self):
159-
speakers_list = [f for f in os.listdir(self.speaker_folder) if f.endswith('.wav')]
160-
return speakers_list
157+
def get_wav_files(self, directory):
158+
""" Finds all the wav files in a directory. """
159+
wav_files = [f for f in os.listdir(directory) if f.endswith('.wav')]
160+
return wav_files
161+
162+
def _get_speakers(self):
163+
"""
164+
Gets info on all the speakers.
165+
166+
Returns a list of {speaker_name,speaker_wav,preview} dicts
167+
"""
168+
speakers = []
169+
for f in os.listdir(self.speaker_folder):
170+
full_path = os.path.join(self.speaker_folder,f)
171+
if os.path.isdir(full_path):
172+
# multi-sample voice
173+
subdir_files = self.get_wav_files(full_path)
174+
if len(subdir_files) == 0:
175+
# no wav files in directory
176+
continue
177+
178+
speaker_name = f
179+
speaker_wav = [os.path.join(self.speaker_folder,f,s) for s in subdir_files]
180+
# use the first file found as the preview
181+
preview = os.path.join(f,subdir_files[0])
182+
speakers.append({
183+
'speaker_name': speaker_name,
184+
'speaker_wav': speaker_wav,
185+
'preview': preview
186+
})
187+
188+
elif f.endswith('.wav'):
189+
speaker_name = os.path.splitext(f)[0]
190+
speaker_wav = full_path
191+
preview = f
192+
speakers.append({
193+
'speaker_name': speaker_name,
194+
'speaker_wav': speaker_wav,
195+
'preview': preview
196+
})
197+
return speakers
161198

162199
def get_speakers(self):
163-
# Use os.path.splitext to split off the extension and take only the name
164-
speakers_list = [os.path.splitext(f)[0] for f in os.listdir(self.speaker_folder) if f.endswith('.wav')]
165-
return speakers_list
200+
""" Gets available speakers """
201+
speakers = [ s['speaker_name'] for s in self._get_speakers() ]
202+
return speakers
203+
166204
# Special format for SillyTavern
167205
def get_speakers_special(self):
168-
speakers_list = []
169206
BASE_URL = os.getenv('BASE_URL', '127.0.0.1:8020')
170207
TUNNEL_URL = os.getenv('TUNNEL_URL', '')
171208

172-
preview_url = ""
173-
for file in os.listdir(self.speaker_folder):
174-
209+
speakers_special = []
210+
211+
speakers = self._get_speakers()
212+
213+
for speaker in speakers:
175214
if TUNNEL_URL == "":
176-
preview_url = f"{BASE_URL}/sample/{file}"
215+
preview_url = f"{BASE_URL}/sample/{speaker['preview']}"
177216
else:
178-
preview_url = f"{TUNNEL_URL}/sample/{file}"
217+
preview_url = f"{TUNNEL_URL}/sample/{speaker['preview']}"
179218

180-
if file.endswith('.wav'):
181-
speaker_name = os.path.splitext(file)[0]
182-
speaker = {
183-
'name': speaker_name,
184-
'voice_id': speaker_name,
219+
speaker_special = {
220+
'name': speaker['speaker_name'],
221+
'voice_id': speaker['speaker_name'],
185222
'preview_url': preview_url
186-
}
187-
speakers_list.append(speaker)
188-
return speakers_list
189-
223+
}
224+
speakers_special.append(speaker_special)
225+
226+
return speakers_special
227+
228+
190229
def list_languages(self):
191230
return reversed_supported_languages
192231

@@ -197,11 +236,11 @@ def clean_text(self,text):
197236
text = re.sub(r'"\s?(.*?)\s?"', r"'\1'", text)
198237
return text
199238

200-
def local_generation(self,text,speaker_wav,language,output_file):
239+
def local_generation(self,text,speaker_name,speaker_wav,language,output_file):
201240
# Log time
202241
generate_start_time = time.time() # Record the start time of loading the model
203242

204-
gpt_cond_latent, speaker_embedding = self.get_or_create_latents(speaker_wav)
243+
gpt_cond_latent, speaker_embedding = self.get_or_create_latents(speaker_name, speaker_wav)
205244

206245
out = self.model.inference(
207246
text,
@@ -230,29 +269,37 @@ def api_generation(self,text,speaker_wav,language,output_file):
230269
language=language,
231270
file_path=output_file,
232271
)
233-
234-
def get_speaker_path(self,speaker_name_or_path):
235-
# Check if the speaker path is a .wav file or just the name
272+
273+
def get_speaker_wav(self, speaker_name_or_path):
274+
""" Gets the speaker_wav(s) for a given speaker name. """
236275
if speaker_name_or_path.endswith('.wav'):
237-
if os.path.isabs(speaker_name_or_path):
238-
# If it's an absolute path for the speaker file
239-
speaker_wav = speaker_name_or_path
240-
else:
241-
# It's just a filename; append it to the speakers folder
242-
speaker_wav = os.path.join(self.speaker_folder, speaker_name_or_path)
276+
# it's a file name
277+
if os.path.isabs(spekaer_name_or_path):
278+
# absolute path; nothing to do
279+
speaker_wav = speaker_name_or_path
280+
else:
281+
# make it a full path
282+
speaker_wav = os.path.join(self.speaker_folder, speaker_name_or_path)
243283
else:
244-
# Look for the corresponding .wav in our list of speakers
245-
speakers_list = self.list_speakers()
246-
if f"{speaker_name_or_path}.wav" in speakers_list:
247-
speaker_wav = os.path.join(self.speaker_folder, f"{speaker_name_or_path}.wav")
248-
else:
249-
raise ValueError(f"Speaker {speaker_name_or_path} not found.")
284+
# it's a speaker name
285+
full_path = os.path.join(self.speaker_folder, speaker_name_or_path)
286+
wav_file = f"{full_path}.wav"
287+
if os.path.isdir(full_path):
288+
# multi-sample speaker
289+
speaker_wav = [ os.path.join(full_path,wav) for wav in self.get_wav_files(full_path) ]
290+
if len(speaker_wav) == 0:
291+
raise ValueError(f"no wav files found in {full_path}")
292+
elif os.path.isfile(wav_file):
293+
speaker_wav = wav_file
294+
else:
295+
raise ValueError(f"Speaker {speaker_name_or_path} not found.")
296+
250297
return speaker_wav
251298

252299

253300
def process_tts_to_file(self, text, speaker_name_or_path, language, file_name_or_path="out.wav"):
254301
try:
255-
speaker_wav = self.get_speaker_path(speaker_name_or_path)
302+
speaker_wav = self.get_speaker_wav(speaker_name_or_path)
256303
# Determine output path based on whether a full path or a file name was provided
257304
if os.path.isabs(file_name_or_path):
258305
# An absolute path was provided by user; use as is.
@@ -268,7 +315,7 @@ def process_tts_to_file(self, text, speaker_name_or_path, language, file_name_or
268315

269316
# Define generation if model via api or locally
270317
if self.model_source == "local":
271-
self.local_generation(clear_text,speaker_wav,language,output_file)
318+
self.local_generation(clear_text,speaker_name_or_path,speaker_wav,language,output_file)
272319
else:
273320
self.api_generation(clear_text,speaker_wav,language,output_file)
274321

@@ -282,4 +329,4 @@ def process_tts_to_file(self, text, speaker_name_or_path, language, file_name_or
282329

283330

284331

285-
332+

0 commit comments

Comments
 (0)