
Commit e1a458e

correct model loading, customize model folder, and generation params
1 parent 5c56c21

6 files changed: +157 -38 lines


.gitignore (+1)

@@ -7,3 +7,4 @@ models
 xtts_api_server/models
 *.pyc
 xtts_api_server/RealtimeTTS/engines/coqui_engine_old.py
+xtts_models

pyproject.toml (+1 -1)

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "xtts-api-server"
-version = "0.7.6"
+version = "0.8.0"
 authors = [
   { name="daswer123", email="daswerq123@gmail.com" },
 ]

xtts_api_server/__main__.py (+3 -3)

@@ -6,9 +6,10 @@
 parser.add_argument("-hs", "--host", default="localhost", help="Host to bind")
 parser.add_argument("-p", "--port", default=8020, type=int, help="Port to bind")
 parser.add_argument("-d", "--device", default="cuda", type=str, help="Device that will be used, you can choose cpu or cuda")
-parser.add_argument("-sf", "--speaker_folder", default="speakers/", type=str, help="The folder where you get the samples for tts")
+parser.add_argument("-sf", "--speaker-folder", default="speakers/", type=str, help="The folder where you get the samples for tts")
 parser.add_argument("-o", "--output", default="output/", type=str, help="Output folder")
 parser.add_argument("-t", "--tunnel", default="", type=str, help="URL of tunnel used (e.g: ngrok, localtunnel)")
+parser.add_argument("-mf", "--model-folder", default="xtts_models/", type=str, help="The place where models for XTTS will be stored, finetuned models should be stored in this folder.")
 parser.add_argument("-ms", "--model-source", default="local", choices=["api","apiManual", "local"],
                     help="Define the model source: 'api' for latest version from repository, apiManual for 2.0.2 model and api inference or 'local' for using local inference and model v2.0.2.")
 parser.add_argument("-v", "--version", default="v2.0.2", type=str, help="You can specify which version of xtts to use or specify your own model, just upload model folder in models folder ,This version will be used everywhere in local and apiManual.")
@@ -28,6 +29,7 @@
 os.environ['DEVICE'] = args.device # Set environment variable for output folder.
 os.environ['OUTPUT'] = args.output # Set environment variable for output folder.
 os.environ['SPEAKER'] = args.speaker_folder # Set environment variable for speaker folder.
+os.environ['MODEL'] = args.model_folder # Set environment variable for model folder.
 os.environ['BASE_HOST'] = host_ip # Set environment variable for base host."
 os.environ['BASE_PORT'] = str(args.port) # Set environment variable for base port."
 os.environ['BASE_URL'] = "http://" + host_ip + ":" + str(args.port) # Set environment variable for base url."
@@ -41,8 +43,6 @@
 os.environ["STREAM_MODE_IMPROVE"] = str(args.streaming_mode_improve).lower() # Enable improved Streaming mode
 os.environ["STREAM_PLAY_SYNC"] = str(args.stream_play_sync).lower() # Enable Streaming mode
 
-
-
 from xtts_api_server.server import app
 
 uvicorn.run(app, host=host_ip, port=args.port)
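A minimal launch example using the new flag (assuming the package is installed and run as a module; all other flags keep their existing defaults):

    python -m xtts_api_server -d cuda -mf xtts_models/ -ms local -v v2.0.2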

xtts_api_server/modeldownloader.py (+3 -2)

@@ -119,14 +119,15 @@ def check_stream2sentence_version():
 
 def download_model(this_dir,model_version):
     # Define paths
-    base_path = this_dir / 'models'
+    base_path = this_dir
     model_path = base_path / f'{model_version}'
 
     # Define files and their corresponding URLs
     files_to_download = {
         "config.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_version}/config.json",
         "model.pth": f"https://huggingface.co/coqui/XTTS-v2/resolve/{model_version}/model.pth?download=true",
-        "vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_version}/vocab.json"
+        "vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_version}/vocab.json",
+        "speakers_xtts.pth": "https://huggingface.co/coqui/XTTS-v2/resolve/main/speakers_xtts.pth?download=true"
     }
 
     # Check and create directories
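Since download_model no longer appends a models subfolder itself, callers now pass the configured model directory directly and the files land in <model_folder>/<model_version>/. A rough sketch of the new call pattern (the folder name is just the default of --model-folder):

    from pathlib import Path
    from xtts_api_server.modeldownloader import download_model

    # Downloads config.json, model.pth, vocab.json and the newly added
    # speakers_xtts.pth into xtts_models/v2.0.2/
    download_model(Path("xtts_models"), "v2.0.2")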

xtts_api_server/server.py (+44 -9)

@@ -14,14 +14,15 @@
 from argparse import ArgumentParser
 from pathlib import Path
 
-from xtts_api_server.tts_funcs import TTSWrapper,supported_languages
+from xtts_api_server.tts_funcs import TTSWrapper,supported_languages,InvalidSettingsError
 from xtts_api_server.RealtimeTTS import TextToAudioStream, CoquiEngine
 from xtts_api_server.modeldownloader import check_stream2sentence_version,install_deepspeed_based_on_python_version
 
 # Default Folders , you can change them via API
 DEVICE = os.getenv('DEVICE',"cuda")
 OUTPUT_FOLDER = os.getenv('OUTPUT', 'output')
 SPEAKER_FOLDER = os.getenv('SPEAKER', 'speakers')
+MODEL_FOLDER = os.getenv('MODEL', 'models')
 BASE_HOST = os.getenv('BASE_URL', '127.0.0.1:8020')
 BASE_URL = os.getenv('BASE_URL', '127.0.0.1:8020')
 MODEL_SOURCE = os.getenv("MODEL_SOURCE", "local")
@@ -40,7 +41,7 @@
 
 # Create an instance of the TTSWrapper class and server
 app = FastAPI()
-XTTS = TTSWrapper(OUTPUT_FOLDER,SPEAKER_FOLDER,LOWVRAM_MODE,MODEL_SOURCE,MODEL_VERSION,DEVICE,DEEPSPEED,USE_CACHE)
+XTTS = TTSWrapper(OUTPUT_FOLDER,SPEAKER_FOLDER,MODEL_FOLDER,LOWVRAM_MODE,MODEL_SOURCE,MODEL_VERSION,DEVICE,DEEPSPEED,USE_CACHE)
 
 # Check for old format model version
 XTTS.model_version = XTTS.check_model_version_old_format(MODEL_VERSION)
@@ -63,12 +64,7 @@
 if STREAM_MODE_IMPROVE:
     logger.info("You launched an improved version of streaming, this version features an improved tokenizer and more context when processing sentences, which can be good for complex languages like Chinese")
 
-this_dir = Path(__file__).parent.resolve()
-
-if XTTS.isModelOfficial(MODEL_VERSION):
-    model_path = this_dir / "models"
-else:
-    model_path = "models"
+model_path = XTTS.model_folder
 
 engine = CoquiEngine(specific_model=MODEL_VERSION,use_deepspeed=DEEPSPEED,local_models_path=str(model_path))
 stream = TextToAudioStream(engine)
@@ -120,6 +116,18 @@ class OutputFolderRequest(BaseModel):
 class SpeakerFolderRequest(BaseModel):
     speaker_folder: str
 
+class ModelNameRequest(BaseModel):
+    model_name: str
+
+class TTSSettingsRequest(BaseModel):
+    temperature: float
+    speed: float
+    length_penalty: float
+    repetition_penalty: float
+    top_p: float
+    top_k: int
+    enable_text_splitting: bool
+
 class SynthesisRequest(BaseModel):
     text: str
     speaker_wav: str
@@ -150,7 +158,16 @@ def get_languages():
 def get_folders():
     speaker_folder = XTTS.speaker_folder
     output_folder = XTTS.output_folder
-    return {"speaker_folder": speaker_folder, "output_folder": output_folder}
+    model_folder = XTTS.model_folder
+    return {"speaker_folder": speaker_folder, "output_folder": output_folder,"model_folder":model_folder}
+
+@app.get("/get_models_list")
+def get_models_list():
+    return XTTS.get_models_list()
+
+@app.get("/get_tts_settings")
+def get_tts_settings():
+    return XTTS.tts_settings
 
 @app.get("/sample/{file_name:path}")
 def get_sample(file_name: str):
@@ -179,6 +196,24 @@ def set_speaker_folder(speaker_req: SpeakerFolderRequest):
         logger.error(e)
         raise HTTPException(status_code=400, detail=str(e))
 
+@app.post("/switch_model")
+def switch_model(modelReq: ModelNameRequest):
+    try:
+        XTTS.switch_model(modelReq.model_name)
+        return {"message": f"Model switched to {modelReq.model_name}"}
+    except InvalidSettingsError as e:
+        logger.error(e)
+        raise HTTPException(status_code=400, detail=str(e))
+
+@app.post("/set_tts_settings")
+def set_tts_settings_endpoint(tts_settings_req: TTSSettingsRequest):
+    try:
+        XTTS.set_tts_settings(**tts_settings_req.dict())
+        return {"message": "Settings successfully applied"}
+    except InvalidSettingsError as e:
+        logger.error(e)
+        raise HTTPException(status_code=400, detail=str(e))
+
 @app.get('/tts_stream')
 async def tts_stream(request: Request, text: str = Query(), speaker_wav: str = Query(), language: str = Query()):
     # Validate local model source.
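The new endpoints can be exercised with a small client sketch like the one below (an illustration only; it assumes the server is running on the default http://localhost:8020 and uses the request fields defined by the Pydantic models above):

    import requests

    BASE = "http://localhost:8020"  # default host/port from __main__.py

    # New read-only endpoints
    print(requests.get(f"{BASE}/get_models_list").json())   # folders inside the model folder
    print(requests.get(f"{BASE}/get_tts_settings").json())  # current generation parameters

    # Update generation parameters; invalid values come back as HTTP 400 (InvalidSettingsError)
    settings = {
        "temperature": 0.7,
        "speed": 1.0,
        "length_penalty": 1.0,
        "repetition_penalty": 5.0,
        "top_p": 0.85,
        "top_k": 50,
        "enable_text_splitting": True,
    }
    print(requests.post(f"{BASE}/set_tts_settings", json=settings).json())

    # Switch to another model folder, e.g. a finetune stored in the model folder
    print(requests.post(f"{BASE}/switch_model", json={"model_name": "v2.0.2"}).json())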

xtts_api_server/tts_funcs.py (+105 -23)

@@ -22,6 +22,10 @@
 import wave
 import numpy as np
 
+# Class to check tts settings
+class InvalidSettingsError(Exception):
+    pass
+
 # List of supported language codes
 supported_languages = {
     "ar":"Arabic",
@@ -43,13 +47,23 @@
     "hi":"Hindi"
 }
 
+default_tts_settings = {
+    "temperature" : 0.75,
+    "length_penalty" : 1.0,
+    "repetition_penalty": 5.0,
+    "top_k" : 50,
+    "top_p" : 0.85,
+    "speed" : 1,
+    "enable_text_splitting": True
+}
+
 official_model_list = ["v2.0.0","v2.0.1","v2.0.2","v2.0.3","main"]
 official_model_list_v2 = ["2.0.0","2.0.1","2.0.2","2.0.3"]
 
 reversed_supported_languages = {name: code for code, name in supported_languages.items()}
 
 class TTSWrapper:
-    def __init__(self,output_folder = "./output", speaker_folder="./speakers",lowvram = False,model_source = "local",model_version = "2.0.2",device = "cuda",deepspeed = False,enable_cache_results = True):
+    def __init__(self,output_folder = "./output", speaker_folder="./speakers",model_folder="./xtts_folder",lowvram = False,model_source = "local",model_version = "2.0.2",device = "cuda",deepspeed = False,enable_cache_results = True):
 
         self.cuda = device # If the user has chosen what to use, we rewrite the value to the value we want to use
         self.device = 'cpu' if lowvram else (self.cuda if torch.cuda.is_available() else "cpu")
@@ -59,12 +73,13 @@ def __init__(self,output_folder = "./output", speaker_folder="./speakers",lowvra
 
         self.model_source = model_source
         self.model_version = model_version
+        self.tts_settings = default_tts_settings
 
         self.deepspeed = deepspeed
 
         self.speaker_folder = speaker_folder
         self.output_folder = output_folder
-        self.custom_models_folder = "./models"
+        self.model_folder = model_folder
 
         self.create_directories()
         check_tts_version()
@@ -90,6 +105,14 @@ def check_model_version_old_format(self,model_version):
             return "v"+model_version
         return model_version
 
+    def get_models_list(self):
+        # Fetch all entries in the directory given by self.model_folder
+        entries = os.listdir(self.model_folder)
+
+        # Filter out and return only directories
+        return [name for name in entries if os.path.isdir(os.path.join(self.model_folder, name))]
+
+
     def get_wav_header(self, channels:int=1, sample_rate:int=24000, width:int=2) -> bytes:
         wav_buf = io.BytesIO()
         with wave.open(wav_buf, "wb") as out:
@@ -147,12 +170,11 @@ def load_model(self,load=True):
             self.model = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
 
         if self.model_source == "apiManual":
-            this_dir = Path(__file__).parent.resolve() / "models"
+            this_dir = Path(self.model_folder)
+
             if self.isModelOfficial(self.model_version):
                 download_model(this_dir,self.model_version)
-            else:
-                this_dir = Path(self.custom_models_folder).resolve()
-
+
             config_path = this_dir / f'{self.model_version}' / 'config.json'
             checkpoint_dir = this_dir / f'{self.model_version}'
 
@@ -170,13 +192,11 @@ def load_local_model(self,load=True):
         logger.info("Model successfully loaded ")
 
     def load_local_model(self,load=True):
-        this_model_dir = Path(__file__).parent.resolve()
+        this_model_dir = Path(self.model_folder)
 
         if self.isModelOfficial(self.model_version):
             download_model(this_model_dir,self.model_version)
-            this_model_dir = this_model_dir / "models"
-        else:
-            this_model_dir = Path(self.custom_models_folder)
+        this_model_dir = this_model_dir
 
         config = XttsConfig()
         config_path = this_model_dir / f'{self.model_version}' / 'config.json'
@@ -188,6 +208,34 @@ def load_local_model(self,load=True):
         self.model.load_checkpoint(config,use_deepspeed=self.deepspeed, checkpoint_dir=str(checkpoint_dir))
         self.model.to(self.device)
 
+    def switch_model(self,model_name):
+
+        model_list = self.get_models_list()
+        # Check to see if the same name is selected
+        if(model_name == self.model_version):
+            raise InvalidSettingsError("The model with this name is already loaded in memory")
+            return
+
+        # Check if the model is in the list at all
+        if(model_name not in model_list):
+            raise InvalidSettingsError(f"A model with `{model_name}` name is not in the models folder, the current available models: {model_list}")
+            return
+
+        # Clear gpu cache from old model
+        self.model = ""
+        torch.cuda.empty_cache()
+        logger.info("Model successfully unloaded from memory")
+
+        # Start load model
+        logger.info(f"Start loading {model_name} model")
+        self.model_version = model_name
+        if self.model_source == "local":
+            self.load_local_model()
+        else:
+            self.load_model()
+
+        logger.info(f"Model successfully loaded")
+
     # LOWVRAM FUNCS
     def switch_model_device(self):
         # We check for lowram and the existence of cuda
@@ -222,7 +270,7 @@ def create_latents_for_all(self):
 
     # DIRICTORIES FUNCS
     def create_directories(self):
-        directories = [self.output_folder, self.speaker_folder,self.custom_models_folder]
+        directories = [self.output_folder, self.speaker_folder,self.model_folder]
 
         for sanctuary in directories:
             # List of folders to be checked for existence
@@ -249,6 +297,50 @@ def set_out_folder(self, folder):
         else:
             raise ValueError("Provided path is not a valid directory")
 
+    def set_tts_settings(self, temperature, speed, length_penalty,
+                         repetition_penalty, top_p, top_k, enable_text_splitting):
+        # Validate each parameter and raise an exception if any checks fail.
+
+        # Check temperature
+        if not (0.01 <= temperature <= 1):
+            raise InvalidSettingsError("Temperature must be between 0.01 and 1.")
+
+        # Check speed
+        if not (0.2 <= speed <= 2):
+            raise InvalidSettingsError("Speed must be between 0.2 and 2.")
+
+        # Check length_penalty (no explicit range specified)
+        if not isinstance(length_penalty, float):
+            raise InvalidSettingsError("Length penalty must be a floating point number.")
+
+        # Check repetition_penalty
+        if not (0.1 <= repetition_penalty <= 10.0):
+            raise InvalidSettingsError("Repetition penalty must be between 0.1 and 10.0.")
+
+        # Check top_p
+        if not (0.01 <= top_p <= 1):
+            raise InvalidSettingsError("Top_p must be between 0.01 and 1 and must be a float.")
+
+        # Check top_k
+        if not (1 <= top_k <= 100):
+            raise InvalidSettingsError("Top_k must be an integer between 1 and 100.")
+
+        # Check enable_text_splitting
+        if not isinstance(enable_text_splitting, bool):
+            raise InvalidSettingsError("Enable text splitting must be either True or False.")
+
+        # All validations passed - proceed to apply settings.
+        self.tts_settings = {
+            "temperature": temperature,
+            "speed": speed,
+            "length_penalty": length_penalty,
+            "repetition_penalty": repetition_penalty,
+            "top_p": top_p,
+            "top_k": top_k,
+            "enable_text_splitting": enable_text_splitting,
+        }
+        print("Successfully updated TTS settings.")
+
     # GET FUNCS
     def get_wav_files(self, directory):
         """ Finds all the wav files in a directory. """
@@ -361,12 +453,7 @@ async def stream_generation(self,text,speaker_name,speaker_wav,language,output_f
             language,
             speaker_embedding=speaker_embedding,
             gpt_cond_latent=gpt_cond_latent,
-            temperature=0.75,
-            length_penalty=1.0,
-            repetition_penalty=5.0,
-            top_k=50,
-            top_p=0.85,
-            enable_text_splitting=True,
+            **self.tts_settings, # Expands the object with the settings and applies them for generation
             stream_chunk_size=100,
         )
 
@@ -402,12 +489,7 @@ def local_generation(self,text,speaker_name,speaker_wav,language,output_file):
             language,
             gpt_cond_latent=gpt_cond_latent,
             speaker_embedding=speaker_embedding,
-            temperature=0.75,
-            length_penalty=1.0,
-            repetition_penalty=5.0,
-            top_k=50,
-            top_p=0.85,
-            enable_text_splitting=True
+            **self.tts_settings, # Expands the object with the settings and applies them for generation
         )
 
         torchaudio.save(output_file, torch.tensor(out["wav"]).unsqueeze(0), 24000)
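The `**self.tts_settings` expansion simply forwards every stored parameter as a keyword argument to the model's inference call. A self-contained illustration of the mechanism using a stand-in function rather than the real XTTS model:

    # Stand-in for model.inference(...)/inference_stream(...) to show how the
    # settings dict expands into keyword arguments.
    def fake_inference(text, language, **kwargs):
        return {"text": text, "language": language, **kwargs}

    tts_settings = {
        "temperature": 0.75,
        "length_penalty": 1.0,
        "repetition_penalty": 5.0,
        "top_k": 50,
        "top_p": 0.85,
        "speed": 1,
        "enable_text_splitting": True,
    }

    # Equivalent to spelling out temperature=0.75, length_penalty=1.0, ... by hand
    out = fake_inference("Hello", "en", **tts_settings)
    print(out["top_k"])  # 50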
