From 6e5a3a9c99a60586c36520f9d6ea60a1b1556440 Mon Sep 17 00:00:00 2001 From: Johnny Date: Wed, 22 Jan 2025 00:26:04 +0100 Subject: [PATCH 01/12] initial blackwell support --- .github/workflows/publish.yaml | 4 ++-- mamba_ssm/__init__.py | 2 +- mamba_ssm/distributed/tensor_parallel.py | 2 +- mamba_ssm/modules/block.py | 2 +- mamba_ssm/modules/mamba2.py | 2 +- mamba_ssm/modules/mamba2_simple.py | 2 +- mamba_ssm/modules/mha.py | 2 +- mamba_ssm/modules/mlp.py | 2 +- mamba_ssm/modules/ssd_minimal.py | 2 +- mamba_ssm/ops/triton/k_activations.py | 2 +- mamba_ssm/ops/triton/layer_norm.py | 2 +- mamba_ssm/ops/triton/layernorm_gated.py | 2 +- mamba_ssm/ops/triton/selective_state_update.py | 2 +- mamba_ssm/ops/triton/ssd_bmm.py | 2 +- mamba_ssm/ops/triton/ssd_chunk_scan.py | 2 +- mamba_ssm/ops/triton/ssd_chunk_state.py | 2 +- mamba_ssm/ops/triton/ssd_combined.py | 2 +- mamba_ssm/ops/triton/ssd_state_passing.py | 2 +- setup.py | 12 +++++++++++- 19 files changed, 30 insertions(+), 20 deletions(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 192f9562..273ee818 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -44,8 +44,8 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001'] - cuda-version: ['11.8.0', '12.3.2'] + torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] + cuda-version: ['11.8.0', '12.3.2', '12.6.3'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. 
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py index ac4f6e31..6280931e 100644 --- a/mamba_ssm/__init__.py +++ b/mamba_ssm/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.2.4" +__version__ = "2.2.5" from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.modules.mamba_simple import Mamba diff --git a/mamba_ssm/distributed/tensor_parallel.py b/mamba_ssm/distributed/tensor_parallel.py index 2d67b530..5d4f1000 100644 --- a/mamba_ssm/distributed/tensor_parallel.py +++ b/mamba_ssm/distributed/tensor_parallel.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao. +# Copyright (c) 2025, Tri Dao. # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py from typing import Optional diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py index 1bd968a0..8ebb8dd1 100644 --- a/mamba_ssm/modules/block.py +++ b/mamba_ssm/modules/block.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. from typing import Optional import torch diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 36b16d47..ceeb3d04 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. import math diff --git a/mamba_ssm/modules/mamba2_simple.py b/mamba_ssm/modules/mamba2_simple.py index 77a6af28..cc51be4f 100644 --- a/mamba_ssm/modules/mamba2_simple.py +++ b/mamba_ssm/modules/mamba2_simple.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. 
import math import torch diff --git a/mamba_ssm/modules/mha.py b/mamba_ssm/modules/mha.py index 978f3ea4..0818394b 100644 --- a/mamba_ssm/modules/mha.py +++ b/mamba_ssm/modules/mha.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. import math diff --git a/mamba_ssm/modules/mlp.py b/mamba_ssm/modules/mlp.py index 33bab5c7..7e6fb16e 100644 --- a/mamba_ssm/modules/mlp.py +++ b/mamba_ssm/modules/mlp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. from torch import nn from torch.nn import functional as F diff --git a/mamba_ssm/modules/ssd_minimal.py b/mamba_ssm/modules/ssd_minimal.py index 9632ebd4..6e8d5382 100644 --- a/mamba_ssm/modules/ssd_minimal.py +++ b/mamba_ssm/modules/ssd_minimal.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Albert Gu and Tri Dao. +# Copyright (c) 2025, Albert Gu and Tri Dao. """Minimal implementation of SSD. This is the same as Listing 1 from the paper. diff --git a/mamba_ssm/ops/triton/k_activations.py b/mamba_ssm/ops/triton/k_activations.py index 79fa2cc6..1b0c2640 100644 --- a/mamba_ssm/ops/triton/k_activations.py +++ b/mamba_ssm/ops/triton/k_activations.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. import torch diff --git a/mamba_ssm/ops/triton/layer_norm.py b/mamba_ssm/ops/triton/layer_norm.py index 200b415a..a2699c4b 100755 --- a/mamba_ssm/ops/triton/layer_norm.py +++ b/mamba_ssm/ops/triton/layer_norm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao. +# Copyright (c) 2025, Tri Dao. # Implement dropout + residual + layer_norm / rms_norm. 
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html diff --git a/mamba_ssm/ops/triton/layernorm_gated.py b/mamba_ssm/ops/triton/layernorm_gated.py index de4b2f48..33ccc0e1 100644 --- a/mamba_ssm/ops/triton/layernorm_gated.py +++ b/mamba_ssm/ops/triton/layernorm_gated.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao. +# Copyright (c) 2025, Tri Dao. # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. diff --git a/mamba_ssm/ops/triton/selective_state_update.py b/mamba_ssm/ops/triton/selective_state_update.py index d425bc72..a11c426c 100644 --- a/mamba_ssm/ops/triton/selective_state_update.py +++ b/mamba_ssm/ops/triton/selective_state_update.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this """ diff --git a/mamba_ssm/ops/triton/ssd_bmm.py b/mamba_ssm/ops/triton/ssd_bmm.py index 48fd4f06..4f505bcc 100644 --- a/mamba_ssm/ops/triton/ssd_bmm.py +++ b/mamba_ssm/ops/triton/ssd_bmm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. """We want triton==2.1.0 or 2.2.0 for this """ diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py index fa5b813a..b7b1d7e6 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_scan.py +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. 
"""We want triton==2.1.0 or 2.2.0 for this """ diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index bb49c9a9..04625490 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. """We want triton==2.1.0 or 2.2.0 for this """ diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index 58a6e04a..54e7a3d9 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. """We want triton==2.1.0 or 2.2.0 for this """ diff --git a/mamba_ssm/ops/triton/ssd_state_passing.py b/mamba_ssm/ops/triton/ssd_state_passing.py index 63863b82..ebf0176d 100644 --- a/mamba_ssm/ops/triton/ssd_state_passing.py +++ b/mamba_ssm/ops/triton/ssd_state_passing.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. +# Copyright (c) 2025, Tri Dao, Albert Gu. 
"""We want triton==2.1.0 or 2.2.0 for this """ diff --git a/setup.py b/setup.py index 7c6196d7..0614a2d8 100755 --- a/setup.py +++ b/setup.py @@ -184,10 +184,20 @@ def append_nvcc_threads(nvcc_extra_args): cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("-gencode") cc_flag.append("arch=compute_87,code=sm_87") - + cc_flag.append("-gencode") + cc_flag.append("arch=compute_89,code=sm_89") if bare_metal_version >= Version("11.8"): cc_flag.append("-gencode") cc_flag.append("arch=compute_90,code=sm_90") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90a,code=sm_90a") + if bare_metal_version >= Version("12.7"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_100,code=sm_100") # B100 + cc_flag.append("-gencode") + cc_flag.append("arch=compute_101,code=sm_101") # Thor + cc_flag.append("-gencode") + cc_flag.append("arch=compute_120,code=sm_100") # RTX50 # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as From 932bc7fc370b8c798045b89bdf8c1afcef068121 Mon Sep 17 00:00:00 2001 From: Johnny Date: Thu, 23 Jan 2025 22:30:59 +0100 Subject: [PATCH 02/12] Update publish.yaml --- .github/workflows/publish.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 273ee818..f8af10f4 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -45,7 +45,7 @@ jobs: os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] - cuda-version: ['11.8.0', '12.3.2', '12.6.3'] + cuda-version: ['11.8.0', '12.3.2', '12.6.3', '12.8.0'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. 
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) From a35151710abc04687038b807313a2086ffeba8b7 Mon Sep 17 00:00:00 2001 From: Johnny Date: Thu, 23 Jan 2025 22:31:49 +0100 Subject: [PATCH 03/12] Update setup.py --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0614a2d8..c10044de 100755 --- a/setup.py +++ b/setup.py @@ -197,7 +197,9 @@ def append_nvcc_threads(nvcc_extra_args): cc_flag.append("-gencode") cc_flag.append("arch=compute_101,code=sm_101") # Thor cc_flag.append("-gencode") - cc_flag.append("arch=compute_120,code=sm_100") # RTX50 + cc_flag.append("arch=compute_120,code=sm_120") # RTX50 + cc_flag.append("-gencode") + cc_flag.append("arch=compute_120a,code=sm_120a") # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as From 25e883bafe447f1c17de796936d83df3ae202a76 Mon Sep 17 00:00:00 2001 From: Johnny Date: Fri, 24 Jan 2025 23:54:27 +0100 Subject: [PATCH 04/12] . --- .github/workflows/publish.yaml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index f8af10f4..8b140f72 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -44,8 +44,8 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] - cuda-version: ['11.8.0', '12.3.2', '12.6.3', '12.8.0'] + torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0', '2.7.0.dev20250130'] + cuda-version: ['11.8.0', '12.6.3', '12.8.0'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. 
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) @@ -93,13 +93,13 @@ jobs: - name: Set up swap space if: runner.os == 'Linux' - uses: pierotofy/set-swap-space@v1.0 + uses: pierotofy/set-swap-space@master with: swap-size-gb: 10 - name: Install CUDA ${{ matrix.cuda-version }} if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.19 + uses: Jimver/cuda-toolkit@v0.2.20 id: cuda-toolkit with: cuda: ${{ matrix.cuda-version }} @@ -121,16 +121,16 @@ jobs: # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 # This code is ugly, maybe there's a better way to do this. export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \ + minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118, '2.7': 121 }[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128 }[env['MATRIX_TORCH_VERSION']]; \ print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} # Can't use --no-deps because we need cudnn etc. 
- # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 + # Hard-coding this version of pytorch-triton for torch 2.7.0.dev20250130 pip install jinja2 - pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl else pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} From f091680e8259073feafcdc8e59aa57b688c946b7 Mon Sep 17 00:00:00 2001 From: Johnny Date: Wed, 5 Feb 2025 11:05:37 +0100 Subject: [PATCH 05/12] Update publish.yaml --- .github/workflows/publish.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 8b140f72..2a7c2d54 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -44,7 +44,7 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0', '2.7.0.dev20250130'] + torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0', '2.7.0.dev20250205'] cuda-version: ['11.8.0', '12.6.3', '12.8.0'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. 
@@ -99,7 +99,7 @@ jobs: - name: Install CUDA ${{ matrix.cuda-version }} if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.20 + uses: Jimver/cuda-toolkit@v0.2.21 id: cuda-toolkit with: cuda: ${{ matrix.cuda-version }} @@ -121,16 +121,16 @@ jobs: # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 # This code is ugly, maybe there's a better way to do this. export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118, '2.7': 121 }[env['MATRIX_TORCH_VERSION']]; \ + minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118, '2.7': 124 }[env['MATRIX_TORCH_VERSION']]; \ maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128 }[env['MATRIX_TORCH_VERSION']]; \ print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} # Can't use --no-deps because we need cudnn etc. 
- # Hard-coding this version of pytorch-triton for torch 2.7.0.dev20250130 + # Hard-coding this version of pytorch-triton for torch 2.7.0.dev20250205 pip install jinja2 - pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl else pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} From bd1113b1b5e69f61ffb411b4f29f03bbf059f63e Mon Sep 17 00:00:00 2001 From: Johnny Date: Wed, 5 Feb 2025 11:11:10 +0100 Subject: [PATCH 06/12] Update publish.yaml --- .github/workflows/publish.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 2a7c2d54..8fa796e8 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -131,7 +131,7 @@ jobs: # Hard-coding this version of pytorch-triton for torch 2.7.0.dev20250205 pip install jinja2 pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version 
}}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_x86_64.whl else pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} fi From b91b86d5e8c877108202da2c5c7166d8b5d12958 Mon Sep 17 00:00:00 2001 From: Johnny Date: Wed, 5 Feb 2025 12:43:15 +0100 Subject: [PATCH 07/12] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c10044de..1127fed3 100755 --- a/setup.py +++ b/setup.py @@ -191,7 +191,7 @@ def append_nvcc_threads(nvcc_extra_args): cc_flag.append("arch=compute_90,code=sm_90") cc_flag.append("-gencode") cc_flag.append("arch=compute_90a,code=sm_90a") - if bare_metal_version >= Version("12.7"): + if bare_metal_version >= Version("12.8"): cc_flag.append("-gencode") cc_flag.append("arch=compute_100,code=sm_100") # B100 cc_flag.append("-gencode") From 768f5373df89a6863ecc3a7eb16972be1c63c9e3 Mon Sep 17 00:00:00 2001 From: Johnny Date: Tue, 25 Feb 2025 15:05:10 +0100 Subject: [PATCH 08/12] Update publish.yaml --- .github/workflows/publish.yaml | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 8fa796e8..dd21c2a4 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -42,10 +42,11 @@ jobs: matrix: # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. 
- os: [ubuntu-20.04] + os: [ubuntu-22.04, ubuntu-22.04-arm] + arch: [x86_64, aarch64] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0', '2.7.0.dev20250205'] - cuda-version: ['11.8.0', '12.6.3', '12.8.0'] + torch-version: ['2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] + cuda-version: ['12.8.0'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) @@ -53,12 +54,12 @@ jobs: cxx11_abi: ['FALSE', 'TRUE'] exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # Pytorch < 2.2 does not support Python 3.12 - - torch-version: '2.1.2' - python-version: '3.12' - # Pytorch < 2.5 does not support Python 3.13 - - torch-version: '2.1.2' - python-version: '3.13' + - os: ubuntu-22.04-arm + arch: x86_64 + # Prevent trying to run aarch64 on ubuntu-22.04 x86 + - os: ubuntu-22.04 + arch: aarch64 + # Pytorch < 2.5 does not support Python 3.13 (your existing excludes) - torch-version: '2.2.2' python-version: '3.13' - torch-version: '2.3.1' @@ -81,7 +82,8 @@ jobs: echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - + echo "MATRIX_ARCH=$(echo ${{ matrix.arch }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + - name: Free up disk space if: ${{ runner.os == 'Linux' }} # https://github.com/easimon/maximize-build-space/blob/master/action.yml @@ -130,8 +132,8 @@ jobs: # Can't use --no-deps because we need cudnn etc. 
# Hard-coding this version of pytorch-triton for torch 2.7.0.dev20250205 pip install jinja2 - pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_x86_64.whl + pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0-cp${{ matrix.python-version|replace('.', '') }}-cp${{ matrix.python-version|replace('.', '') }}-manylinux_2_28_${{ matrix.arch }}.whl + pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_${{ matrix.arch }}.whl else pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} fi From 133e476e5061783824112f39f40c105ec6e93cdf Mon Sep 17 00:00:00 2001 From: Johnny Date: Tue, 25 Feb 2025 15:14:03 +0100 Subject: [PATCH 09/12] Update setup.py --- setup.py | 379 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 213 insertions(+), 166 deletions(-) diff --git a/setup.py b/setup.py index 1127fed3..5ac28b39 100755 --- a/setup.py +++ b/setup.py @@ -1,13 +1,17 @@ -# Copyright (c) 2023, Albert Gu, Tri Dao. + +# Copyright (c) 2025, Tri Dao. 
+ import sys +import functools import warnings import os import re import ast +import glob +import shutil from pathlib import Path from packaging.version import parse, Version import platform -import shutil from setuptools import setup, find_packages import subprocess @@ -19,9 +23,11 @@ import torch from torch.utils.cpp_extension import ( BuildExtension, + CppExtension, CUDAExtension, CUDA_HOME, - HIP_HOME + ROCM_HOME, + IS_HIP_EXTENSION, ) @@ -32,9 +38,24 @@ # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) +BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto") + +if BUILD_TARGET == "auto": + if IS_HIP_EXTENSION: + IS_ROCM = True + else: + IS_ROCM = False +else: + if BUILD_TARGET == "cuda": + IS_ROCM = False + elif BUILD_TARGET == "rocm": + IS_ROCM = True + PACKAGE_NAME = "mamba_ssm" -BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}" +BASE_WHEEL_URL = ( + "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}" +) # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation @@ -42,14 +63,24 @@ SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE" +USE_TRITON_ROCM = os.getenv("MAMBA_TRITON_AMD_ENABLE", "FALSE") == "TRUE" -def get_platform(): +@functools.lru_cache(maxsize=None) +def cuda_archs() -> str: + return os.getenv("MAMBA_CUDA_ARCHS", "80;90;100;120").split(";") + +def get_arch(): """ - Returns the platform name as used in wheel filenames. + Returns the system aarch for the current system. 
""" if sys.platform.startswith("linux"): - return "linux_x86_64" + if platform.machine() == "x86_64": + return "x86_64" + elif platform.machine() == "arm64" or platform.machine() == "aarch64": + return "aarch64" + else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) elif sys.platform == "darwin": mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) return f"macosx_{mac_version}_x86_64" @@ -58,56 +89,37 @@ def get_platform(): else: raise ValueError("Unsupported platform: {}".format(sys.platform)) +def get_system() -> str: + """ + Returns the system name as used in wheel filenames. + """ + if platform.system() == "Windows": + return "win" + elif platform.system() == "Darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:1]) + return f"macos_{mac_version}" + elif platform.system() == "Linux": + return "linux" + else: + raise ValueError("Unsupported system: {}".format(platform.system())) + +def get_platform() -> str: + """ + Returns the platform name as used in wheel filenames. 
+ """ + return f"{get_system()}_{get_arch()}" def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output( - [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True - ) + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) output = raw_output.split() release_idx = output.index("release") + 1 - bare_metal_ver = parse(output[release_idx].split(",")[0]) + bare_metal_version = parse(output[release_idx].split(",")[0]) - return raw_output, bare_metal_ver + return raw_output, bare_metal_version -def get_hip_version(rocm_dir): - - hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc") - try: - raw_output = subprocess.check_output( - [hipcc_bin, "--version"], universal_newlines=True - ) - except Exception as e: - print( - f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}" - ) - return None, None - - for line in raw_output.split("\n"): - if "HIP version" in line: - rocm_version = parse(line.split()[-1].rstrip('-').replace('-', '+')) # local version is not parsed correctly - return line, rocm_version - - return None, None - - -def get_torch_hip_version(): - - if torch.version.hip: - return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+')) - else: - return None - - -def check_if_hip_home_none(global_option: str) -> None: - - if HIP_HOME is not None: - return - # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary - # in that case. - warnings.warn( - f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?" 
- ) +def get_hip_version(): + return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+')) def check_if_cuda_home_none(global_option: str) -> None: @@ -122,108 +134,107 @@ def check_if_cuda_home_none(global_option: str) -> None: ) -def append_nvcc_threads(nvcc_extra_args): - return nvcc_extra_args + ["--threads", "4"] - +def check_if_rocm_home_none(global_option: str) -> None: + if ROCM_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but hipcc was not found." + ) -cmdclass = {} -ext_modules = [] +def append_nvcc_threads(nvcc_extra_args): + nvcc_threads = os.getenv("NVCC_THREADS") or "2" + return nvcc_extra_args + ["--threads", nvcc_threads] -HIP_BUILD = bool(torch.version.hip) -if not SKIP_CUDA_BUILD: - print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) +def rename_cpp_to_cu(cpp_files): + for entry in cpp_files: + shutil.copy(entry, os.path.splitext(entry)[0] + ".cu") - cc_flag = [] - if HIP_BUILD: - check_if_hip_home_none(PACKAGE_NAME) +def validate_and_update_archs(archs): + # List of allowed architectures + allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"] - rocm_home = os.getenv("ROCM_PATH") - _, hip_version = get_hip_version(rocm_home) + # Validate if each element in archs is in allowed_archs + assert all( + arch in allowed_archs for arch in archs + ), f"One of GPU archs of {archs} is invalid or not supported by Causal-conv1d" - if HIP_HOME is not None: - if hip_version < Version("6.0"): - raise RuntimeError( - f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. " - "Note: make sure HIP has a supported version by running hipcc --version." 
- ) - if hip_version == Version("6.0"): - warnings.warn( - f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. " - "Refer to the README.md for detailed instructions.", - UserWarning - ) - cc_flag.append("-DBUILD_PYTHON_PACKAGE") +cmdclass = {} +ext_modules = [] +# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp +# files included in the source distribution, in case the user compiles from source. +if os.path.isdir(".git"): + subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"], check=True) + subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True) +else: + if IS_ROCM: + if not USE_TRITON_ROCM: + assert ( + os.path.exists("csrc/composable_kernel/example/ck_tile/01_fmha/generate.py") + ), "csrc/composable_kernel is missing, please use source distribution or git clone" else: - check_if_cuda_home_none(PACKAGE_NAME) - # Check, if CUDA11 is installed for compute capability 8.0 + assert ( + os.path.exists("csrc/cutlass/include/cutlass/cutlass.h") + ), "csrc/cutlass is missing, please use source distribution or git clone" - if CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version < Version("11.6"): - raise RuntimeError( +if not SKIP_CUDA_BUILD and not IS_ROCM: + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + check_if_cuda_home_none(PACKAGE_NAME) + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.7"): + raise RuntimeError( f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " "Note: make sure nvcc has a supported version by running nvcc -V." 
- ) + ) - cc_flag.append("-gencode") - cc_flag.append("arch=compute_53,code=sm_53") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_62,code=sm_62") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_70,code=sm_70") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_72,code=sm_72") + if "80" in cuda_archs(): cc_flag.append("-gencode") cc_flag.append("arch=compute_80,code=sm_80") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_87,code=sm_87") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_89,code=sm_89") - if bare_metal_version >= Version("11.8"): + if CUDA_HOME is not None: + if bare_metal_version >= Version("11.8") and "90" in cuda_archs(): cc_flag.append("-gencode") cc_flag.append("arch=compute_90,code=sm_90") + if bare_metal_version >= Version("12.8") and "100" in cuda_archs(): cc_flag.append("-gencode") - cc_flag.append("arch=compute_90a,code=sm_90a") - if bare_metal_version >= Version("12.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_100,code=sm_100") # B100 + cc_flag.append("arch=compute_100,code=sm_100") + if bare_metal_version >= Version("12.8") and "120" in cuda_archs(): cc_flag.append("-gencode") - cc_flag.append("arch=compute_101,code=sm_101") # Thor - cc_flag.append("-gencode") - cc_flag.append("arch=compute_120,code=sm_120") # RTX50 - cc_flag.append("-gencode") - cc_flag.append("arch=compute_120a,code=sm_120a") - + cc_flag.append("arch=compute_120,code=sm_120") # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # torch._C._GLIBCXX_USE_CXX11_ABI # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 if FORCE_CXX11_ABI: torch._C._GLIBCXX_USE_CXX11_ABI = True - - if HIP_BUILD: - - extra_compile_args = { - "cxx": ["-O3", "-std=c++17"], - "nvcc": [ - "-O3", - "-std=c++17", - f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}", - "-U__CUDA_NO_HALF_OPERATORS__", - 
"-U__CUDA_NO_HALF_CONVERSIONS__", - "-fgpu-flush-denormals-to-zero", - ] - + cc_flag, - } - else: - extra_compile_args = { + + ext_modules.append( + CUDAExtension( + name="selective_scan_cuda", + sources=[ + "csrc/selective_scan/selective_scan.cpp", + "csrc/selective_scan/selective_scan_fwd_fp32.cu", + "csrc/selective_scan/selective_scan_fwd_fp16.cu", + "csrc/selective_scan/selective_scan_fwd_bf16.cu", + "csrc/selective_scan/selective_scan_bwd_fp32_real.cu", + "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu", + "csrc/selective_scan/selective_scan_bwd_fp16_real.cu", + "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", + "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", + "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", + ], + extra_compile_args = { "cxx": ["-O3", "-std=c++17"], "nvcc": append_nvcc_threads( [ @@ -243,30 +254,48 @@ def append_nvcc_threads(nvcc_extra_args): ] + cc_flag ), - } - - ext_modules.append( - CUDAExtension( - name="selective_scan_cuda", - sources=[ - "csrc/selective_scan/selective_scan.cpp", - "csrc/selective_scan/selective_scan_fwd_fp32.cu", - "csrc/selective_scan/selective_scan_fwd_fp16.cu", - "csrc/selective_scan/selective_scan_fwd_bf16.cu", - "csrc/selective_scan/selective_scan_bwd_fp32_real.cu", - "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu", - "csrc/selective_scan/selective_scan_bwd_fp16_real.cu", - "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", - "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", - "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", - ], - extra_compile_args=extra_compile_args, + }, include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], ) ) + +elif not SKIP_CUDA_BUILD and IS_ROCM: + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + generator_flag = [] + torch_dir = torch.__path__[0] + archs = os.getenv("GPU_ARCHS", 
"native").split(";") + validate_and_update_archs(archs) + cc_flag = [f"--offload-arch={arch}" for arch in archs] + hip_version = get_hip_version() + if hip_version > Version('5.7.23302'): + cc_flag += ["-fno-offload-uniform-block"] + if hip_version > Version('6.1.40090'): + cc_flag += ["-mllvm", "-enable-post-misched=0"] + if hip_version > Version('6.2.41132'): + cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true", + "-mllvm", "-amdgpu-function-calls=false"] + if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'): + cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"] + if USE_TRITON_ROCM: + # Skip C++ extension compilation if using Triton Backend + pass + else: + extra_compile_args = { + "cxx": ["-O3", "-std=c++17"], + "nvcc": [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-fgpu-flush-denormals-to-zero", + ] + + cc_flag, + } -def get_package_version(): +def get_package_version() -> str: with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f: version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) public_version = ast.literal_eval(version_match.group(1)) @@ -277,49 +306,46 @@ def get_package_version(): return str(public_version) -def get_wheel_url(): - # Determine the version numbers that will be used to determine the correct wheel +def get_wheel_url() -> tuple[str, str]: torch_version_raw = parse(torch.__version__) + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + mamba_ssm_version = get_package_version() + torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" + cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() - if HIP_BUILD: - # We're using the HIP version used to build torch, not the one currently installed - torch_hip_version = get_torch_hip_version() - hip_ver = f"{torch_hip_version.major}{torch_hip_version.minor}" + if IS_ROCM: + torch_hip_version = 
get_hip_version() + hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}" + wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" else: + # Determine the version numbers that will be used to determine the correct wheel # We're using the CUDA version used to build torch, not the one currently installed # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) torch_cuda_version = parse(torch.version.cuda) # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3 # to save CI time. Minor versions should be compatible. torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" cuda_version = f"{torch_cuda_version.major}" - gpu_compute_version = hip_ver if HIP_BUILD else cuda_version - cuda_or_hip = "hip" if HIP_BUILD else "cu" + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" - python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" - platform_name = get_platform() - mamba_ssm_version = get_package_version() - torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" - cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() + wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename) - # Determine wheel URL based on CUDA version, torch version, python version and OS - wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" - wheel_url = BASE_WHEEL_URL.format( - tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename - 
) return wheel_url, wheel_filename class CachedWheelsCommand(_bdist_wheel): """ The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot - find an existing wheel (which is currently the case for all installs). We use + find an existing wheel (which is currently the case for all mamba_ssm installs). We use the environment parameters to detect whether there is already a pre-built version of a compatible wheel available and short-circuits the standard full build pipeline. """ - def run(self): + def run(self) -> None: if FORCE_BUILD: return super().run() @@ -339,12 +365,33 @@ def run(self): wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") print("Raw wheel path", wheel_path) - shutil.move(wheel_filename, wheel_path) - except urllib.error.HTTPError: + os.rename(wheel_filename, wheel_path) + except (urllib.error.HTTPError, urllib.error.URLError): print("Precompiled wheel not found. Building from source...") # If the wheel could not be downloaded, build from source super().run() + +class NinjaBuildExtension(BuildExtension): + def __init__(self, *args, **kwargs) -> None: + # do not override env MAX_JOBS if already exists + if not os.environ.get("MAX_JOBS"): + import psutil + + # calculate the maximum allowed NUM_JOBS based on cores + max_num_jobs_cores = max(1, os.cpu_count() // 2) + + # calculate the maximum allowed NUM_JOBS based on free memory + free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB + max_num_jobs_memory = int(free_memory_gb / 9) # each JOB peak memory cost is ~8-9GB when threads = 4 + + # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation + max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory)) + os.environ["MAX_JOBS"] = str(max_jobs) + + super().__init__(*args, **kwargs) + + setup( name=PACKAGE_NAME, version=get_package_version(), From 29063912258d093c8d8157493f891c560f2dbc60 Mon Sep 17 00:00:00
2001 From: Johnny Date: Tue, 25 Feb 2025 15:17:06 +0100 Subject: [PATCH 10/12] Update setup.py --- setup.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/setup.py b/setup.py index 5ac28b39..aaa475f2 100755 --- a/setup.py +++ b/setup.py @@ -170,18 +170,7 @@ def validate_and_update_archs(archs): # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source. if os.path.isdir(".git"): - subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"], check=True) - subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True) -else: - if IS_ROCM: - if not USE_TRITON_ROCM: - assert ( - os.path.exists("csrc/composable_kernel/example/ck_tile/01_fmha/generate.py") - ), "csrc/composable_kernel is missing, please use source distribution or git clone" - else: - assert ( - os.path.exists("csrc/cutlass/include/cutlass/cutlass.h") - ), "csrc/cutlass is missing, please use source distribution or git clone" + subprocess.run(["git", "submodule", "update", "--init", "csrc/selective_scan"], check=True) if not SKIP_CUDA_BUILD and not IS_ROCM: print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) From eca5db81af7a325de417c40554870d2d961f9361 Mon Sep 17 00:00:00 2001 From: johnnynunez Date: Thu, 1 May 2025 09:58:35 +0200 Subject: [PATCH 11/12] Update publish.yaml and setup.py for compatibility improvements and dependency updates --- .github/workflows/publish.yaml | 73 ++++++++++++++++++---------------- setup.py | 6 +++ 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index dd21c2a4..096fda3b 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -36,36 +36,38 @@ jobs: name: Build Wheel needs: setup_release runs-on: ${{ matrix.os }} - strategy: fail-fast: false matrix: - # Using 
ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the - # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. - os: [ubuntu-22.04, ubuntu-22.04-arm] - arch: [x86_64, aarch64] - python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] - cuda-version: ['12.8.0'] - # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. - # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. - # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) - # when building without C++11 ABI and using it on nvcr images. - cxx11_abi: ['FALSE', 'TRUE'] - exclude: - # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - - os: ubuntu-22.04-arm - arch: x86_64 - # Prevent trying to run aarch64 on ubuntu-22.04 x86 - - os: ubuntu-22.04 - arch: aarch64 - # Pytorch < 2.5 does not support Python 3.13 (your existing excludes) - - torch-version: '2.2.2' - python-version: '3.13' - - torch-version: '2.3.1' - python-version: '3.13' - - torch-version: '2.4.0' - python-version: '3.13' + # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the + # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. + os: [ ubuntu-22.04, ubuntu-22.04-arm ] + python-version: [ '3.9', '3.10', '3.11', '3.12', '3.13' ] + torch-version: [ '2.4.0', '2.5.1', '2.6.0', '2.7.0' ] + cuda-version: [ '12.4.1', '12.8.1' ] + # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. + # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. + # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) + # when building without C++11 ABI and using it on nvcr images. 
+ cxx11_abi: [ 'FALSE', 'TRUE' ] + exclude: + # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix + # Pytorch < 2.5 does not support Python 3.13 + # PyTorch < 2.5 doesn’t support Python 3.13 + - torch-version: '2.4.0' + python-version: '3.13' + + # PyTorch 2.7.0 must only use CUDA 12.8.1 + - torch-version: '2.7.0' + cuda-version: '12.4.1' + + # All other PyTorch (< 2.7.0) must only use CUDA 12.4.1 + - torch-version: '2.4.0' + cuda-version: '12.8.1' + - torch-version: '2.5.1' + cuda-version: '12.8.1' + - torch-version: '2.6.0' + cuda-version: '12.8.1' steps: - name: Checkout @@ -101,7 +103,7 @@ jobs: - name: Install CUDA ${{ matrix.cuda-version }} if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.21 + uses: Jimver/cuda-toolkit@v0.2.23 id: cuda-toolkit with: cuda: ${{ matrix.cuda-version }} @@ -115,25 +117,26 @@ jobs: run: | pip install --upgrade pip # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools - pip install setuptools==68.0.0 + pip install setuptools==75.8.0 # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable pip install typing-extensions==4.12.2 # We want to figure out the CUDA version to download pytorch # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # This code is ugly, maybe there's a better way to do this. 
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118, '2.7': 124 }[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128 }[env['MATRIX_TORCH_VERSION']]; \ + minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128}[env['MATRIX_TORCH_VERSION']]; \ print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} # Can't use --no-deps because we need cudnn etc. - # Hard-coding this version of pytorch-triton for torch 2.7.0.dev20250205 + # Hard-coding this version of pytorch-triton for torch 2.8.0.dev20250425 pip install jinja2 - pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0-cp${{ matrix.python-version|replace('.', '') }}-cp${{ matrix.python-version|replace('.', '') }}-manylinux_2_28_${{ matrix.arch }}.whl - pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_${{ matrix.arch }}.whl + pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.3.0+gitab727c40-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_${{ matrix.arch }}.whl + pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_${{ matrix.arch }}.whl else pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url 
https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} fi @@ -199,7 +202,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.12' - name: Install dependencies run: | diff --git a/setup.py b/setup.py index aaa475f2..6adabea8 100755 --- a/setup.py +++ b/setup.py @@ -198,9 +198,15 @@ def validate_and_update_archs(archs): if bare_metal_version >= Version("12.8") and "100" in cuda_archs(): cc_flag.append("-gencode") cc_flag.append("arch=compute_100,code=sm_100") + if bare_metal_version >= Version("12.8") and "101" in cuda_archs(): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_101,code=sm_101") if bare_metal_version >= Version("12.8") and "120" in cuda_archs(): cc_flag.append("-gencode") cc_flag.append("arch=compute_120,code=sm_120") + if bare_metal_version >= Version("13.0") and "110" in cuda_archs(): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_110,code=sm_110") # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # torch._C._GLIBCXX_USE_CXX11_ABI From 7ec75ed0796f707d36e6ffd95a1c5fdef3480cb0 Mon Sep 17 00:00:00 2001 From: johnnynunez Date: Thu, 1 May 2025 10:09:04 +0200 Subject: [PATCH 12/12] Update publish.yaml for improved compatibility and dependency management --- .github/workflows/publish.yaml | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 096fda3b..2d659139 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -39,12 +39,12 @@ jobs: strategy: fail-fast: false matrix: - # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the + # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. 
os: [ ubuntu-22.04, ubuntu-22.04-arm ] python-version: [ '3.9', '3.10', '3.11', '3.12', '3.13' ] torch-version: [ '2.4.0', '2.5.1', '2.6.0', '2.7.0' ] - cuda-version: [ '12.4.1', '12.8.1' ] + cuda-version: [ '12.8.1' ] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) @@ -53,22 +53,9 @@ jobs: exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # Pytorch < 2.5 does not support Python 3.13 - # PyTorch < 2.5 doesn’t support Python 3.13 - torch-version: '2.4.0' python-version: '3.13' - # PyTorch 2.7.0 must only use CUDA 12.8.1 - - torch-version: '2.7.0' - cuda-version: '12.4.1' - - # All other PyTorch (< 2.7.0) must only use CUDA 12.4.1 - - torch-version: '2.4.0' - cuda-version: '12.8.1' - - torch-version: '2.5.1' - cuda-version: '12.8.1' - - torch-version: '2.6.0' - cuda-version: '12.8.1' - steps: - name: Checkout uses: actions/checkout@v4