@@ -1,3 +1,7 @@
+# Copyright (c) 2025 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
 from cv2 import repeat
 import torch
 from einops import rearrange
@@ -6,8 +10,16 @@
 
 class Inference:
     def __init__(
-        self, model, tokenizer_obj, dualcodec_inference_obj, device="cuda", normalize=False,
-        half=False, split_paragraph=True, offset_sizes=[16384, 4096, 4096, 4096], **kwargs
+        self,
+        model,
+        tokenizer_obj,
+        dualcodec_inference_obj,
+        device="cuda",
+        normalize=False,
+        half=False,
+        split_paragraph=True,
+        offset_sizes=[16384, 4096, 4096, 4096],
+        **kwargs,
     ) -> None:
         self.model = model
         import safetensors.torch
@@ -21,9 +33,9 @@ def __init__(
         self.offset_sizes = offset_sizes
 
         self.model = self.model.half()
-
+
         self.split_paragraph = split_paragraph
-
+
     @torch.no_grad()
     def inference(
         self,
@@ -68,8 +80,9 @@ def inference(
         prompt_len_tmp = len(self.tokenizer.encode(prompt_text)) // 2
 
         if self.split_paragraph:
-            if prompt_language == 'zh':
+            if prompt_language == "zh":
                 from dualcodec.utils.frontend_utils import split_paragraph
+
                 texts = split_paragraph(
                     target_text,
                     None,
@@ -79,8 +92,9 @@ def inference(
                     merge_len=20,
                     comma_split=False,
                 )
-            elif prompt_language == 'ja':
+            elif prompt_language == "ja":
                 from dualcodec.utils.frontend_utils import split_paragraph
+
                 texts = split_paragraph(
                     target_text,
                     None,
@@ -90,8 +104,9 @@ def inference(
                     merge_len=20,
                     comma_split=False,
                 )
-            elif prompt_language == 'en':
+            elif prompt_language == "en":
                 from dualcodec.utils.frontend_utils import split_paragraph
+
                 texts = split_paragraph(
                     target_text,
                     self.tokenizer.encode,
@@ -103,8 +118,8 @@ def inference(
                 )
         else:
             texts = [target_text]
-        if prompt_language == 'en':
-            texts = [prompt_text + ' ' + t for t in texts]
+        if prompt_language == "en":
+            texts = [prompt_text + " " + t for t in texts]
         else:
             texts = [prompt_text + t for t in texts]
         print(texts)
@@ -115,12 +130,20 @@ def inference(
 
             if self.normalize:
                 from dualcodec.dataset.processor import normalize
-                text = list(normalize([{
-                    'language': prompt_language,
-                    'text': text,
-                }], en_punct=True, use_kana=False))[0]['text']
-                print(text)
 
+                text = list(
+                    normalize(
+                        [
+                            {
+                                "language": prompt_language,
+                                "text": text,
+                            }
+                        ],
+                        en_punct=True,
+                        use_kana=False,
+                    )
+                )[0]["text"]
+                print(text)
 
             prompt_text_tokens = torch.tensor(
                 [
@@ -143,13 +166,17 @@ def inference(
 
             # prompt semantic codes
             # semantic_code, _ = self._extract_semantic_code(input_features, attention_mask)
-            semantic_codes, acoustic_codes = self.dualcodec_inference_obj.encode(prompt_speech, n_quantizers=4)
-            semantic_codes = rearrange(semantic_codes, 'b t -> b t 1')
+            semantic_codes, acoustic_codes = self.dualcodec_inference_obj.encode(
+                prompt_speech, n_quantizers=4
+            )
+            semantic_codes = rearrange(semantic_codes, "b t -> b t 1")
             num_codec_layers = 4
-            semantic_code = torch.cat([semantic_codes, acoustic_codes], dim=-1)[..., :num_codec_layers]
+            semantic_code = torch.cat([semantic_codes, acoustic_codes], dim=-1)[
+                ..., :num_codec_layers
+            ]
 
             semantic_code = offset_codes(semantic_code, self.offset_sizes)
-            semantic_code = rearrange(semantic_code, ' b t q -> b (t q)')
+            semantic_code = rearrange(semantic_code, " b t q -> b (t q)")
 
             ret_semantic_code = semantic_code.clone().detach()
 
@@ -169,6 +196,6 @@ def inference(
 
             all_codes.append(out)
 
-        all_codes = torch.cat(all_codes, dim=1) # FIXME not tested
+        all_codes = torch.cat(all_codes, dim=1)  # FIXME not tested
         out = self.dualcodec_inference_obj.decode(all_codes)
         return out
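
A note on the codebook offsetting this commit reformats: `offset_codes` is called with `offset_sizes=[16384, 4096, 4096, 4096]` right before the `"b t q -> b (t q)"` flattening, but its body is not part of this diff. What follows is a hedged sketch of the usual scheme, not the repository's actual implementation: shift each quantizer layer's codes by the cumulative sizes of the preceding codebooks, so the four layers occupy disjoint token-ID ranges and the flattened stream stays unambiguous. The name `offset_codes_sketch` is hypothetical.

# Hypothetical sketch; the real offset_codes lives elsewhere in the
# repository and is not shown in this commit.
import torch
from einops import rearrange

def offset_codes_sketch(codes, offset_sizes):
    # codes: (b, t, q) integer indices, one column per quantizer layer.
    # With offset_sizes = [16384, 4096, 4096, 4096], layer 0 keeps
    # [0, 16384); layer 1 starts at 16384; layer 2 at 20480; layer 3
    # at 24576.
    starts = torch.cumsum(torch.tensor([0] + list(offset_sizes[:-1])), dim=0)
    return codes + starts.to(codes.device)  # broadcasts over the q axis

codes = torch.randint(0, 4096, (1, 10, 4))  # (b, t, q) toy input
flat = rearrange(
    offset_codes_sketch(codes, [16384, 4096, 4096, 4096]), "b t q -> b (t q)"
)  # (b, t * q): four interleaved tokens per codec frame

Under this reading, the model consumes a single interleaved stream over a vocabulary of sum(offset_sizes) = 28672 codec tokens, and decoding would subtract the same per-layer starts before handing the codes to `self.dualcodec_inference_obj.decode`.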