diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc index cff6b81d9e..9db537c0f8 100644 --- a/sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc +++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc @@ -62,34 +62,34 @@ static OfflineTransducerDecoderResult DecodeOne( encoder_shape.data(), encoder_shape.size()); for (int32_t q = 0; q != max_symbols_per_frame; ++q) { - Ort::Value logit = model->RunJoiner(View(&cur_encoder_out), - View(&decoder_output_pair.first)); - - float *p_logit = logit.GetTensorMutableData(); - if (blank_penalty > 0) { - p_logit[blank_id] -= blank_penalty; - } - - auto y = static_cast(std::distance( - static_cast(p_logit), - std::max_element(static_cast(p_logit), - static_cast(p_logit) + vocab_size))); - - if (y != blank_id) { - ans.tokens.push_back(y); - ans.timestamps.push_back(t); - - decoder_input_pair = BuildDecoderInput(y, model->Allocator()); - - decoder_output_pair = - model->RunDecoder(std::move(decoder_input_pair.first), - std::move(decoder_input_pair.second), - std::move(decoder_output_pair.second)); - } else { - break; - } // if (y != blank_id) + Ort::Value logit = model->RunJoiner(View(&cur_encoder_out), + View(&decoder_output_pair.first)); + + float *p_logit = logit.GetTensorMutableData(); + if (blank_penalty > 0) { + p_logit[blank_id] -= blank_penalty; + } + + auto y = static_cast(std::distance( + static_cast(p_logit), + std::max_element(static_cast(p_logit), + static_cast(p_logit) + vocab_size))); + + if (y != blank_id) { + ans.tokens.push_back(y); + ans.timestamps.push_back(t); + + decoder_input_pair = BuildDecoderInput(y, model->Allocator()); + + decoder_output_pair = + model->RunDecoder(std::move(decoder_input_pair.first), + std::move(decoder_input_pair.second), + std::move(decoder_output_pair.second)); + } else { + break; + } // if (y != blank_id) } - } // for (int32_t i = 0; i != num_rows; ++i) + } // for (int32_t i = 0; i != num_rows; ++i) return ans; } @@ -104,14 +104,24 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( int32_t dim1 = static_cast(shape[1]); int32_t dim2 = static_cast(shape[2]); - const int32_t *p_length = encoder_out_length.GetTensorData(); + auto length_type = + encoder_out_length.GetTensorTypeAndShapeInfo().GetElementType(); + if ((length_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) && + (length_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) { + SHERPA_ONNX_LOGE("Unsupported encoder_out_length data type: %d", + static_cast(length_type)); + SHERPA_ONNX_EXIT(-1); + } + const float *p = encoder_out.GetTensorData(); std::vector ans(batch_size); for (int32_t i = 0; i != batch_size; ++i) { const float *this_p = p + dim1 * dim2 * i; - int32_t this_len = p_length[i]; + int32_t this_len = length_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 + ? encoder_out_length.GetTensorData()[i] + : encoder_out_length.GetTensorData()[i]; ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_); }