feat: Initial state support for Mamba SSM (1) #488
base: main
Conversation
@mzusman I've noticed you made changes to files in the csrc directory, but I'm having trouble getting these changes to take effect in my environment. Could you please tell me the exact instructions to rebuild and install the mamba_ssm package so the changes are applied? It seems I always get the original package.
@daphneOdera-618 Yeah, the default setup.py behaviour is to download the upstream's wheel upon "installing". What you would need to do to force a build is to add the force-build flag when installing.
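For reference, forcing a local build presumably looks something like the following. MAMBA_FORCE_BUILD is the environment variable checked by the upstream setup.py; treat the exact commands as an untested sketch and verify the flag name against your checkout:

```bash
# Assumed rebuild steps -- verify the flag name in setup.py of your checkout.
pip uninstall -y mamba-ssm
# --no-build-isolation may be needed so the build can see your installed torch.
MAMBA_FORCE_BUILD=TRUE pip install -e . --no-build-isolation
```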
Unfortunately, this PR changes the API of selective_scan_cuda.fwd in an incompatible way. The same API is also invoked in MambaInnerFn.forward besides SelectiveScanFn.forward, leading to runtime errors in code that uses MambaInnerFn (e.g. the Mamba implementation found in the transformers library when running in vanilla training mode without cache_params). I think MambaInnerFn.forward could be modified to use the new API version, but I don't know how to produce the prerequisite additional empty vector (x) from what is available in MambaInnerFn.forward.
Since conv1d_out in MambaInnerFn seems to play the same role as u in SelectiveScanFn, adding this hack in place of the original invocation of selective_scan_cuda.fwd seems to work:
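The snippet originally posted here did not survive extraction. Below is a minimal sketch of what such a workaround might look like, assuming the patched selective_scan_cuda.fwd takes an extra pre-allocated state tensor; the argument name, position, shape, and dtype are guesses, not the PR's confirmed API:

```python
import torch
import selective_scan_cuda  # compiled extension built from csrc/


def fwd_with_empty_state(conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus):
    """Hypothetical call site for MambaInnerFn.forward: treat conv1d_out as `u`
    and pass a zero initial state so the result matches the stateless kernel.
    The extra argument's position, shape, and dtype are assumptions."""
    batch, dim, _ = conv1d_out.shape
    dstate = A.shape[-1]
    prev_state = torch.zeros(batch, dim, dstate,
                             device=conv1d_out.device, dtype=conv1d_out.dtype)
    return selective_scan_cuda.fwd(conv1d_out, delta, A, B, C, D, z,
                                   delta_bias, delta_softplus, prev_state)
```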
Adds chunked prefill / initial state capability to the Mamba SSM (Mamba 1). This is done by prepending the last forward pass's state to the FWD pass kernel and reading the data accordingly.
Latency is not affected (the benchmark script shows similar latencies between this PR and main, ~130 ms).
Added tests that check correctness when running on chunks (see the sketch below).
Limitations:
This PR enables efficient speculative decoding, prefix caching and prefill chunking.
FIX #233 #473 #258 #101
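The chunked-correctness tests presumably do something along these lines. This is a sketch, not the PR's test code: `prev_state` is a placeholder name for whatever initial-state argument this PR adds to selective_scan_fn, and the real name, position, and semantics may differ.

```python
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

batch, dim, dstate, seqlen, chunk = 2, 16, 8, 128, 32
u = torch.randn(batch, dim, seqlen, device="cuda")
delta = torch.rand(batch, dim, seqlen, device="cuda")
A = -torch.rand(dim, dstate, device="cuda")
B = torch.randn(batch, dstate, seqlen, device="cuda")
C = torch.randn(batch, dstate, seqlen, device="cuda")

# Reference: a single pass over the full sequence.
ref, _ = selective_scan_fn(u, delta, A, B, C, return_last_state=True)

# Chunked: process the sequence piece by piece, carrying the last state forward.
outs, state = [], None  # None on the first chunk = no initial state (assumed)
for s in range(0, seqlen, chunk):
    out, state = selective_scan_fn(u[..., s:s + chunk], delta[..., s:s + chunk],
                                   A, B, C, return_last_state=True,
                                   prev_state=state)  # hypothetical kwarg
    outs.append(out)

assert torch.allclose(ref, torch.cat(outs, dim=-1), atol=1e-4)
```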