Skip to content

Commit 1121c31

Browse files
authored
[src] Batched spectral feature extraction on GPU (#3889)
1 parent 48d2115 commit 1121c31

7 files changed

+1375
-5
lines changed

src/cudafeat/Makefile

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ ifeq ($(CUDA), true)
88
TESTFILES =
99

1010
ifeq ($(CUDA), true)
11-
OBJFILES += feature-window-cuda.o feature-spectral-cuda.o feature-online-cmvn-cuda.o \
12-
online-ivector-feature-cuda-kernels.o online-ivector-feature-cuda.o \
13-
online-cuda-feature-pipeline.o feature-online-batched-cmvn-cuda.o \
14-
feature-online-batched-cmvn-cuda-kernels.o
11+
OBJFILES += feature-window-cuda.o feature-spectral-cuda.o \
12+
feature-online-cmvn-cuda.o feature-online-batched-spectral-cuda.o \
13+
feature-spectral-batched-kernels.o \
14+
online-ivector-feature-cuda-kernels.o online-ivector-feature-cuda.o \
15+
online-cuda-feature-pipeline.o feature-online-batched-cmvn-cuda.o \
16+
feature-online-batched-cmvn-cuda-kernels.o
1517
endif
1618

1719
LIBNAME = kaldi-cudafeat
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
// cudafeature/feature-online-batched-spectral-cuda.cc
2+
//
3+
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
4+
// Justin Luitjens, Levi Barnes
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
18+
#include "cudafeat/feature-online-batched-spectral-cuda.h"
19+
#include "cudafeat/feature-spectral-batched-kernels.h"
20+
21+
namespace kaldi {
22+
23+
CudaOnlineBatchedSpectralFeatures::CudaOnlineBatchedSpectralFeatures(
24+
const CudaSpectralFeatureOptions &opts, int32_t max_chunk_frames,
25+
int32_t num_channels, int32_t max_lanes)
26+
: MfccComputer(opts.mfcc_opts),
27+
cu_lifter_coeffs_(lifter_coeffs_),
28+
cu_dct_matrix_(dct_matrix_),
29+
window_function_(opts.mfcc_opts.frame_opts),
30+
max_chunk_frames_(max_chunk_frames),
31+
num_channels_(num_channels),
32+
max_lanes_(max_lanes) {
33+
KALDI_ASSERT(max_chunk_frames > 0);
34+
const MelBanks *mel_banks = GetMelBanks(1.0);
35+
const std::vector<std::pair<int32, Vector<BaseFloat>>> &bins =
36+
mel_banks->GetBins();
37+
int size = bins.size();
38+
bin_size_ = size;
39+
std::vector<int32> offsets(size), sizes(size);
40+
std::vector<float *> vecs(size);
41+
cu_vecs_ = new CuVector<float>[size];
42+
for (int i = 0; i < bins.size(); i++) {
43+
cu_vecs_[i].Resize(bins[i].second.Dim(), kUndefined);
44+
cu_vecs_[i].CopyFromVec(bins[i].second);
45+
vecs[i] = cu_vecs_[i].Data();
46+
sizes[i] = cu_vecs_[i].Dim();
47+
offsets[i] = bins[i].first;
48+
}
49+
offsets_ = static_cast<int32 *>(
50+
CuDevice::Instantiate().Malloc(size * sizeof(int32)));
51+
sizes_ = static_cast<int32 *>(
52+
CuDevice::Instantiate().Malloc(size * sizeof(int32)));
53+
vecs_ = static_cast<float **>(
54+
CuDevice::Instantiate().Malloc(size * sizeof(float *)));
55+
56+
CU_SAFE_CALL(cudaMemcpyAsync(vecs_, &vecs[0], size * sizeof(float *),
57+
cudaMemcpyHostToDevice, cudaStreamPerThread));
58+
CU_SAFE_CALL(cudaMemcpyAsync(offsets_, &offsets[0], size * sizeof(int32),
59+
cudaMemcpyHostToDevice, cudaStreamPerThread));
60+
CU_SAFE_CALL(cudaMemcpyAsync(sizes_, &sizes[0], size * sizeof(int32),
61+
cudaMemcpyHostToDevice, cudaStreamPerThread));
62+
CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread));
63+
64+
const FrameExtractionOptions frame_opts = opts.mfcc_opts.frame_opts;
65+
frame_length_ = frame_opts.WindowSize();
66+
padded_length_ = frame_opts.PaddedWindowSize();
67+
fft_length_ = padded_length_ / 2; // + 1;
68+
fft_batch_size_ = 800;
69+
70+
// place holders to get strides for cufft. these will be resized correctly
71+
// later. The +2 for cufft/fftw requirements of an extra element at the end.
72+
// Turning off stride because cufft seems buggy with a stride
73+
int32_t fft_num_frames =
74+
max_chunk_frames +
75+
(fft_batch_size_ - max_chunk_frames_ % fft_batch_size_);
76+
cu_windows_.Resize(fft_num_frames * max_lanes_, padded_length_, kUndefined,
77+
kStrideEqualNumCols);
78+
//+1 matches cufft/fftw requirements
79+
tmp_window_.Resize(fft_num_frames * max_lanes_, padded_length_ + 2,
80+
kUndefined, kStrideEqualNumCols);
81+
82+
// Pre-allocated memory for power spectra
83+
power_spectrum_.Resize(max_chunk_frames_ * max_lanes_, padded_length_ / 2 + 1,
84+
kUndefined);
85+
raw_log_energies_.Resize(max_lanes_, max_chunk_frames_, kUndefined);
86+
cu_mel_energies_.Resize(max_chunk_frames_ * max_lanes_, bin_size_,
87+
kUndefined);
88+
int32_t max_stash_size =
89+
2 * (frame_opts.WindowSize() / 2 + frame_opts.WindowShift());
90+
stash_.Resize(num_channels_, max_stash_size);
91+
92+
stride_ = cu_windows_.Stride();
93+
tmp_stride_ = tmp_window_.Stride();
94+
95+
cufftPlanMany(&plan_, 1, &padded_length_, NULL, 1, stride_, NULL, 1,
96+
tmp_stride_ / 2, CUFFT_R2C, fft_batch_size_);
97+
cufftSetStream(plan_, cudaStreamPerThread);
98+
cumfcc_opts_ = opts;
99+
}
100+
101+
// ExtractWindow extracts a windowed frame of waveform with a power-of-two,
102+
// padded size. It does mean subtraction, pre-emphasis and dithering as
103+
// requested.
104+
void CudaOnlineBatchedSpectralFeatures::ExtractWindowsBatched(
105+
const LaneDesc *lanes, int32_t num_lanes,
106+
const CuMatrixBase<BaseFloat> &wave) {
107+
CU_SAFE_CALL(cudaGetLastError());
108+
const FrameExtractionOptions &opts = GetFrameOptions();
109+
cuda_extract_window(
110+
lanes, num_lanes, max_chunk_frames_, opts.WindowShift(),
111+
opts.WindowSize(), opts.PaddedWindowSize(), opts.snip_edges, wave.Data(),
112+
wave.Stride(), cu_windows_.Data(), opts.WindowSize(),
113+
cu_windows_.Stride(), stash_.Data(), stash_.NumCols(), stash_.Stride());
114+
}
115+
116+
void CudaOnlineBatchedSpectralFeatures::ProcessWindowsBatched(
117+
const LaneDesc *lanes, int32_t num_lanes,
118+
const FrameExtractionOptions &opts,
119+
CuMatrixBase<BaseFloat> *log_energy_pre_window) {
120+
int fft_num_frames = cu_windows_.NumRows();
121+
KALDI_ASSERT(fft_num_frames % fft_batch_size_ == 0);
122+
123+
cuda_process_window(
124+
lanes, num_lanes, max_chunk_frames_, frame_length_, opts.dither,
125+
std::numeric_limits<float>::epsilon(), opts.remove_dc_offset,
126+
opts.preemph_coeff, NeedRawLogEnergy(), log_energy_pre_window->Data(),
127+
log_energy_pre_window->Stride(), window_function_.cu_window.Data(),
128+
tmp_window_.Data(), tmp_window_.Stride(), cu_windows_.Data(),
129+
cu_windows_.Stride());
130+
131+
CU_SAFE_CALL(cudaGetLastError());
132+
}
133+
134+
void CudaOnlineBatchedSpectralFeatures::UpdateStashBatched(
135+
const LaneDesc *lanes, int32_t num_lanes,
136+
const CuMatrixBase<BaseFloat> &wave) {
137+
KALDI_ASSERT(stash_.NumCols() < 1024);
138+
139+
cuda_update_stash(lanes, num_lanes, wave.Data(), wave.Stride(), stash_.Data(),
140+
stash_.NumCols(), stash_.Stride());
141+
}
142+
143+
void CudaOnlineBatchedSpectralFeatures::ComputeFinalFeaturesBatched(
144+
const LaneDesc *lanes, int32_t num_lanes, BaseFloat vtln_wrap,
145+
CuMatrix<BaseFloat> *cu_signal_log_energy,
146+
CuMatrix<BaseFloat> *cu_features) {
147+
MfccOptions mfcc_opts = cumfcc_opts_.mfcc_opts;
148+
Vector<float> tmp;
149+
KALDI_ASSERT(mfcc_opts.htk_compat == false);
150+
151+
if (num_lanes == 0) return;
152+
153+
if (mfcc_opts.use_energy && !mfcc_opts.raw_energy) {
154+
cuda_dot_log(max_chunk_frames_, num_lanes, cu_windows_.NumCols(),
155+
cu_windows_.Data(), cu_windows_.Stride(),
156+
cu_signal_log_energy->Data(), cu_signal_log_energy->Stride());
157+
CU_SAFE_CALL(cudaGetLastError());
158+
}
159+
160+
// make sure a reallocation hasn't changed these
161+
KALDI_ASSERT(cu_windows_.Stride() == stride_);
162+
KALDI_ASSERT(tmp_window_.Stride() == tmp_stride_);
163+
164+
// Perform FFTs in batches of fft_size. This reduces memory requirements
165+
for (int idx = 0; idx < max_chunk_frames_ * num_lanes;
166+
idx += fft_batch_size_) {
167+
CUFFT_SAFE_CALL(cufftExecR2C(
168+
plan_, cu_windows_.Data() + cu_windows_.Stride() * idx,
169+
(cufftComplex *)(tmp_window_.Data() + tmp_window_.Stride() * idx)));
170+
}
171+
172+
// Compute Power spectrum
173+
cuda_power_spectrum(max_chunk_frames_, num_lanes, padded_length_,
174+
tmp_window_.Data(), tmp_window_.Stride(),
175+
power_spectrum_.Data(), power_spectrum_.Stride(),
176+
cumfcc_opts_.use_power);
177+
CU_SAFE_CALL(cudaGetLastError());
178+
179+
// mel banks
180+
int num_bins = bin_size_;
181+
cuda_mel_banks_compute(lanes, num_lanes, max_chunk_frames_, num_bins,
182+
std::numeric_limits<float>::epsilon(), offsets_,
183+
sizes_, vecs_, power_spectrum_.Data(),
184+
power_spectrum_.Stride(), cu_mel_energies_.Data(),
185+
cu_mel_energies_.Stride(), cumfcc_opts_.use_log_fbank);
186+
CU_SAFE_CALL(cudaGetLastError());
187+
188+
// dct transform
189+
if (cumfcc_opts_.use_dct) {
190+
if (cu_features->NumRows() > cu_mel_energies_.NumRows()) {
191+
CuSubMatrix<BaseFloat> cu_feats_sub(*cu_features, 0,
192+
cu_mel_energies_.NumRows(), 0,
193+
cu_features->NumCols());
194+
cu_feats_sub.AddMatMat(1.0, cu_mel_energies_, kNoTrans, cu_dct_matrix_,
195+
kTrans, 0.0);
196+
} else {
197+
cu_features->AddMatMat(1.0, cu_mel_energies_, kNoTrans, cu_dct_matrix_,
198+
kTrans, 0.0);
199+
}
200+
cuda_apply_lifter_and_floor_energy(
201+
lanes, num_lanes, max_chunk_frames_, cu_features->NumCols(),
202+
mfcc_opts.cepstral_lifter, mfcc_opts.use_energy, mfcc_opts.energy_floor,
203+
cu_signal_log_energy->Data(), cu_signal_log_energy->Stride(),
204+
cu_lifter_coeffs_.Data(), cu_features->Data(), cu_features->Stride());
205+
206+
} else {
207+
cudaMemcpyAsync(cu_features->Data(), cu_mel_energies_.Data(),
208+
sizeof(BaseFloat) * max_chunk_frames_ * num_lanes *
209+
cu_features->Stride(),
210+
cudaMemcpyDeviceToDevice, cudaStreamPerThread);
211+
}
212+
CU_SAFE_CALL(cudaGetLastError());
213+
}
214+
215+
void CudaOnlineBatchedSpectralFeatures::ComputeFeaturesBatched(
216+
const LaneDesc *lanes, int32_t n_lanes,
217+
const CuMatrixBase<BaseFloat> &cu_wave_in, BaseFloat sample_freq,
218+
BaseFloat vtln_warp, CuMatrix<BaseFloat> *cu_feats_out) {
219+
// Note: cu_features is actually a rank 3 tensor.
220+
// channels x frames x features
221+
// it is currently represented as a matrix with n_channels*n_frames rows and
222+
// n_features cols
223+
const FrameExtractionOptions &frame_opts = GetFrameOptions();
224+
225+
if (frame_opts.dither != 0.0f) {
226+
// Calling cu-rand directly
227+
// CuRand class works on CuMatrixBase which must
228+
// assume that the matrix is part of a larger matrix
229+
// Doing this directly avoids unecessary memory copies
230+
CURAND_SAFE_CALL(
231+
curandGenerateNormal(GetCurandHandle(), tmp_window_.Data(),
232+
tmp_window_.NumRows() * tmp_window_.Stride(),
233+
0.0 /*mean*/, 1.0 /*stddev*/));
234+
}
235+
236+
// Extract Windows
237+
ExtractWindowsBatched(lanes, n_lanes, cu_wave_in);
238+
239+
UpdateStashBatched(lanes, n_lanes, cu_wave_in);
240+
241+
// Process Windows
242+
ProcessWindowsBatched(lanes, n_lanes, frame_opts, &raw_log_energies_);
243+
244+
// Compute Features
245+
ComputeFinalFeaturesBatched(lanes, n_lanes, 1.0, &raw_log_energies_,
246+
cu_feats_out);
247+
}
248+
249+
CudaOnlineBatchedSpectralFeatures::~CudaOnlineBatchedSpectralFeatures() {
250+
delete[] cu_vecs_;
251+
CuDevice::Instantiate().Free(vecs_);
252+
CuDevice::Instantiate().Free(offsets_);
253+
CuDevice::Instantiate().Free(sizes_);
254+
cufftDestroy(plan_);
255+
}
256+
} // namespace kaldi
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// cudafeat/feature-batched-spectral-cuda.h
2+
//
3+
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
4+
// Justin Luitjens, Levi Barnes
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
18+
#ifndef KALDI_CUDAFEAT_FEATURE_BATCHED_SPECTRAL_CUDA_H_
19+
#define KALDI_CUDAFEAT_FEATURE_BATCHED_SPECTRAL_CUDA_H_
20+
21+
#if HAVE_CUDA == 1
22+
#include <cufft.h>
23+
#endif
24+
25+
#include "cudafeat/feature-spectral-cuda.h"
26+
#include "cudafeat/feature-window-cuda.h"
27+
#include "cudafeat/lane-desc.h"
28+
#include "cudamatrix/cu-matrix.h"
29+
#include "cudamatrix/cu-vector.h"
30+
#include "feat/feature-fbank.h"
31+
#include "feat/feature-mfcc.h"
32+
33+
namespace kaldi {
34+
// This class implements MFCC and Fbank computation in CUDA.
35+
// It handles batched input.
36+
// It takes input from device memory and outputs to
37+
// device memory. It also does no synchronization.
38+
class CudaOnlineBatchedSpectralFeatures : public MfccComputer {
39+
public:
40+
void ComputeFeatures(const CuVectorBase<BaseFloat> &cu_wave,
41+
BaseFloat sample_freq, BaseFloat vtln_warp,
42+
CuMatrix<BaseFloat> *cu_features) {
43+
// Non-batched processing not allowed from
44+
// CudaOnlineBatchedSpectralFeatures
45+
KALDI_ASSERT(false);
46+
}
47+
48+
void ComputeFeaturesBatched(const LaneDesc *lanes, int32_t n_lanes,
49+
const CuMatrixBase<BaseFloat> &cu_wave_in,
50+
BaseFloat sample_freq, BaseFloat vtln_warp,
51+
CuMatrix<BaseFloat> *cu_feats_out);
52+
53+
CudaOnlineBatchedSpectralFeatures(const CudaSpectralFeatureOptions &opts,
54+
int32_t max_chunk_frames,
55+
int32_t num_channels, int32_t max_lanes);
56+
~CudaOnlineBatchedSpectralFeatures();
57+
CudaSpectralFeatureOptions cumfcc_opts_;
58+
int32 Dim()
59+
// The dimension of the output is different for MFCC and Fbank.
60+
// This returns the appropriate value depending on the feature
61+
// extraction algorithm
62+
{
63+
if (cumfcc_opts_.feature_type == MFCC) return MfccComputer::Dim();
64+
// If we're running fbank, we need to set the dimension right
65+
else
66+
return cumfcc_opts_.mfcc_opts.mel_opts.num_bins +
67+
(cumfcc_opts_.mfcc_opts.use_energy ? 1 : 0);
68+
}
69+
70+
private:
71+
72+
void ExtractWindowsBatched(const LaneDesc *lanes, int32_t num_lanes,
73+
const CuMatrixBase<BaseFloat> &wave);
74+
75+
void UpdateStashBatched(const LaneDesc *lanes, int32_t num_lanes,
76+
const CuMatrixBase<BaseFloat> &wave);
77+
78+
void ProcessWindowsBatched(const LaneDesc *lanes, int32_t num_lanes,
79+
const FrameExtractionOptions &opts,
80+
CuMatrixBase<BaseFloat> *log_energy_pre_window);
81+
82+
void ComputeFinalFeaturesBatched(const LaneDesc *lanes, int32_t num_lanes,
83+
BaseFloat vtln_wrap,
84+
CuMatrix<BaseFloat> *cu_signal_log_energy,
85+
CuMatrix<BaseFloat> *cu_features);
86+
87+
CuVector<float> cu_lifter_coeffs_;
88+
CuMatrix<BaseFloat> cu_windows_;
89+
CuMatrix<float> tmp_window_, cu_mel_energies_;
90+
CuMatrix<float> cu_dct_matrix_;
91+
CuMatrix<BaseFloat> stash_;
92+
CuMatrix<BaseFloat> power_spectrum_;
93+
CuMatrix<BaseFloat> raw_log_energies_;
94+
95+
int frame_length_, padded_length_, fft_length_, fft_batch_size_;
96+
cufftHandle plan_;
97+
CudaFeatureWindowFunction window_function_;
98+
99+
int bin_size_;
100+
int32 *offsets_, *sizes_;
101+
CuVector<float> *cu_vecs_;
102+
float **vecs_;
103+
104+
// for sanity checking cufft
105+
int32_t stride_, tmp_stride_;
106+
107+
int32_t max_chunk_frames_;
108+
int32_t num_channels_;
109+
int32_t max_lanes_;
110+
};
111+
} // namespace kaldi
112+
113+
#endif

0 commit comments

Comments
 (0)