@@ -521,45 +521,6 @@ def min_p_filter(logits, min_p = 0.1):
     limit = min_p * max_probs
     return torch.where(probs < limit, float('-inf'), logits)
 
-# MLP parameterized N-dimensional positions
-
-class MLPAxialPositions(Module):
-    def __init__(
-        self,
-        *,
-        num_dimensions, # 2 for images, 3 for video, etc etc
-        dim,
-        expand_factor = 2.
-    ):
-        super().__init__()
-        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(
-            dim = dim,
-            num_axial_dims = num_dimensions,
-            mlp_expansion = expand_factor
-        )
-
-        # tensor typing
-
-        self._d = dim
-
-    @property
-    def device(self):
-        return next(self.parameters()).device
-
-    @typecheck
-    def forward(
-        self,
-        modality_shape: Int['p'] | torch.Size,
-        flatten_dims = False
-    ) -> Float['... {self._d}']:
-
-        pos_emb = self.axial_pos_emb(modality_shape)
-
-        if flatten_dims:
-            pos_emb = rearrange(pos_emb, '... d -> (...) d')
-
-        return pos_emb
-
 # random fourier embedding
 
 class RandomFourierEmbed(Module):
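For reference, the two unchanged lines at the top of this hunk are the tail of `min_p_filter`. A minimal self-contained sketch of min-p logit filtering consistent with those lines follows; the `probs` and `max_probs` definitions are assumptions for illustration, not copied from the file.

```python
# a minimal sketch of min-p filtering; only `limit` and the final torch.where
# come from the diff above, the softmax / amax lines are assumed for illustration
import torch

def min_p_filter(logits, min_p = 0.1):
    probs = logits.softmax(dim = -1)                   # assumed: convert logits to probabilities
    max_probs = probs.amax(dim = -1, keepdim = True)   # assumed: peak probability per row
    limit = min_p * max_probs                          # threshold scales with the peak probability
    return torch.where(probs < limit, float('-inf'), logits)

logits = torch.randn(2, 32000)
filtered = min_p_filter(logits)                        # sub-threshold logits are masked to -inf
```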
@@ -1222,9 +1183,9 @@ def __init__(
 
         assert exists(modality_ndim), '`modality_num_dim` must be set if you wish to automatically inject axial positional embeddings'
 
-        pos_generating_mlp = MLPAxialPositions(
+        pos_generating_mlp = ContinuousAxialPositionalEmbedding(
             dim = dim,
-            num_dimensions = modality_ndim,
+            num_axial_dims = modality_ndim,
         )
 
         self.pos_emb_mlp.append(pos_generating_mlp)
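With the wrapper gone, each modality registers the axial embedding module directly. A minimal sketch of the construction path shown in this hunk, assuming `ContinuousAxialPositionalEmbedding` comes from the axial-positional-embedding package and that `self.pos_emb_mlp` is an `nn.ModuleList` (both assumptions beyond what the hunk itself shows):

```python
# a minimal sketch of per-modality registration; the import path and the ModuleList
# container are assumptions, the constructor kwargs mirror the hunk above
from torch import nn
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

dim = 512
modality_num_dims = (2, 3)   # e.g. one image modality (2d) and one video modality (3d)

pos_emb_mlp = nn.ModuleList([])

for modality_ndim in modality_num_dims:
    pos_generating_mlp = ContinuousAxialPositionalEmbedding(
        dim = dim,
        num_axial_dims = modality_ndim,
    )
    pos_emb_mlp.append(pos_generating_mlp)
```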
@@ -1909,7 +1870,7 @@ def forward_modality(
         # maybe add axial pos emb
 
         if mod.add_pos_emb:
-            axial_pos_emb = mod.pos_emb_mlp(tensor(axial_dims), flatten_dims = True)
+            axial_pos_emb = mod.pos_emb_mlp(tensor(axial_dims), flatten = True)
             noised_tokens = noised_tokens + axial_pos_emb
 
         # attention
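The keyword rename here follows the forward signature of the underlying module, which takes `flatten = True` rather than the wrapper's `flatten_dims`. A small sketch of the updated call, with shapes inferred from the removed wrapper's `'... d -> (...) d'` flatten; the concrete dimensions are illustrative:

```python
# a minimal sketch of the renamed call; dims are illustrative, the expected shapes
# follow the flatten behaviour of the wrapper removed in the first hunk
import torch
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

pos_emb_mlp = ContinuousAxialPositionalEmbedding(dim = 512, num_axial_dims = 2)

axial_dims = (8, 8)                                                    # an 8 x 8 grid of image tokens
axial_pos_emb = pos_emb_mlp(torch.tensor(axial_dims), flatten = True)  # expected shape (64, 512)

noised_tokens = torch.randn(64, 512)
noised_tokens = noised_tokens + axial_pos_emb                          # positions added to the modality tokens
```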
@@ -2365,7 +2326,7 @@ def inner(pred_flow):
         if need_axial_pos_emb:
 
             if exists(mod.pos_emb_mlp):
-                pos_emb = mod.pos_emb_mlp(tensor(modality_shape_tuple), flatten_dims = True)
+                pos_emb = mod.pos_emb_mlp(tensor(modality_shape_tuple), flatten = True)
 
                 pos_emb = F.pad(pos_emb, (0, 0, precede_modality_tokens, succeed_modality_tokens), value = 0.)
             else:
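For context on the padding line: the flattened positional embedding covers only the modality tokens, so it is zero-padded along the sequence dimension to line up with the text tokens before and after the modality. A small sketch; the token counts are made up for illustration:

```python
# a minimal sketch of aligning the modality pos emb with surrounding text tokens;
# token counts are illustrative, only the F.pad call mirrors the hunk above
import torch
import torch.nn.functional as F

dim = 512
precede_modality_tokens = 3    # e.g. text tokens before the image
succeed_modality_tokens = 2    # e.g. text tokens after the image

pos_emb = torch.randn(64, dim) # flattened axial positions for 8 x 8 image tokens

# zero-pad so the surrounding text positions receive no axial embedding
pos_emb = F.pad(pos_emb, (0, 0, precede_modality_tokens, succeed_modality_tokens), value = 0.)

assert pos_emb.shape == (precede_modality_tokens + 64 + succeed_modality_tokens, dim)
```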