Skip to content

Commit 0a15f1d

Browse files
committed
Revert "Merge branch 'state-spaces:main' into feat/add-cu_seqlens"
This reverts commit aeb4d3b, reversing changes made to cda4b5a.
1 parent aeb4d3b commit 0a15f1d

File tree

5 files changed

+24
-54
lines changed

5 files changed

+24
-54
lines changed

.github/workflows/publish.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
4545
os: [ubuntu-20.04]
4646
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
47-
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0']
47+
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240505']
4848
cuda-version: ['11.8.0', '12.2.2']
4949
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
5050
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
@@ -117,7 +117,7 @@ jobs:
117117
# This code is ugly, maybe there's a better way to do this.
118118
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
119119
minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
120-
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124}[env['MATRIX_TORCH_VERSION']]; \
120+
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
121121
print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
122122
)
123123
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ with an efficient hardware-aware design and implementation in the spirit of [Fla
2121

2222
- [Option] `pip install causal-conv1d>=1.4.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
2323
- `pip install mamba-ssm`: the core Mamba package.
24-
- `pip install mamba-ssm[causal-conv1d]`: To install core Mamba package and causal-conv1d.
25-
- `pip install mamba-ssm[dev]`: To install core Mamba package and dev depdencies.
2624

2725
It can also be built from source with `pip install .` from this repository.
2826

mamba_ssm/modules/mha.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,7 @@ def _update_kvcache_attention(self, q, kv, inference_params):
180180
).transpose(1, 2)
181181
else:
182182
batch = q.shape[0]
183-
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
184-
kv_cache = kv_cache[:batch]
183+
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
185184
cache_seqlens = (
186185
inference_params.lengths_per_sample[:batch]
187186
if inference_params.lengths_per_sample is not None

pyproject.toml

Lines changed: 0 additions & 46 deletions
This file was deleted.

setup.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020
from torch.utils.cpp_extension import (
2121
BuildExtension,
22+
CppExtension,
2223
CUDAExtension,
2324
CUDA_HOME,
2425
HIP_HOME
@@ -348,13 +349,31 @@ def run(self):
348349
"mamba_ssm.egg-info",
349350
)
350351
),
352+
author="Tri Dao, Albert Gu",
353+
author_email="tri@tridao.me, agu@cs.cmu.edu",
354+
description="Mamba state-space model",
351355
long_description=long_description,
352356
long_description_content_type="text/markdown",
353-
357+
url="https://github.com/state-spaces/mamba",
358+
classifiers=[
359+
"Programming Language :: Python :: 3",
360+
"License :: OSI Approved :: BSD License",
361+
"Operating System :: Unix",
362+
],
354363
ext_modules=ext_modules,
355364
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
356365
if ext_modules
357366
else {
358367
"bdist_wheel": CachedWheelsCommand,
359-
}
368+
},
369+
python_requires=">=3.8",
370+
install_requires=[
371+
"torch",
372+
"packaging",
373+
"ninja",
374+
"einops",
375+
"triton",
376+
"transformers",
377+
# "causal_conv1d>=1.4.0",
378+
],
360379
)

0 commit comments

Comments
 (0)