|
| 1 | +from xtts_api_server.tts_funcs import official_model_list |
1 | 2 | from torch.multiprocessing import Process, Pipe, Event, set_start_method |
2 | 3 | from .base_engine import BaseEngine |
3 | 4 | from typing import Union, List |
@@ -92,7 +93,16 @@ def __init__(self, |
92 | 93 | ModelManager().download_model(model_name) |
93 | 94 | else: |
94 | 95 | logging.info(f"Local XTTS Model: \"{specific_model}\" specified") |
95 | | - self.local_model_path = self.download_model(specific_model, local_models_path) |
| 96 | + is_official_model = False |
| 97 | + for model in official_model_list: |
| 98 | + if self.specific_model == model: |
| 99 | + is_official_model = True |
| 100 | + break |
| 101 | + |
| 102 | + if is_official_model: |
| 103 | + self.local_model_path = self.download_model(specific_model, local_models_path) |
| 104 | + else: |
| 105 | + self.local_model_path = os.path.join(local_models_path,specific_model) |
96 | 106 |
|
97 | 107 | self.synthesize_process = Process(target=CoquiEngine._synthesize_worker, args=(child_synthesize_pipe, model_name, cloning_reference_wav, language, self.main_synthesize_ready_event, level, self.speed, thread_count, stream_chunk_size, full_sentences, overlap_wav_len, temperature, length_penalty, repetition_penalty, top_k, top_p, enable_text_splitting, use_mps, self.local_model_path, use_deepspeed, self.voices_path)) |
98 | 108 | self.synthesize_process.start() |
@@ -540,28 +550,29 @@ def download_file(url, destination): |
540 | 550 | progress_bar.close() |
541 | 551 |
|
542 | 552 | @staticmethod |
543 | | - def download_model(model_name = "2.0.2", local_models_path = None): |
| 553 | + def download_model(model_name = "v2.0.2", local_models_path = None): |
544 | 554 |
|
545 | 555 | # Creating a unique folder for each model version |
546 | 556 | if local_models_path and len(local_models_path) > 0: |
547 | | - model_folder = os.path.join(local_models_path, f'v{model_name}') |
| 557 | + model_folder = os.path.join(local_models_path, f'{model_name}') |
548 | 558 | logging.info(f"Local models path: \"{model_folder}\"") |
549 | 559 | else: |
550 | | - model_folder = os.path.join(os.getcwd(), 'models', f'v{model_name}') |
| 560 | + model_folder = os.path.join(os.getcwd(), 'models', f'{model_name}') |
551 | 561 | logging.info(f"Checking for models within application directory: \"{model_folder}\"") |
552 | 562 |
|
553 | 563 | os.makedirs(model_folder, exist_ok=True) |
| 564 | + print(model_name) |
554 | 565 |
|
555 | 566 | files = { |
556 | | - "config.json": f"https://huggingface.co/coqui/XTTS-v2/raw/v{model_name}/config.json", |
557 | | - "model.pth": f"https://huggingface.co/coqui/XTTS-v2/resolve/v{model_name}/model.pth?download=true", |
558 | | - "vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/v{model_name}/vocab.json" |
| 567 | + "config.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_name}/config.json", |
| 568 | + "model.pth": f"https://huggingface.co/coqui/XTTS-v2/resolve/{model_name}/model.pth?download=true", |
| 569 | + "vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_name}/vocab.json" |
559 | 570 | } |
560 | 571 |
|
561 | 572 | for file_name, url in files.items(): |
562 | 573 | file_path = os.path.join(model_folder, file_name) |
563 | 574 | if not os.path.exists(file_path): |
564 | | - logger.info(f"Downloading {file_name} for Model v{model_name}...") |
| 575 | + logger.info(f"Downloading {file_name} for Model {model_name}...") |
565 | 576 | CoquiEngine.download_file(url, file_path) |
566 | 577 | # r = requests.get(url, allow_redirects=True) |
567 | 578 | # with open(file_path, 'wb') as f: |
|
0 commit comments