 from ema_pytorch import EMA

+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+
 from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

 from hyper_connections import HyperConnections
@@ -530,15 +532,10 @@ def __init__(
         expand_factor = 2.
     ):
         super().__init__()
-        self.num_dimensions = num_dimensions
-        dim_hidden = int(dim * expand_factor)
-
-        self.mlp = nn.Sequential(
-            nn.Linear(num_dimensions, dim),
-            nn.SiLU(),
-            nn.Linear(dim, dim_hidden),
-            nn.SiLU(),
-            nn.Linear(dim_hidden, dim)
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(
+            dim = dim,
+            num_axial_dims = num_dimensions,
+            mlp_expansion = expand_factor
         )

         # tensor typing
@@ -556,18 +553,7 @@ def forward(
         flatten_dims = False
     ) -> Float['... {self._d}']:

-        if isinstance(modality_shape, torch.Size):
-            modality_shape = tensor(modality_shape)
-
-        modality_shape = modality_shape.to(self.device)
-
-        assert len(modality_shape) == self.num_dimensions
-        dimensions = modality_shape.tolist()
-
-        grid = torch.meshgrid([torch.arange(dim_len, device = self.device) for dim_len in dimensions], indexing = 'ij')
-        axial_positions = stack(grid, dim = -1)
-
-        pos_emb = self.mlp(axial_positions.float())
+        pos_emb = self.axial_pos_emb(modality_shape)

         if flatten_dims:
             pos_emb = rearrange(pos_emb, '... d -> (...) d')
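
For reference, the block below is a minimal, self-contained sketch of what the removed lines computed: build an N-dimensional coordinate grid from the modality shape, then map each coordinate vector through an MLP to a `dim`-sized positional embedding. It mirrors the deleted in-file code, not the actual internals of ContinuousAxialPositionalEmbedding; the SimpleAxialPosEmb name and the usage lines at the bottom are illustrative only.

# Sketch only: reproduces the deleted meshgrid + MLP logic above,
# not the internals of ContinuousAxialPositionalEmbedding.
import torch
from torch import nn, stack, tensor

class SimpleAxialPosEmb(nn.Module):
    def __init__(self, dim, num_dimensions, expand_factor = 2.):
        super().__init__()
        self.num_dimensions = num_dimensions
        dim_hidden = int(dim * expand_factor)

        # maps each N-dimensional coordinate to a `dim`-sized positional embedding
        self.mlp = nn.Sequential(
            nn.Linear(num_dimensions, dim),
            nn.SiLU(),
            nn.Linear(dim, dim_hidden),
            nn.SiLU(),
            nn.Linear(dim_hidden, dim)
        )

    def forward(self, modality_shape):
        if isinstance(modality_shape, torch.Size):
            modality_shape = tensor(modality_shape)

        dimensions = modality_shape.tolist()
        assert len(dimensions) == self.num_dimensions

        # coordinate grid of shape (*dimensions, num_dimensions)
        grid = torch.meshgrid([torch.arange(d) for d in dimensions], indexing = 'ij')
        axial_positions = stack(grid, dim = -1)

        return self.mlp(axial_positions.float())

pos_emb_module = SimpleAxialPosEmb(dim = 512, num_dimensions = 2)
pos_emb = pos_emb_module(torch.Size((8, 8)))  # shape (8, 8, 512)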