Skip to content

Commit 2c3f70d

Browse files
committed
move axial positional embedding to a factorized version in a reusable lib
1 parent 8b2d0be commit 2c3f70d

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.8.0"
3+
version = "0.9.0"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -24,6 +24,7 @@ classifiers=[
2424
]
2525

2626
dependencies = [
27+
'axial-positional-embedding>=0.3.3',
2728
'beartype',
2829
'einx>=0.3.0',
2930
'einops>=0.8.0',

transfusion_pytorch/transfusion.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040

4141
from ema_pytorch import EMA
4242

43+
from axial_positional_embedding import ContinuousAxialPositionalEmbedding
44+
4345
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
4446

4547
from hyper_connections import HyperConnections
@@ -533,12 +535,10 @@ def __init__(
533535
self.num_dimensions = num_dimensions
534536
dim_hidden = int(dim * expand_factor)
535537

536-
self.mlp = nn.Sequential(
537-
nn.Linear(num_dimensions, dim),
538-
nn.SiLU(),
539-
nn.Linear(dim, dim_hidden),
540-
nn.SiLU(),
541-
nn.Linear(dim_hidden, dim)
538+
self.axial_pos_emb = ContinuousAxialPositionalEmbedding(
539+
dim = dim,
540+
num_axial_dims = num_dimensions,
541+
mlp_expansion = expand_factor
542542
)
543543

544544
# tensor typing
@@ -562,12 +562,9 @@ def forward(
562562
modality_shape = modality_shape.to(self.device)
563563

564564
assert len(modality_shape) == self.num_dimensions
565-
dimensions = modality_shape.tolist()
566-
567-
grid = torch.meshgrid([torch.arange(dim_len, device = self.device) for dim_len in dimensions], indexing = 'ij')
568-
axial_positions = stack(grid, dim = -1)
565+
dimensions = tuple(modality_shape.tolist())
569566

570-
pos_emb = self.mlp(axial_positions.float())
567+
pos_emb = self.axial_pos_emb(dimensions)
571568

572569
if flatten_dims:
573570
pos_emb = rearrange(pos_emb, '... d -> (...) d')

0 commit comments

Comments (0)