This repository was archived by the owner on Sep 26, 2025. It is now read-only.

Commit 60f6f62

Author: Laurent

move some tests into the adapters test folder

1 parent: a51d695

File tree

6 files changed: +142 −111 lines
File renamed without changes.
Lines changed: 35 additions & 20 deletions
@@ -1,5 +1,3 @@
-from typing import Iterator
-
 import pytest
 import torch
 
@@ -10,48 +8,63 @@
 from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
 
 
-@pytest.fixture(scope="module", params=[True, False])
-def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet | SDXLUNet]:
-    xl: bool = request.param
-    unet = SDXLUNet(in_channels=4) if xl else SD1UNet(in_channels=4)
-    yield unet
+@pytest.fixture(scope="module")
+def unet(
+    refiners_unet: SD1UNet | SDXLUNet,
+) -> SD1UNet | SDXLUNet:
+    return refiners_unet
 
 
-def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None:
+def test_inject_eject_freeu(
+    unet: SD1UNet | SDXLUNet,
+) -> None:
+    initial_repr = repr(unet)
     freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])
 
-    assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
+    assert unet.parent is None
+    assert unet.find(FreeUResidualConcatenator) is None
+    assert repr(unet) == initial_repr
+
+    freeu.inject()
+    assert unet.parent is not None
+    assert unet.find(FreeUResidualConcatenator) is not None
+    assert repr(unet) != initial_repr
 
-    with pytest.raises(AssertionError) as exc:
-        freeu.eject()
-    assert "could not find" in str(exc.value)
+    freeu.eject()
+    assert unet.parent is None
+    assert unet.find(FreeUResidualConcatenator) is None
+    assert repr(unet) == initial_repr
 
     freeu.inject()
-    assert len(list(unet.walk(FreeUResidualConcatenator))) == 2
+    assert unet.parent is not None
+    assert unet.find(FreeUResidualConcatenator) is not None
+    assert repr(unet) != initial_repr
 
     freeu.eject()
-    assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
+    assert unet.parent is None
+    assert unet.find(FreeUResidualConcatenator) is None
+    assert repr(unet) == initial_repr
 
 
 def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
     num_blocks = len(unet.layer("UpBlocks", Chain))
-
     with pytest.raises(AssertionError):
         SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))
 
 
 def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None:
     with pytest.raises(AssertionError):
         SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9, 0.9])
+    with pytest.raises(AssertionError):
+        SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2, 1.2], skip_scales=[0.9, 0.9])
 
 
-def test_freeu_identity_scales() -> None:
+def test_freeu_identity_scales(unet: SD1UNet | SDXLUNet) -> None:
     manual_seed(0)
-    text_embedding = torch.randn(1, 77, 768)
-    timestep = torch.randint(0, 999, size=(1, 1))
-    x = torch.randn(1, 4, 32, 32)
+    text_embedding = torch.randn(1, 77, 768, dtype=unet.dtype, device=unet.device)
+    timestep = torch.randint(0, 999, size=(1, 1), device=unet.device)
+    x = torch.randn(1, 4, 32, 32, dtype=unet.dtype, device=unet.device)
 
-    unet = SD1UNet(in_channels=4)
     unet.set_clip_text_embedding(clip_text_embedding=text_embedding)  # not flushed between forward-s
 
     with no_grad():
@@ -65,5 +78,7 @@ def test_freeu_identity_scales() -> None:
         unet.set_timestep(timestep=timestep)
         y_2 = unet(x.clone())
 
+    freeu.eject()
+
     # The FFT -> inverse FFT sequence (skip features) introduces small numerical differences
     assert torch.allclose(y_1, y_2, atol=1e-5)
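
The rewritten test verifies structural round-tripping: injecting the adapter must change the UNet (a parent appears, FreeU layers become findable, the repr changes), and ejecting must restore it exactly. A minimal standalone sketch of that lifecycle, assuming the refiners API as exercised in the diff above; the freeu import path appears in the diff, while the SD1UNet import path and in_channels=4 mirror the old in-module fixture and are assumptions here:

# Minimal sketch of the inject/eject lifecycle asserted by test_inject_eject_freeu.
# The freeu import path comes from the diff; the SD1UNet import path is assumed.
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter

unet = SD1UNet(in_channels=4)  # mirrors the old in-module fixture this commit removes
freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])

assert unet.find(FreeUResidualConcatenator) is None  # constructing the adapter changes nothing

freeu.inject()  # grafts the FreeU scaling layers into the UNet in place
assert unet.find(FreeUResidualConcatenator) is not None

freeu.eject()  # restores the original structure, so repr(unet) round-trips
assert unet.find(FreeUResidualConcatenator) is None

Compared with the old walk()-based layer count, asserting on repr and parent checks that ejection restores the original structure exactly rather than merely removing the counted layers.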
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+from hashlib import sha256
+from pathlib import Path
+from warnings import warn
+
+import pytest
+import torch
+from huggingface_hub import hf_hub_download  # type: ignore
+
+from refiners.fluxion.utils import load_tensors
+from refiners.foundationals.latent_diffusion import StableDiffusion_1
+from refiners.foundationals.latent_diffusion.lora import Lora, SDLoraManager
+
+
+@pytest.fixture
+def manager(refiners_sd15: StableDiffusion_1) -> SDLoraManager:
+    return SDLoraManager(refiners_sd15)
+
+
+@pytest.fixture
+def pokemon_lora_weights(
+    test_weights_path: Path,
+    use_local_weights: bool,
+) -> dict[str, torch.Tensor]:
+    if use_local_weights:
+        weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
+        if not weights_path.is_file():
+            warn(f"could not find weights at {weights_path}, skipping")
+            pytest.skip(allow_module_level=True)
+    else:
+        weights_path = Path(
+            hf_hub_download(
+                repo_id="pcuenq/pokemon-lora",
+                filename="pytorch_lora_weights.bin",
+                revision="bc3cb5256ebc303457acab170ca6219a66dd31f5",
+            )
+        )
+
+    expected_sha256 = "f712fcfb6618da14d25a4f3e0c9460a878fc2417e2df95cdd683a73f71b50384"
+    retrieved_sha256 = sha256(weights_path.read_bytes()).hexdigest().lower()
+    assert retrieved_sha256 == expected_sha256, f"expected {expected_sha256}, got {retrieved_sha256}"
+
+    return load_tensors(weights_path)
+
+
+def test_add_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
+    assert "pokemon-lora" in manager.names
+
+    with pytest.raises(AssertionError) as exc:
+        manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
+    assert "already exists" in str(exc.value)
+
+
+def test_add_multiple_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", pokemon_lora_weights)
+    manager.add_loras("pokemon-lora2", pokemon_lora_weights)
+    assert "pokemon-lora" in manager.names
+    assert "pokemon-lora2" in manager.names
+
+
+def test_remove_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", pokemon_lora_weights)
+    manager.add_loras("pokemon-lora2", pokemon_lora_weights)
+    manager.remove_loras("pokemon-lora")
+    assert "pokemon-lora" not in manager.names
+    assert "pokemon-lora2" in manager.names
+
+    manager.remove_loras("pokemon-lora2")
+    assert "pokemon-lora2" not in manager.names
+    assert len(manager.names) == 0
+
+
+def test_remove_all(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", pokemon_lora_weights)
+    manager.add_loras("pokemon-lora2", pokemon_lora_weights)
+    manager.remove_all()
+    assert len(manager.names) == 0
+
+
+def test_get_lora(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
+    assert all(isinstance(lora, Lora) for lora in manager.get_loras_by_name("pokemon-lora"))
+
+
+def test_get_scale(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights, scale=0.4)
+    assert manager.get_scale("pokemon-lora") == 0.4
+
+
+def test_names(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    assert manager.names == []
+
+    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
+    assert manager.names == ["pokemon-lora"]
+
+    manager.add_loras("pokemon-lora2", tensors=pokemon_lora_weights)
+    assert set(manager.names) == set(["pokemon-lora", "pokemon-lora2"])
+
+
+def test_scales(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    assert manager.scales == {}
+
+    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights, scale=0.4)
+    assert manager.scales == {"pokemon-lora": 0.4}
+
+    manager.add_loras("pokemon-lora2", tensors=pokemon_lora_weights, scale=0.5)
+    assert manager.scales == {"pokemon-lora": 0.4, "pokemon-lora2": 0.5}
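
Both moved modules now consume shared fixtures (refiners_unet, refiners_sd15, test_weights_path, use_local_weights) rather than constructing models inline, which is what lets them live in the adapters test folder. A hypothetical sketch of a conftest.py providing those fixtures; the names come from the diff, but the bodies below are illustrative assumptions, not the repository's actual definitions:

# Hypothetical conftest.py sketch; fixture names are from the diff,
# the bodies are assumptions for illustration only.
from pathlib import Path

import pytest

from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet, StableDiffusion_1


@pytest.fixture(scope="module", params=["sd15", "sdxl"])
def refiners_unet(request: pytest.FixtureRequest) -> SD1UNet | SDXLUNet:
    # assumed: parametrizes over both UNet variants, like the old
    # in-module fixture did with params=[True, False]
    return SDXLUNet(in_channels=4) if request.param == "sdxl" else SD1UNet(in_channels=4)


@pytest.fixture(scope="module")
def refiners_sd15() -> StableDiffusion_1:
    # assumed: a Stable Diffusion 1.5 pipeline for the SDLoraManager tests
    return StableDiffusion_1()


@pytest.fixture(scope="session")
def test_weights_path() -> Path:
    # assumed: root directory for locally stored test weights
    return Path("tests/weights")


@pytest.fixture(scope="session")
def use_local_weights() -> bool:
    # assumed: toggled via a command-line option or environment variable in the real suite
    return False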
File renamed without changes.
File renamed without changes.

tests/foundationals/latent_diffusion/test_lora_manager.py

Lines changed: 0 additions & 91 deletions
This file was deleted.

0 commit comments
