Skip to content

Commit aeb4d3b

Browse files
authored
Merge branch 'state-spaces:main' into feat/add-cu_seqlens
2 parents cda4b5a + 62db608 commit aeb4d3b

File tree

5 files changed

+54
-24
lines changed

5 files changed

+54
-24
lines changed

.github/workflows/publish.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
4545
os: [ubuntu-20.04]
4646
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
47-
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240505']
47+
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0']
4848
cuda-version: ['11.8.0', '12.2.2']
4949
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
5050
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
@@ -117,7 +117,7 @@ jobs:
117117
# This code is ugly, maybe there's a better way to do this.
118118
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
119119
minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
120-
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
120+
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124}[env['MATRIX_TORCH_VERSION']]; \
121121
print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
122122
)
123123
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ with an efficient hardware-aware design and implementation in the spirit of [Fla
2121

2222
- [Optional] `pip install "causal-conv1d>=1.4.0"`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
2323
- `pip install mamba-ssm`: the core Mamba package.
24+
- `pip install mamba-ssm[causal-conv1d]`: installs the core Mamba package together with causal-conv1d.
25+
- `pip install mamba-ssm[dev]`: installs the core Mamba package together with the dev dependencies.
2426

2527
It can also be built from source with `pip install .` from this repository.
2628

mamba_ssm/modules/mha.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ def _update_kvcache_attention(self, q, kv, inference_params):
180180
).transpose(1, 2)
181181
else:
182182
batch = q.shape[0]
183-
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
183+
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
184+
kv_cache = kv_cache[:batch]
184185
cache_seqlens = (
185186
inference_params.lengths_per_sample[:batch]
186187
if inference_params.lengths_per_sample is not None

pyproject.toml

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
[project]
2+
name = "mamba_ssm"
3+
description = "Mamba state-space model"
4+
readme = "README.md"
5+
authors = [
6+
{ name = "Tri Dao", email = "tri@tridao.me" },
7+
{ name = "Albert Gu", email = "agu@cs.cmu.edu" }
8+
]
9+
requires-python = ">= 3.7"
10+
dynamic = ["version"]
11+
license = { file = "LICENSE" } # Include a LICENSE file in your repo
12+
keywords = ["cuda", "pytorch", "state-space model"]
13+
classifiers = [
14+
"Programming Language :: Python :: 3",
15+
"License :: OSI Approved :: BSD License",
16+
"Operating System :: Unix"
17+
]
18+
dependencies = [
19+
"torch",
20+
"ninja",
21+
"einops",
22+
"triton",
23+
"transformers",
24+
"packaging",
25+
"setuptools>=61.0.0",
26+
]
27+
urls = { name = "Repository", url = "https://github.com/state-spaces/mamba"}
28+
29+
[project.optional-dependencies]
30+
causal-conv1d = [
31+
"causal-conv1d>=1.2.0"
32+
]
33+
dev = [
34+
"pytest"
35+
]
36+
37+
38+
[build-system]
39+
requires = [
40+
"setuptools>=61.0.0",
41+
"wheel",
42+
"torch",
43+
"packaging",
44+
"ninja",
45+
]
46+
build-backend = "setuptools.build_meta"

setup.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import torch
2020
from torch.utils.cpp_extension import (
2121
BuildExtension,
22-
CppExtension,
2322
CUDAExtension,
2423
CUDA_HOME,
2524
HIP_HOME
@@ -349,31 +348,13 @@ def run(self):
349348
"mamba_ssm.egg-info",
350349
)
351350
),
352-
author="Tri Dao, Albert Gu",
353-
author_email="tri@tridao.me, agu@cs.cmu.edu",
354-
description="Mamba state-space model",
355351
long_description=long_description,
356352
long_description_content_type="text/markdown",
357-
url="https://github.com/state-spaces/mamba",
358-
classifiers=[
359-
"Programming Language :: Python :: 3",
360-
"License :: OSI Approved :: BSD License",
361-
"Operating System :: Unix",
362-
],
353+
363354
ext_modules=ext_modules,
364355
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
365356
if ext_modules
366357
else {
367358
"bdist_wheel": CachedWheelsCommand,
368-
},
369-
python_requires=">=3.8",
370-
install_requires=[
371-
"torch",
372-
"packaging",
373-
"ninja",
374-
"einops",
375-
"triton",
376-
"transformers",
377-
# "causal_conv1d>=1.4.0",
378-
],
359+
}
379360
)

0 commit comments

Comments
 (0)