
from ema_pytorch import EMA

+ from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

from hyper_connections import HyperConnections
@@ -530,15 +532,10 @@ def __init__(
        expand_factor = 2.
    ):
        super().__init__()
-        self.num_dimensions = num_dimensions
-        dim_hidden = int(dim * expand_factor)
-
-        self.mlp = nn.Sequential(
-            nn.Linear(num_dimensions, dim),
-            nn.SiLU(),
-            nn.Linear(dim, dim_hidden),
-            nn.SiLU(),
-            nn.Linear(dim_hidden, dim)
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(
+            dim = dim,
+            num_axial_dims = num_dimensions,
+            mlp_expansion = expand_factor
        )

        # tensor typing
@@ -556,18 +553,7 @@ def forward(
        flatten_dims = False
    ) -> Float['... {self._d}']:

-        if isinstance(modality_shape, torch.Size):
-            modality_shape = tensor(modality_shape)
-
-        modality_shape = modality_shape.to(self.device)
-
-        assert len(modality_shape) == self.num_dimensions
-        dimensions = modality_shape.tolist()
-
-        grid = torch.meshgrid([torch.arange(dim_len, device = self.device) for dim_len in dimensions], indexing = 'ij')
-        axial_positions = stack(grid, dim = -1)
-
-        pos_emb = self.mlp(axial_positions.float())
+        pos_emb = self.axial_pos_emb(modality_shape)

        if flatten_dims:
            pos_emb = rearrange(pos_emb, '... d -> (...) d')
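
For reference, the hand-rolled path removed above (a coordinate grid fed through a small MLP) can be summarized as a standalone sketch. The class name ContinuousAxialPosEmbSketch below is illustrative only; the actual ContinuousAxialPositionalEmbedding module from axial_positional_embedding is assumed to wrap roughly this logic behind the dim / num_axial_dims / mlp_expansion arguments shown in the diff.

import torch
from torch import nn, tensor, stack

class ContinuousAxialPosEmbSketch(nn.Module):
    # mirrors the removed code path: n-dim coordinates -> MLP -> d-dim embedding
    def __init__(self, dim, num_dimensions = 2, expand_factor = 2.):
        super().__init__()
        self.num_dimensions = num_dimensions
        dim_hidden = int(dim * expand_factor)

        # MLP mapping each axial coordinate vector to an embedding of size `dim`
        self.mlp = nn.Sequential(
            nn.Linear(num_dimensions, dim),
            nn.SiLU(),
            nn.Linear(dim, dim_hidden),
            nn.SiLU(),
            nn.Linear(dim_hidden, dim)
        )

    def forward(self, modality_shape, flatten_dims = False):
        if isinstance(modality_shape, torch.Size):
            modality_shape = tensor(modality_shape)

        assert len(modality_shape) == self.num_dimensions

        # build the full grid of integer coordinates for the given shape
        grid = torch.meshgrid([torch.arange(n) for n in modality_shape.tolist()], indexing = 'ij')
        axial_positions = stack(grid, dim = -1)

        pos_emb = self.mlp(axial_positions.float())             # (*modality_shape, dim)

        if flatten_dims:
            pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])    # ((h w ...), dim)

        return pos_emb

# usage: positional embeddings for a 4 x 6 image-like modality
pos_emb = ContinuousAxialPosEmbSketch(dim = 32, num_dimensions = 2)(torch.Size((4, 6)))
print(pos_emb.shape)  # torch.Size([4, 6, 32])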