Commit b20474f

Laurent authored and Laurent2916 committed
add various torch.dtype test fixtures
1 parent 16714e6 commit b20474f


tests/conftest.py

Lines changed: 21 additions & 1 deletion
```diff
@@ -1,8 +1,11 @@
 import os
 from pathlib import Path
+from typing import Callable
 
 import torch
-from pytest import fixture
+from pytest import FixtureRequest, fixture, skip
+
+from refiners.fluxion.utils import str_to_dtype
 
 PARENT_PATH = Path(__file__).parent
 
@@ -18,6 +21,23 @@ def test_device() -> torch.device:
     return torch.device(test_device)
 
 
+def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
+    @fixture(scope="session", params=params)
+    def dtype_fixture(request: FixtureRequest) -> torch.dtype:
+        torch_dtype = str_to_dtype(request.param)
+        if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+            skip("bfloat16 is not supported on this test device")
+        return torch_dtype
+
+    return dtype_fixture
+
+
+test_dtype_fp32_bf16_fp16 = dtype_fixture_factory(["float32", "bfloat16", "float16"])
+test_dtype_fp32_fp16 = dtype_fixture_factory(["float32", "float16"])
+test_dtype_fp32_bf16 = dtype_fixture_factory(["float32", "bfloat16"])
+test_dtype_fp16_bf16 = dtype_fixture_factory(["float16", "bfloat16"])
+
+
 @fixture(scope="session")
 def test_weights_path() -> Path:
     from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR")
```
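The newly imported `str_to_dtype` helper from `refiners.fluxion.utils` is used here only to map a string parameter such as "float32" to the corresponding `torch.dtype`. A hypothetical stand-in, assuming the helper simply resolves the attribute of the same name on `torch` (the real implementation may validate its input differently):

```python
import torch


def str_to_dtype(dtype: str) -> torch.dtype:
    # Hypothetical sketch of refiners.fluxion.utils.str_to_dtype:
    # resolve e.g. "float32" -> torch.float32, "bfloat16" -> torch.bfloat16.
    result = getattr(torch, dtype)
    assert isinstance(result, torch.dtype), f"{dtype!r} does not name a torch.dtype"
    return result
```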

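Because pytest discovers each factory result under the conftest attribute it is assigned to, a test opts into dtype parametrization simply by naming one of these fixtures as an argument; the session-scoped fixture then runs the test once per dtype in its `params` list. A minimal usage sketch (the test name and tensor shape are illustrative, not part of this commit):

```python
import torch


def test_ones_respects_dtype(test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
    # Runs three times: once each for float32, bfloat16, and float16.
    # The bfloat16 run is skipped automatically on devices where
    # torch.cuda.is_bf16_supported() is False, via skip() in the fixture.
    x = torch.ones(2, 3, dtype=test_dtype_fp32_bf16_fp16)
    assert x.dtype == test_dtype_fp32_bf16_fp16
```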