Skip to content

Commit e708c31

Browse files
committed
fix use of dtypes in autoencoder tests
1 parent 346a0fb commit e708c31

File tree

2 files changed

+21 −30 lines changed

tests/foundationals/latent_diffusion/conftest.py

Lines changed: 0 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
from pathlib import Path
22

33
import pytest
4-
import torch
54
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
65
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
76
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
@@ -93,25 +92,6 @@ def refiners_sdxl(
9392
)
9493

9594

96-
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
97-
def refiners_autoencoder(
98-
request: pytest.FixtureRequest,
99-
refiners_sd15_autoencoder: SD1Autoencoder,
100-
refiners_sdxl_autoencoder: SDXLAutoencoder,
101-
test_dtype_fp32_bf16_fp16: torch.dtype,
102-
) -> SD1Autoencoder | SDXLAutoencoder:
103-
model_version = request.param
104-
match (model_version, test_dtype_fp32_bf16_fp16):
105-
case ("SD1.5", _):
106-
return refiners_sd15_autoencoder
107-
case ("SDXL", torch.float16):
108-
return refiners_sdxl_autoencoder
109-
case ("SDXL", _):
110-
return refiners_sdxl_autoencoder
111-
case _:
112-
raise ValueError(f"Unknown model version: {model_version}")
113-
114-
11595
@pytest.fixture(scope="module")
11696
def diffusers_sd15_pipeline(
11797
sd15_diffusers_runwayml_path: str,

tests/foundationals/latent_diffusion/test_autoencoders.py

Lines changed: 21 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,11 @@
77
from tests.utils import ensure_similar_images
88

99
from refiners.fluxion.utils import no_grad
10-
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
10+
from refiners.foundationals.latent_diffusion import (
11+
LatentDiffusionAutoencoder,
12+
SD1Autoencoder,
13+
SDXLAutoencoder,
14+
)
1115

1216

1317
@pytest.fixture(scope="module")
@@ -16,25 +20,32 @@ def sample_image() -> Image.Image:
1620
if not test_image.is_file():
1721
warn(f"could not reference image at {test_image}, skipping")
1822
pytest.skip(allow_module_level=True)
19-
img = Image.open(test_image) # type: ignore
23+
img = Image.open(test_image)
2024
assert img.size == (512, 512)
2125
return img
2226

2327

24-
@pytest.fixture(scope="module")
28+
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
2529
def autoencoder(
26-
refiners_autoencoder: LatentDiffusionAutoencoder,
30+
request: pytest.FixtureRequest,
31+
refiners_sd15_autoencoder: SD1Autoencoder,
32+
refiners_sdxl_autoencoder: SDXLAutoencoder,
2733
test_device: torch.device,
34+
test_dtype_fp32_bf16_fp16: torch.dtype,
2835
) -> LatentDiffusionAutoencoder:
29-
return refiners_autoencoder.to(test_device)
36+
model_version = request.param
37+
if model_version == "SDXL" and test_dtype_fp32_bf16_fp16 == torch.float16:
38+
pytest.skip("SDXL autoencoder does not support float16")
39+
ae = refiners_sd15_autoencoder if model_version == "SD1.5" else refiners_sdxl_autoencoder
40+
return ae.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
3041

3142

3243
@no_grad()
3344
def test_encode_decode_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
3445
encoded = autoencoder.image_to_latents(sample_image)
3546
decoded = autoencoder.latents_to_image(encoded)
3647

37-
assert decoded.mode == "RGB" # type: ignore
48+
assert decoded.mode == "RGB"
3849

3950
# Ensure no saturation. The green channel (band = 1) must not max out.
4051
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore
@@ -53,7 +64,7 @@ def test_encode_decode_images(autoencoder: LatentDiffusionAutoencoder, sample_im
5364

5465
@no_grad()
5566
def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
56-
sample_image = sample_image.resize((2048, 2048)) # type: ignore
67+
sample_image = sample_image.resize((2048, 2048))
5768

5869
with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)):
5970
encoded = autoencoder.tiled_image_to_latents(sample_image)
@@ -64,7 +75,7 @@ def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image
6475

6576
@no_grad()
6677
def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
67-
sample_image = sample_image.resize((2048, 2048)) # type: ignore
78+
sample_image = sample_image.resize((2048, 2048))
6879

6980
with autoencoder.tiled_inference(sample_image, tile_size=(512, 1024)):
7081
encoded = autoencoder.tiled_image_to_latents(sample_image)
@@ -75,7 +86,7 @@ def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoenc
7586

7687
@no_grad()
7788
def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
78-
sample_image = sample_image.resize((1024, 1024)) # type: ignore
89+
sample_image = sample_image.resize((1024, 1024))
7990

8091
with autoencoder.tiled_inference(sample_image, tile_size=(2048, 2048)):
8192
encoded = autoencoder.tiled_image_to_latents(sample_image)
@@ -87,7 +98,7 @@ def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, s
8798
@no_grad()
8899
def test_tiled_autoencoder_rectangular_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
89100
sample_image = sample_image.crop((0, 0, 300, 500))
90-
sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) # type: ignore
101+
sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4))
91102

92103
with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)):
93104
encoded = autoencoder.tiled_image_to_latents(sample_image)

0 commit comments

Comments (0)