Commit 154e366

committed: update training scripts.
1 parent 7c7dda3 commit 154e366

9 files changed: 123 additions and 73 deletions

egs/aishell/s10/chain/inference.py

Lines changed: 3 additions & 2 deletions
@@ -34,8 +34,9 @@ def main():
                             output_dim=args.output_dim,
                             lda_mat_filename=args.lda_mat_filename,
                             hidden_dim=args.hidden_dim,
-                            kernel_size_list=args.kernel_size_list,
-                            stride_list=args.stride_list)
+                            bottleneck_dim=args.bottleneck_dim,
+                            time_stride_list=args.time_stride_list,
+                            conv_stride_list=args.conv_stride_list)
 
     load_checkpoint(args.checkpoint, model)

egs/aishell/s10/chain/model.py

Lines changed: 9 additions & 0 deletions
@@ -201,6 +201,14 @@ def forward(self, x):
 
         return nnet_output, xent_output
 
+    def constrain_orthonormal(self):
+        for i in range(len(self.tdnnfs)):
+            self.tdnnfs[i].constrain_orthonormal()
+
+        self.prefinal_l.constrain_orthonormal()
+        self.prefinal_chain.constrain_orthonormal()
+        self.prefinal_xent.constrain_orthonormal()
+
 
 if __name__ == '__main__':
     feat_dim = 43
@@ -212,3 +220,4 @@ def forward(self, x):
     x = torch.arange(N * T * C).reshape(N, T, C).float()
     nnet_output, xent_output = model(x)
     print(x.shape, nnet_output.shape, xent_output.shape)
+    model.constrain_orthonormal()

egs/aishell/s10/chain/options.py

Lines changed: 20 additions & 13 deletions
@@ -129,18 +129,19 @@ def _check_args(args):
     assert args.feat_dim > 0
     assert args.output_dim > 0
     assert args.hidden_dim > 0
+    assert args.bottleneck_dim > 0
 
-    assert args.kernel_size_list is not None
-    assert len(args.kernel_size_list) > 0
+    assert args.time_stride_list is not None
+    assert len(args.time_stride_list) > 0
 
-    assert args.stride_list is not None
-    assert len(args.stride_list) > 0
+    assert args.conv_stride_list is not None
+    assert len(args.conv_stride_list) > 0
 
-    args.kernel_size_list = [int(k) for k in args.kernel_size_list.split(', ')]
+    args.time_stride_list = [int(k) for k in args.time_stride_list.split(', ')]
 
-    args.stride_list = [int(k) for k in args.stride_list.split(', ')]
+    args.conv_stride_list = [int(k) for k in args.conv_stride_list.split(', ')]
 
-    assert len(args.kernel_size_list) == len(args.stride_list)
+    assert len(args.time_stride_list) == len(args.conv_stride_list)
 
     assert args.log_level in ['debug', 'info', 'warning']
 
@@ -195,15 +196,21 @@ def get_args():
                         required=True,
                         type=int)
 
-    parser.add_argument('--kernel-size-list',
-                        dest='kernel_size_list',
-                        help='kernel size list',
+    parser.add_argument('--bottleneck-dim',
+                        dest='bottleneck_dim',
+                        help='nn bottleneck dimension',
+                        required=True,
+                        type=int)
+
+    parser.add_argument('--time-stride-list',
+                        dest='time_stride_list',
+                        help='time stride list',
                         required=True,
                         type=str)
 
-    parser.add_argument('--stride-list',
-                        dest='stride_list',
-                        help='stride list',
+    parser.add_argument('--conv-stride-list',
+                        dest='conv_stride_list',
+                        help='conv stride list',
                         required=True,
                         type=str)
 
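As with the previous `--kernel-size-list`/`--stride-list` pair, the new list-valued options are plain strings that `_check_args` splits on a comma followed by a space, so the values passed from the shell must keep exactly that format. A minimal sketch of the convention (`parse_int_list` is an illustrative name, not part of the recipe):

def parse_int_list(s):
    # mirrors _check_args() above: items are separated by ", " (comma + space)
    return [int(k) for k in s.split(', ')]

time_stride_list = parse_int_list('1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1')
conv_stride_list = parse_int_list('1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1')
assert len(time_stride_list) == len(conv_stride_list)

A more forgiving variant would split on ',' and strip whitespace from each item, but that is not what the recipe does.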

egs/aishell/s10/chain/tdnnf_layer.py

Lines changed: 20 additions & 12 deletions
@@ -8,7 +8,7 @@
 import torch.nn.functional as F
 
 
-def _constraint_orthonormal_internal(M):
+def _constrain_orthonormal_internal(M):
     '''
     Refer to
       void ConstrainOrthonormalInternal(BaseFloat scale, CuMatrixBase<BaseFloat> *M)
@@ -58,7 +58,7 @@ def __init__(self, dim, bottleneck_dim, time_stride):
         assert time_stride in [0, 1]
         # WARNING(fangjun): kaldi uses [-1, 0] for the first linear layer
         # and [0, 1] for the second affine layer;
-        # We use [-1, 0, 1] for the first linear layer
+        # we use [-1, 0, 1] for the first linear layer if time_stride == 1
 
         if time_stride == 0:
             kernel_size = 1
@@ -79,7 +79,7 @@ def forward(self, x):
         x = self.conv(x)
         return x
 
-    def constraint_orthonormal(self):
+    def constrain_orthonormal(self):
         state_dict = self.conv.state_dict()
         w = state_dict['weight']
         # w is of shape [out_channels, in_channels, kernel_size]
@@ -97,7 +97,7 @@ def constraint_orthonormal(self):
             w = w.t()
             need_transpose = True
 
-        w = _constraint_orthonormal_internal(w)
+        w = _constrain_orthonormal_internal(w)
 
         if need_transpose:
             w = w.t()
@@ -142,6 +142,9 @@ def forward(self, x):
 
         return x
 
+    def constrain_orthonormal(self):
+        self.linear.constrain_orthonormal()
+
 
 class FactorizedTDNN(nn.Module):
     '''
@@ -175,6 +178,8 @@ def __init__(self,
                                         time_stride=time_stride)
 
         # affine requires [N, C, T]
+        # WARNING(fangjun): we do not use nn.Linear here
+        # since we want to use `stride`
         self.affine = nn.Conv1d(in_channels=bottleneck_dim,
                                 out_channels=dim,
                                 kernel_size=1,
@@ -191,31 +196,34 @@ def forward(self, x):
         input_x = x
 
         x = self.linear(x)
+
         # at this point, x is [N, C, T]
 
         x = self.affine(x)
+
         # at this point, x is [N, C, T]
 
         x = F.relu(x)
+
         # at this point, x is [N, C, T]
 
         x = self.batchnorm(x)
+
         # at this point, x is [N, C, T]
 
         # TODO(fangjun): implement GeneralDropoutComponent in PyTorch
 
-        # at this point, x is [N, C, T]
         if self.linear.kernel_size == 3:
             x = self.bypass_scale * input_x[:, :, 1:-1:self.conv_stride] + x
         else:
             x = self.bypass_scale * input_x[:, :, ::self.conv_stride] + x
         return x
 
-    def constraint_orthonormal(self):
-        self.linear.constraint_orthonormal()
+    def constrain_orthonormal(self):
+        self.linear.constrain_orthonormal()
 
 
-def _test_constraint_orthonormal():
+def _test_constrain_orthonormal():
 
     def compute_loss(M):
         P = torch.mm(M, M.t())
@@ -238,7 +246,7 @@ def compute_loss(M):
     loss.append(compute_loss(w))
 
     for i in range(15):
-        w = _constraint_orthonormal_internal(w)
+        w = _constrain_orthonormal_internal(w)
         loss.append(compute_loss(w))
 
     for i in range(1, len(loss)):
@@ -252,11 +260,11 @@ def compute_loss(M):
                            time_stride=1,
                            conv_stride=3)
     loss = []
-    model.constraint_orthonormal()
+    model.constrain_orthonormal()
    loss.append(
        compute_loss(model.linear.conv.state_dict()['weight'].reshape(128, -1)))
    for i in range(5):
-        model.constraint_orthonormal()
+        model.constrain_orthonormal()
        loss.append(
            compute_loss(model.linear.conv.state_dict()['weight'].reshape(
                128, -1)))
@@ -308,4 +316,4 @@ def _test_factorized_tdnn():
 if __name__ == '__main__':
     torch.manual_seed(20200130)
     _test_factorized_tdnn()
-    _test_constraint_orthonormal()
+    _test_constrain_orthonormal()
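
For context on the rename: `constrain_orthonormal` (previously `constraint_orthonormal`) nudges the weight matrix M of each factorized layer toward a semi-orthogonal matrix, i.e. toward M Mᵀ = scale² · I, following the Kaldi function cited in the docstring. Below is a minimal sketch of one such update step, assuming M has no more rows than columns (the recipe transposes w first to guarantee this); `orthonormal_step` is a name invented here, and the recipe's `_constrain_orthonormal_internal` follows Kaldi's ConstrainOrthonormalInternal more closely:

import torch

def orthonormal_step(M, update_speed=0.125):
    # choose the scale^2 that M M^T is already closest to (a "floating" scale)
    P = torch.mm(M, M.t())
    scale_sq = torch.trace(torch.mm(P, P.t())) / torch.trace(P)
    Q = P - scale_sq * torch.eye(P.size(0), dtype=M.dtype, device=M.device)
    # the gradient of ||M M^T - scale^2 I||_F^2 w.r.t. M is 4 Q M,
    # so take one gradient-descent step with step size update_speed
    return M - 4.0 * update_speed * torch.mm(Q, M)

Repeated steps drive M Mᵀ toward scale² · I, which is what `_test_constrain_orthonormal` tracks with `compute_loss`.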

egs/aishell/s10/chain/train.py

Lines changed: 9 additions & 2 deletions
@@ -11,6 +11,7 @@
 # disable warnings when loading tensorboard
 warnings.simplefilter(action='ignore', category=FutureWarning)
 
+import numpy as np
 import torch
 import torch.optim as optim
 from torch.nn.utils import clip_grad_value_
@@ -84,6 +85,11 @@ def train_one_epoch(dataloader, model, device, optimizer, criterion,
         total_weight += objf_l2_term_weight[2].item()
         num_frames = nnet_output.shape[0]
         total_frames += num_frames
+
+        if np.random.choice(4) == 0:
+            with torch.no_grad():
+                model.constrain_orthonormal()
+
         if batch_idx % 100 == 0:
             logging.info(
                 'Process {}/{}({:.6f}%) global average objf: {:.6f} over {} '
@@ -135,8 +141,9 @@ def main():
                             output_dim=args.output_dim,
                             lda_mat_filename=args.lda_mat_filename,
                             hidden_dim=args.hidden_dim,
-                            kernel_size_list=args.kernel_size_list,
-                            stride_list=args.stride_list)
+                            bottleneck_dim=args.bottleneck_dim,
+                            time_stride_list=args.time_stride_list,
+                            conv_stride_list=args.conv_stride_list)
 
     start_epoch = 0
     num_epochs = args.num_epochs
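
Note that the constraint is applied stochastically rather than on every minibatch: `np.random.choice(4) == 0` holds with probability 1/4, so the weights are re-orthonormalized on roughly a quarter of the updates, and the in-place change is wrapped in `torch.no_grad()` to keep it out of the autograd graph. A small helper expressing the same idea (`maybe_constrain` is an illustrative name, not part of the recipe):

import numpy as np
import torch

def maybe_constrain(model, prob=0.25):
    # apply the semi-orthonormal constraint on about `prob` of the calls;
    # no_grad() because the weights are modified in place, outside autograd
    if np.random.random() < prob:
        with torch.no_grad():
            model.constrain_orthonormal()

Calling such a helper once per processed minibatch reproduces the behaviour of the block added to `train_one_epoch` above.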

egs/aishell/s10/conf/mfcc_hires.conf

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# config for high-resolution MFCC features, intended for neural network training.
+# Note: we keep all cepstra, so it has the same info as filterbank features,
+# but MFCC is more easily compressible (because less correlated) which is why
+# we prefer this method.
+--use-energy=false        # use average of log energy, not energy.
+--sample-frequency=16000  # AISHELL-2 is sampled at 16kHz
+--num-mel-bins=40         # similar to Google's setup.
+--num-ceps=40             # there is no dimensionality reduction.
+--low-freq=20             # low cutoff frequency for mel bins
+--high-freq=-400          # high cutoff frequency, relative to Nyquist of 8000 (=7600)

egs/aishell/s10/local/run_chain.sh

Lines changed: 27 additions & 19 deletions
@@ -9,7 +9,7 @@ stage=0
 
 # GPU device id to use (count from 0).
 # you can also set `CUDA_VISIBLE_DEVICES` and set `device_id=0`
-device_id=0
+device_id=6
 
 nj=10
 
@@ -19,8 +19,8 @@ lat_dir=exp/tri5a_lats # input lat dir
 treedir=exp/chain/tri5_tree # output tree dir
 
 # You should know how to calculate your model's left/right context **manually**
-model_left_context=12
-model_right_context=12
+model_left_context=28
+model_right_context=28
 egs_left_context=$[$model_left_context + 1]
 egs_right_context=$[$model_right_context + 1]
 frames_per_eg=150,110,90
@@ -30,9 +30,10 @@ minibatch_size=128
 num_epochs=6
 lr=1e-3
 
-hidden_dim=625
-kernel_size_list="1, 3, 3, 3, 3, 3" # comma separated list
-stride_list="1, 1, 3, 1, 1, 1" # comma separated list
+hidden_dim=1024
+bottleneck_dim=128
+time_stride_list="1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1" # comma separated list
+conv_stride_list="1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1" # comma separated list
 
 log_level=info # valid values: debug, info, warning
 
@@ -47,11 +48,16 @@ save_nn_output_as_compressed=false
 
 if [[ $stage -le 0 ]]; then
   for datadir in train dev test; do
-    dst_dir=data/fbank_pitch/$datadir
+    dst_dir=data/mfcc_hires/$datadir
     if [[ ! -f $dst_dir/feats.scp ]]; then
+      echo "making mfcc-pitch features for LF-MMI training"
       utils/copy_data_dir.sh data/$datadir $dst_dir
-      echo "making fbank-pitch features for LF-MMI training"
-      steps/make_fbank_pitch.sh --cmd $train_cmd --nj $nj $dst_dir || exit 1
+      steps/make_mfcc_pitch.sh \
+        --mfcc-config conf/mfcc_hires.conf \
+        --pitch-config conf/pitch.conf \
+        --cmd "$train_cmd" \
+        --nj $nj \
+        $dst_dir || exit 1
       steps/compute_cmvn_stats.sh $dst_dir || exit 1
       utils/fix_data_dir.sh $dst_dir
     else
@@ -80,12 +86,12 @@ if [[ $stage -le 2 ]]; then
   # step compared with other recipes.
   steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \
     --context-opts "--context-width=2 --central-position=1" \
-    --cmd "$train_cmd" 5000 data/train $lang $ali_dir $treedir
+    --cmd "$train_cmd" 5000 data/mfcc/train $lang $ali_dir $treedir
 fi
 
 if [[ $stage -le 3 ]]; then
   echo "creating phone language-model"
-  $train_cmd exp/chain/log/make_phone_lm.log \
+  "$train_cmd" exp/chain/log/make_phone_lm.log \
     chain-est-phone-lm \
     "ark:gunzip -c $treedir/ali.*.gz | ali-to-phones $treedir/final.mdl ark:- ark:- |" \
     exp/chain/phone_lm.fst || exit 1
@@ -95,7 +101,7 @@ if [[ $stage -le 4 ]]; then
   echo "creating denominator FST"
   copy-transition-model $treedir/final.mdl exp/chain/0.trans_mdl
   cp $treedir/tree exp/chain
-  $train_cmd exp/chain/log/make_den_fst.log \
+  "$train_cmd" exp/chain/log/make_den_fst.log \
     chain-make-den-fst exp/chain/tree exp/chain/0.trans_mdl exp/chain/phone_lm.fst \
     exp/chain/den.fst exp/chain/normalization.fst || exit 1
 fi
@@ -119,7 +125,7 @@ if [[ $stage -le 5 ]]; then
     --right-tolerance 5 \
     --srand 0 \
     --stage -10 \
-    data/fbank_pitch/train \
+    data/mfcc_hires/train \
     exp/chain $lat_dir exp/chain/egs
 fi
 
@@ -157,16 +163,17 @@ if [[ $stage -le 8 ]]; then
 
   # sort the options alphabetically
   python3 ./chain/train.py \
+    --bottleneck-dim $bottleneck_dim \
     --checkpoint=${train_checkpoint:-} \
+    --conv-stride-list "$conv_stride_list" \
     --device-id $device_id \
     --dir exp/chain/train \
     --feat-dim $feat_dim \
     --hidden-dim $hidden_dim \
     --is-training true \
-    --kernel-size-list "$kernel_size_list" \
     --log-level $log_level \
     --output-dim $output_dim \
-    --stride-list "$stride_list" \
+    --time-stride-list "$time_stride_list" \
     --train.cegs-dir exp/chain/merged_egs \
     --train.den-fst exp/chain/den.fst \
     --train.egs-left-context $egs_left_context \
@@ -186,20 +193,21 @@ if [[ $stage -le 9 ]]; then
     best_epoch=$(cat exp/chain/train/best-epoch-info | grep 'best epoch' | awk '{print $NF}')
     inference_checkpoint=exp/chain/train/epoch-${best_epoch}.pt
     python3 ./chain/inference.py \
+      --bottleneck-dim $bottleneck_dim \
      --checkpoint $inference_checkpoint \
+      --conv-stride-list "$conv_stride_list" \
      --device-id $device_id \
      --dir exp/chain/inference/$x \
      --feat-dim $feat_dim \
-      --feats-scp data/fbank_pitch/$x/feats.scp \
+      --feats-scp data/mfcc_hires/$x/feats.scp \
      --hidden-dim $hidden_dim \
      --is-training false \
-      --kernel-size-list "$kernel_size_list" \
      --log-level $log_level \
      --model-left-context $model_left_context \
      --model-right-context $model_right_context \
      --output-dim $output_dim \
      --save-as-compressed $save_nn_output_as_compressed \
-      --stride-list "$stride_list" || exit 1
+      --time-stride-list "$time_stride_list" || exit 1
   fi
   done
 fi
@@ -228,7 +236,7 @@ if [[ $stage -le 11 ]]; then
 
   for x in test dev; do
     ./local/score.sh --cmd "$decode_cmd" \
-      data/fbank_pitch/$x \
+      data/mfcc_hires/$x \
      exp/chain/graph \
      exp/chain/decode_res/$x || exit 1
   done
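
The jump in model context from 12 to 28 frames follows from the new layer lists: a TDNN-F layer with time_stride=1 adds one frame of context on each side at its own frame rate, and every layer after the conv_stride=3 layer runs at one third of the input rate. A rough receptive-field sketch of that accounting (a sketch only: the 12 TDNN-F layers alone account for 27 frames, and the remaining frame presumably comes from the model's input layers in model.py, which these lists do not describe):

def tdnnf_context(time_stride_list, conv_stride_list):
    # left (= right) context, in input frames, contributed by the TDNN-F stack
    context = 0
    subsampling = 1  # cumulative subsampling seen by the current layer
    for time_stride, conv_stride in zip(time_stride_list, conv_stride_list):
        context += time_stride * subsampling  # time_stride=1 -> kernel [-1, 0, 1]
        subsampling *= conv_stride
    return context

print(tdnnf_context([1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
                    [1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1]))  # prints 27

`egs_left_context` and `egs_right_context` then add one extra frame each, as the script already does.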
