
Commit ca189f6

Merge branch 'main' into feat/add-cu_seqlens
2 parents 5955450 + 9127d1f

File tree: 9 files changed, +53 −21 lines


.github/workflows/publish.yaml

Lines changed: 27 additions & 4 deletions

@@ -43,15 +43,24 @@ jobs:
         # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
         # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
         os: [ubuntu-20.04]
-        python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
-        torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0']
+        python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
+        torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240105']
         cuda-version: ['11.8.0', '12.2.2']
         # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
         # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
         # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
         # when building without C++11 ABI and using it on nvcr images.
         cxx11_abi: ['FALSE', 'TRUE']
         exclude:
+          # Pytorch < 2.2 does not support Python 3.12
+          - torch-version: '1.12.1'
+            python-version: '3.12'
+          - torch-version: '1.13.1'
+            python-version: '3.12'
+          - torch-version: '2.0.1'
+            python-version: '3.12'
+          - torch-version: '2.1.2'
+            python-version: '3.12'
          # Pytorch <= 1.12 does not support Python 3.11
          - torch-version: '1.12.1'
            python-version: '3.11'
@@ -62,6 +71,8 @@ jobs:
            python-version: '3.7'
          - torch-version: '2.2.0'
            python-version: '3.7'
+          - torch-version: '2.3.0.dev20240105'
+            python-version: '3.7'
          # Pytorch <= 2.0 only supports CUDA <= 11.8
          - torch-version: '1.12.1'
            cuda-version: '12.2.2'
@@ -119,12 +130,24 @@ jobs:
          # If we don't install before installing Pytorch, we get error for torch 2.0.1
          # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
          pip install lit
+          # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
+          pip install setuptools
          # We want to figure out the CUDA version to download pytorch
          # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
          # This code is ugly, maybe there's a better way to do this.
-          export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
+          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
+            minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118}[env['MATRIX_TORCH_VERSION']]; \
+            maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121}[env['MATRIX_TORCH_VERSION']]; \
+            print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
+          )
          if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
-            pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+            if [[ ${MATRIX_TORCH_VERSION} == "2.2" ]]; then
+              # --no-deps because we can't install old versions of pytorch-triton
+              pip install typing-extensions jinja2
+              pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+            else
+              pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+            fi
          else
            pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
          fi
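
For reference, the exported one-liner above clamps the detected CUDA version into the range of wheel builds published for each torch release. A standalone Python sketch of the same clamping logic (dictionary values copied from the workflow; the function name is illustrative):

    # Pick the CUDA wheel suffix (e.g. 118 -> cu118) closest to the system CUDA,
    # clamped into the range published for the given torch minor version.
    MIN_CUDA = {"1.12": 113, "1.13": 116, "2.0": 117, "2.1": 118, "2.2": 118, "2.3": 118}
    MAX_CUDA = {"1.12": 116, "1.13": 117, "2.0": 118, "2.1": 121, "2.2": 121, "2.3": 121}

    def torch_cuda_version(torch_minor: str, system_cuda: int) -> int:
        return max(min(system_cuda, MAX_CUDA[torch_minor]), MIN_CUDA[torch_minor])

    assert torch_cuda_version("1.12", 117) == 116  # system CUDA too new: clamp down to cu116
    assert torch_cuda_version("2.0", 113) == 117   # system CUDA too old: clamp up to cu117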

README.md

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ with an efficient hardware-aware design and implementation in the spirit of [Fla

 ## Installation

-- `pip install causal-conv1d>=1.1.0,<1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
+- [Option] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
 - `pip install mamba-ssm`: the core Mamba package.

 It can also be built from source with `pip install .` from this repository.

mamba_ssm/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-__version__ = "1.1.4"
+__version__ = "1.2.0.post1"

 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from mamba_ssm.modules.mamba_simple import Mamba

mamba_ssm/models/config_mamba.py

Lines changed: 1 addition & 0 deletions

@@ -12,3 +12,4 @@ class MambaConfig:
     residual_in_fp32: bool = True
     fused_add_norm: bool = True
     pad_vocab_size_multiple: int = 8
+    tie_embeddings: bool = True
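
The new field defaults to True, so existing configs keep the previous behavior. A hedged usage sketch (assuming the other MambaConfig fields keep their defaults, as the diff context shows):

    from mamba_ssm.models.config_mamba import MambaConfig

    # Untie the LM head from the input embedding; all other fields
    # keep their dataclass defaults.
    config = MambaConfig(tie_embeddings=False)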

mamba_ssm/models/mixer_seq_simple.py

Lines changed: 4 additions & 4 deletions

@@ -220,8 +220,9 @@ def __init__(
         self.tie_weights()

     def tie_weights(self):
-        self.lm_head.weight = self.backbone.embedding.weight
-
+        if self.config.tie_embeddings:
+            self.lm_head.weight = self.backbone.embedding.weight
+
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
         return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

@@ -251,8 +252,7 @@ def save_pretrained(self, save_directory):
         Save the model and its configuration file to a directory.
         """
         # Ensure save_directory exists
-        if not os.path.exists(save_directory):
-            os.makedirs(save_directory)
+        os.makedirs(save_directory, exist_ok=True)

         # Save the model's state_dict
         model_path = os.path.join(save_directory, 'pytorch_model.bin')
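
What tie_weights does when tie_embeddings is True, shown in isolation with plain PyTorch (a minimal sketch, not the model's actual modules):

    import torch.nn as nn

    vocab_size, d_model = 1000, 64
    embedding = nn.Embedding(vocab_size, d_model)          # weight: (vocab_size, d_model)
    lm_head = nn.Linear(d_model, vocab_size, bias=False)   # weight: (vocab_size, d_model)

    lm_head.weight = embedding.weight  # tie: both modules now share one Parameter
    assert lm_head.weight.data_ptr() == embedding.weight.data_ptr()

The save_pretrained hunk is an idiom swap: os.makedirs(..., exist_ok=True) replaces the check-then-create pair and avoids the race between the existence check and the creation.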

mamba_ssm/modules/mamba_simple.py

Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,7 @@
 try:
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 except ImportError:
-    causal_conv1d_fn, causal_conv1d_update = None
+    causal_conv1d_fn, causal_conv1d_update = None, None

 try:
     from mamba_ssm.ops.triton.selective_state_update import selective_state_update
@@ -143,7 +143,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):

         A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
         # In the backward pass we write dx and dz next to each other to avoid torch.cat
-        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
+        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states
             out = mamba_inner_fn(
                 xz,
                 self.conv1d.weight,
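
The first hunk fixes a latent crash rather than a style issue: unpacking a bare None into two names raises immediately, so the ImportError fallback itself would fail. A minimal reproduction:

    try:
        a, b = None  # raises: cannot unpack non-iterable NoneType object
    except TypeError as e:
        print(e)

    a, b = None, None  # correct: two names, two values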

mamba_ssm/ops/selective_scan_interface.py

Lines changed: 15 additions & 7 deletions

@@ -6,8 +6,13 @@

 from einops import rearrange, repeat

-from causal_conv1d import causal_conv1d_fn
-import causal_conv1d_cuda
+try:
+    from causal_conv1d import causal_conv1d_fn
+    import causal_conv1d_cuda
+except ImportError:
+    causal_conv1d_fn = None
+    causal_conv1d_cuda = None
+
 import selective_scan_cuda

@@ -168,6 +173,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         """
         xz: (batch, dim, seqlen)
         """
+        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         assert checkpoint_lvl in [0, 1]
         L = xz.shape[-1]
         delta_rank = delta_proj_weight.shape[1]
@@ -196,7 +202,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
             assert x.shape[2] == (d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2]

         conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)

         # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
         if cu_seqlens is not None:
@@ -262,6 +268,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
     @custom_bwd
     def backward(ctx, dout):
         # dout: (batch, seqlen, dim)
+        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
          conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, cu_seqlens) = ctx.saved_tensors
         L = xz.shape[-1]
@@ -285,7 +292,7 @@ def backward(ctx, dout):
             x = padded_x
             assert x.shape[2] == (d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2]

-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)

         # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
         if cu_seqlens is not None:
@@ -345,8 +352,8 @@ def backward(ctx, dout):
         dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
         # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
         # backward of conv1d with the backward of chunk).
-        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
+        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
         )
         dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
         dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
@@ -374,11 +381,12 @@ def mamba_inner_ref(
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
     C_proj_bias=None, delta_softplus=True
 ):
+    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
     L = xz.shape[-1]
     delta_rank = delta_proj_weight.shape[1]
     d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
     x, z = xz.chunk(2, dim=1)
-    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
+    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
     # We're being very careful here about the layout, to avoid extra transposes.
     # We want delta to have d as the slowest moving dimension
     # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
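
The import change follows the usual optional-dependency pattern: record the missing extension as None at import time, then fail with an actionable message only on code paths that actually need it (the extra None arguments and the *_ unpack at the call sites presumably track the extended kernel signatures in causal-conv1d >= 1.2.0, matching the README and setup.py pin changes). The pattern in isolation, with an illustrative function name:

    try:
        import causal_conv1d_cuda  # optional compiled extension
    except ImportError:
        causal_conv1d_cuda = None

    def fused_path(x):
        # Fail with an actionable message only when this path is actually used.
        assert causal_conv1d_cuda is not None, (
            "causal_conv1d_cuda is not available. Please install causal-conv1d."
        )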

setup.py

Lines changed: 1 addition & 1 deletion

@@ -271,6 +271,6 @@ def run(self):
         "einops",
         "triton",
         "transformers",
-        "causal_conv1d>=1.1.0,<1.2.0",
+        # "causal_conv1d>=1.2.0",
     ],
 )

tests/ops/triton/test_selective_state_update.py

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@
 # @pytest.mark.parametrize("dstate", [16])
 @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
 # @pytest.mark.parametrize("dim", [2048])
-def test_causal_conv1d_update(dim, dstate, has_z, itype):
+def test_selective_state_update(dim, dstate, has_z, itype):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
     if itype == torch.bfloat16:
