[Feature] Support variable-length sequences for mamba block #244
Conversation
Hello @tridao @albertfgu, thanks for the awesome work on mamba; it is really a strong competitor to the transformer! We have noticed some issues (#236, #180) stating a need for training on variable-length sequences, but the corresponding functionality was not available. Also, in real-world scenarios the length distribution of datasets varies a lot, and simply padding every sample to the maximum length wastes computing resources on meaningless padded tokens. So we implemented this PR and hope it helps!
Force-pushed from 77e58cb to a78a9eb.
Force-pushed from aea08ca to 842bef5.
Hello, it's great to see your work on variable-length data. How can I use the method you provided? Is there any difference in results between it and padding?
Thank you for your interest in this PR! Update (2024/03/19):
Thank you for your reply. Due to performance considerations, I would like to use bidirectional mamba. Should I wait for your updated code?
Hi @EricPaul03, @Dmovic has written a unit test for the backward pass of the mamba block with variable-length sequences, and the results show numerical equality for both the forward and backward passes with varlen inputs. I haven't tried it with bidirectional mamba, but since it is numerically equivalent for the default unidirectional mamba, I think you can just give it a try!
To give a simple example: what we originally pass into the original mamba block is a padded batch of fixed-length sequences, whereas with this PR the variable-length sequences are packed into a single sequence. From the above figure, we can clearly see that through this PR the mamba block can focus computing resources on the variable-length sequences themselves and avoid the overhead of meaningless padding tokens. Variable-length training is very useful for optimizing hardware utilization during training, and the well-known flash attention already supports variable-length training via its `varlen_fwd`/`varlen_bwd` API.
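For illustration, here is a minimal sketch of the packing step described above. The helper names (`pack`, `build_cu_seqlens`) are my own, not part of this PR's API; the point is simply how variable-length sequences become one packed tensor of shape `(1, total_tokens, hidden_dim)` plus the cumulative boundaries `cu_seqlens`:

```python
import torch

def build_cu_seqlens(lengths, device="cpu"):
    """Cumulative sequence boundaries: N sequences give N + 1 monotonically increasing offsets."""
    lengths = torch.as_tensor(lengths, dtype=torch.int32, device=device)
    return torch.cat([torch.zeros(1, dtype=torch.int32, device=device),
                      torch.cumsum(lengths, dim=0).to(torch.int32)])

def pack(sequences):
    """Concatenate a list of (seqlen_i, hidden_dim) tensors into one (1, total_tokens, hidden_dim) tensor."""
    return torch.cat(sequences, dim=0).unsqueeze(0)

# Three toy sequences of lengths 5, 2 and 7 with hidden_dim = 16
hidden_dim = 16
seqs = [torch.randn(l, hidden_dim) for l in (5, 2, 7)]
packed = pack(seqs)                        # shape (1, 14, 16) instead of a padded (3, 7, 16)
cu_seqlens = build_cu_seqlens([5, 2, 7])   # tensor([0, 5, 7, 14]) marks where each sequence starts/ends
```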
Thank you for your answer. This is great code that I will try to use for my project!
Sorry to bother you again. I would like to implement the same operation for bidirectional mamba. Do I also need to reset the values in cu_seqlens when flipping the propagation direction to cope with the flipped sequence, and can the two directions share d_conv?
I think I should divide conv1d_out, delta, etc. into subsequences and reverse each subsequence separately (instead of reversing the entire sequence and using the same cu_seqlens)?
I copied some methods from MixerModel to help use this feature.
For bidirectional mamba, you need to pass in the `cu_seqlens` of the flipped sequence as well.
For example, if you have the original `cu_seqlens`, you can calculate the flipped `cu_seqlens` from it and the total token count.
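As a concrete illustration of that calculation (my own sketch, not code from this PR): assuming the whole packed sequence is flipped for the reverse scan, the boundaries of the flipped packing follow directly from the original `cu_seqlens` and the total token count.

```python
import torch

def flip_cu_seqlens(cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Boundaries of the packed sequence after flipping it along the token dimension.

    Example: cu_seqlens = [0, 5, 7, 14] (lengths 5, 2, 7) becomes [0, 7, 9, 14]
    (lengths 7, 2, 5), i.e. the sub-sequence order is reversed as well.
    """
    total = cu_seqlens[-1]
    return total - cu_seqlens.flip(0)
```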
I think you might not need to divide these items into subsequences; all you need is to pass in the corresponding `cu_seqlens`. To combine the benefits of bidirectional mamba and this PR's variable-length sequences, I drew my graphical understanding here. The mechanism can be viewed simply as follows: when scanning bidirectionally, the hidden states need to be reset at the sequence boundaries of both directions.
It's great to see that there is already a paper/project (Is Mamba Compatible with Trajectory Optimization in Offline Reinforcement Learning, NeurIPS'24) adopting our code in the area of offline reinforcement learning.
Hi @zigzagcai, thank you for the great work! I tried to install your version but ran into an `undefined symbol` ImportError (full traceback below). The full pipeline I used is the following:

```bash
# (optionally) clone causal-conv1d, also tried pip install causal-conv1d==1.4.0
git clone https://github.com/Dao-AILab/causal-conv1d
cd causal-conv1d
git checkout v1.4.0
pip install -e .
cd ..

# clone and checkout your PR
git clone https://github.com/state-spaces/mamba
cd mamba
git fetch origin pull/244/head:pr-244
git checkout pr-244
pip install -e .
```

I tried installing with pytorch 2.4 and 2.1 and with cuda 12.5 and 12.1; all settings show the same problem:

```
> python tests/ops/test_mamba_cu_seqlens_equivalence.py
Traceback (most recent call last):
  File "/.../mamba/tests/ops/test_mamba_cu_seqlens_equivalence.py", line 5, in <module>
    from mamba_ssm.modules.mamba_simple import Mamba
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/__init__.py", line 3, in <module>
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/selective_scan_interface.py", line 16, in <module>
    import selective_scan_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops10zeros_like4callERKNS_6TensorESt8optionalIN3c1010ScalarTypeEES5_INS6_6LayoutEES5_INS6_6DeviceEES5_IbES5_INS6_12MemoryFormatEE
```

Additionally, I also took a look at the installed packages.
For example, `pip show mamba-ssm` gives:

```
> pip show mamba-ssm
Name: mamba_ssm
Version: 2.2.2
Summary: Mamba state-space model
Home-page:
Author:
Author-email: Tri Dao <tri@tridao.me>, Albert Gu <agu@cs.cmu.edu>
...
Location: /usr/local/lib/python3.10/dist-packages
Requires: einops, ninja, packaging, setuptools, torch, transformers, triton (causal-conv1d is not here)
Required-by:
```

If this issue doesn't occur for you, could you provide the install script you are using for the most up-to-date version? Thanks!
Hi @JindongJiang, I have shared my minimal reproduction steps here.
Hi @JindongJiang, firstly, thanks for your interest in this PR!
Hi @zigzagcai, thank you very much for the help. Interestingly, deleting the ...
Besides the pytorch and cuda versions, I used the same setup as you suggested:
I will now try cuda 11.8 as well and will let you know if I get the same problem.
Hi @zigzagcai, I am back with the cuda 11.8 results, and the problem still exists. This time I am (almost) fully following your setup script:
The only difference is that I had to do ...
Complete results and env:
It is actually quite surprising that the big discrepancies only happen at the beginning and the end, i.e. in in_proj and out_proj. Could you provide some comments on this? Thanks!
Hi @JindongJiang, the error below is caused by the recent merge commit.
I just reverted the recent merge commit in 0a15f1d. Could you please re-try my branch? I just re-tested the code in my environment and it is okay.
The test results:
FYI, my local environment (including cuda version and pip packages):
BTW @JindongJiang, which GPU model are you using: A100, H100, or something else? That way I can better understand your software and hardware environment.
Hi @zigzagcai, thank you very much for the updates and new commits. I will test the new setup. I got the above results using an A100.
Force-pushed from 0a15f1d to cda4b5a.
Hi @zigzagcai, it seems that the grad discrepancy only exists when I use a docker image on slurm. I have two ways to run the experiments:
Thank you for your help again! I think the problem is not in the implementation then. I will use conda without docker for now.
Very glad to see it is helpful to you! You are right. I guess there might be some conflicts when you try to install packages inside the docker image.
Hi @zigzagcai, here is how I install the dependencies, which might be useful for those working with CUDA 12.5:

```bash
conda create -n your_env_name python=3.10.13
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install -r requirements.txt
git clone git@github.com:hustvl/Vim.git
pip install -e causal-conv1d>=1.1.0
pip install -e mamba-1p1p1
pip install --upgrade huggingface-hub==0.24.0
```

I made a slight adjustment to your example; here is the revised version (abridged):

```python
from collections import Counter
import torch

sentences = [ ... ]
word_counter = Counter(chain(*[sentence.lower().split() for sentence in sentences]))

def variable_length_sequences(new_tensor): ...
def unpack(packed_hidden_states, cu_seqlens): ...
def pack(hidden_states, cu_seqlens): ...

hidden_dim = 256
new_tensor_reeshaped_index = variable_length_sequences(padded_sequences)
out_ref = mamba(hidden_states)
```

I noticed that when processing 4 sentences, you receive embeddings for only 3 sentences (torch.Size([3, 6, 256])). It might be helpful to append last_index + 1 to the list in your variable_length_sequences function (i.e., start_indexes.append(last_index + 1)). This adjustment should ensure that the number of output sentences matches the number of input sentences (torch.Size([4, 6, 256])).

I am now receiving embeddings with a shape of torch.Size([4, 6, 256]). However, one of my sentences contains only three words. Should I apply masking to the returned sequences to remove embeddings that might not be meaningful?

Thanks!
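For what it's worth, the `last_index + 1` fix suggested above matches the general rule that the boundary list for N packed sequences needs N + 1 entries, the last one being the total token count. A small illustrative snippet (the names here are mine, not from the code above):

```python
import torch

# Illustrative: build cu_seqlens from per-sentence lengths.
# For N sequences there must be N + 1 boundaries, ending at the total length.
lengths = torch.tensor([6, 5, 3, 6])        # e.g. 4 sentences, one of them only 3 words long
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)])
print(cu_seqlens)                           # tensor([ 0,  6, 11, 14, 20]) -> 5 boundaries for 4 sentences
```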
Hi,
Thank you very much for your code and illustrations, but I have some doubts about the seqlen and seq_idx parameters of Mamba2 in the following figure. Could you provide the corresponding illustration for these parameters?
Thanks for the great job! I am confused about whether this version shares hidden states between the packed batches. An answer would be a great help to me!
Hi @CacatuaAlan, thank you! This version of the code can handle packed hidden states, which combine multiple batches of hidden states into one. How do we avoid cross-batch contamination? The hidden states are reset at the sequence boundaries given by `cu_seqlens`, so the packed sequences never contaminate each other.
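To make the boundary-reset idea concrete, here is a slow, pure-PyTorch toy recurrence over a packed input that zeroes its hidden state at every `cu_seqlens` boundary. It only illustrates the mechanism; it is not the PR's CUDA kernel or the actual selective-scan math.

```python
import torch

def packed_linear_scan(x, a, cu_seqlens):
    """Toy recurrence h_t = a * h_{t-1} + x_t over a packed (total_tokens, dim) input,
    with h reset to zero at every sub-sequence start so packed sequences never mix."""
    total, dim = x.shape
    starts = set(cu_seqlens[:-1].tolist())
    h = torch.zeros(dim, dtype=x.dtype)
    out = torch.empty_like(x)
    for t in range(total):
        if t in starts:
            h = torch.zeros(dim, dtype=x.dtype)   # boundary: forget the previous sequence's state
        h = a * h + x[t]
        out[t] = h
    return out

x = torch.randn(14, 4)
cu_seqlens = torch.tensor([0, 5, 7, 14], dtype=torch.int32)
y = packed_linear_scan(x, 0.9, cu_seqlens)
# y[5] depends only on x[5], not on x[0:5], because the state was reset at t = 5.
assert torch.allclose(y[5], x[5])
```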
Hi @zongtianhu,
For example, consider a packed sentence consisting of 7 sub-sentences:
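As an illustrative stand-in with made-up lengths: 7 sub-sentences packed into one sequence give 8 `cu_seqlens` entries, and `seq_idx` tags every token with the index of the sub-sentence it belongs to.

```python
import torch

lengths = torch.tensor([3, 5, 2, 4, 6, 1, 3])   # 7 made-up sub-sentence lengths
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)])
seq_idx = torch.repeat_interleave(torch.arange(7), lengths)

print(cu_seqlens)   # tensor([ 0,  3,  8, 10, 14, 20, 21, 24])  -> 8 boundaries
print(seq_idx)      # 24 entries: 0,0,0,1,1,1,1,1,2,2,3,3,3,3,4,...,6
```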
Hi @zigzagcai, great work! If I pack all samples as per your example and run the forward pass, it seems I only get the final ssm_state of the whole packed sequence rather than one per sub-sequence. Looking at the cuda kernel, storing the ssm_state happens inside this block (L#279-L#298, selective_scan_fwd_kernel.cuh):
It seems that the InclusiveScan op only returns the last state in the scan?
Hey @zigzagcai! First of all, thank you for your work. I'm trying to use your feature for my project, and I benchmarked the variable-length forward pass against separate forward passes for each sample. Unfortunately, the results are rather disappointing: the separate passes are ~2.5x faster than variable-length batching. Any idea what might be going wrong?
Output:
Environment:
Hi @fzsomb, sorry for the late response. I have checked the code, and there are two points that should be pointed out:
Using my performance test script, the performance comparison for the forward pass is shown here:

```
Generate random cu_seqlens = [0, 239450, 335932, 339432, 429781, 449130, 490937, 596597, 627200]
max diff for output in varlen_mamba fwd pass: 5.960464477539063e-08
mean diff for output in varlen_mamba fwd pass: 2.5034383455135867e-09
max diff for output in varlen_mamba fwd pass: 8.940696716308594e-08
mean diff for output in varlen_mamba fwd pass: 2.5367627998207354e-09
max diff for output in varlen_mamba fwd pass: 5.21540641784668e-08
mean diff for output in varlen_mamba fwd pass: 2.57035126516314e-09
max diff for output in varlen_mamba fwd pass: 8.940696716308594e-08
mean diff for output in varlen_mamba fwd pass: 2.532647425113055e-09
max diff for output in varlen_mamba fwd pass: 7.450580596923828e-08
mean diff for output in varlen_mamba fwd pass: 2.5388156021932673e-09
max diff for output in varlen_mamba fwd pass: 8.940696716308594e-08
mean diff for output in varlen_mamba fwd pass: 2.545051724922587e-09
max diff for output in varlen_mamba fwd pass: 7.450580596923828e-08
mean diff for output in varlen_mamba fwd pass: 2.5323849683900335e-09
max diff for output in varlen_mamba fwd pass: 7.450580596923828e-08
mean diff for output in varlen_mamba fwd pass: 2.5522350899365165e-09
Total forward time for separate: 0.4627113342285156
Total forward time for batched: 0.015723705291748047
```

Environment:
You would see a nearly 30x speedup for the example batched inputs, measured by the forward time of the varlen_mamba block. As a comparison, if you comment out the two lines that precompute `seq_idx` and instead let it be computed on-the-fly in the varlen_mamba forward pass, the performance would be:
We can clearly see that the performance of varlen_mamba is bottlenecked by the on-the-fly construction of `seq_idx`. Therefore, in actual training scenarios, we need to prepare the necessary `cu_seqlens`/`seq_idx`/`position_ids` ahead of time and pass them into the mamba forward pass.
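A minimal sketch of precomputing this metadata once per batch (for example in the data collator) instead of on-the-fly; the helper name and exact return values are my own choices, not the PR's:

```python
import torch

def precompute_varlen_metadata(lengths, device="cpu"):
    """Given per-sample lengths, build cu_seqlens, seq_idx and position_ids up front."""
    lengths = torch.as_tensor(lengths, dtype=torch.int32, device=device)
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device=device),
                            lengths.cumsum(0, dtype=torch.int32)])
    # seq_idx: which sub-sequence each token of the packed batch belongs to, shape (1, total_tokens)
    seq_idx = torch.repeat_interleave(
        torch.arange(len(lengths), dtype=torch.int32, device=device), lengths.long()
    ).unsqueeze(0)
    # position_ids: position of each token within its own sub-sequence, shape (1, total_tokens)
    position_ids = torch.cat(
        [torch.arange(l, dtype=torch.long, device=device) for l in lengths.tolist()]
    ).unsqueeze(0)
    return cu_seqlens, seq_idx, position_ids
```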
Support variable-length sequences for the mamba block via `cu_seqlens`/`seq_idx`/`position_ids` in the `forward` pass and `backward` pass, similar to what has been done (such as the cumulative sequence lengths `cu_seqlens` or the lower-triangular block-diagonal `attention mask`) in the flash-attention `varlen_fwd`/`varlen_bwd` API. We have tested that training with variable-length sequences on real-world datasets can bring an end-to-end 2~4x speedup.
Why do we need this?
High speedup and hardware utilization on the real-world datasets we tested. This can be used to improve hardware utilization when you have variable-length sequences and don't want to waste computing resources on meaningless padded tokens. It is especially useful when you train mamba on real-world datasets, where the length distribution varies a lot and a large proportion of samples are short sequences. Last but not least, we ensure exact fwd/bwd numerical equality with the padding approach.
How to use?
Zero learning overhead: the packed mamba API is similar to the packed flash-attn API or the packed mamba2 API. You just need to pack multiple variable-length sequences into one and additionally pass `cu_seqlens`/`seq_idx`/`position_ids` into the mamba `forward` pass. There is no need to modify `causal-conv1d`; just using the original https://github.com/Dao-AILab/causal-conv1d (version >= 1.4.0) is fine.
Note:
We thank @wang-zerui for the forward-pass Python reference implementation and the invaluable discussion on how to ensure numerical equality.
This is joint work with @wang-zerui, @Dmovic, and @ptxu78.
Example usage:
https://github.com/zigzagcai/varlen_mamba/blob/feat/add-cu_seqlens/tests/ops/test_mamba_varlen.py
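The linked test is the authoritative reference for the exact call signature; the following is only a rough, hypothetical sketch of the call pattern, under the assumption that this PR's `Mamba.forward` accepts a `cu_seqlens` keyword (and possibly `seq_idx`/`position_ids`, see the test above):

```python
import torch
from mamba_ssm.modules.mamba_simple import Mamba

device, dtype = "cuda", torch.float32
mamba = Mamba(d_model=256, device=device, dtype=dtype)

# Two variable-length sequences packed into one "batch" of shape (1, total_tokens, d_model)
lens = [37, 91]
seqs = [torch.randn(l, 256, device=device, dtype=dtype) for l in lens]
packed = torch.cat(seqs, dim=0).unsqueeze(0)
cu_seqlens = torch.tensor([0, 37, 128], dtype=torch.int32, device=device)

out_packed = mamba(packed, cu_seqlens=cu_seqlens)   # cu_seqlens keyword assumed from this PR

# Reference: run each sequence separately and compare against the packed output
out_ref = torch.cat([mamba(s.unsqueeze(0)).squeeze(0) for s in seqs], dim=0).unsqueeze(0)
print((out_packed - out_ref).abs().max())           # expected to be around 1e-7 in float32, per the thread above
```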
Limitation:
Some related issues about mamba and flash-attn variable-length training: