Skip to content

Commit 4595133

Browse files
author
Laurent
committed
move some tests into the adapters test folder
1 parent a51d695 commit 4595133

File tree

5 files changed

+35
-20
lines changed

5 files changed

+35
-20
lines changed
File renamed without changes.
Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Iterator
2-
31
import pytest
42
import torch
53

@@ -10,48 +8,63 @@
108
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
119

1210

13-
@pytest.fixture(scope="module", params=[True, False])
14-
def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet | SDXLUNet]:
15-
xl: bool = request.param
16-
unet = SDXLUNet(in_channels=4) if xl else SD1UNet(in_channels=4)
17-
yield unet
11+
@pytest.fixture(scope="module")
12+
def unet(
13+
refiners_unet: SD1UNet | SDXLUNet,
14+
) -> SD1UNet | SDXLUNet:
15+
return refiners_unet
1816

1917

20-
def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None:
18+
def test_inject_eject_freeu(
19+
unet: SD1UNet | SDXLUNet,
20+
) -> None:
21+
initial_repr = repr(unet)
2122
freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])
2223

23-
assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
24+
assert unet.parent is None
25+
assert unet.find(FreeUResidualConcatenator) is None
26+
assert repr(unet) == initial_repr
27+
28+
freeu.inject()
29+
assert unet.parent is not None
30+
assert unet.find(FreeUResidualConcatenator) is not None
31+
assert repr(unet) != initial_repr
2432

25-
with pytest.raises(AssertionError) as exc:
26-
freeu.eject()
27-
assert "could not find" in str(exc.value)
33+
freeu.eject()
34+
assert unet.parent is None
35+
assert unet.find(FreeUResidualConcatenator) is None
36+
assert repr(unet) == initial_repr
2837

2938
freeu.inject()
30-
assert len(list(unet.walk(FreeUResidualConcatenator))) == 2
39+
assert unet.parent is not None
40+
assert unet.find(FreeUResidualConcatenator) is not None
41+
assert repr(unet) != initial_repr
3142

3243
freeu.eject()
33-
assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
44+
assert unet.parent is None
45+
assert unet.find(FreeUResidualConcatenator) is None
46+
assert repr(unet) == initial_repr
3447

3548

3649
def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
3750
num_blocks = len(unet.layer("UpBlocks", Chain))
38-
3951
with pytest.raises(AssertionError):
4052
SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))
4153

4254

4355
def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None:
4456
with pytest.raises(AssertionError):
4557
SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9, 0.9])
58+
with pytest.raises(AssertionError):
59+
SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2, 1.2], skip_scales=[0.9, 0.9])
4660

4761

48-
def test_freeu_identity_scales() -> None:
62+
def test_freeu_identity_scales(unet: SD1UNet | SDXLUNet) -> None:
4963
manual_seed(0)
50-
text_embedding = torch.randn(1, 77, 768)
51-
timestep = torch.randint(0, 999, size=(1, 1))
52-
x = torch.randn(1, 4, 32, 32)
64+
text_embedding = torch.randn(1, 77, 768, dtype=unet.dtype, device=unet.device)
65+
timestep = torch.randint(0, 999, size=(1, 1), device=unet.device)
66+
x = torch.randn(1, 4, 32, 32, dtype=unet.dtype, device=unet.device)
5367

54-
unet = SD1UNet(in_channels=4)
5568
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
5669

5770
with no_grad():
@@ -65,5 +78,7 @@ def test_freeu_identity_scales() -> None:
6578
unet.set_timestep(timestep=timestep)
6679
y_2 = unet(x.clone())
6780

81+
freeu.eject()
82+
6883
# The FFT -> inverse FFT sequence (skip features) introduces small numerical differences
6984
assert torch.allclose(y_1, y_2, atol=1e-5)
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)