|
| 1 | +from xtts_api_server.tts_funcs import official_model_list |
1 | 2 | from torch.multiprocessing import Process, Pipe, Event, set_start_method
|
2 | 3 | from .base_engine import BaseEngine
|
3 | 4 | from typing import Union, List
|
@@ -92,7 +93,16 @@ def __init__(self,
|
92 | 93 | ModelManager().download_model(model_name)
|
93 | 94 | else:
|
94 | 95 | logging.info(f"Local XTTS Model: \"{specific_model}\" specified")
|
95 |
| - self.local_model_path = self.download_model(specific_model, local_models_path) |
| 96 | + is_official_model = False |
| 97 | + for model in official_model_list: |
| 98 | + if self.specific_model == model: |
| 99 | + is_official_model = True |
| 100 | + break |
| 101 | + |
| 102 | + if is_official_model: |
| 103 | + self.local_model_path = self.download_model(specific_model, local_models_path) |
| 104 | + else: |
| 105 | + self.local_model_path = os.path.join(local_models_path,specific_model) |
96 | 106 |
|
97 | 107 | self.synthesize_process = Process(target=CoquiEngine._synthesize_worker, args=(child_synthesize_pipe, model_name, cloning_reference_wav, language, self.main_synthesize_ready_event, level, self.speed, thread_count, stream_chunk_size, full_sentences, overlap_wav_len, temperature, length_penalty, repetition_penalty, top_k, top_p, enable_text_splitting, use_mps, self.local_model_path, use_deepspeed, self.voices_path))
|
98 | 108 | self.synthesize_process.start()
|
@@ -540,28 +550,29 @@ def download_file(url, destination):
|
540 | 550 | progress_bar.close()
|
541 | 551 |
|
542 | 552 | @staticmethod
|
543 |
| - def download_model(model_name = "2.0.2", local_models_path = None): |
| 553 | + def download_model(model_name = "v2.0.2", local_models_path = None): |
544 | 554 |
|
545 | 555 | # Creating a unique folder for each model version
|
546 | 556 | if local_models_path and len(local_models_path) > 0:
|
547 |
| - model_folder = os.path.join(local_models_path, f'v{model_name}') |
| 557 | + model_folder = os.path.join(local_models_path, f'{model_name}') |
548 | 558 | logging.info(f"Local models path: \"{model_folder}\"")
|
549 | 559 | else:
|
550 |
| - model_folder = os.path.join(os.getcwd(), 'models', f'v{model_name}') |
| 560 | + model_folder = os.path.join(os.getcwd(), 'models', f'{model_name}') |
551 | 561 | logging.info(f"Checking for models within application directory: \"{model_folder}\"")
|
552 | 562 |
|
553 | 563 | os.makedirs(model_folder, exist_ok=True)
|
| 564 | + print(model_name) |
554 | 565 |
|
555 | 566 | files = {
|
556 |
| - "config.json": f"https://huggingface.co/coqui/XTTS-v2/raw/v{model_name}/config.json", |
557 |
| - "model.pth": f"https://huggingface.co/coqui/XTTS-v2/resolve/v{model_name}/model.pth?download=true", |
558 |
| - "vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/v{model_name}/vocab.json" |
| 567 | + "config.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_name}/config.json", |
| 568 | + "model.pth": f"https://huggingface.co/coqui/XTTS-v2/resolve/{model_name}/model.pth?download=true", |
| 569 | + "vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_name}/vocab.json" |
559 | 570 | }
|
560 | 571 |
|
561 | 572 | for file_name, url in files.items():
|
562 | 573 | file_path = os.path.join(model_folder, file_name)
|
563 | 574 | if not os.path.exists(file_path):
|
564 |
| - logger.info(f"Downloading {file_name} for Model v{model_name}...") |
| 575 | + logger.info(f"Downloading {file_name} for Model {model_name}...") |
565 | 576 | CoquiEngine.download_file(url, file_path)
|
566 | 577 | # r = requests.get(url, allow_redirects=True)
|
567 | 578 | # with open(file_path, 'wb') as f:
|
|
0 commit comments