7
7
from tests .utils import ensure_similar_images
8
8
9
9
from refiners .fluxion .utils import no_grad
10
- from refiners .foundationals .latent_diffusion .auto_encoder import LatentDiffusionAutoencoder
10
+ from refiners .foundationals .latent_diffusion import (
11
+ LatentDiffusionAutoencoder ,
12
+ SD1Autoencoder ,
13
+ SDXLAutoencoder ,
14
+ )
11
15
12
16
13
17
@pytest .fixture (scope = "module" )
@@ -16,25 +20,32 @@ def sample_image() -> Image.Image:
16
20
if not test_image .is_file ():
17
21
warn (f"could not reference image at { test_image } , skipping" )
18
22
pytest .skip (allow_module_level = True )
19
- img = Image .open (test_image ) # type: ignore
23
+ img = Image .open (test_image )
20
24
assert img .size == (512 , 512 )
21
25
return img
22
26
23
27
24
- @pytest .fixture (scope = "module" )
28
+ @pytest .fixture (scope = "module" , params = [ "SD1.5" , "SDXL" ] )
25
29
def autoencoder (
26
- refiners_autoencoder : LatentDiffusionAutoencoder ,
30
+ request : pytest .FixtureRequest ,
31
+ refiners_sd15_autoencoder : SD1Autoencoder ,
32
+ refiners_sdxl_autoencoder : SDXLAutoencoder ,
27
33
test_device : torch .device ,
34
+ test_dtype_fp32_bf16_fp16 : torch .dtype ,
28
35
) -> LatentDiffusionAutoencoder :
29
- return refiners_autoencoder .to (test_device )
36
+ model_version = request .param
37
+ if model_version == "SDXL" and test_dtype_fp32_bf16_fp16 == torch .float16 :
38
+ pytest .skip ("SDXL autoencoder does not support float16" )
39
+ ae = refiners_sd15_autoencoder if model_version == "SD1.5" else refiners_sdxl_autoencoder
40
+ return ae .to (device = test_device , dtype = test_dtype_fp32_bf16_fp16 )
30
41
31
42
32
43
@no_grad ()
33
44
def test_encode_decode_image (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
34
45
encoded = autoencoder .image_to_latents (sample_image )
35
46
decoded = autoencoder .latents_to_image (encoded )
36
47
37
- assert decoded .mode == "RGB" # type: ignore
48
+ assert decoded .mode == "RGB"
38
49
39
50
# Ensure no saturation. The green channel (band = 1) must not max out.
40
51
assert max (iter (decoded .getdata (band = 1 ))) < 255 # type: ignore
@@ -53,7 +64,7 @@ def test_encode_decode_images(autoencoder: LatentDiffusionAutoencoder, sample_im
53
64
54
65
@no_grad ()
55
66
def test_tiled_autoencoder (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
56
- sample_image = sample_image .resize ((2048 , 2048 )) # type: ignore
67
+ sample_image = sample_image .resize ((2048 , 2048 ))
57
68
58
69
with autoencoder .tiled_inference (sample_image , tile_size = (512 , 512 )):
59
70
encoded = autoencoder .tiled_image_to_latents (sample_image )
@@ -64,7 +75,7 @@ def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image
64
75
65
76
@no_grad ()
66
77
def test_tiled_autoencoder_rectangular_tiles (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
67
- sample_image = sample_image .resize ((2048 , 2048 )) # type: ignore
78
+ sample_image = sample_image .resize ((2048 , 2048 ))
68
79
69
80
with autoencoder .tiled_inference (sample_image , tile_size = (512 , 1024 )):
70
81
encoded = autoencoder .tiled_image_to_latents (sample_image )
@@ -75,7 +86,7 @@ def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoenc
75
86
76
87
@no_grad ()
77
88
def test_tiled_autoencoder_large_tile (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
78
- sample_image = sample_image .resize ((1024 , 1024 )) # type: ignore
89
+ sample_image = sample_image .resize ((1024 , 1024 ))
79
90
80
91
with autoencoder .tiled_inference (sample_image , tile_size = (2048 , 2048 )):
81
92
encoded = autoencoder .tiled_image_to_latents (sample_image )
@@ -87,7 +98,7 @@ def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, s
87
98
@no_grad ()
88
99
def test_tiled_autoencoder_rectangular_image (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
89
100
sample_image = sample_image .crop ((0 , 0 , 300 , 500 ))
90
- sample_image = sample_image .resize ((sample_image .width * 4 , sample_image .height * 4 )) # type: ignore
101
+ sample_image = sample_image .resize ((sample_image .width * 4 , sample_image .height * 4 ))
91
102
92
103
with autoencoder .tiled_inference (sample_image , tile_size = (512 , 512 )):
93
104
encoded = autoencoder .tiled_image_to_latents (sample_image )
0 commit comments