Skip to content

Commit 95beb5c

Browse files
author
Laurent
committed
rename test_unet.py to test_sd15_unet.py + use test_device fixture
1 parent 3196a26 commit 95beb5c

File tree

2 files changed

+31
-25
lines changed

2 files changed

+31
-25
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
import torch
3+
4+
from refiners.fluxion import manual_seed
5+
from refiners.fluxion.utils import no_grad
6+
from refiners.foundationals.latent_diffusion import SD1UNet
7+
8+
9+
@pytest.fixture(scope="module")
def refiners_sd15_unet(test_device: torch.device) -> SD1UNet:
    """Module-scoped SD 1.5 UNet placed on the test device.

    Built once per test module; the context-flush test relies on reusing
    the same instance across forward passes.
    """
    return SD1UNet(in_channels=4, device=test_device)
13+
14+
15+
def test_unet_context_flush(refiners_sd15_unet: SD1UNet):
    """Two forward passes on identical input must produce identical outputs.

    The CLIP text embedding is set once and deliberately NOT flushed between
    the two forward passes: if stale per-call context leaked from the first
    pass into the second, the outputs would diverge.
    """
    manual_seed(0)
    unet = refiners_sd15_unet
    clip_embedding = torch.randn(1, 77, 768, device=unet.device, dtype=unet.dtype)
    step = torch.randint(0, 999, size=(1, 1), device=unet.device)
    latents = torch.randn(1, 4, 32, 32, device=unet.device, dtype=unet.dtype)

    # Set once, outside the loop — intentionally not flushed between forwards.
    unet.set_clip_text_embedding(clip_text_embedding=clip_embedding)

    outputs = []
    for _ in range(2):
        with no_grad():
            unet.set_timestep(timestep=step)
            outputs.append(unet(latents.clone()))

    assert torch.equal(outputs[0], outputs[1])

tests/foundationals/latent_diffusion/test_unet.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments

Comments
 (0)