
Commit eef07ec

properly check for bfloat16
- we check only the test device, not the machine in general
- we don't want emulated bfloat16 (e.g. CPU)
1 parent f3d2b6c commit eef07ec

2 files changed, +14 −4 lines changed


src/refiners/fluxion/utils.py

Lines changed: 10 additions & 0 deletions
@@ -304,3 +304,13 @@ def str_to_dtype(dtype: str) -> torch.dtype:
             return torch.bool
         case _:
             raise ValueError(f"Unknown dtype: {dtype}")
+
+
+def device_has_bfloat16(device: torch.device) -> bool:
+    cuda_version = cast(str | None, torch.version.cuda)  # type: ignore
+    if cuda_version is None or int(cuda_version.split(".")[0]) < 11:
+        return False
+    try:
+        return torch.cuda.get_device_properties(device).major >= 8  # type: ignore
+    except ValueError:
+        return False
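
For context, not part of the commit: a minimal sketch of how this device-scoped helper differs from the machine-wide torch.cuda.is_bf16_supported() check it replaces in the tests. The device strings below are illustrative.

import torch

from refiners.fluxion.utils import device_has_bfloat16

# Machine-wide check: can be True when the machine has a capable GPU even if
# the tests actually run on another device (e.g. CPU), where bfloat16 would
# only be emulated.
machine_wide = torch.cuda.is_bf16_supported()

# Device-scoped check: True only for a CUDA device with compute capability
# >= 8 (Ampere or newer) on a CUDA 11+ build; non-CUDA devices return False
# because torch.cuda.get_device_properties raises ValueError for them.
on_cpu = device_has_bfloat16(torch.device("cpu"))     # False
on_gpu = device_has_bfloat16(torch.device("cuda:0"))  # depends on the GPU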

tests/conftest.py

Lines changed: 4 additions & 4 deletions
@@ -5,7 +5,7 @@
 import torch
 from pytest import FixtureRequest, fixture, skip

-from refiners.fluxion.utils import str_to_dtype
+from refiners.fluxion.utils import device_has_bfloat16, str_to_dtype

 PARENT_PATH = Path(__file__).parent

@@ -21,11 +21,11 @@ def test_device() -> torch.device:
     return torch.device(test_device)


-def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
+def dtype_fixture_factory(params: list[str]) -> Callable[[torch.device, FixtureRequest], torch.dtype]:
     @fixture(scope="session", params=params)
-    def dtype_fixture(request: FixtureRequest) -> torch.dtype:
+    def dtype_fixture(test_device: torch.device, request: FixtureRequest) -> torch.dtype:
         torch_dtype = str_to_dtype(request.param)
-        if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+        if torch_dtype == torch.bfloat16 and not device_has_bfloat16(test_device):
             skip("bfloat16 is not supported on this test device")
         return torch_dtype
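
Again for illustration only, the fixture and test names below are assumptions rather than part of this diff: the factory declares session-scoped, parameterized dtype fixtures, and any test requesting one now skips its bfloat16 run whenever the selected test_device lacks native support.

# Hypothetical fixture built with the factory shown above.
test_dtype_fp32_bf16 = dtype_fixture_factory(["float32", "bfloat16"])


def test_ones(test_device: torch.device, test_dtype_fp32_bf16: torch.dtype) -> None:
    # Runs once per dtype; the bfloat16 parameterization is skipped on
    # devices without native support (e.g. CPU or pre-Ampere GPUs).
    x = torch.ones(2, 2, device=test_device, dtype=test_dtype_fp32_bf16)
    assert x.dtype == test_dtype_fp32_bf16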
