
Commit 11a177b

remove chinese characters in dualcodec
1 parent 3145aae commit 11a177b

File tree

14 files changed: +1139 -43 lines changed


models/codec/dualcodec/README.md

Lines changed: 5 additions & 0 deletions
@@ -144,6 +144,11 @@ python -m dualcodec.infer.valle.cli_valle_infer --ref_audio <path_to_ref_audio>
 ```
 You can also leave all options empty and it will use the default values.
 
+#### Gradio interface
+```bash
+python -m dualcodec.infer.valle.gradio_valle_demo
+```
+
 ### DualCodec-Voicebox
 #### CLI Inference
 ```bash
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
vocab_size: 84644 # ${51866+32768+10}
speech_vocab_size: 32768
initial_offset: 10

llama_cfg:
  _target_: transformers.models.llama.modeling_llama.LlamaConfig
  vocab_size: ${..vocab_size}
  hidden_size: 2048
  intermediate_size: 8192
  num_hidden_layers: 10
  num_attention_heads: 16
  pad_token_id: 0
  bos_token_id: 1
  eos_token_id: 2

llm:
  _target_: transformers.models.llama.modeling_llama.LlamaForCausalLM
  config: ${..llama_cfg}
model:
  _target_: dualcodec.model_tts.flattened_ar.llama_wrapper.LLM
  llm: ${..llm}
  config: ${..llama_cfg}
  speech_vocab_size: ${..speech_vocab_size}
  initial_offset: ${..initial_offset}
  sep_token: 3
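
Configs in this style are typically materialized with Hydra's `instantiate`. The sketch below is an assumed usage, not part of this commit: the file name is hypothetical, and it assumes the relative `${..}` interpolations resolve when the file is loaded standalone.

```python
# Illustrative only: build the flattened-AR model from the Hydra-style config above.
# "flattened_ar_config.yaml" is a hypothetical file name.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("flattened_ar_config.yaml")
OmegaConf.resolve(cfg)  # resolve ${..vocab_size}-style relative interpolations

llama_cfg = instantiate(cfg.llama_cfg)  # transformers LlamaConfig(vocab_size=84644, ...)
model = instantiate(cfg.model)          # llama_wrapper.LLM wrapping a LlamaForCausalLM
print(type(model))
```
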
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
import torch
from einops import rearrange

import numpy as np


def offset_codes(semantic_code, offset_sizes):
    """
    Applies layer-specific offsets to each codec layer.

    Args:
        semantic_code (torch.Tensor): Input tensor of shape (batch_size, T, num_codec_layers).
        offset_sizes (list[int]): List of offsets for each codec layer to distinguish them.

    Returns:
        torch.Tensor: Offset-applied tensor of shape (batch_size, T, num_codec_layers).
    """
    # Calculate cumulative offsets for each layer
    cumulative_offsets = np.cumsum([0] + offset_sizes[:-1])  # Start with 0 for the first layer
    # Apply offsets layer by layer
    offsetted_code = []
    for i, offset in enumerate(cumulative_offsets):
        current_layer_code = semantic_code[..., i].clone().detach()  # Extract layer i
        current_layer_code += offset  # Apply the cumulative offset
        offsetted_code.append(current_layer_code)

    # Stack all layers along the codec layer dimension
    offsetted_code = torch.stack(offsetted_code, dim=-1)  # Shape: (batch_size, T, num_codec_layers)

    return offsetted_code


def deoffset_codes(flattened_codes, offset_sizes):
    """
    De-offsets a flattened tensor by subtracting the cumulative codebook-size offsets for each codec layer.

    Args:
        flattened_codes (torch.Tensor): The offset and flattened tensor of shape (batch_size, T * num_codec_layers).
        offset_sizes (list[int]): A list of codebook sizes for each codec layer, used to remove offsets.

    Returns:
        torch.Tensor: The de-offset tensor of shape (batch_size, T, num_codec_layers).
    """
    # Calculate cumulative offsets for each layer
    cumulative_offsets = np.cumsum([0] + offset_sizes[:-1])  # Start with 0 for the first layer

    # Determine dimensions for reshaping
    batch_size, flattened_dim = flattened_codes.shape
    num_codec_layers = len(offset_sizes)
    T = flattened_dim // num_codec_layers

    # Reshape flattened_codes back to (batch_size, T, num_codec_layers)
    reshaped_codes = flattened_codes.view(batch_size, T, num_codec_layers)

    # De-offset each layer by subtracting the respective cumulative offset
    deoffsetted_code = []
    for i, offset in enumerate(cumulative_offsets):
        current_layer_code = reshaped_codes[..., i].clone()  # Clone to avoid in-place operation
        current_layer_code = current_layer_code - offset  # Remove the cumulative offset
        deoffsetted_code.append(current_layer_code)

    # Stack all layers along the codec layer dimension
    deoffsetted_code = torch.stack(deoffsetted_code, dim=-1)  # Shape: (batch_size, T, num_codec_layers)

    return deoffsetted_code
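
A quick round-trip check of the two helpers above (illustrative only, not part of the commit). It assumes `offset_codes` and `deoffset_codes` from this file are in scope; the `offset_sizes` values mirror the defaults used by the inference wrapper later in this commit.

```python
import torch
from einops import rearrange

offset_sizes = [16384, 4096, 4096, 4096]   # cumulative offsets become 0, 16384, 20480, 24576
codes = torch.randint(0, 4096, (1, 5, 4))  # (batch, T, num_codec_layers)

offset = offset_codes(codes, offset_sizes)     # shift layer i by its cumulative offset
flat = rearrange(offset, "b t q -> b (t q)")   # flatten time-major: t0q0, t0q1, ..., t4q3
restored = deoffset_codes(flat, offset_sizes)  # reshape to (b, T, q) and remove the offsets

assert torch.equal(restored, codes)
```
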
Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
import torch
from einops import rearrange
from .flatten_patterns import offset_codes, deoffset_codes


class Inference:
    def __init__(
        self, model, tokenizer_obj, dualcodec_inference_obj, device="cuda", normalize=False,
        half=False, split_paragraph=True, offset_sizes=[16384, 4096, 4096, 4096], **kwargs
    ) -> None:
        self.model = model
        import safetensors.torch

        self.model.to(device)
        self.model.eval()
        self.tokenizer = tokenizer_obj
        self.dualcodec_inference_obj = dualcodec_inference_obj
        self.device = device
        self.normalize = normalize
        self.offset_sizes = offset_sizes

        self.model = self.model.half()

        self.split_paragraph = split_paragraph

    @torch.no_grad()
    def inference(
        self,
        speech_24k,
        prompt_speech,
        prompt_text,
        prompt_language,
        target_text,
        target_language,
        use_prompt_text=True,
        temp=1.0,
        top_k=1000,
        top_p=0.85,
        repeat_penalty=1.1,
    ):
        """
        Generate speech for the target text, prompted by a reference speech clip and its transcript.

        Args:
            prompt_speech (str or Tensor): Speech file path or a tensor with shape (n_samples,).
            prompt_text (str): Text prompt.
            prompt_language (str): Language of the prompt.
            target_text (str): Target text to be synthesized.
            target_language (str): Language of the target text.
            use_prompt_text (bool, optional): Whether to use the prompt text as input. Defaults to True.
            temp (float, optional): Temperature parameter for the distribution. Defaults to 1.0.
            top_k (int, optional): Number of tokens to keep before applying `top_p`. Defaults to 1000.
            top_p (float, optional): Probability threshold to use for filtering tokens. Defaults to 0.85.
            repeat_penalty (float, optional): Repetition penalty applied during sampling. Defaults to 1.1.

        Returns:
            Decoded audio for the generated codec tokens.
        """
        self.model.eval()
        prompt_text = prompt_text.strip()
        # prompt_text = prompt_text.replace('.',',')
        # prompt_text = prompt_text.replace('。',',')
        target_text = target_text.replace("\n", "")
        target_text = target_text.replace("\t", "")
        return_values_0 = []
        return_values_1 = []

        prompt_len_tmp = len(self.tokenizer.encode(prompt_text)) // 2

        if self.split_paragraph:
            if prompt_language == 'zh':
                from dualcodec.utils.frontend_utils import split_paragraph
                texts = split_paragraph(
                    target_text,
                    None,
                    "zh",
                    token_max_n=60 - prompt_len_tmp,
                    token_min_n=40 - prompt_len_tmp,
                    merge_len=20,
                    comma_split=False,
                )
            elif prompt_language == 'ja':
                from dualcodec.utils.frontend_utils import split_paragraph
                texts = split_paragraph(
                    target_text,
                    None,
                    "zh",
                    token_max_n=70,
                    token_min_n=60,
                    merge_len=20,
                    comma_split=False,
                )
            elif prompt_language == 'en':
                from dualcodec.utils.frontend_utils import split_paragraph
                texts = split_paragraph(
                    target_text,
                    self.tokenizer.encode,
                    "en",
                    token_max_n=70 - prompt_len_tmp,
                    token_min_n=60 - prompt_len_tmp,
                    merge_len=20,
                    comma_split=True,
                )
        else:
            texts = [target_text]
        if prompt_language == 'en':
            texts = [prompt_text + ' ' + t for t in texts]
        else:
            texts = [prompt_text + t for t in texts]
        print(texts)

        all_codes = []

        for text in texts:

            if self.normalize:
                from dualcodec.dataset.processor import normalize
                text = list(normalize([{
                    'language': prompt_language,
                    'text': text,
                }], en_punct=True, use_kana=False))[0]['text']
            print(text)

            prompt_text_tokens = torch.tensor(
                [
                    [self.tokenizer.to_language_token(prompt_language)]
                    + self.tokenizer.encode(text)
                ],
                dtype=torch.int32,
                device=self.device,
            )
            prompt_text_len = torch.tensor(
                [prompt_text_tokens.shape[-1]], device=self.device
            )

            # target_text_tokens = torch.tensor(
            #     [tokenizer.encode(target_text)], dtype=torch.int32
            # )
            # target_text_len = torch.tensor([target_text_tokens.shape[-1]])

            text_token = prompt_text_tokens

            # prompt semantic codes
            # semantic_code, _ = self._extract_semantic_code(input_features, attention_mask)
            semantic_codes, acoustic_codes = self.dualcodec_inference_obj.encode(prompt_speech, n_quantizers=4)
            semantic_codes = rearrange(semantic_codes, 'b t -> b t 1')
            num_codec_layers = 4
            semantic_code = torch.cat([semantic_codes, acoustic_codes], dim=-1)[..., :num_codec_layers]

            semantic_code = offset_codes(semantic_code, self.offset_sizes)
            semantic_code = rearrange(semantic_code, 'b t q -> b (t q)')

            ret_semantic_code = semantic_code.clone().detach()

            out = self.model.inference(
                text=text_token,
                text_len=prompt_text_len,
                prompt_text=None,
                prompt_text_len=None,
                prompt_speech_token=semantic_code,
                prompt_speech_token_len=torch.tensor([semantic_code.shape[-1]]),
                top_k=top_k,
                top_p=top_p,
                repeat_penalty=repeat_penalty,
                temperature=temp,
            )
            out = deoffset_codes(out, self.offset_sizes)

            all_codes.append(out)

        all_codes = torch.cat(all_codes, dim=1)  # FIXME not tested
        out = self.dualcodec_inference_obj.decode(all_codes)
        return out
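
A rough usage sketch for the wrapper above. This is hypothetical: how `model`, `tokenizer_obj`, and `dualcodec_inference_obj` are constructed (for example, from the Hydra config earlier in this commit and a pretrained DualCodec checkpoint) is assumed and not shown, and the file path and texts are placeholders.

```python
# Hypothetical driver; `model`, `tokenizer_obj`, and `dualcodec_inference_obj`
# must be built elsewhere before this runs.
infer = Inference(
    model=model,                                      # flattened-AR LLM (see config above)
    tokenizer_obj=tokenizer_obj,                      # provides encode() and to_language_token()
    dualcodec_inference_obj=dualcodec_inference_obj,  # provides encode(..., n_quantizers=4) and decode()
    device="cuda",
    normalize=False,
)

audio = infer.inference(
    speech_24k=None,              # accepted but unused by the current implementation
    prompt_speech="prompt.wav",   # hypothetical reference clip
    prompt_text="Transcript of the reference clip.",
    prompt_language="en",
    target_text="Text to be synthesized in the reference speaker's voice.",
    target_language="en",
    top_k=1000,
    top_p=0.85,
    temp=1.0,
    repeat_penalty=1.1,
)
# `audio` is whatever dualcodec_inference_obj.decode() returns for the generated codes.
```
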
