-
Notifications
You must be signed in to change notification settings - Fork 903
Limit number of tokens in fire red asr decoding. #2459
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
WalkthroughThe Decode method in the ASR decoder classes and their implementations was updated to accept an additional integer parameter representing the number of feature frames. Corresponding method signatures and invocations were updated, and the decoding loop now uses this parameter to determine the maximum number of tokens processed. Changes
Sequence Diagram(s)sequenceDiagram
participant RecognizerImpl
participant Decoder
participant GreedySearchDecoder
RecognizerImpl->>Decoder: Decode(cross_k, cross_v, num_frames)
alt GreedySearchDecoder implementation
Decoder->>GreedySearchDecoder: Decode(cross_k, cross_v, num_feature_frames)
GreedySearchDecoder->>GreedySearchDecoder: Calculate num_possible_tokens
GreedySearchDecoder->>GreedySearchDecoder: Run decoding loop (up to num_possible_tokens)
GreedySearchDecoder-->>Decoder: Return results
end
Decoder-->>RecognizerImpl: Return results
Estimated code review effort🎯 2 (Simple) | ⏱️ ~8 minutes Poem
Note ⚡️ Unit Test Generation is now available in beta!Learn more here, or try it out under "Finishing Touches" below. ✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Nitpick comments (1)
sherpa-onnx/csrc/offline-fire-red-asr-decoder.h (1)
34-35
: Update documentation to include the new parameter.The method signature correctly adds the
num_feature_frames
parameter, but the documentation comment (lines 24-32) should be updated to describe this new parameter.Apply this diff to update the documentation:
/** Run beam search given the output from the FireRedAsr encoder model. * * @param n_layer_cross_k A 4-D tensor of shape * (num_decoder_layers, N, T, d_model). * @param n_layer_cross_v A 4-D tensor of shape * (num_decoder_layers, N, T, d_model). + * @param num_feature_frames The number of feature frames in the input. * * @return Return a vector of size `N` containing the decoded results. */
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
(1 hunks)sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
(2 hunks)sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
(1 hunks)sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
(1 hunks)
🔇 Additional comments (2)
sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h (1)
21-22
: LGTM!The method signature correctly overrides the base class interface with the new
num_feature_frames
parameter. Theoverride
specifier ensures compile-time verification of the signature match.sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc (1)
19-20
: LGTM!The method signature correctly implements the updated interface with the new
num_feature_frames
parameter.
// assume at most 6 tokens per second | ||
int32_t num_possible_tokens = num_feature_frames / 100 * 6; | ||
num_possible_tokens = | ||
std::min<int32_t>(num_possible_tokens, meta_data.max_len / 2); | ||
|
||
for (int32_t i = 0; i < num_possible_tokens; ++i) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Improve the token calculation logic with better documentation and validation.
The dynamic token limit calculation has several areas for improvement:
- The hardcoded constants (100 Hz frame rate, 6 tokens/second) should be documented or made configurable
- Missing validation for
num_feature_frames
- Potential integer overflow for very large frame counts
Apply this diff to improve the implementation:
- // assume at most 6 tokens per second
- int32_t num_possible_tokens = num_feature_frames / 100 * 6;
- num_possible_tokens =
- std::min<int32_t>(num_possible_tokens, meta_data.max_len / 2);
+ // Calculate maximum tokens based on feature frames
+ // Assumptions: 100 Hz feature frame rate (10ms), max 6 tokens per second
+ int32_t num_possible_tokens = 0;
+ if (num_feature_frames > 0) {
+ // Use int64_t to prevent overflow, then cast back
+ int64_t tokens_estimate = static_cast<int64_t>(num_feature_frames) * 6 / 100;
+ num_possible_tokens = static_cast<int32_t>(
+ std::min(tokens_estimate, static_cast<int64_t>(meta_data.max_len / 2)));
+ }
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
// assume at most 6 tokens per second | |
int32_t num_possible_tokens = num_feature_frames / 100 * 6; | |
num_possible_tokens = | |
std::min<int32_t>(num_possible_tokens, meta_data.max_len / 2); | |
for (int32_t i = 0; i < num_possible_tokens; ++i) { | |
// Calculate maximum tokens based on feature frames | |
// Assumptions: 100 Hz feature frame rate (10ms), max 6 tokens per second | |
int32_t num_possible_tokens = 0; | |
if (num_feature_frames > 0) { | |
// Use int64_t to prevent overflow, then cast back | |
int64_t tokens_estimate = | |
static_cast<int64_t>(num_feature_frames) * 6 / 100; | |
num_possible_tokens = static_cast<int32_t>( | |
std::min(tokens_estimate, | |
static_cast<int64_t>(meta_data.max_len / 2))); | |
} | |
for (int32_t i = 0; i < num_possible_tokens; ++i) { |
🤖 Prompt for AI Agents
In sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc around lines
57 to 62, improve the token calculation by adding comments explaining the
constants used (100 Hz frame rate and 6 tokens per second) or making them
configurable variables. Add validation to ensure num_feature_frames is positive
before using it in calculations. Also, implement checks or use safe arithmetic
to prevent integer overflow when computing num_possible_tokens for very large
num_feature_frames values.
auto results = decoder_->Decode(std::move(cross_kv.first), | ||
std::move(cross_kv.second), num_frames); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix potential type mismatch between num_frames and expected parameter type.
The Decode
method expects an int32_t num_feature_frames
parameter, but num_frames
is declared as int64_t
on line 109. This could cause issues on systems where these types differ in size.
Apply this diff to fix the type mismatch:
- int64_t num_frames = f.size() / feat_dim;
+ int32_t num_frames = static_cast<int32_t>(f.size() / feat_dim);
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
auto results = decoder_->Decode(std::move(cross_kv.first), | |
std::move(cross_kv.second), num_frames); | |
// Compute number of feature frames | |
int32_t num_frames = static_cast<int32_t>(f.size() / feat_dim); | |
auto results = decoder_->Decode(std::move(cross_kv.first), | |
std::move(cross_kv.second), num_frames); |
🤖 Prompt for AI Agents
In sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h around lines 109 and
122-123, the variable num_frames is declared as int64_t but passed to Decode
which expects int32_t. Change the declaration of num_frames from int64_t to
int32_t to match the expected parameter type and avoid potential type mismatch
issues.
To avoid the following case reported by one of our users.
Summary by CodeRabbit
New Features
Bug Fixes