|
18 | 18 | import re
|
19 | 19 | import json
|
20 | 20 | import socket
|
| 21 | +import io |
| 22 | +import wave |
| 23 | +import numpy as np |
21 | 24 |
|
22 | 25 | # List of supported language codes
|
23 | 26 | supported_languages = {
|
@@ -87,6 +90,16 @@ def check_model_version_old_format(self,model_version):
|
87 | 90 | return "v"+model_version
|
88 | 91 | return model_version
|
89 | 92 |
|
def get_wav_header(self, channels: int = 1, sample_rate: int = 24000, width: int = 2) -> bytes:
    """Return the bytes of an empty RIFF/WAVE file with the given format.

    Useful as the leading header of a streamed PCM response: the 44-byte
    result describes `channels` channels of `width`-byte samples at
    `sample_rate` Hz, followed by zero audio frames.

    Args:
        channels: Number of audio channels (default mono).
        sample_rate: Sample rate in Hz (default 24000).
        width: Sample width in bytes (default 2, i.e. 16-bit PCM).

    Returns:
        The serialized WAV header as raw bytes.
    """
    header_buf = io.BytesIO()
    writer = wave.open(header_buf, "wb")
    try:
        # nframes=0 here; writeframes(b"") finalizes the (empty) data chunk.
        writer.setparams((channels, width, sample_rate, 0, "NONE", "not compressed"))
        writer.writeframes(b"")
    finally:
        writer.close()
    return header_buf.getvalue()
90 | 103 | # CACHE FUNCS
|
91 | 104 | def check_cache(self, text_params):
|
92 | 105 | if not self.enable_cache_results:
|
@@ -336,6 +349,48 @@ def clean_text(self,text):
|
336 | 349 | text = re.sub(r'"\s?(.*?)\s?"', r"'\1'", text)
|
337 | 350 | return text
|
338 | 351 |
|
async def stream_generation(self, text, speaker_name, speaker_wav, language, output_file):
    """Stream TTS audio for `text` as raw 16-bit PCM byte chunks.

    Runs the local XTTS model in streaming mode, yielding each synthesized
    chunk as little-endian int16 bytes as soon as it is produced. After the
    stream is exhausted, the concatenated waveform is also saved to
    `output_file` at 24 kHz.

    Args:
        text: Text to synthesize.
        speaker_name: Speaker identifier used for latent lookup/creation.
        speaker_wav: Reference wav used to compute speaker latents.
        language: Language code passed to the model.
        output_file: Path where the full waveform is written after streaming.

    Yields:
        bytes: One chunk of int16 PCM audio per model step.
    """
    generate_start_time = time.time()  # Record the start of generation for timing

    gpt_cond_latent, speaker_embedding = self.get_or_create_latents(speaker_name, speaker_wav)
    collected_chunks = []  # tensors kept so the complete file can be written afterwards

    chunks = self.model.inference_stream(
        text,
        language,
        speaker_embedding=speaker_embedding,
        gpt_cond_latent=gpt_cond_latent,
        temperature=0.75,
        length_penalty=1.0,
        repetition_penalty=5.0,
        top_k=50,
        top_p=0.85,
        enable_text_splitting=True,
        stream_chunk_size=100,
    )

    for chunk in chunks:
        # Some model versions yield a list of tensors per step; merge them first.
        if isinstance(chunk, list):
            chunk = torch.cat(chunk, dim=0)
        collected_chunks.append(chunk)
        # Convert float waveform (assumed in [-1, 1] — clipped to be safe)
        # into 16-bit PCM bytes for the streaming consumer.
        pcm = chunk.cpu().numpy()[None, :]
        pcm = np.clip(pcm, -1, 1)
        pcm = (pcm * 32767).astype(np.int16)
        yield pcm.tobytes()

    if collected_chunks:
        wav = torch.cat(collected_chunks, dim=0)
        torchaudio.save(output_file, wav.cpu().squeeze().unsqueeze(0), 24000)
    else:
        logger.warning("No audio generated.")

    generate_end_time = time.time()  # Record the time to generate TTS
    generate_elapsed_time = generate_end_time - generate_start_time

    logger.info(f"Processing time: {generate_elapsed_time:.2f} seconds.")
339 | 394 | def local_generation(self,text,speaker_name,speaker_wav,language,output_file):
|
340 | 395 | # Log time
|
341 | 396 | generate_start_time = time.time() # Record the start time of loading the model
|
@@ -398,7 +453,7 @@ def get_speaker_wav(self, speaker_name_or_path):
|
398 | 453 |
|
399 | 454 |
|
400 | 455 | # MAIN FUNC
|
401 |
| - def process_tts_to_file(self, text, speaker_name_or_path, language, file_name_or_path="out.wav"): |
| 456 | + def process_tts_to_file(self, text, speaker_name_or_path, language, file_name_or_path="out.wav", stream=False): |
402 | 457 | try:
|
403 | 458 | speaker_wav = self.get_speaker_wav(speaker_name_or_path)
|
404 | 459 | # Determine output path based on whether a full path or a file name was provided
|
@@ -441,7 +496,16 @@ def process_tts_to_file(self, text, speaker_name_or_path, language, file_name_or
|
441 | 496 |
|
442 | 497 | # Define generation if model via api or locally
|
443 | 498 | if self.model_source == "local":
|
444 |
| - self.local_generation(clear_text,speaker_name_or_path,speaker_wav,language,output_file) |
| 499 | + if stream: |
| 500 | + async def stream_fn(): |
| 501 | + async for chunk in self.stream_generation(clear_text,speaker_name_or_path,speaker_wav,language,output_file): |
| 502 | + yield chunk |
| 503 | + self.switch_model_device() |
| 504 | + # After generation completes successfully... |
| 505 | + self.update_cache(text_params,output_file) |
| 506 | + return stream_fn() |
| 507 | + else: |
| 508 | + self.local_generation(clear_text,speaker_name_or_path,speaker_wav,language,output_file) |
445 | 509 | else:
|
446 | 510 | self.api_generation(clear_text,speaker_wav,language,output_file)
|
447 | 511 |
|
|
0 commit comments