File tree Expand file tree Collapse file tree 1 file changed +21
-1
lines changed Expand file tree Collapse file tree 1 file changed +21
-1
lines changed Original file line number Diff line number Diff line change 1
1
import os
2
2
from pathlib import Path
3
+ from typing import Callable
3
4
4
5
import torch
5
- from pytest import fixture
6
+ from pytest import FixtureRequest , fixture , skip
7
+
8
+ from refiners .fluxion .utils import str_to_dtype
6
9
7
10
PARENT_PATH = Path (__file__ ).parent
8
11
@@ -18,6 +21,23 @@ def test_device() -> torch.device:
18
21
return torch .device (test_device )
19
22
20
23
24
+ def dtype_fixture_factory (params : list [str ]) -> Callable [[FixtureRequest ], torch .dtype ]:
25
+ @fixture (scope = "session" , params = params )
26
+ def dtype_fixture (request : FixtureRequest ) -> torch .dtype :
27
+ torch_dtype = str_to_dtype (request .param )
28
+ if torch_dtype == torch .bfloat16 and not torch .cuda .is_bf16_supported ():
29
+ skip ("bfloat16 is not supported on this test device" )
30
+ return torch_dtype
31
+
32
+ return dtype_fixture
33
+
34
+
35
+ test_dtype_fp32_bf16_fp16 = dtype_fixture_factory (["float32" , "bfloat16" , "float16" ])
36
+ test_dtype_fp32_fp16 = dtype_fixture_factory (["float32" , "float16" ])
37
+ test_dtype_fp32_bf16 = dtype_fixture_factory (["float32" , "bfloat16" ])
38
+ test_dtype_fp16_bf16 = dtype_fixture_factory (["float16" , "bfloat16" ])
39
+
40
+
21
41
@fixture (scope = "session" )
22
42
def test_weights_path () -> Path :
23
43
from_env = os .getenv ("REFINERS_TEST_WEIGHTS_DIR" )
You can’t perform that action at this time.
0 commit comments