File tree Expand file tree Collapse file tree 1 file changed +8
-6
lines changed Expand file tree Collapse file tree 1 file changed +8
-6
lines changed Original file line number Diff line number Diff line change 1
1
import torch
2
2
from functools import partial
3
+ from typing import Callable
3
4
4
-
5
- def custom_amp_decorator (dec , cuda_amp_deprecated ):
6
- def decorator (func ):
7
- return dec (func ) if not cuda_amp_deprecated else partial (dec , func , device_type = "cuda" )
5
+ def custom_amp_decorator (dec : Callable , cuda_amp_deprecated : bool ):
6
+ def decorator (* args , ** kwargs ):
7
+ if cuda_amp_deprecated :
8
+ kwargs ["device_type" ] = "cuda"
9
+ return dec (* args , ** kwargs )
8
10
return decorator
9
11
10
12
11
- if hasattr (torch .amp , "custom_fwd" ):
13
+ if hasattr (torch .amp , "custom_fwd" ): # type: ignore[attr-defined]
12
14
deprecated = True
13
- from torch .amp import custom_fwd , custom_bwd
15
+ from torch .amp import custom_fwd , custom_bwd # type: ignore[attr-defined]
14
16
else :
15
17
deprecated = False
16
18
from torch .cuda .amp import custom_fwd , custom_bwd
You can’t perform that action at this time.
0 commit comments