2 files changed: +14 −4 lines

@@ -304,3 +304,13 @@ def str_to_dtype(dtype: str) -> torch.dtype:
             return torch.bool
         case _:
             raise ValueError(f"Unknown dtype: {dtype}")
+
+
+def device_has_bfloat16(device: torch.device) -> bool:
+    cuda_version = cast(str | None, torch.version.cuda)  # type: ignore
+    if cuda_version is None or int(cuda_version.split(".")[0]) < 11:
+        return False
+    try:
+        return torch.cuda.get_device_properties(device).major >= 8  # type: ignore
+    except ValueError:
+        return False
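For context, the new `device_has_bfloat16` helper (added next to `str_to_dtype` in `refiners.fluxion.utils`) encodes the usual requirements for CUDA bfloat16: a CUDA 11+ build of PyTorch and a compute-capability 8.0 (Ampere) or newer GPU, checked against an explicit device rather than the current one as `torch.cuda.is_bf16_supported()` does. A minimal usage sketch, assuming a standard PyTorch install; the device and dtype selection below are illustrative and not part of this change:

```python
import torch

from refiners.fluxion.utils import device_has_bfloat16

# Illustrative setup: pick the first CUDA device if one is available, else the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fall back to float32 when the device cannot run bfloat16 (pre-Ampere GPU,
# pre-11 CUDA build, or a non-CUDA device).
dtype = torch.bfloat16 if device_has_bfloat16(device) else torch.float32
x = torch.ones(4, device=device, dtype=dtype)
```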
@@ -5,7 +5,7 @@
 import torch
 from pytest import FixtureRequest, fixture, skip
 
-from refiners.fluxion.utils import str_to_dtype
+from refiners.fluxion.utils import device_has_bfloat16, str_to_dtype
 
 
 PARENT_PATH = Path(__file__).parent
@@ -21,11 +21,11 @@ def test_device() -> torch.device:
     return torch.device(test_device)
 
 
-def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
+def dtype_fixture_factory(params: list[str]) -> Callable[[torch.device, FixtureRequest], torch.dtype]:
     @fixture(scope="session", params=params)
-    def dtype_fixture(request: FixtureRequest) -> torch.dtype:
+    def dtype_fixture(test_device: torch.device, request: FixtureRequest) -> torch.dtype:
         torch_dtype = str_to_dtype(request.param)
-        if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+        if torch_dtype == torch.bfloat16 and not device_has_bfloat16(test_device):
             skip("bfloat16 is not supported on this test device")
         return torch_dtype
 
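Passing the session-scoped `test_device` fixture into the generated dtype fixtures lets pytest resolve the device first, so bfloat16 parametrizations are skipped per device instead of relying on whichever CUDA device happens to be current. A sketch of how such a factory is typically consumed in the same test fixture module; the fixture name `test_dtype_fp32_bf16_fp16` and the test below are illustrative assumptions, not taken from this diff:

```python
import torch

# Hypothetical module-level fixture built with the factory defined above:
# tests requesting it run once per dtype string, and the bfloat16 run is
# skipped when device_has_bfloat16(test_device) is False.
test_dtype_fp32_bf16_fp16 = dtype_fixture_factory(["float32", "bfloat16", "float16"])


def test_ones(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
    # The dtype fixture already guarantees the device can handle this dtype.
    x = torch.ones(2, 2, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
    assert x.dtype == test_dtype_fp32_bf16_fp16
```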