
Support cuda 12.8.0 and SBSA wheels #677


Open · wants to merge 13 commits into main
64 changes: 28 additions & 36 deletions .github/workflows/publish.yaml
@@ -36,35 +36,25 @@ jobs:
name: Build Wheel
needs: setup_release
runs-on: ${{ matrix.os }}

strategy:
fail-fast: false
matrix:
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001']
cuda-version: ['11.8.0', '12.3.2']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# Pytorch < 2.2 does not support Python 3.12
- torch-version: '2.1.2'
python-version: '3.12'
# Pytorch < 2.5 does not support Python 3.13
- torch-version: '2.1.2'
python-version: '3.13'
- torch-version: '2.2.2'
python-version: '3.13'
- torch-version: '2.3.1'
python-version: '3.13'
- torch-version: '2.4.0'
python-version: '3.13'
# Using ubuntu-22.04 and ubuntu-22.04-arm so we build both x86_64 and aarch64 (SBSA)
# wheels. Ideally we'd use the manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ ubuntu-22.04, ubuntu-22.04-arm ]
python-version: [ '3.9', '3.10', '3.11', '3.12', '3.13' ]
torch-version: [ '2.4.0', '2.5.1', '2.6.0', '2.7.0' ]
cuda-version: [ '12.8.1' ]
# We need separate wheels that either use the C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: [ 'FALSE', 'TRUE' ]
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# Pytorch < 2.5 does not support Python 3.13
- torch-version: '2.4.0'
python-version: '3.13'

steps:
- name: Checkout
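
Side note on the cxx11_abi axis above: a wheel built against one C++ ABI fails to import (the undefined-symbol error quoted in the comment) on a torch built with the other, so both variants get published. A minimal sketch, using only public torch APIs, of how to check which variant a given environment needs:

import torch

# True if torch was built with -D_GLIBCXX_USE_CXX11_ABI=1 (e.g. the torch
# inside nvcr images); False for the official PyPI wheels. Pick the matching
# cxx11abiTRUE / cxx11abiFALSE wheel accordingly.
print(torch.compiled_with_cxx11_abi())
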
@@ -81,7 +71,8 @@ jobs:
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV

echo "MATRIX_ARCH=$(echo ${{ matrix.arch }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV

- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
@@ -93,13 +84,13 @@

- name: Set up swap space
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@v1.0
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 10

- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
uses: Jimver/cuda-toolkit@v0.2.19
uses: Jimver/cuda-toolkit@v0.2.23
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda-version }}
@@ -113,25 +104,26 @@
run: |
pip install --upgrade pip
# For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
pip install setuptools==68.0.0
pip install setuptools==75.8.0
# With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
# AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
pip install typing-extensions==4.12.2
# We want to figure out the CUDA version to download pytorch
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118}[env['MATRIX_TORCH_VERSION']]; \
maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128}[env['MATRIX_TORCH_VERSION']]; \
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
# pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
# Can't use --no-deps because we need cudnn etc.
# Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001
# Hard-coding this version of pytorch-triton for torch 2.8.0.dev20250425
pip install jinja2
pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.3.0+gitab727c40-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_${{ matrix.arch }}.whl
pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_${{ matrix.arch }}.whl
else
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
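
The exported TORCH_CUDA_VERSION one-liner above is dense; here is the same logic written out as a plain-Python sketch (the tables mirror the workflow's dictionaries, per pytorch's release compatibility matrix):

# Clamp the system CUDA version to the range of CUDA builds each torch
# release actually ships: the oldest cu11x build for CUDA 11, else the
# newest supported cu12x build.
MIN_CUDA = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118}
MAX_CUDA = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128}

def torch_cuda_version(torch_mm: str, system_cuda: int) -> int:
    # system_cuda is major*10 + minor, e.g. 12.8 -> 128, matching MATRIX_CUDA_VERSION
    return MIN_CUDA[torch_mm] if system_cuda < 120 else MAX_CUDA[torch_mm]

assert torch_cuda_version('2.6', 128) == 126  # torch 2.6 tops out at cu126
assert torch_cuda_version('2.7', 128) == 128  # torch 2.7 ships cu128 wheels
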
@@ -197,7 +189,7 @@ jobs:

- uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.12'

- name: Install dependencies
run: |
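Each matrix cell publishes one wheel per (os, python, torch, cxx11_abi) combination. As a hedged sketch of what a cell produces, assuming the +cuNNtorchN.Mcxx11abi{TRUE,FALSE} local-version naming used by previous mamba-ssm release assets (the actual filename is produced by setup.py, not by this workflow):

# Hypothetical helper: predict the wheel filename for one matrix cell.
# The local-version convention here is an assumption based on past
# mamba-ssm GitHub releases; verify against a real release asset.
def wheel_name(version: str, cuda_major: int, torch_mm: str,
               cxx11_abi: bool, py: str, arch: str) -> str:
    cp = 'cp' + py.replace('.', '')
    abi = 'TRUE' if cxx11_abi else 'FALSE'
    return (f"mamba_ssm-{version}+cu{cuda_major}torch{torch_mm}"
            f"cxx11abi{abi}-{cp}-{cp}-linux_{arch}.whl")

# One new SBSA cell: ubuntu-22.04-arm, Python 3.12, torch 2.7, CUDA 12.8.
print(wheel_name('2.2.5', 12, '2.7', False, '3.12', 'aarch64'))
# mamba_ssm-2.2.5+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_aarch64.whl
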
2 changes: 1 addition & 1 deletion mamba_ssm/__init__.py
@@ -1,4 +1,4 @@
__version__ = "2.2.4"
__version__ = "2.2.5"

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
2 changes: 1 addition & 1 deletion mamba_ssm/distributed/tensor_parallel.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao.
# Copyright (c) 2025, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

2 changes: 1 addition & 1 deletion mamba_ssm/modules/block.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.
from typing import Optional

import torch
2 changes: 1 addition & 1 deletion mamba_ssm/modules/mamba2.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.

import math

2 changes: 1 addition & 1 deletion mamba_ssm/modules/mamba2_simple.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.

import math
import torch
2 changes: 1 addition & 1 deletion mamba_ssm/modules/mha.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.

import math

2 changes: 1 addition & 1 deletion mamba_ssm/modules/mlp.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.
from torch import nn
from torch.nn import functional as F

2 changes: 1 addition & 1 deletion mamba_ssm/modules/ssd_minimal.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Albert Gu and Tri Dao.
# Copyright (c) 2025, Albert Gu and Tri Dao.
"""Minimal implementation of SSD.

This is the same as Listing 1 from the paper.
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/k_activations.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.

import torch

2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/layer_norm.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao.
# Copyright (c) 2025, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/layernorm_gated.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao.
# Copyright (c) 2025, Tri Dao.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/selective_state_update.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
"""
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_bmm.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_chunk_scan.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_chunk_state.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_combined.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_state_passing.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2025, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""