@@ -2172,6 +2172,10 @@ def inner(pred_flow):
 
             return inner
 
+        # prepare storing of sizes of all modalities that require axial positions, for delayed application for efficiency
+
+        pos_emb_max_axial_dims: dict[int, list[Tensor]] = defaultdict(list)
+
         # go through all modality samples and do necessary transform
 
         for batch_index, batch_modalities in enumerate(modalities):
@@ -2326,9 +2330,9 @@ def inner(pred_flow):
                 if need_axial_pos_emb:
 
                     if exists(mod.pos_emb_mlp):
-                        pos_emb = mod.pos_emb_mlp(tensor(modality_shape_tuple), flatten = True)
+                        pos_emb_max_axial_dims[modality_type].append(tensor(modality_shape_tuple))
+                        pos_emb = (modality_type, modality_shape_tuple, (precede_modality_tokens, succeed_modality_tokens))
 
-                        pos_emb = F.pad(pos_emb, (0, 0, precede_modality_tokens, succeed_modality_tokens), value = 0.)
                     else:
                         pos_emb = torch.zeros(text_tensor.shape[0], self.dim, device = device)
 
@@ -2337,7 +2341,7 @@ def inner(pred_flow):
             text.append(cat(batch_text))
 
             if need_axial_pos_emb:
-                modality_pos_emb.append(cat(batch_modality_pos_emb, dim = -2))
+                modality_pos_emb.append(batch_modality_pos_emb)
 
             modality_tokens.append(cat(batch_modality_tokens))
             modality_positions.append(batch_modality_positions)
@@ -2351,8 +2355,38 @@ def inner(pred_flow):
 
         modality_tokens = pad_sequence(modality_tokens, padding_value = 0.)
 
+        # handle modality positional embedding
+
         if need_axial_pos_emb:
-            modality_pos_emb = pad_sequence(modality_pos_emb, padding_value = 0.)
+            pos_emb_max_axial_dims = {mod_type: stack(sizes, dim = -1).amax(dim = -1) for mod_type, sizes in pos_emb_max_axial_dims.items()}
+            factorized_pos_emb = {mod_type: self.get_modality_info(mod_type).pos_emb_mlp(max_size, return_factorized = True) for mod_type, max_size in pos_emb_max_axial_dims.items()}
+
+            # lazy evaluate the modality positional embedding from the factorized positional embedding from maximum axial dims
+
+            evaluated_pos_emb = []
+
+            for batch_modality_pos_emb in modality_pos_emb:
+                evaluated_batch_pos_emb = []
+
+                for maybe_pos_emb_config in batch_modality_pos_emb:
+
+                    if is_tensor(maybe_pos_emb_config):
+                        evaluated_batch_pos_emb.append(maybe_pos_emb_config)
+                        continue
+
+                    mod_type, mod_size, padding = maybe_pos_emb_config
+
+                    mod_info = self.get_modality_info(mod_type)
+                    mod_factorized_pos_emb = factorized_pos_emb[mod_type]
+
+                    mod_pos_emb = mod_info.pos_emb_mlp.combine_factorized(mod_factorized_pos_emb, mod_size, flatten = True)
+                    mod_pos_emb = F.pad(mod_pos_emb, (0, 0, *padding), value = 0.) # handle padding for preceding and succeeding meta tokens
+
+                    evaluated_batch_pos_emb.append(mod_pos_emb)
+
+                evaluated_pos_emb.append(cat(evaluated_batch_pos_emb, dim = -2))
+
+            modality_pos_emb = pad_sequence(evaluated_pos_emb, padding_value = 0.)
 
         # handle training mode and removal of last token
 
@@ -2384,7 +2418,6 @@ def inner(pred_flow):
         if modality_positions.numel() == 0:
             modality_positions = F.pad(modality_positions, (0, 0, 0, 1))
 
-
         # sort the modalities tensor and sanitize, readying for noising of modalities
 
         modality_positions, sorted_indices = order_modality_positions_by_seq_offset(modality_positions)
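
The change above avoids running the axial positional embedding MLP once per modality instance: it first collects every instance's shape, reduces them to a per-axis maximum for each modality type, builds the factorized per-axis tables once at that maximum size, and only then slices and combines them per sample. The sketch below is a minimal, self-contained illustration of that pattern; `AxialPosEmb`, its `factorize` method, and the toy shapes are assumptions made for the example, not the repository's actual `pos_emb_mlp` implementation.

```python
import torch
from torch import nn, Tensor

class AxialPosEmb(nn.Module):
    """Simplified stand-in for a factorized axial positional embedding."""

    def __init__(self, dim: int, num_axes: int):
        super().__init__()
        # one small projection per axis, evaluated on that axis' coordinates
        self.axis_mlps = nn.ModuleList([nn.Linear(1, dim) for _ in range(num_axes)])

    def factorize(self, max_size: Tensor) -> list[Tensor]:
        # build one per-axis table at the batch-wide maximum size, so the MLPs run only once
        return [
            mlp(torch.arange(int(n), dtype = torch.float)[:, None])
            for mlp, n in zip(self.axis_mlps, max_size)
        ]

    def combine_factorized(self, factorized: list[Tensor], size: tuple[int, ...]) -> Tensor:
        # slice each axis table down to this sample's size, broadcast-sum across axes, flatten to (n, dim)
        axes = [table[:s] for table, s in zip(factorized, size)]
        out = axes[0]
        for index, axis in enumerate(axes[1:], start = 1):
            out = out.unsqueeze(index) + axis
        return out.reshape(-1, out.shape[-1])

# usage: gather per-sample sizes, reduce to a per-axis maximum, factorize once, combine per sample

pos_emb_mlp = AxialPosEmb(dim = 8, num_axes = 2)

sizes = [torch.tensor([4, 6]), torch.tensor([3, 7])]            # (height, width) per sample
max_size = torch.stack(sizes, dim = -1).amax(dim = -1)          # tensor([4, 7])

factorized = pos_emb_mlp.factorize(max_size)                    # computed once for the whole batch
per_sample = [pos_emb_mlp.combine_factorized(factorized, tuple(s.tolist())) for s in sizes]

print([tuple(p.shape) for p in per_sample])                     # [(24, 8), (21, 8)]
```

The saving comes from the factorization: the per-axis tables are computed once at the batch-wide maximum, and each sample only pays for a slice, a broadcast sum, and a flatten, which mirrors the `return_factorized = True` / `combine_factorized` split in the diff.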