
Commit f3d2b6c

Laurent (Laurent2916) authored and committed
modify some foundational tests to also test in float16 and bfloat16
1 parent b20474f · commit f3d2b6c

6 files changed: +170 -45 lines changed

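Note that the test_dtype_fp32_bf16_fp16 and test_dtype_fp32_fp16 arguments used throughout this diff are themselves fixtures; they are not defined in any of the files below and presumably live in the shared tests conftest as parametrized dtype fixtures. A minimal sketch of what they could look like (an assumption, not code from this commit):

import pytest
import torch


@pytest.fixture(scope="module", params=[torch.float32, torch.bfloat16, torch.float16])
def test_dtype_fp32_bf16_fp16(request: pytest.FixtureRequest) -> torch.dtype:
    # every test requesting this fixture is collected once per dtype
    return request.param


@pytest.fixture(scope="module", params=[torch.float32, torch.float16])
def test_dtype_fp32_fp16(request: pytest.FixtureRequest) -> torch.dtype:
    # narrower variant (no bfloat16), used by the CLIP text encoder tests
    return request.param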

tests/foundationals/clip/test_image_encoder.py

Lines changed: 23 additions & 12 deletions
@@ -10,12 +10,16 @@
 
 
 @pytest.fixture(scope="module")
-def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPImageEncoderH:
+def our_encoder(
+    test_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
+) -> CLIPImageEncoderH:
     weights = test_weights_path / "CLIPImageEncoderH.safetensors"
     if not weights.is_file():
         warn(f"could not find weights at {weights}, skipping")
         pytest.skip(allow_module_level=True)
-    encoder = CLIPImageEncoderH(device=test_device)
+    encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
     tensors = load_from_safetensors(weights)
     encoder.load_state_dict(tensors)
     return encoder
@@ -31,24 +35,31 @@ def stabilityai_unclip_weights_path(test_weights_path: Path):
 
 
 @pytest.fixture(scope="module")
-def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection:
-    return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to(  # type: ignore
-        test_device  # type: ignore
-    )
+def ref_encoder(
+    stabilityai_unclip_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
+) -> CLIPVisionModelWithProjection:
+    return CLIPVisionModelWithProjection.from_pretrained(  # type: ignore
+        stabilityai_unclip_weights_path,
+        subfolder="image_encoder",
+    ).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
 
 
+@no_grad()
+@pytest.mark.flaky(reruns=3)
 def test_encoder(
     ref_encoder: CLIPVisionModelWithProjection,
     our_encoder: CLIPImageEncoderH,
-    test_device: torch.device,
 ):
-    x = torch.randn(1, 3, 224, 224).to(test_device)
+    assert ref_encoder.dtype == our_encoder.dtype
+    assert ref_encoder.device == our_encoder.device
+    x = torch.randn((1, 3, 224, 224), dtype=ref_encoder.dtype, device=ref_encoder.device)
 
-    with no_grad():
-        ref_embeddings = ref_encoder(x).image_embeds
-        our_embeddings = our_encoder(x)
+    ref_embeddings = ref_encoder(x).image_embeds
+    our_embeddings = our_encoder(x)
 
     assert ref_embeddings.shape == (1, 1024)
     assert our_embeddings.shape == (1, 1024)
 
-    assert (our_embeddings - ref_embeddings).abs().max() < 0.01
+    assert torch.allclose(our_embeddings, ref_embeddings, atol=0.05)
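
The looser comparison (atol goes from 0.01 to 0.05) presumably reflects the fact that this test now also runs in bfloat16 and float16, which carry far fewer mantissa bits than float32, so the gap against the transformers reference grows. A quick illustrative check of machine epsilon per dtype (for context only, not part of the commit):

import torch

# machine epsilon: the smallest relative step representable in each dtype
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    print(dtype, torch.finfo(dtype).eps)
# torch.float32  ~1.19e-07
# torch.bfloat16 ~7.81e-03
# torch.float16  ~9.77e-04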

tests/foundationals/clip/test_text_encoder.py

Lines changed: 20 additions & 11 deletions
@@ -30,13 +30,17 @@
 
 
 @pytest.fixture(scope="module")
-def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPTextEncoderL:
+def our_encoder(
+    test_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_fp16: torch.dtype,
+) -> CLIPTextEncoderL:
     weights = test_weights_path / "CLIPTextEncoderL.safetensors"
     if not weights.is_file():
         warn(f"could not find weights at {weights}, skipping")
         pytest.skip(allow_module_level=True)
-    encoder = CLIPTextEncoderL(device=test_device)
     tensors = load_from_safetensors(weights)
+    encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16)
     encoder.load_state_dict(tensors)
     return encoder
 
@@ -56,8 +60,15 @@ def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer:
 
 
 @pytest.fixture(scope="module")
-def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> transformers.CLIPTextModel:
-    return transformers.CLIPTextModel.from_pretrained(runwayml_weights_path, subfolder="text_encoder").to(test_device)  # type: ignore
+def ref_encoder(
+    runwayml_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_fp16: torch.dtype,
+) -> transformers.CLIPTextModel:
+    return transformers.CLIPTextModel.from_pretrained(  # type: ignore
+        runwayml_weights_path,
+        subfolder="text_encoder",
+    ).to(device=test_device, dtype=test_dtype_fp32_fp16)  # type: ignore
 
 
 def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
@@ -70,12 +81,12 @@ def prompt(request: pytest.FixtureRequest):
     return long_prompt if request.param == "<long prompt>" else request.param
 
 
+@no_grad()
 def test_encoder(
     prompt: str,
     ref_tokenizer: transformers.CLIPTokenizer,
     ref_encoder: transformers.CLIPTextModel,
     our_encoder: CLIPTextEncoderL,
-    test_device: torch.device,
 ):
     ref_tokens = ref_tokenizer(  # type: ignore
         prompt,
@@ -89,18 +100,16 @@ def test_encoder(
     our_tokens = tokenizer(prompt)
     assert torch.equal(our_tokens, ref_tokens)
 
-    with no_grad():
-        ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
-        our_embeddings = our_encoder(prompt)
+    ref_embeddings = ref_encoder(ref_tokens.to(device=ref_encoder.device))[0]
+    our_embeddings = our_encoder(prompt)
 
     assert ref_embeddings.shape == (1, 77, 768)
     assert our_embeddings.shape == (1, 77, 768)
 
     # FG-336 - Not strictly equal because we do not use the same implementation
     # of self-attention. We use `scaled_dot_product_attention` which can have
-    # numerical differences depending on the backend.
-    # Also we use FP16 weights.
-    assert (our_embeddings - ref_embeddings).abs().max() < 0.01
+    # numerical differences depending on the backend. Also we use FP16 weights.
+    torch.testing.assert_close(our_embeddings, ref_embeddings, atol=0.035, rtol=0.0)
 
 
 def test_list_string_tokenizer(

tests/foundationals/dinov2/test_dinov2.py

Lines changed: 8 additions & 7 deletions
@@ -109,7 +109,7 @@ def test_dinov2_facebook_weights(
 ) -> None:
     manual_seed(2)
     input_data = torch.randn(
-        (1, 3, resolution, resolution),
+        size=(1, 3, resolution, resolution),
         device=test_device,
     )
 
@@ -129,27 +129,28 @@ def test_dinov2_facebook_weights(
 
 
 @no_grad()
-def test_dinov2_float16(
+def test_dinov2(
     resolution: int,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
     test_device: torch.device,
 ) -> None:
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
         pytest.skip()
 
-    model = DINOv2_small(device=test_device, dtype=torch.float16)
+    model = DINOv2_small(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
 
     manual_seed(2)
     input_data = torch.randn(
-        (1, 3, resolution, resolution),
+        size=(1, 3, resolution, resolution),
         device=test_device,
-        dtype=torch.float16,
+        dtype=test_dtype_fp32_bf16_fp16,
     )
 
     output = model(input_data)
     sequence_length = (resolution // model.patch_size) ** 2 + 1
     assert output.shape == (1, sequence_length, model.embedding_dim)
-    assert output.dtype == torch.float16
+    assert output.dtype == test_dtype_fp32_bf16_fp16
 
 
 @no_grad()
@@ -162,7 +163,7 @@ def test_dinov2_batch_size(
     batch_size = 4
     manual_seed(2)
    input_data = torch.randn(
-        (batch_size, 3, resolution, resolution),
+        size=(batch_size, 3, resolution, resolution),
         device=test_device,
     )

tests/foundationals/latent_diffusion/test_auto_encoder.py

Lines changed: 33 additions & 12 deletions
@@ -6,25 +6,46 @@
 from PIL import Image
 from tests.utils import ensure_similar_images
 
-from refiners.fluxion.utils import load_from_safetensors, no_grad
-from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
+from refiners.fluxion.utils import no_grad
+from refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder, SD1Autoencoder, SDXLAutoencoder
 
 
 @pytest.fixture(scope="module")
 def ref_path() -> Path:
     return Path(__file__).parent / "test_auto_encoder_ref"
 
 
-@pytest.fixture(scope="module")
-def lda(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
-    lda_weights = test_weights_path / "lda.safetensors"
-    if not lda_weights.is_file():
-        warn(f"could not find weights at {lda_weights}, skipping")
-        pytest.skip(allow_module_level=True)
-    encoder = LatentDiffusionAutoencoder(device=test_device)
-    tensors = load_from_safetensors(lda_weights)
-    encoder.load_state_dict(tensors)
-    return encoder
+@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
+def lda(
+    request: pytest.FixtureRequest,
+    test_weights_path: Path,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
+    test_device: torch.device,
+) -> LatentDiffusionAutoencoder:
+    model_version = request.param
+    match (model_version, test_dtype_fp32_bf16_fp16):
+        case ("SD1.5", _):
+            weight_path = test_weights_path / "lda.safetensors"
+            if not weight_path.is_file():
+                warn(f"could not find weights at {weight_path}, skipping")
+                pytest.skip(allow_module_level=True)
+            model = SD1Autoencoder().load_from_safetensors(weight_path)
+        case ("SDXL", torch.float16):
+            weight_path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
+            if not weight_path.is_file():
+                warn(f"could not find weights at {weight_path}, skipping")
+                pytest.skip(allow_module_level=True)
+            model = SDXLAutoencoder().load_from_safetensors(weight_path)
+        case ("SDXL", _):
+            weight_path = test_weights_path / "sdxl-lda.safetensors"
+            if not weight_path.is_file():
+                warn(f"could not find weights at {weight_path}, skipping")
+                pytest.skip(allow_module_level=True)
+            model = SDXLAutoencoder().load_from_safetensors(weight_path)
+        case _:
+            raise ValueError(f"Unknown model version: {model_version}")
+    model = model.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+    return model
 
 
 @pytest.fixture(scope="module")
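
Since the lda fixture above is parametrized over ["SD1.5", "SDXL"] and also requests the parametrized dtype fixture, pytest collects one case per (model version, dtype) pair for every autoencoder test. A rough illustration of the resulting matrix (dtype values assumed as in the sketch near the top of this page):

import itertools

import torch

model_versions = ["SD1.5", "SDXL"]
dtypes = [torch.float32, torch.bfloat16, torch.float16]  # assumed fixture params
for version, dtype in itertools.product(model_versions, dtypes):
    print(f"lda case: {version} / {dtype}")  # 6 combinations in total
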
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+import torch
+from PIL import Image
+
+from refiners.fluxion.utils import manual_seed, no_grad
+from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting, StableDiffusion_XL
+from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
+
+
+@no_grad()
+def test_sample_noise_zero_offset(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
+    manual_seed(2)
+    latents_0 = LatentDiffusionModel.sample_noise(
+        size=(1, 4, 64, 64),
+        device=test_device,
+        dtype=test_dtype_fp32_bf16_fp16,
+    )
+    manual_seed(2)
+    latents_1 = LatentDiffusionModel.sample_noise(
+        size=(1, 4, 64, 64),
+        offset_noise=0.0,  # should be no-op
+        device=test_device,
+        dtype=test_dtype_fp32_bf16_fp16,
+    )
+
+    assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
+
+
+@no_grad()
+def test_sd15_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
+    sd = StableDiffusion_1(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+
+    # prepare inputs
+    latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+    text_embedding = sd.compute_clip_text_embedding("")
+
+    # run the pipeline of models, for a single step
+    output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
+
+    assert output.shape == (1, 4, 64, 64)
+
+
+@no_grad()
+def test_sd15_inpainting_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
+    sd = StableDiffusion_1_Inpainting(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+
+    # prepare inputs
+    latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+    target_image = Image.new("RGB", (512, 512))
+    mask = Image.new("L", (512, 512))
+    sd.set_inpainting_conditions(target_image=target_image, mask=mask)
+    text_embedding = sd.compute_clip_text_embedding("")
+
+    # run the pipeline of models, for a single step
+    output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
+
+    assert output.shape == (1, 4, 64, 64)
+
+
+@no_grad()
+def test_sdxl_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
+    sd = StableDiffusion_XL(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+
+    # prepare inputs
+    latent_noise = torch.randn(1, 4, 128, 128, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+    text_embedding, pooled_text_embedding = sd.compute_clip_text_embedding("")
+    time_ids = sd.default_time_ids
+
+    # run the pipeline of models, for a single step
+    output = sd(
+        latent_noise,
+        step=0,
+        clip_text_embedding=text_embedding,
+        pooled_text_embedding=pooled_text_embedding,
+        time_ids=time_ids,
+    )
+
+    assert output.shape == (1, 4, 128, 128)

tests/foundationals/latent_diffusion/test_sd15_unet.py

Lines changed: 9 additions & 3 deletions
@@ -7,9 +7,15 @@
 
 
 @pytest.fixture(scope="module")
-def refiners_sd15_unet(test_device: torch.device) -> SD1UNet:
-    unet = SD1UNet(in_channels=4, device=test_device)
-    return unet
+def refiners_sd15_unet(
+    test_device: torch.device,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
+) -> SD1UNet:
+    return SD1UNet(
+        in_channels=4,
+        device=test_device,
+        dtype=test_dtype_fp32_bf16_fp16,
+    )
 
 
 def test_unet_context_flush(refiners_sd15_unet: SD1UNet):
