  from ema_pytorch import EMA

+ from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+
  from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

  from hyper_connections import HyperConnections
@@ -533,12 +535,10 @@ def __init__(
         self.num_dimensions = num_dimensions
         dim_hidden = int(dim * expand_factor)

-        self.mlp = nn.Sequential(
-            nn.Linear(num_dimensions, dim),
-            nn.SiLU(),
-            nn.Linear(dim, dim_hidden),
-            nn.SiLU(),
-            nn.Linear(dim_hidden, dim)
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(
+            dim = dim,
+            num_axial_dims = num_dimensions,
+            mlp_expansion = expand_factor
         )

         # tensor typing
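For context, a minimal sketch of what this constructor change amounts to. The ContinuousAxialPositionalEmbedding argument names are taken from the diff above; the concrete values for dim, num_dimensions, and expand_factor are illustrative, not from the source:

import torch
from torch import nn
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

dim, num_dimensions, expand_factor = 512, 2, 2.0  # illustrative values

# previous approach: a plain MLP mapping raw integer coordinates to embeddings,
# with the coordinate grid built by hand in forward()
mlp = nn.Sequential(
    nn.Linear(num_dimensions, dim),
    nn.SiLU(),
    nn.Linear(dim, int(dim * expand_factor)),
    nn.SiLU(),
    nn.Linear(int(dim * expand_factor), dim)
)

# replacement: the library module takes over both the coordinate grid and the MLP
axial_pos_emb = ContinuousAxialPositionalEmbedding(
    dim = dim,
    num_axial_dims = num_dimensions,
    mlp_expansion = expand_factor
)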
@@ -562,12 +562,9 @@ def forward(
         modality_shape = modality_shape.to(self.device)

         assert len(modality_shape) == self.num_dimensions
-        dimensions = modality_shape.tolist()
-
-        grid = torch.meshgrid([torch.arange(dim_len, device = self.device) for dim_len in dimensions], indexing = 'ij')
-        axial_positions = stack(grid, dim = -1)
+        dimensions = tuple(modality_shape.tolist())

-        pos_emb = self.mlp(axial_positions.float())
+        pos_emb = self.axial_pos_emb(dimensions)

         if flatten_dims:
             pos_emb = rearrange(pos_emb, '... d -> (...) d')
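Likewise, a hedged sketch of the new forward path, reusing the axial_pos_emb instance from the sketch above. The shape values are illustrative, and the module is assumed to return a (*dims, dim) tensor when called with a tuple of axial lengths, which is what the rearrange pattern in the diff implies:

import torch
from einops import rearrange

modality_shape = torch.tensor([8, 16])       # e.g. an 8 x 16 feature grid
dimensions = tuple(modality_shape.tolist())  # (8, 16)

pos_emb = axial_pos_emb(dimensions)          # assumed shape: (8, 16, dim)

# optionally flatten the axial dims into a single sequence axis, as in the diff
pos_emb = rearrange(pos_emb, '... d -> (...) d')  # (8 * 16, dim)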