Skip to content

Commit 4485c28

Browse files
committed
Complete the working latent example with Oxford Flowers and the Stable Diffusion 1.4 VAE
1 parent 0c11473 commit 4485c28

File tree

1 file changed

+49
-25
lines changed

1 file changed

+49
-25
lines changed

train_latent_only.py

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,33 @@
1818
# hf related
1919

2020
from datasets import load_dataset
21+
from diffusers.models import AutoencoderKL
22+
23+
# Pretrained Stable Diffusion 1.4 VAE, used as a frozen image <-> latent codec:
# the Encoder/Decoder wrappers below only ever call it under torch.no_grad()
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder = "vae")
24+
25+
class Encoder(Module):
    """Wrap a pretrained VAE so it maps images in [0, 1] to scaled latents."""

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, image):
        # the VAE expects pixels in [-1, 1]; no gradients flow through it
        with torch.no_grad():
            posterior = self.vae.encode(image * 2 - 1)

        # 0.18215 is the Stable Diffusion latent scaling factor
        sampled = posterior.latent_dist.sample()
        return sampled * 0.18215
35+
36+
class Decoder(Module):
    """Wrap a pretrained VAE so it maps scaled latents back to images in [0, 1]."""

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, latents):
        # undo the Stable Diffusion latent scaling before decoding
        unscaled = (1 / 0.18215) * latents

        # the VAE itself is frozen — decode without building a graph
        with torch.no_grad():
            decoded = self.vae.decode(unscaled).sample

        # VAE outputs live in [-1, 1]; map back to valid pixel range [0, 1]
        return (decoded / 2 + 0.5).clamp(0, 1)
2148

2249
# results folder
2350

@@ -36,23 +63,13 @@ def divisible_by(num, den):
3663

3764
# encoder / decoder
3865

39-
class Encoder(Module):
    # Patchify-based stand-in encoder (replaced by the VAE version above in this
    # commit): folds 4x4 pixel patches into the channel dimension, then shifts
    # [0, 1] pixels to [-1, 1].
    def forward(self, x):
        # (..., c, 4h, 4w) -> (..., h, w, 16c); p1 = p2 = 4 patch size
        x = rearrange(x, '... c (h p1) (w p2) -> ... h w (p1 p2 c)', p1 = 4, p2 = 4)
        return x * 2 - 1
43-
44-
class Decoder(Module):
    # Inverse of the patchify Encoder above: unfolds 4x4 patches back into
    # pixels (c fixed to 3 = RGB) and maps [-1, 1] back to clamped [0, 1].
    def forward(self, x):
        x = rearrange(x, '... h w (p1 p2 c) -> ... c (h p1) (w p2)', p1 = 4, p2 = 4, c = 3)
        return ((x + 1) * 0.5).clamp(min = 0., max = 1.)
48-
4966
model = Transfusion(
5067
num_text_tokens = 10,
51-
dim_latent = 4 * 4 * 3,
52-
channel_first_latent = False,
53-
modality_default_shape = (16, 16),
54-
modality_encoder = Encoder(),
55-
modality_decoder = Decoder(),
68+
dim_latent = 4,
69+
channel_first_latent = True,
70+
modality_default_shape = (32, 32),
71+
modality_encoder = Encoder(vae),
72+
modality_decoder = Decoder(vae),
5673
add_pos_emb = True,
5774
modality_num_dim = 2,
5875
velocity_consistency_loss_weight = 0.1,
@@ -65,28 +82,34 @@ def forward(self, x):
6582
)
6683
).cuda()
6784

# exponential-moving-average copy of the model, used below for sampling;
# 0.9 is presumably the EMA decay rate — confirm against create_ema's signature
ema_model = model.create_ema(0.9)
6986

7087
class FlowersDataset(Dataset):
71-
def __init__(self):
88+
def __init__(self, image_size):
7289
self.ds = load_dataset("nelorth/oxford-flowers")['train']
7390

91+
self.transform = T.Compose([
92+
T.Resize((image_size, image_size)),
93+
T.PILToTensor()
94+
])
95+
7496
def __len__(self):
7597
return len(self.ds)
7698

7799
def __getitem__(self, idx):
78100
pil = self.ds[idx]['image']
79-
image_tensor = T.PILToTensor()(pil)
80-
return T.Resize((64, 64))(image_tensor / 255.)
101+
tensor = self.transform(pil)
102+
return tensor / 255.
81103

82104
def cycle(iter_dl):
83105
while True:
84106
for batch in iter_dl:
85107
yield batch
86108

87-
# 256x256 images so the VAE's 8x downsampling yields 32x32 latents
dataset = FlowersDataset(256)

# small batch since 256x256 images are heavy; the train loop below
# accumulates gradients over 4 sub-batches before each optimizer step
dataloader = DataLoader(dataset, batch_size = 4, shuffle = True)

iter_dl = cycle(dataloader)

optimizer = Adam(model.parameters(), lr = 8e-4)
@@ -95,8 +118,9 @@ def cycle(iter_dl):
95118

96119
for step in range(1, 100_000 + 1):
97120

98-
loss = model.forward_modality(next(iter_dl))
99-
loss.backward()
121+
for _ in range(4):
122+
loss = model.forward_modality(next(iter_dl))
123+
(loss / 4).backward()
100124

101125
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
102126

@@ -108,9 +132,9 @@ def cycle(iter_dl):
108132
print(f'{step}: {loss.item():.3f}')
109133

110134
if divisible_by(step, SAMPLE_EVERY):
111-
image = ema_model.generate_modality_only(batch_size = 64)
135+
image = ema_model.generate_modality_only(batch_size = 4)
112136

113137
save_image(
114-
rearrange(image, '(gh gw) c h w -> c (gh h) (gw w)', gh = 8).detach().cpu(),
138+
rearrange(image, '(gh gw) c h w -> c (gh h) (gw w)', gh = 2).detach().cpu(),
115139
str(results_folder / f'{step}.png')
116140
)

0 commit comments

Comments
 (0)