Skip to content

Commit fb16a09

Browse files
authored
Merge pull request #37 from Cohee1207/streaming
Add HTTP streaming for local models
2 parents 6d5eca8 + 894b715 commit fb16a09

File tree

2 files changed

+96
-3
lines changed

2 files changed

+96
-3
lines changed

xtts_api_server/server.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from TTS.api import TTS
2-
from fastapi import FastAPI, HTTPException
2+
from fastapi import FastAPI, HTTPException, Request, Query
33
from fastapi.middleware.cors import CORSMiddleware
44
from fastapi.responses import FileResponse,StreamingResponse
55

@@ -179,6 +179,35 @@ def set_speaker_folder(speaker_req: SpeakerFolderRequest):
179179
logger.error(e)
180180
raise HTTPException(status_code=400, detail=str(e))
181181

182+
@app.get('/tts_stream')
async def tts_stream(request: Request, text: str = Query(), speaker_wav: str = Query(), language: str = Query()):
    """Stream synthesized speech over HTTP as a chunked WAV response.

    Query parameters:
        text: the text to synthesize.
        speaker_wav: speaker name or path to a reference wav.
        language: language code (validated against supported_languages).

    Raises:
        HTTPException 400: when the model is not local, or the language
            code is not in supported_languages.
    """
    # Chunked generation is only wired up for the local inference path.
    if XTTS.model_source != "local":
        raise HTTPException(status_code=400,
                            detail="HTTP Streaming is only supported for local models.")
    # Fail fast on a bad language code before any inference work starts.
    if language.lower() not in supported_languages:
        raise HTTPException(status_code=400,
                            detail="Language code sent is either unsupported or misspelled.")

    async def audio_stream():
        # With stream=True, process_tts_to_file returns an async generator
        # of raw PCM chunks instead of synchronously writing a file.
        pcm_chunks = XTTS.process_tts_to_file(
            text=text,
            speaker_name_or_path=speaker_wav,
            language=language.lower(),
            stream=True,
        )
        # Send the WAV header first so clients can begin decoding right away.
        yield XTTS.get_wav_header()
        async for piece in pcm_chunks:
            # Abandon generation as soon as the client disconnects.
            if await request.is_disconnected():
                break
            yield piece

    return StreamingResponse(audio_stream(), media_type='audio/x-wav')
182211
@app.post("/tts_to_audio/")
183212
async def tts_to_audio(request: SynthesisRequest):
184213
if STREAM_MODE or STREAM_MODE_IMPROVE:

xtts_api_server/tts_funcs.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import re
1919
import json
2020
import socket
21+
import io
22+
import wave
23+
import numpy as np
2124

2225
# List of supported language codes
2326
supported_languages = {
@@ -87,6 +90,16 @@ def check_model_version_old_format(self,model_version):
8790
return "v"+model_version
8891
return model_version
8992

93+
def get_wav_header(self, channels:int=1, sample_rate:int=24000, width:int=2) -> bytes:
    """Build a 44-byte RIFF/WAV header suitable for prefixing a live PCM stream.

    Args:
        channels: number of audio channels (default mono).
        sample_rate: frames per second (default 24 kHz, XTTS output rate).
        width: sample width in bytes (default 2 = 16-bit PCM).

    Returns:
        The header bytes; raw PCM frames are appended to it by the caller.

    The stdlib ``wave`` writer is handed zero frames, so it stamps both the
    RIFF chunk size and the ``data`` chunk size as 0 — strict decoders read
    that as "no audio" and stop before any streamed samples arrive. Patch
    both size fields to 0xFFFFFFFF, the conventional "unknown/maximum
    length" value for open-ended WAV streams.
    """
    wav_buf = io.BytesIO()
    with wave.open(wav_buf, "wb") as out:
        out.setnchannels(channels)
        out.setsampwidth(width)
        out.setframerate(sample_rate)
        out.writeframes(b"")
    header = bytearray(wav_buf.getvalue())
    unknown_size = (0xFFFFFFFF).to_bytes(4, "little")
    # RIFF chunk size field lives at byte offset 4..8.
    header[4:8] = unknown_size
    # The data chunk size immediately follows the b"data" tag.
    data_tag = header.rindex(b"data")
    header[data_tag + 4:data_tag + 8] = unknown_size
    return bytes(header)
102+
90103
# CACHE FUNCS
91104
def check_cache(self, text_params):
92105
if not self.enable_cache_results:
@@ -336,6 +349,48 @@ def clean_text(self,text):
336349
text = re.sub(r'"\s?(.*?)\s?"', r"'\1'", text)
337350
return text
338351

352+
async def stream_generation(self,text,speaker_name,speaker_wav,language,output_file):
    """Yield 16-bit little-endian PCM chunks from local streaming inference.

    Also accumulates every chunk and, once the stream ends, saves the full
    waveform to output_file at 24 kHz. Logs total processing time.

    Args:
        text: text to synthesize.
        speaker_name: speaker identifier used for latent lookup/creation.
        speaker_wav: reference wav used when latents must be created.
        language: language code passed through to the model.
        output_file: destination path for the assembled waveform.
    """
    start_ts = time.time()  # wall-clock start for the timing log below

    gpt_cond_latent, speaker_embedding = self.get_or_create_latents(speaker_name, speaker_wav)
    collected = []

    stream = self.model.inference_stream(
        text,
        language,
        speaker_embedding=speaker_embedding,
        gpt_cond_latent=gpt_cond_latent,
        temperature=0.75,
        length_penalty=1.0,
        repetition_penalty=5.0,
        top_k=50,
        top_p=0.85,
        enable_text_splitting=True,
        stream_chunk_size=100,
    )

    # NOTE(review): this loop contains no await — inference blocks the event
    # loop while each chunk is produced; confirm that is acceptable here.
    for piece in stream:
        if isinstance(piece, list):
            piece = torch.cat(piece, dim=0)
        collected.append(piece)
        samples = piece.cpu().numpy()
        samples = samples[None, : int(samples.shape[0])]
        # Convert float audio in [-1, 1] to int16 PCM bytes for the stream.
        samples = np.clip(samples, -1, 1)
        samples = (samples * 32767).astype(np.int16)
        yield samples.tobytes()

    if collected:
        full_wav = torch.cat(collected, dim=0)
        torchaudio.save(output_file, full_wav.cpu().squeeze().unsqueeze(0), 24000)
    else:
        logger.warning("No audio generated.")

    elapsed = time.time() - start_ts

    logger.info(f"Processing time: {elapsed:.2f} seconds.")
393+
339394
def local_generation(self,text,speaker_name,speaker_wav,language,output_file):
340395
# Log time
341396
generate_start_time = time.time() # Record the start time of loading the model
@@ -398,7 +453,7 @@ def get_speaker_wav(self, speaker_name_or_path):
398453

399454

400455
# MAIN FUNC
401-
def process_tts_to_file(self, text, speaker_name_or_path, language, file_name_or_path="out.wav"):
456+
def process_tts_to_file(self, text, speaker_name_or_path, language, file_name_or_path="out.wav", stream=False):
402457
try:
403458
speaker_wav = self.get_speaker_wav(speaker_name_or_path)
404459
# Determine output path based on whether a full path or a file name was provided
@@ -441,7 +496,16 @@ def process_tts_to_file(self, text, speaker_name_or_path, language, file_name_or
441496

442497
# Define generation if model via api or locally
443498
if self.model_source == "local":
444-
self.local_generation(clear_text,speaker_name_or_path,speaker_wav,language,output_file)
499+
if stream:
500+
async def stream_fn():
501+
async for chunk in self.stream_generation(clear_text,speaker_name_or_path,speaker_wav,language,output_file):
502+
yield chunk
503+
self.switch_model_device()
504+
# After generation completes successfully...
505+
self.update_cache(text_params,output_file)
506+
return stream_fn()
507+
else:
508+
self.local_generation(clear_text,speaker_name_or_path,speaker_wav,language,output_file)
445509
else:
446510
self.api_generation(clear_text,speaker_wav,language,output_file)
447511

0 commit comments

Comments
 (0)