Fixes and Enhancements for Mamba Inference and Reference Implementations #743


Open · wants to merge 1 commit into main

Conversation

mohiuddin-khan-shiam

This pull request addresses several bugs and limitations within the Mamba codebase, primarily aimed at improving inference robustness in the Mamba2 module and increasing the accuracy of reference implementations.

Key changes include:

  • In mamba_ssm/modules/mamba2.py:
    • Resolved an issue in _get_states_from_cache to correctly handle dynamic batch sizes during inference, ensuring proper state re-initialization when the batch size changes (see the first sketch after this list).
    • Removed the batch == 1 assertion in the forward method for variable-length sequence inference, enabling batched processing for these inputs.
    • Updated the fallback path in the step method to support ngroups > 1, allowing grouped SSM inference even when the Triton kernels are not available (second sketch below).
  • In mamba_ssm/ops/selective_scan_interface.py:
    • Added optional RMS normalization for the B, C, and delta tensors within mamba_inner_ref to better match the fused MambaInnerFn's behavior and improve numerical consistency (third sketch below).
    • Corrected a shape comment in selective_scan_ref for clarity.
  • In mamba_ssm/models/mixer_seq_simple.py:
    • Removed a redundant comment in the _init_weights function.
  • In mamba_ssm/utils/hf.py:
    • Fixed a bug in load_state_dict_hf so that dtype conversion and device placement are applied correctly when loading Hugging Face model weights (fourth sketch below).

These modifications enhance the stability, flexibility, and correctness of the Mamba library.
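
A minimal sketch of the batch-size-aware cache lookup. The cache layout (`key_value_memory_dict` holding a per-layer `(conv_state, ssm_state)` pair) follows Mamba's `InferenceParams`; the exact state shapes and attribute names here are assumptions, not the merged diff:

```python
import torch

# Method of Mamba2 (sketch). Re-allocates the cached states when the
# incoming batch size differs from the cached one, instead of returning
# stale tensors of the wrong shape.
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
    assert self.layer_idx is not None
    cached = inference_params.key_value_memory_dict.get(self.layer_idx)
    # Re-initialize if nothing is cached yet *or* the batch size changed.
    if cached is None or cached[0].shape[0] != batch_size:
        conv_state = torch.zeros(
            batch_size, self.conv1d.weight.shape[0], self.d_conv,
            device=self.conv1d.weight.device, dtype=self.conv1d.weight.dtype,
        )
        ssm_state = torch.zeros(
            batch_size, self.nheads, self.headdim, self.d_state,
            device=self.in_proj.weight.device, dtype=self.in_proj.weight.dtype,
        )
        inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        return conv_state, ssm_state
    conv_state, ssm_state = cached
    if initialize_states:
        conv_state.zero_()
        ssm_state.zero_()
    return conv_state, ssm_state
```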
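A sketch of the grouped fallback step: with ngroups > 1, B and C are produced per group, so the non-Triton path has to broadcast each group's B/C to the heads it serves before the state update. The helper name and shapes are assumptions:

```python
import torch
from einops import rearrange, repeat

def ssm_step_ref(ssm_state, x, dt, A, B, C, D, ngroups):
    """Hypothetical single-token SSM step that tolerates ngroups > 1.
    Assumed shapes: x (b, h, p), dt (b, h), A (h,), B and C (b, g, n),
    D (h,), ssm_state (b, h, p, n), with h divisible by g."""
    heads_per_group = x.shape[1] // ngroups
    # Broadcast each group's B/C to every head it serves: (b, g, n) -> (b, h, n).
    B = repeat(B, "b g n -> b (g r) n", r=heads_per_group)
    C = repeat(C, "b g n -> b (g r) n", r=heads_per_group)
    dA = torch.exp(dt * A)  # discretized per-head state decay, (b, h)
    dBx = torch.einsum("bh,bhn,bhp->bhpn", dt, B, x)
    ssm_state = ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx
    y = torch.einsum("bhpn,bhn->bhp", ssm_state, C)
    return y + rearrange(D, "h -> h 1") * x, ssm_state
```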
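A sketch of the optional RMS normalization added to the reference path. `rms_norm_ref` and `maybe_rms_norm_bcdt` are illustrative names, assuming B/C of shape (b, d_state, l) and delta of shape (b, d_inner, l) normalized along their channel dimension, mirroring the fused kernel's optional weights:

```python
import torch
from einops import rearrange

def rms_norm_ref(x, weight, eps=1e-5):
    """RMS-normalize over the last dimension, computed in fp32 for stability."""
    dtype = x.dtype
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x * weight.float()).to(dtype)

def maybe_rms_norm_bcdt(B, C, delta, b_weight=None, c_weight=None, dt_weight=None, eps=1e-5):
    """Apply the optional RMS norms when their weights are provided, as the
    fused MambaInnerFn does; skip them otherwise."""
    if b_weight is not None:
        B = rearrange(rms_norm_ref(rearrange(B, "b n l -> b l n"), b_weight, eps), "b l n -> b n l")
    if c_weight is not None:
        C = rearrange(rms_norm_ref(rearrange(C, "b n l -> b l n"), c_weight, eps), "b l n -> b n l")
    if dt_weight is not None:
        delta = rearrange(rms_norm_ref(rearrange(delta, "b d l -> b l d"), dt_weight, eps), "b l d -> b d l")
    return B, C, delta
```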
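A sketch of the corrected weight-loading order in load_state_dict_hf: cast the dtype first, then move to the target device, so neither step is silently skipped. The file-resolution call is the standard `transformers` helper; treat the exact control flow as an illustration rather than the merged code:

```python
import torch
from transformers.utils import WEIGHTS_NAME
from transformers.utils.hub import cached_file

def load_state_dict_hf(model_name, device=None, dtype=None):
    resolved_file = cached_file(model_name, WEIGHTS_NAME)
    # Load to CPU when a cast is pending, avoiding a transient
    # full-precision copy on the GPU.
    mapped_device = "cpu" if dtype not in (None, torch.float32) else device
    state_dict = torch.load(resolved_file, map_location=mapped_device)
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    if device is not None:
        state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
```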
