Skip to content

Commit 2017c98

Browse files
authored
Fix custom fwd and bwd for older PyTorch versions (#608)
#596 (comment)
1 parent 83a5c90 commit 2017c98

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

mamba_ssm/utils/torch.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
import torch
22
from functools import partial
3+
from typing import Callable
34

4-
5-
def custom_amp_decorator(dec, cuda_amp_deprecated):
6-
def decorator(func):
7-
return dec(func) if not cuda_amp_deprecated else partial(dec, func, device_type="cuda")
5+
def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
6+
def decorator(*args, **kwargs):
7+
if cuda_amp_deprecated:
8+
kwargs["device_type"] = "cuda"
9+
return dec(*args, **kwargs)
810
return decorator
911

1012

11-
if hasattr(torch.amp, "custom_fwd"):
13+
if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
1214
deprecated = True
13-
from torch.amp import custom_fwd, custom_bwd
15+
from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
1416
else:
1517
deprecated = False
1618
from torch.cuda.amp import custom_fwd, custom_bwd

0 commit comments

Comments
 (0)