Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 29 additions & 24 deletions sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ static OfflineTransducerDecoderResult DecodeOne(

int32_t vocab_size = model->VocabSize();
int32_t blank_id = vocab_size - 1;
int32_t max_symbols_per_frame = 10;

auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());

Expand All @@ -60,30 +61,34 @@ static OfflineTransducerDecoderResult DecodeOne(
memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
encoder_shape.data(), encoder_shape.size());

Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
View(&decoder_output_pair.first));

float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty > 0) {
p_logit[blank_id] -= blank_penalty;
for (int32_t q = 0; q != max_symbols_per_frame; ++q) {
Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
View(&decoder_output_pair.first));

float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty > 0) {
p_logit[blank_id] -= blank_penalty;
}

auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));

if (y != blank_id) {
ans.tokens.push_back(y);
ans.timestamps.push_back(t);

decoder_input_pair = BuildDecoderInput(y, model->Allocator());

decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
std::move(decoder_output_pair.second));
} else {
break;
} // if (y != blank_id)
}

auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));

if (y != blank_id) {
ans.tokens.push_back(y);
ans.timestamps.push_back(t);

decoder_input_pair = BuildDecoderInput(y, model->Allocator());

decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
std::move(decoder_output_pair.second));
} // if (y != blank_id)
} // for (int32_t t = 0; t != num_rows; ++t)

return ans;
Expand All @@ -99,7 +104,7 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
int32_t dim1 = static_cast<int32_t>(shape[1]);
int32_t dim2 = static_cast<int32_t>(shape[2]);

const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
const int32_t *p_length = encoder_out_length.GetTensorData<int32_t>();
const float *p = encoder_out.GetTensorData<float>();

std::vector<OfflineTransducerDecoderResult> ans(batch_size);
Expand Down
Loading