
Commit 2c42b47

Add MaskGCT demo ipynb (#306)
1 parent c37307d commit 2c42b47

File tree

2 files changed: 314 additions & 1 deletion


.gitignore

Lines changed: 0 additions & 1 deletion
@@ -56,7 +56,6 @@ processed_data
 data
 model_ckpt
 logs
-*.ipynb
 *.lst
 source_audio
 result

models/tts/maskgct/maskgct_demo.ipynb

Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import librosa\n",
    "import safetensors\n",
    "from utils.util import load_config\n",
    "\n",
    "from models.codec.kmeans.repcodec_model import RepCodec\n",
    "from models.tts.maskgct.maskgct_s2a import MaskGCT_S2A\n",
    "from models.tts.maskgct.maskgct_t2s import MaskGCT_T2S\n",
    "from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder\n",
    "from transformers import Wav2Vec2BertModel\n",
    "\n",
    "from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import SeamlessM4TFeatureExtractor\n",
    "processor = SeamlessM4TFeatureExtractor.from_pretrained(\"facebook/w2v-bert-2.0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "def g2p_(text, language):\n",
    "    if language in [\"zh\", \"en\"]:\n",
    "        return chn_eng_g2p(text)\n",
    "    else:\n",
    "        return g2p(text, sentence=None, language=language)\n",
    "\n",
    "def build_t2s_model(cfg, device):\n",
    "    t2s_model = MaskGCT_T2S(cfg=cfg)\n",
    "    t2s_model.eval()\n",
    "    t2s_model.to(device)\n",
    "    return t2s_model\n",
    "\n",
    "def build_s2a_model(cfg, device):\n",
    "    soundstorm_model = MaskGCT_S2A(cfg=cfg)\n",
    "    soundstorm_model.eval()\n",
    "    soundstorm_model.to(device)\n",
    "    return soundstorm_model\n",
    "\n",
    "def build_semantic_model(device):\n",
    "    semantic_model = Wav2Vec2BertModel.from_pretrained(\"facebook/w2v-bert-2.0\")\n",
    "    semantic_model.eval()\n",
    "    semantic_model.to(device)\n",
    "    stat_mean_var = torch.load(\"./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt\")\n",
    "    semantic_mean = stat_mean_var[\"mean\"]\n",
    "    semantic_std = torch.sqrt(stat_mean_var[\"var\"])\n",
    "    semantic_mean = semantic_mean.to(device)\n",
    "    semantic_std = semantic_std.to(device)\n",
    "    return semantic_model, semantic_mean, semantic_std\n",
    "\n",
    "def build_semantic_codec(cfg, device):\n",
    "    semantic_codec = RepCodec(cfg=cfg)\n",
    "    semantic_codec.eval()\n",
    "    semantic_codec.to(device)\n",
    "    return semantic_codec\n",
    "\n",
    "def build_acoustic_codec(cfg, device):\n",
    "    codec_encoder = CodecEncoder(cfg=cfg.encoder)\n",
    "    codec_decoder = CodecDecoder(cfg=cfg.decoder)\n",
    "    codec_encoder.eval()\n",
    "    codec_decoder.eval()\n",
    "    codec_encoder.to(device)\n",
    "    codec_decoder.to(device)\n",
    "    return codec_encoder, codec_decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def extract_features(speech, processor):\n",
    "    inputs = processor(speech, sampling_rate=16000, return_tensors=\"pt\")\n",
    "    input_features = inputs[\"input_features\"][0]\n",
    "    attention_mask = inputs[\"attention_mask\"][0]\n",
    "    return input_features, attention_mask\n",
    "\n",
    "@torch.no_grad()\n",
    "def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask):\n",
    "    vq_emb = semantic_model(\n",
    "        input_features=input_features,\n",
    "        attention_mask=attention_mask,\n",
    "        output_hidden_states=True,\n",
    "    )\n",
    "    feat = vq_emb.hidden_states[17] # (B, T, C)\n",
    "    feat = (feat - semantic_mean.to(feat)) / semantic_std.to(feat)\n",
    "\n",
    "    semantic_code, rec_feat = semantic_codec.quantize(feat) # (B, T)\n",
    "    return semantic_code, rec_feat\n",
    "\n",
    "@torch.no_grad()\n",
    "def extract_acoustic_code(speech):\n",
    "    vq_emb = codec_encoder(speech.unsqueeze(1))\n",
    "    _, vq, _, _, _ = codec_decoder.quantizer(vq_emb)\n",
    "    acoustic_code = vq.permute(\n",
    "        1, 2, 0\n",
    "    )\n",
    "    return acoustic_code\n",
    "\n",
    "@torch.no_grad()\n",
    "def text2semantic(prompt_speech, prompt_text, prompt_language, target_text, target_language, target_len=None, n_timesteps=50, cfg=2.5, rescale_cfg=0.75):\n",
    "\n",
    "    prompt_phone_id = g2p_(prompt_text, prompt_language)[1]\n",
    "\n",
    "    target_phone_id = g2p_(target_text, target_language)[1]\n",
    "\n",
    "    if target_len is None:\n",
    "        target_len = int((len(prompt_speech) * len(target_phone_id) / len(prompt_phone_id)) / 16000 * 50)\n",
    "    else:\n",
    "        target_len = int(target_len * 50)\n",
    "\n",
    "    prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long).to(device)\n",
    "    target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device)\n",
    "\n",
    "    phone_id = torch.cat([prompt_phone_id, target_phone_id])\n",
    "\n",
    "    input_features, attention_mask = extract_features(prompt_speech, processor)\n",
    "    input_features = input_features.unsqueeze(0).to(device)\n",
    "    attention_mask = attention_mask.unsqueeze(0).to(device)\n",
    "    semantic_code, _ = extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask)\n",
    "\n",
    "    predict_semantic = t2s_model.reverse_diffusion(semantic_code[:, :], target_len, phone_id.unsqueeze(0), n_timesteps=n_timesteps, cfg=cfg, rescale_cfg=rescale_cfg)\n",
    "\n",
    "    print(\"predict semantic shape\", predict_semantic.shape)\n",
    "\n",
    "    combine_semantic_code = torch.cat([semantic_code[:,:], predict_semantic], dim=-1)\n",
    "    prompt_semantic_code = semantic_code\n",
    "\n",
    "    return combine_semantic_code, prompt_semantic_code\n",
    "\n",
    "@torch.no_grad()\n",
    "def semantic2acoustic(combine_semantic_code, acoustic_code, n_timesteps=[25,10,1,1,1,1,1,1,1,1,1,1], cfg=2.5, rescale_cfg=0.75):\n",
    "\n",
    "    semantic_code = combine_semantic_code\n",
    "\n",
    "    cond = s2a_model_1layer.cond_emb(semantic_code)\n",
    "    prompt = acoustic_code[:,:,:]\n",
    "    predict_1layer = s2a_model_1layer.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=n_timesteps[:1], cfg=cfg, rescale_cfg=rescale_cfg)\n",
    "\n",
    "    cond = s2a_model_full.cond_emb(semantic_code)\n",
    "    prompt = acoustic_code[:,:,:]\n",
    "    predict_full = s2a_model_full.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=n_timesteps, cfg=cfg, rescale_cfg=rescale_cfg, gt_code=predict_1layer)\n",
    "\n",
    "    vq_emb = codec_decoder.vq2emb(predict_full.permute(2,0,1), n_quantizers=12)\n",
    "    recovered_audio = codec_decoder(vq_emb)\n",
    "    prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2,0,1), n_quantizers=12)\n",
    "    recovered_prompt_audio = codec_decoder(prompt_vq_emb)\n",
    "    recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()\n",
    "    recovered_audio = recovered_audio[0][0].cpu().numpy()\n",
    "    combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])\n",
    "\n",
    "    return combine_audio, recovered_audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def maskgct_inference(prompt_speech_path, prompt_text, target_text, language=\"en\", target_language=\"en\", target_len=None, n_timesteps=25, cfg=2.5, rescale_cfg=0.75, n_timesteps_s2a=[25,10,1,1,1,1,1,1,1,1,1,1], cfg_s2a=2.5, rescale_cfg_s2a=0.75):\n",
    "    speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]\n",
    "    speech = librosa.load(prompt_speech_path, sr=24000)[0]\n",
    "\n",
    "    combine_semantic_code, _ = text2semantic(speech_16k, prompt_text, language, target_text, target_language, target_len, n_timesteps, cfg, rescale_cfg)\n",
    "    acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))\n",
    "    _, recovered_audio = semantic2acoustic(combine_semantic_code, acoustic_code, n_timesteps=n_timesteps_s2a, cfg=cfg_s2a, rescale_cfg=rescale_cfg_s2a)\n",
    "\n",
    "    return recovered_audio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\")\n",
    "cfg_path = \"./models/tts/maskgct/config/maskgct.json\"\n",
    "cfg = load_config(cfg_path)\n",
    "\n",
    "# 1. build semantic model (w2v-bert-2.0)\n",
    "semantic_model, semantic_mean, semantic_std = build_semantic_model(device)\n",
    "# 2. build semantic codec\n",
    "semantic_codec = build_semantic_codec(cfg.model.semantic_codec, device)\n",
    "# 3. build acoustic codec\n",
    "codec_encoder, codec_decoder = build_acoustic_codec(cfg.model.acoustic_codec, device)\n",
    "# 4. build t2s model\n",
    "t2s_model = build_t2s_model(cfg.model.t2s_model, device)\n",
    "# 5. build s2a model\n",
    "s2a_model_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)\n",
    "s2a_model_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "# download semantic codec ckpt\n",
    "semantic_code_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"semantic_codec/model.safetensors\")\n",
    "# download acoustic codec ckpt\n",
    "codec_encoder_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"acoustic_codec/model.safetensors\")\n",
    "codec_decoder_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"acoustic_codec/model_1.safetensors\")\n",
    "# download t2s model ckpt\n",
    "t2s_model_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"t2s_model/model.safetensors\")\n",
    "# download s2a model ckpt\n",
    "s2a_1layer_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"s2a_model/s2a_model_1layer/model.safetensors\")\n",
    "s2a_full_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"s2a_model/s2a_model_full/model.safetensors\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load semantic codec\n",
    "safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)\n",
    "# load acoustic codec\n",
    "safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)\n",
    "safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)\n",
    "# load t2s model\n",
    "safetensors.torch.load_model(t2s_model, t2s_model_ckpt)\n",
    "# load s2a model\n",
    "safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)\n",
    "safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_wav_path = \"./models/tts/maskgct/wav/prompt.wav\"\n",
    "prompt_text = \" We do not break. We never give in. We never back down.\"\n",
    "target_text = \"In this paper, we introduce MaskGCT, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision.\"\n",
    "target_len = 18 # Specify the target duration (in seconds). If target_len = None, we use a simple rule to predict the target duration.\n",
    "recovered_audio = maskgct_inference(prompt_wav_path, prompt_text, target_text, \"en\", \"en\", target_len=target_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Audio\n",
    "Audio(recovered_audio, rate=24000)"
   ]
  }
 ],
 "metadata": {
  "fileId": "8353ad98-61bb-49ea-b655-c8f6a3264cc3",
  "filePath": "/opt/tiger/SpeechGeneration2/models/tts/maskgct/maskgct_demo.ipynb",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
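Note: the last notebook cell only plays the generated waveform inline with IPython.display.Audio at 24 kHz. A minimal sketch for saving the result to disk instead, assuming the soundfile package is available in the environment (the demo itself does not import it, and the output filename is arbitrary):

import soundfile as sf

# recovered_audio is the 1-D NumPy waveform returned by maskgct_inference,
# decoded by the acoustic codec at a 24 kHz sampling rate.
sf.write("maskgct_output.wav", recovered_audio, samplerate=24000)

Also note that the "Build Model" cell hard-codes device = torch.device("cuda:0"), so the demo assumes a CUDA GPU is present; a common fallback for CPU-only machines (much slower, and not exercised by this commit) would be device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu").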
