1
- from typing import Iterator
2
-
3
1
import pytest
4
2
import torch
5
3
10
8
from refiners .foundationals .latent_diffusion .freeu import FreeUResidualConcatenator , SDFreeUAdapter
11
9
12
10
13
@pytest.fixture(scope="module")
def unet(refiners_unet: SD1UNet | SDXLUNet) -> SD1UNet | SDXLUNet:
    """Expose the shared ``refiners_unet`` fixture under the local name ``unet``.

    The shared fixture is parametrized over both SD1 and SDXL UNets, so every
    test in this module runs against both architectures.
    """
    return refiners_unet
18
16
19
17
20
def test_inject_eject_freeu(
    unet: SD1UNet | SDXLUNet,
) -> None:
    """Injecting and ejecting the FreeU adapter must be fully reversible.

    After each ``eject`` the UNet must be structurally identical to its
    pristine state (same ``repr``, no parent, no FreeU layers), and the
    cycle must be repeatable.
    """
    pristine_repr = repr(unet)

    freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])

    def expect_pristine() -> None:
        # UNet is detached and contains no FreeU machinery.
        assert unet.parent is None
        assert unet.find(FreeUResidualConcatenator) is None
        assert repr(unet) == pristine_repr

    def expect_adapted() -> None:
        # UNet is owned by the adapter and FreeU layers are present.
        assert unet.parent is not None
        assert unet.find(FreeUResidualConcatenator) is not None
        assert repr(unet) != pristine_repr

    # Merely constructing the adapter must not mutate the UNet.
    expect_pristine()

    # Two full inject/eject cycles: each eject must restore the exact
    # initial structure, and a second inject must work after an eject.
    for _ in range(2):
        freeu.inject()
        expect_adapted()

        freeu.eject()
        expect_pristine()
34
47
35
48
36
49
def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
    """Providing more scale values than the UNet has up-blocks must be rejected."""
    num_up_blocks = len(unet.layer("UpBlocks", Chain))
    too_many = num_up_blocks + 1
    with pytest.raises(AssertionError):
        SDFreeUAdapter(unet, backbone_scales=[1.2] * too_many, skip_scales=[0.9] * too_many)
41
53
42
54
43
55
def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None:
    """``backbone_scales`` and ``skip_scales`` must have equal lengths.

    Both mismatch directions (fewer backbone scales, fewer skip scales)
    must raise.
    """
    mismatched_pairs = [
        ([1.2, 1.2], [0.9, 0.9, 0.9]),  # fewer backbone scales than skip scales
        ([1.2, 1.2, 1.2], [0.9, 0.9]),  # more backbone scales than skip scales
    ]
    for backbone, skip in mismatched_pairs:
        with pytest.raises(AssertionError):
            SDFreeUAdapter(unet, backbone_scales=backbone, skip_scales=skip)
46
60
47
61
48
- def test_freeu_identity_scales () -> None :
62
+ def test_freeu_identity_scales (unet : SD1UNet | SDXLUNet ) -> None :
49
63
manual_seed (0 )
50
- text_embedding = torch .randn (1 , 77 , 768 )
51
- timestep = torch .randint (0 , 999 , size = (1 , 1 ))
52
- x = torch .randn (1 , 4 , 32 , 32 )
64
+ text_embedding = torch .randn (1 , 77 , 768 , dtype = unet . dtype , device = unet . device )
65
+ timestep = torch .randint (0 , 999 , size = (1 , 1 ), device = unet . device )
66
+ x = torch .randn (1 , 4 , 32 , 32 , dtype = unet . dtype , device = unet . device )
53
67
54
- unet = SD1UNet (in_channels = 4 )
55
68
unet .set_clip_text_embedding (clip_text_embedding = text_embedding ) # not flushed between forward-s
56
69
57
70
with no_grad ():
@@ -65,5 +78,7 @@ def test_freeu_identity_scales() -> None:
65
78
unet .set_timestep (timestep = timestep )
66
79
y_2 = unet (x .clone ())
67
80
81
+ freeu .eject ()
82
+
68
83
# The FFT -> inverse FFT sequence (skip features) introduces small numerical differences
69
84
assert torch .allclose (y_1 , y_2 , atol = 1e-5 )
0 commit comments