Skip to content

Commit 78a3336

Browse files
author
Laurent
committed
add DINOv2 similarity to compare_images + relax some test constraints
1 parent f45fa7e commit 78a3336

File tree

6 files changed

+75
-34
lines changed

6 files changed

+75
-34
lines changed

tests/e2e/test_diffusion.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ def test_diffusion_std_random_init_bfloat16(
832832
)
833833
predicted_image = sd15.lda.latents_to_image(x)
834834

835-
ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16)
835+
ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16, min_psnr=30, min_ssim=0.97)
836836

837837

838838
@no_grad()
@@ -1166,7 +1166,7 @@ def test_diffusion_inpainting_float16(
11661166
predicted_image = sd15.lda.latents_to_image(x)
11671167

11681168
# PSNR and SSIM thresholds are deliberately loose here because float16 is even less accurate than float32.
1169-
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
1169+
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95, min_dinov2=0.96)
11701170

11711171

11721172
@no_grad()
@@ -1245,7 +1245,7 @@ def test_diffusion_controlnet_tile_upscale(
12451245
predicted_image = sd15.lda.latents_to_image(x)
12461246

12471247
# Note: rather large tolerances are used on purpose here (loose comparison with diffusers' output)
1248-
ensure_similar_images(predicted_image, expected_image, min_psnr=24, min_ssim=0.75)
1248+
ensure_similar_images(predicted_image, expected_image, min_psnr=24, min_ssim=0.75, min_dinov2=0.94)
12491249

12501250

12511251
@no_grad()
@@ -1852,7 +1852,7 @@ def test_diffusion_ella_adapter(
18521852
condition_scale=12,
18531853
)
18541854
predicted_image = sd15.lda.latents_to_image(x)
1855-
ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=35, min_ssim=0.98)
1855+
ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=31, min_ssim=0.98)
18561856

18571857

18581858
@no_grad()
@@ -1937,7 +1937,7 @@ def test_diffusion_ip_adapter_multi(
19371937
)
19381938
predicted_image = sd15.lda.decode_latents(x)
19391939

1940-
ensure_similar_images(predicted_image, expected_image_ip_adapter_multi)
1940+
ensure_similar_images(predicted_image, expected_image_ip_adapter_multi, min_psnr=43, min_ssim=0.98)
19411941

19421942

19431943
@no_grad()
@@ -2130,7 +2130,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
21302130
sdxl.lda.to(dtype=torch.float32)
21312131
predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
21322132

2133-
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
2133+
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman, min_psnr=43, min_ssim=0.98)
21342134

21352135

21362136
@no_grad()
@@ -2608,11 +2608,11 @@ def test_style_aligned(
26082608

26092609
# tile all images horizontally
26102610
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
2611-
for i in range(len(predicted_images)):
2612-
merged_image.paste(predicted_images[i], (i * 1024, 0)) # type: ignore
2611+
for i, image in enumerate(predicted_images):
2612+
merged_image.paste(image, (1024 * i, 0))
26132613

26142614
# compare against reference image
2615-
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)
2615+
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=12, min_ssim=0.39, min_dinov2=0.95)
26162616

26172617

26182618
@no_grad()
@@ -2624,7 +2624,7 @@ def test_multi_upscaler(
26242624
generator = torch.Generator(device=multi_upscaler.device)
26252625
generator.manual_seed(37)
26262626
predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
2627-
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)
2627+
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=25, min_ssim=0.85, min_dinov2=0.96)
26282628

26292629

26302630
@no_grad()

tests/e2e/test_doc_examples.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_guide_adapting_sdxl_vanilla(
110110
)
111111

112112
predicted_image = sdxl.lda.decode_latents(x)
113-
ensure_similar_images(predicted_image, expected_image)
113+
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
114114

115115

116116
@no_grad()
@@ -152,7 +152,7 @@ def test_guide_adapting_sdxl_single_lora(
152152
)
153153

154154
predicted_image = sdxl.lda.decode_latents(x)
155-
ensure_similar_images(predicted_image, expected_image)
155+
ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)
156156

157157

158158
@no_grad()
@@ -196,7 +196,7 @@ def test_guide_adapting_sdxl_multiple_loras(
196196
)
197197

198198
predicted_image = sdxl.lda.decode_latents(x)
199-
ensure_similar_images(predicted_image, expected_image)
199+
ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)
200200

201201

202202
@no_grad()
@@ -256,7 +256,7 @@ def test_guide_adapting_sdxl_loras_ip_adapter(
256256
)
257257

258258
predicted_image = sdxl.lda.decode_latents(x)
259-
ensure_similar_images(predicted_image, expected_image)
259+
ensure_similar_images(predicted_image, expected_image, min_psnr=29, min_ssim=0.98)
260260

261261

262262
# We do not (yet) test the last example using T2i-Adapter with Zoe Depth.

tests/e2e/test_lightning.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_lightning_base_4step(
9393
)
9494
predicted_image = sdxl.lda.latents_to_image(x)
9595

96-
ensure_similar_images(predicted_image, expected_image)
96+
ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)
9797

9898

9999
@no_grad()
@@ -144,7 +144,7 @@ def test_lightning_base_1step(
144144
)
145145
predicted_image = sdxl.lda.latents_to_image(x)
146146

147-
ensure_similar_images(predicted_image, expected_image)
147+
ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)
148148

149149

150150
@no_grad()
@@ -198,4 +198,4 @@ def test_lightning_lora_4step(
198198
)
199199
predicted_image = sdxl.lda.latents_to_image(x)
200200

201-
ensure_similar_images(predicted_image, expected_image)
201+
ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)

tests/foundationals/segment_anything/test_hq_sam.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,15 +238,15 @@ def test_predictor(
238238
assert torch.allclose(
239239
reference_low_res_mask_hq,
240240
refiners_low_res_mask_hq,
241-
atol=4e-3,
241+
atol=1e-2,
242+
)
243+
assert ( # absolute diff in number of pixels
244+
torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 10
242245
)
243-
assert (
244-
torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 2
245-
) # The diff on the logits above leads to an absolute diff of 2 pixel on the high res masks
246246
assert torch.allclose(
247247
iou_predictions_np,
248248
torch.max(iou_predictions),
249-
atol=1e-5,
249+
atol=1e-4,
250250
)
251251

252252

tests/foundationals/segment_anything/test_sam.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,9 @@ def test_predictor(
317317
for i in range(3):
318318
mask_prediction = masks[i].cpu()
319319
facebook_mask = torch.as_tensor(facebook_masks[i])
320-
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
321-
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)
320+
iou = intersection_over_union(mask_prediction, facebook_mask)
321+
assert isclose(iou, 1.0, rel_tol=5e-04), f"iou: {iou}"
322+
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-04)
322323

323324

324325
def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:

tests/utils.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,65 @@
1+
from functools import cache
12
from pathlib import Path
3+
from textwrap import dedent
24

3-
import numpy as np
45
import piq # type: ignore
56
import torch
67
import torch.nn as nn
78
from PIL import Image
89
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
910

11+
from refiners.conversion.models import dinov2
12+
from refiners.fluxion.utils import image_to_tensor
13+
from refiners.foundationals.dinov2 import DINOv2_small
1014

11-
def compare_images(img_1: Image.Image, img_2: Image.Image) -> tuple[int, float]:
12-
x1, x2 = (
13-
torch.tensor(np.array(x).astype(np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.0 for x in (img_1, img_2)
15+
16+
@cache
17+
def get_small_dinov2_model() -> DINOv2_small:
18+
model = DINOv2_small()
19+
model.load_from_safetensors(
20+
dinov2.small.converted.local_path
21+
if dinov2.small.converted.local_path.exists()
22+
else dinov2.small.converted.hf_cache_path
23+
)
24+
return model
25+
26+
27+
def compare_images(
28+
img_1: Image.Image,
29+
img_2: Image.Image,
30+
) -> tuple[float, float, float]:
31+
x1 = image_to_tensor(img_1)
32+
x2 = image_to_tensor(img_2)
33+
34+
psnr = piq.psnr(x1, x2) # type: ignore
35+
ssim = piq.ssim(x1, x2) # type: ignore
36+
37+
dinov2_model = get_small_dinov2_model()
38+
dinov2 = torch.nn.functional.cosine_similarity(
39+
dinov2_model(x1)[:, 0],
40+
dinov2_model(x2)[:, 0],
1441
)
15-
return (piq.psnr(x1, x2), piq.ssim(x1, x2).item()) # type: ignore
1642

43+
return psnr.item(), ssim.item(), dinov2.item() # type: ignore
1744

18-
def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int = 45, min_ssim: float = 0.99):
19-
psnr, ssim = compare_images(img_1, img_2)
20-
assert (psnr >= min_psnr) and (
21-
ssim >= min_ssim
22-
), f"PSNR {psnr} / SSIM {ssim}, expected at least {min_psnr} / {min_ssim}"
45+
46+
def ensure_similar_images(
47+
img_1: Image.Image,
48+
img_2: Image.Image,
49+
min_psnr: int = 45,
50+
min_ssim: float = 0.99,
51+
min_dinov2: float = 0.99,
52+
) -> None:
53+
psnr, ssim, dinov2 = compare_images(img_1, img_2)
54+
if (psnr < min_psnr) or (ssim < min_ssim) or (dinov2 < min_dinov2):
55+
raise AssertionError(
56+
dedent(f"""
57+
Images are not similar enough!
58+
- PSNR: {psnr:08.05f} (required at least {min_psnr:08.05f})
59+
- SSIM: {ssim:08.06f} (required at least {min_ssim:08.06f})
60+
- DINO: {dinov2:08.06f} (required at least {min_dinov2:08.06f})
61+
""").strip()
62+
)
2363

2464

2565
class T5TextEmbedder(nn.Module):

0 commit comments

Comments
 (0)