
Commit 1a14b55

move axial positional embedding to a factorized version in a reusable lib
1 parent 8b2d0be commit 1a14b55

File tree

2 files changed: +9 -22 lines changed


pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.8.0"
+version = "0.9.1"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -24,6 +24,7 @@ classifiers=[
 ]
 
 dependencies = [
+    'axial-positional-embedding>=0.3.4',
     'beartype',
     'einx>=0.3.0',
     'einops>=0.8.0',

transfusion_pytorch/transfusion.py

Lines changed: 7 additions & 21 deletions
@@ -40,6 +40,8 @@
 
 from ema_pytorch import EMA
 
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+
 from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
 
 from hyper_connections import HyperConnections
@@ -530,15 +532,10 @@ def __init__(
         expand_factor = 2.
     ):
         super().__init__()
-        self.num_dimensions = num_dimensions
-        dim_hidden = int(dim * expand_factor)
-
-        self.mlp = nn.Sequential(
-            nn.Linear(num_dimensions, dim),
-            nn.SiLU(),
-            nn.Linear(dim, dim_hidden),
-            nn.SiLU(),
-            nn.Linear(dim_hidden, dim)
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(
+            dim = dim,
+            num_axial_dims = num_dimensions,
+            mlp_expansion = expand_factor
         )
 
         # tensor typing
@@ -556,18 +553,7 @@ def forward(
         flatten_dims = False
     ) -> Float['... {self._d}']:
 
-        if isinstance(modality_shape, torch.Size):
-            modality_shape = tensor(modality_shape)
-
-        modality_shape = modality_shape.to(self.device)
-
-        assert len(modality_shape) == self.num_dimensions
-        dimensions = modality_shape.tolist()
-
-        grid = torch.meshgrid([torch.arange(dim_len, device = self.device) for dim_len in dimensions], indexing = 'ij')
-        axial_positions = stack(grid, dim = -1)
-
-        pos_emb = self.mlp(axial_positions.float())
+        pos_emb = self.axial_pos_emb(modality_shape)
 
         if flatten_dims:
             pos_emb = rearrange(pos_emb, '... d -> (...) d')
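
For reference, a minimal standalone sketch (not part of this commit) of how the factorized module is exercised. It reuses only the constructor arguments and the forward call visible in the diff above; the concrete dim, shape, and output-shape comments are illustrative assumptions rather than documented library behaviour.

import torch
from einops import rearrange
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

# same constructor arguments the commit passes through; 512 and 2 are example values
axial_pos_emb = ContinuousAxialPositionalEmbedding(
    dim = 512,              # model dimension
    num_axial_dims = 2,     # e.g. height and width for an image modality
    mlp_expansion = 2.      # plays the role of the old `expand_factor`
)

# the commit feeds the modality shape straight to the module
pos_emb = axial_pos_emb(torch.Size((8, 8)))     # presumably one dim-512 vector per (h, w) position

# `flatten_dims = True` in the forward above corresponds to this rearrange
flat_pos_emb = rearrange(pos_emb, '... d -> (...) d')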
