Commit 83a5c90

Fix custom fwd and bwd for older PyTorch versions (#596)
* Fix custom fwd and bwd for older torch versions
* Add the new utils file that was missing from the previous push
* Use functools.partial to fix keyword-argument passing with the decorator
1 parent bc84fb1 · commit 83a5c90
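Background for the change: newer PyTorch releases move the AMP decorators from torch.cuda.amp to torch.amp and expect an explicit device_type keyword, so the earlier `from torch.amp import custom_bwd, custom_fwd` imports fail on older installs. Below is a minimal standalone sketch of the feature check the new shim relies on; the print statements are illustrative only, and the "newer/older" split is an assumption based on the deprecation notices, since the commit itself only tests hasattr.

    import torch

    if hasattr(torch.amp, "custom_fwd"):
        # Newer PyTorch: the decorators live in torch.amp and take a required
        # device_type keyword when parameterized.
        from torch.amp import custom_bwd, custom_fwd
        print("torch.amp.custom_fwd available (device_type keyword expected)")
    else:
        # Older PyTorch: only the CUDA-specific torch.cuda.amp variants exist
        # and are applied bare, with no device_type argument.
        from torch.cuda.amp import custom_bwd, custom_fwd
        print("falling back to torch.cuda.amp.custom_fwd")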

File tree (5 files changed: +31 additions, -12 deletions)

* mamba_ssm/distributed/tensor_parallel.py
* mamba_ssm/ops/selective_scan_interface.py
* mamba_ssm/ops/triton/layer_norm.py
* mamba_ssm/ops/triton/ssd_combined.py
* mamba_ssm/utils/torch.py

mamba_ssm/distributed/tensor_parallel.py

Lines changed: 3 additions & 3 deletions
@@ -6,8 +6,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torch.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 from einops import rearrange
 
@@ -22,7 +22,7 @@
 
 class ParallelLinearFunc(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
         """
         If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
@@ -58,7 +58,7 @@ def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
         return output
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
     def backward(ctx, grad_output):
         grad_output = grad_output.contiguous()
         process_group = ctx.process_group

mamba_ssm/ops/selective_scan_interface.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch.amp import custom_bwd, custom_fwd
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 from einops import rearrange, repeat
 
@@ -160,7 +160,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta
 class MambaInnerFn(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                 out_proj_weight, out_proj_bias,
                 A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
@@ -236,7 +236,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
     def backward(ctx, dout):
         # dout: (batch, seqlen, dim)
         assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."

mamba_ssm/ops/triton/layer_norm.py

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch.amp import custom_fwd, custom_bwd
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 import triton
 import triton.language as tl
@@ -982,7 +982,7 @@ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
 
 class LayerNormLinearFn(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(
         ctx,
         x,
@@ -1041,7 +1041,7 @@ def forward(
         return out if not prenorm else (out, residual_out.reshape(x_shape_og))
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
     def backward(ctx, dout, *args):
         x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
         dout = dout.reshape(-1, dout.shape[-1])

mamba_ssm/ops/triton/ssd_combined.py

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.amp import custom_bwd, custom_fwd
+from mamba_ssm.utils.torch import custom_bwd, custom_fwd
 
 import triton
 import triton.language as tl
@@ -754,7 +754,7 @@ def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=
 class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(device_type="cuda")
+    @custom_fwd
     def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
                 rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None,
                 ngroups=1, norm_before_gate=True):
@@ -832,7 +832,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
         return out if not return_final_states else (out, final_states)
 
     @staticmethod
-    @custom_bwd(device_type="cuda")
+    @custom_bwd
     def backward(ctx, dout, *args):
         zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
         dfinal_states = args[0] if ctx.return_final_states else None

mamba_ssm/utils/torch.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+import torch
+from functools import partial
+
+
+def custom_amp_decorator(dec, cuda_amp_deprecated):
+    def decorator(func):
+        return dec(func) if not cuda_amp_deprecated else partial(dec, func, device_type="cuda")
+    return decorator
+
+
+if hasattr(torch.amp, "custom_fwd"):
+    deprecated = True
+    from torch.amp import custom_fwd, custom_bwd
+else:
+    deprecated = False
+    from torch.cuda.amp import custom_fwd, custom_bwd
+
+custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
+custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
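For reference, a minimal sketch of how the shim is meant to be used at the call sites patched above. `DoubleFn` and its body are invented for illustration; only the import and the bare @custom_fwd / @custom_bwd decoration mirror the repo. The decorators are applied without arguments, and on newer PyTorch the shim supplies device_type="cuda" itself via functools.partial.

    import torch
    from mamba_ssm.utils.torch import custom_bwd, custom_fwd

    class DoubleFn(torch.autograd.Function):
        # Toy autograd.Function, not from the repo; it only demonstrates the
        # bare decorator usage that the patched files now share.
        @staticmethod
        @custom_fwd
        def forward(ctx, x):
            return x * 2

        @staticmethod
        @custom_bwd
        def backward(ctx, grad_output):
            return grad_output * 2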
