Skip to content

Commit 0d82328

Browse files
committed
add working e2e script for latent + text + custom learned down/upsampling for oxford flowers
1 parent 8cbe40d commit 0d82328

File tree

2 files changed

+282
-0
lines changed

2 files changed

+282
-0
lines changed

data/flowers/labels.txt

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
pink primrose
2+
hard-leaved pocket orchid
3+
canterbury bells
4+
sweet pea
5+
english marigold
6+
tiger lily
7+
moon orchid
8+
bird of paradise
9+
monkshood
10+
globe thistle
11+
snapdragon
12+
colt's foot
13+
king protea
14+
spear thistle
15+
yellow iris
16+
globe-flower
17+
purple coneflower
18+
peruvian lily
19+
balloon flower
20+
giant white arum lily
21+
fire lily
22+
pincushion flower
23+
fritillary
24+
red ginger
25+
grape hyacinth
26+
corn poppy
27+
prince of wales feathers
28+
stemless gentian
29+
artichoke
30+
sweet william
31+
carnation
32+
garden phlox
33+
love in the mist
34+
mexican aster
35+
alpine sea holly
36+
ruby-lipped cattleya
37+
cape flower
38+
great masterwort
39+
siam tulip
40+
lenten rose
41+
barbeton daisy
42+
daffodil
43+
sword lily
44+
poinsettia
45+
bolero deep blue
46+
wallflower
47+
marigold
48+
buttercup
49+
oxeye daisy
50+
common dandelion
51+
petunia
52+
wild pansy
53+
primula
54+
sunflower
55+
pelargonium
56+
bishop of llandaff
57+
gaura
58+
geranium
59+
orange dahlia
60+
pink-yellow dahlia?
61+
cautleya spicata
62+
japanese anemone
63+
black-eyed susan
64+
silverbush
65+
californian poppy
66+
osteospermum
67+
spring crocus
68+
bearded iris
69+
windflower
70+
tree poppy
71+
gazania
72+
azalea
73+
water lily
74+
rose
75+
thorn apple
76+
morning glory
77+
passion flower
78+
lotus
79+
toad lily
80+
anthurium
81+
frangipani
82+
clematis
83+
hibiscus
84+
columbine
85+
desert-rose
86+
tree mallow
87+
magnolia
88+
cyclamen
89+
watercress
90+
canna lily
91+
hippeastrum
92+
bee balm
93+
ball moss
94+
foxglove
95+
bougainvillea
96+
camellia
97+
mallow
98+
mexican petunia
99+
bromelia
100+
blanket flower
101+
trumpet creeper
102+
blackberry lily

train_latent_with_text.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
from shutil import rmtree
2+
from pathlib import Path
3+
4+
import torch
5+
from torch import nn, tensor, Tensor
6+
from torch.nn import Module
7+
from torch.utils.data import Dataset, DataLoader
8+
from torch.optim import Adam
9+
10+
from einops import rearrange
11+
12+
import torchvision
13+
import torchvision.transforms as T
14+
from torchvision.utils import save_image
15+
16+
from transfusion_pytorch import Transfusion, print_modality_sample
17+
18+
# hf related
19+
20+
from datasets import load_dataset
21+
from diffusers.models import AutoencoderKL
22+
23+
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder = "vae")
24+
25+
class Encoder(Module):
26+
def __init__(self, vae):
27+
super().__init__()
28+
self.vae = vae
29+
30+
def forward(self, image):
31+
with torch.no_grad():
32+
latent = self.vae.encode(image * 2 - 1)
33+
34+
return 0.18215 * latent.latent_dist.sample()
35+
36+
class Decoder(Module):
37+
def __init__(self, vae):
38+
super().__init__()
39+
self.vae = vae
40+
41+
def forward(self, latents):
42+
latents = (1 / 0.18215) * latents
43+
44+
with torch.no_grad():
45+
image = self.vae.decode(latents).sample
46+
47+
return (image / 2 + 0.5).clamp(0, 1)
48+
49+
# results folder
50+
51+
rmtree('./results', ignore_errors = True)
52+
results_folder = Path('./results')
53+
results_folder.mkdir(exist_ok = True, parents = True)
54+
55+
# constants
56+
57+
SAMPLE_EVERY = 100
58+
59+
with open("./data/flowers/labels.txt", "r") as file:
60+
content = file.read()
61+
62+
LABELS_TEXT = content.split('\n')
63+
64+
# functions
65+
66+
def divisible_by(num, den):
67+
return (num % den) == 0
68+
69+
def decode_token(token):
70+
return str(chr(max(32, token)))
71+
72+
def decode_tokens(tokens: Tensor) -> str:
73+
return "".join(list(map(decode_token, tokens.tolist())))
74+
75+
def encode_tokens(str: str) -> Tensor:
76+
return tensor([*bytes(str, 'UTF-8')])
77+
78+
# encoder / decoder
79+
80+
model = Transfusion(
81+
num_text_tokens = 256,
82+
dim_latent = 4,
83+
channel_first_latent = True,
84+
modality_default_shape = (4, 4),
85+
modality_encoder = Encoder(vae),
86+
modality_decoder = Decoder(vae),
87+
pre_post_transformer_enc_dec = (
88+
nn.Conv2d(4, 128, 3, 2, 1),
89+
nn.ConvTranspose2d(128, 4, 3, 2, 1, output_padding = 1),
90+
),
91+
add_pos_emb = True,
92+
modality_num_dim = 2,
93+
velocity_consistency_loss_weight = 0.1,
94+
reconstruction_loss_weight = 0.1,
95+
transformer = dict(
96+
dim = 128,
97+
depth = 8,
98+
dim_head = 64,
99+
heads = 8
100+
)
101+
).cuda()
102+
103+
ema_model = model.create_ema(0.9)
104+
105+
class FlowersDataset(Dataset):
106+
def __init__(self, image_size):
107+
self.ds = load_dataset("nelorth/oxford-flowers")['train']
108+
109+
self.transform = T.Compose([
110+
T.Resize((image_size, image_size)),
111+
T.PILToTensor(),
112+
T.Lambda(lambda t: t / 255.)
113+
])
114+
115+
def __len__(self):
116+
return len(self.ds)
117+
118+
def __getitem__(self, idx):
119+
sample = self.ds[idx]
120+
pil = sample['image']
121+
122+
labels_int = sample['label']
123+
labels_text = LABELS_TEXT[labels_int]
124+
125+
tensor = self.transform(pil)
126+
return encode_tokens(labels_text), tensor
127+
128+
def cycle(iter_dl):
129+
while True:
130+
for batch in iter_dl:
131+
yield batch
132+
133+
dataset = FlowersDataset(128)
134+
135+
dataloader = model.create_dataloader(dataset, batch_size = 4, shuffle = True)
136+
137+
iter_dl = cycle(dataloader)
138+
139+
optimizer = Adam(model.parameters(), lr = 8e-4)
140+
141+
# train loop
142+
143+
for step in range(1, 100_000 + 1):
144+
145+
for _ in range(4):
146+
loss = model.forward(next(iter_dl))
147+
(loss / 4).backward()
148+
149+
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
150+
151+
optimizer.step()
152+
optimizer.zero_grad()
153+
154+
ema_model.update()
155+
156+
print(f'{step}: {loss.item():.3f}')
157+
158+
if divisible_by(step, SAMPLE_EVERY):
159+
sample = ema_model.sample()
160+
161+
print_modality_sample(sample)
162+
163+
if len(sample) < 3:
164+
continue
165+
166+
text_tensor, maybe_image, *_ = sample
167+
168+
if not isinstance(maybe_image, tuple):
169+
continue
170+
171+
_, image = maybe_image
172+
text_tensor = text_tensor[text_tensor < 256] # todo: offer a utility function for removing meta tags and special tokens
173+
174+
text = decode_tokens(text_tensor)
175+
filename = str(results_folder / f'{text}.{step}.png')
176+
177+
save_image(
178+
image.detach().cpu(),
179+
filename
180+
)

0 commit comments

Comments
 (0)