Skip to content

Commit d7e726b

Browse files
committed
complete refactor to allow for learned encoders / decoders (unet in paper), validated with an mnist training script
1 parent 3b2076e commit d7e726b

File tree

6 files changed

+170
-169
lines changed

6 files changed

+170
-169
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name: Tests the examples in README
2-
on: push
2+
on: [push, pull_request]
33

44
env:
55
TYPECHECK: True

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.4.16"
3+
version = "0.5.0"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_transfusion.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_transfusion(
3636
dim_latent = (384, 192), # specify multiple latent dimensions
3737
modality_default_shape = ((32,), (64,)),
3838
transformer = dict(
39-
dim = 512,
39+
dim = 64,
4040
depth = 2,
4141
use_flex_attn = use_flex_attn
4242
)
@@ -80,7 +80,7 @@ def test_auto_modality_transform(
8080
channel_first_latent = True,
8181
modality_default_shape = (32,),
8282
transformer = dict(
83-
dim = 512,
83+
dim = 64,
8484
depth = 2,
8585
use_flex_attn = use_flex_attn
8686
)
@@ -117,7 +117,7 @@ def test_text(
117117
channel_first_latent = True,
118118
modality_default_shape = (32,),
119119
transformer = dict(
120-
dim = 512,
120+
dim = 64,
121121
depth = 2,
122122
use_flex_attn = use_flex_attn
123123
)
@@ -141,7 +141,7 @@ def test_modality_only(
141141
channel_first_latent = channel_first,
142142
modality_default_shape = (32,),
143143
transformer = dict(
144-
dim = 512,
144+
dim = 64,
145145
depth = 2,
146146
use_flex_attn = False
147147
)
@@ -173,8 +173,8 @@ def test_text_image_end_to_end(
173173
modality_encoder = mock_vae_encoder,
174174
modality_decoder = mock_vae_decoder,
175175
transformer = dict(
176-
dim = 512,
177-
depth = 8
176+
dim = 64,
177+
depth = 2
178178
)
179179
)
180180

@@ -196,24 +196,26 @@ def test_text_image_end_to_end(
196196

197197
# allow researchers to experiment with different time distributions across multiple modalities in a sample
198198

199-
def modality_length_to_times(modality_length):
200-
has_modality = modality_length > 0
201-
return torch.where(has_modality, torch.ones_like(modality_length), 0.)
199+
def num_modalities_to_times(num_modalities):
200+
batch = num_modalities.shape[0]
201+
device = num_modalities.device
202+
total_modalities = num_modalities.amax().item()
203+
return torch.ones((batch, total_modalities), device = device)
202204

203-
time_fn = modality_length_to_times if custom_time_fn else None
205+
time_fn = num_modalities_to_times if custom_time_fn else None
204206

205207
# forward
206208

207209
loss = model(
208210
text_and_images,
209-
modality_length_to_times_fn = time_fn
211+
num_modalities_to_times_fn = time_fn
210212
)
211213

212214
loss.backward()
213215

214216
# after much training
215217

216-
one_multimodal_sample = model.sample()
218+
one_multimodal_sample = model.sample(max_length = 128)
217219

218220
def test_velocity_consistency():
219221
mock_encoder = nn.Conv2d(3, 384, 3, padding = 1)
@@ -228,7 +230,7 @@ def test_velocity_consistency():
228230
modality_encoder = mock_encoder,
229231
modality_decoder = mock_decoder,
230232
transformer = dict(
231-
dim = 512,
233+
dim = 64,
232234
depth = 1
233235
)
234236
)
@@ -251,14 +253,9 @@ def test_velocity_consistency():
251253
]
252254
]
253255

254-
def modality_length_to_times(modality_length):
255-
has_modality = modality_length > 0
256-
return torch.where(has_modality, torch.ones_like(modality_length), 0.)
257-
258256
loss, breakdown = model(
259257
text_and_images,
260258
velocity_consistency_ema_model = ema_model,
261-
modality_length_to_times_fn = modality_length_to_times,
262259
return_breakdown = True
263260
)
264261

@@ -275,7 +272,7 @@ def test_axial_pos_emb():
275272
add_pos_emb = True,
276273
modality_num_dim = (2, 1),
277274
transformer = dict(
278-
dim = 512,
275+
dim = 64,
279276
depth = 8
280277
)
281278
)
@@ -295,7 +292,7 @@ def test_axial_pos_emb():
295292

296293
# after much training
297294

298-
one_multimodal_sample = model.sample()
295+
one_multimodal_sample = model.sample(max_length = 128)
299296

300297
# unet related
301298

train_mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
IMAGE_AFTER_TEXT = False
2525
NUM_TRAIN_STEPS = 10_000
2626
SAMPLE_EVERY = 250
27-
CHANNEL_FIRST = False
27+
CHANNEL_FIRST = True
2828

2929
# functions
3030

train_mnist_vae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __getitem__(self, idx):
5959

6060
# contrived encoder / decoder with layernorm at bottleneck
6161

62-
autoencoder_train_steps = 15000
62+
autoencoder_train_steps = 15_000
6363
dim_latent = 16
6464

6565
class Normalize(Module):
@@ -133,7 +133,7 @@ def forward(self, x):
133133

134134
# training transfusion
135135

136-
dataloader = model.create_dataloader(dataset, batch_size = 16, collate_fn = collate_fn, shuffle = True)
136+
dataloader = model.create_dataloader(dataset, batch_size = 16, shuffle = True)
137137
iter_dl = cycle(dataloader)
138138

139139
optimizer = Adam(model.parameters_without_encoder_decoder(), lr = 3e-4)

0 commit comments

Comments (0)