Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 39 additions & 29 deletions sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,34 +62,34 @@ static OfflineTransducerDecoderResult DecodeOne(
encoder_shape.data(), encoder_shape.size());

for (int32_t q = 0; q != max_symbols_per_frame; ++q) {
Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
View(&decoder_output_pair.first));

float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty > 0) {
p_logit[blank_id] -= blank_penalty;
}

auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));

if (y != blank_id) {
ans.tokens.push_back(y);
ans.timestamps.push_back(t);

decoder_input_pair = BuildDecoderInput(y, model->Allocator());

decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
std::move(decoder_output_pair.second));
} else {
break;
} // if (y != blank_id)
Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
View(&decoder_output_pair.first));

float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty > 0) {
p_logit[blank_id] -= blank_penalty;
}

auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));

if (y != blank_id) {
ans.tokens.push_back(y);
ans.timestamps.push_back(t);

decoder_input_pair = BuildDecoderInput(y, model->Allocator());

decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
std::move(decoder_output_pair.second));
} else {
break;
} // if (y != blank_id)
}
} // for (int32_t i = 0; i != num_rows; ++i)
} // for (int32_t i = 0; i != num_rows; ++i)

return ans;
}
Expand All @@ -104,14 +104,24 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
int32_t dim1 = static_cast<int32_t>(shape[1]);
int32_t dim2 = static_cast<int32_t>(shape[2]);

const int32_t *p_length = encoder_out_length.GetTensorData<int32_t>();
auto length_type =
encoder_out_length.GetTensorTypeAndShapeInfo().GetElementType();
if ((length_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) &&
(length_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) {
SHERPA_ONNX_LOGE("Unsupported encoder_out_length data type: %d",
static_cast<int32_t>(length_type));
SHERPA_ONNX_EXIT(-1);
}

const float *p = encoder_out.GetTensorData<float>();

std::vector<OfflineTransducerDecoderResult> ans(batch_size);

for (int32_t i = 0; i != batch_size; ++i) {
const float *this_p = p + dim1 * dim2 * i;
int32_t this_len = p_length[i];
int32_t this_len = length_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
? encoder_out_length.GetTensorData<int32_t>()[i]
: encoder_out_length.GetTensorData<int64_t>()[i];
Comment on lines 118 to +124
Copy link

Copilot AI Jul 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tensor data is being accessed inside the loop for each iteration. Consider caching the tensor data pointer outside the loop to avoid repeated GetTensorData calls.

Copilot uses AI. Check for mistakes.


ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_);
}
Expand Down
Loading