@@ -36,7 +36,7 @@ def test_transfusion(
         dim_latent = (384, 192), # specify multiple latent dimensions
         modality_default_shape = ((32,), (64,)),
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 2,
             use_flex_attn = use_flex_attn
         )
@@ -80,7 +80,7 @@ def test_auto_modality_transform(
         channel_first_latent = True,
         modality_default_shape = (32,),
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 2,
             use_flex_attn = use_flex_attn
         )
@@ -117,7 +117,7 @@ def test_text(
         channel_first_latent = True,
         modality_default_shape = (32,),
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 2,
             use_flex_attn = use_flex_attn
         )
@@ -141,7 +141,7 @@ def test_modality_only(
         channel_first_latent = channel_first,
         modality_default_shape = (32,),
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 2,
             use_flex_attn = False
         )
@@ -173,8 +173,8 @@ def test_text_image_end_to_end(
         modality_encoder = mock_vae_encoder,
         modality_decoder = mock_vae_decoder,
         transformer = dict(
-            dim = 512,
-            depth = 8
+            dim = 64,
+            depth = 2
         )
     )
 
@@ -196,24 +196,26 @@ def test_text_image_end_to_end(
 
     # allow researchers to experiment with different time distributions across multiple modalities in a sample
 
-    def modality_length_to_times(modality_length):
-        has_modality = modality_length > 0
-        return torch.where(has_modality, torch.ones_like(modality_length), 0.)
+    def num_modalities_to_times(num_modalities):
+        batch = num_modalities.shape[0]
+        device = num_modalities.device
+        total_modalities = num_modalities.amax().item()
+        return torch.ones((batch, total_modalities), device = device)
 
-    time_fn = modality_length_to_times if custom_time_fn else None
+    time_fn = num_modalities_to_times if custom_time_fn else None
 
     # forward
 
     loss = model(
         text_and_images,
-        modality_length_to_times_fn = time_fn
+        num_modalities_to_times_fn = time_fn
     )
 
     loss.backward()
 
     # after much training
 
-    one_multimodal_sample = model.sample()
+    one_multimodal_sample = model.sample(max_length = 128)
 
 def test_velocity_consistency():
     mock_encoder = nn.Conv2d(3, 384, 3, padding = 1)
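
The test above pins every modality's time to 1. Reading the new function, the implied contract is: the callback receives a (batch,) integer tensor of per-sample modality counts and returns a (batch, max_modalities) float tensor of flow times. A minimal sketch of a hypothetical alternative under that contract (random_modality_times is not part of the library; it just draws an independent uniform time per modality slot):

    import torch

    def random_modality_times(num_modalities):
        # num_modalities: (batch,) tensor counting the modalities in each sample
        batch = num_modalities.shape[0]
        total_modalities = num_modalities.amax().item()
        # one independent uniform time in [0, 1) per modality slot
        return torch.rand((batch, total_modalities), device = num_modalities.device)

    # usage mirrors the test: loss = model(text_and_images, num_modalities_to_times_fn = random_modality_times)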
@@ -228,7 +230,7 @@ def test_velocity_consistency():
         modality_encoder = mock_encoder,
         modality_decoder = mock_decoder,
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 1
         )
     )
@@ -251,14 +253,9 @@ def test_velocity_consistency():
         ]
     ]
 
-    def modality_length_to_times(modality_length):
-        has_modality = modality_length > 0
-        return torch.where(has_modality, torch.ones_like(modality_length), 0.)
-
     loss, breakdown = model(
         text_and_images,
         velocity_consistency_ema_model = ema_model,
-        modality_length_to_times_fn = modality_length_to_times,
         return_breakdown = True
     )
 
@@ -275,7 +272,7 @@ def test_axial_pos_emb():
         add_pos_emb = True,
         modality_num_dim = (2, 1),
         transformer = dict(
-            dim = 512,
+            dim = 64,
             depth = 8
         )
     )
@@ -295,7 +292,7 @@ def test_axial_pos_emb():
 
     # after much training
 
-    one_multimodal_sample = model.sample()
+    one_multimodal_sample = model.sample(max_length = 128)
 
 # unet related
 
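
Sampling in these tests now passes an explicit max_length bound rather than relying on a default. A usage sketch, assuming a trained model as above (the 128 cap is taken from the tests; whether it counts tokens or modality positions is not shown in this diff):

    # bound generation explicitly when sampling a multimodal sequence
    one_multimodal_sample = model.sample(max_length = 128)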