Skip to content
This repository was archived by the owner on Sep 26, 2025. It is now read-only.

Commit 13f5bd1

Browse files
committed
fix
1 parent f95d8a9 commit 13f5bd1

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

tests/adapters/test_ella_adapter.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pytest
21
import torch
32

43
import refiners.fluxion.layers as fl
@@ -12,9 +11,8 @@ def new_adapter(target: SD1UNet) -> SD1ELLAAdapter:
1211

1312

1413
@no_grad()
15-
@pytest.mark.parametrize("k_unet", [SD1UNet])
16-
def test_inject_eject(k_unet: type[SD1UNet], test_device: torch.device):
17-
unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16)
14+
def test_inject_eject(test_device: torch.device):
15+
unet = SD1UNet(in_channels=4, device=test_device, dtype=torch.float16)
1816
initial_repr = repr(unet)
1917
adapter = new_adapter(unet)
2018
assert repr(unet) == initial_repr
@@ -29,9 +27,8 @@ def test_inject_eject(k_unet: type[SD1UNet], test_device: torch.device):
2927

3028

3129
@no_grad()
32-
@pytest.mark.parametrize("k_unet", [SD1UNet])
33-
def test_scale(k_unet: type[SD1UNet], test_device: torch.device):
34-
unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16)
30+
def test_scale(test_device: torch.device):
31+
unet = SD1UNet(in_channels=4, device=test_device, dtype=torch.float16)
3532
adapter = new_adapter(unet).inject()
3633

3734
def predicate(m: fl.Module, p: fl.Chain) -> bool:

tests/e2e/test_diffusion.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -513,12 +513,26 @@ def lda_ft_mse_weights(test_weights_path: Path) -> Path:
513513

514514

515515
@pytest.fixture(scope="module")
516-
def ella_adapter_weights(test_weights_path: Path) -> Path:
516+
def ella_weights(test_weights_path: Path) -> tuple[Path, Path]:
517517
ella_adapter_weights = test_weights_path / "ELLA-Adapter/ella-sd1.5-tsc-t5xl.safetensors"
518518
if not ella_adapter_weights.is_file():
519519
warn(f"could not find weights at {ella_adapter_weights}, skipping")
520520
pytest.skip(allow_module_level=True)
521-
return ella_adapter_weights
521+
t5xl_weights = test_weights_path / "QQGYLab/T5XLFP16"
522+
t5xl_files = [
523+
"config.json",
524+
"model.safetensors",
525+
"special_tokens_map.json",
526+
"spiece.model",
527+
"tokenizer_config.json",
528+
"tokenizer.json",
529+
]
530+
for file in t5xl_files:
531+
if not (t5xl_weights / file).is_file():
532+
warn(f"could not find weights at {t5xl_weights / file}, skipping")
533+
pytest.skip(allow_module_level=True)
534+
535+
return (ella_adapter_weights, t5xl_weights)
522536

523537

524538
@pytest.fixture(scope="module")
@@ -605,10 +619,6 @@ def sd15_std_sde(
605619
def sd15_std_float16(
606620
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
607621
) -> StableDiffusion_1:
608-
if test_device.type == "cpu":
609-
warn("not running on CPU, skipping")
610-
pytest.skip()
611-
612622
sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
613623

614624
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
@@ -1817,15 +1827,13 @@ def test_diffusion_textual_inversion_random_init(
18171827
@no_grad()
18181828
def test_diffusion_ella_adapter(
18191829
sd15_std_float16: StableDiffusion_1,
1820-
ella_adapter_weights: Path,
1821-
test_weights_path: Path,
1830+
ella_weights: tuple[Path, Path],
18221831
expected_image_ella_adapter: Image.Image,
18231832
test_device: torch.device,
18241833
):
18251834
sd15 = sd15_std_float16
1826-
t5_encoder = T5TextEmbedder(pretrained_path=test_weights_path / "QQGYLab/T5XLFP16", max_length=128).to(
1827-
test_device, torch.float16
1828-
)
1835+
ella_adapter_weights, t5xl_weights = ella_weights
1836+
t5_encoder = T5TextEmbedder(pretrained_path=t5xl_weights, max_length=128).to(test_device, torch.float16)
18291837

18301838
prompt = "a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region"
18311839
negative_prompt = ""

0 commit comments

Comments
 (0)