Commit 0a746ed

move logic to external lib and simplify again
1 parent 1a14b55 commit 0a746ed

File tree

2 files changed, +6 -45 lines changed


pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.9.1"
+version = "0.9.2"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -24,7 +24,7 @@ classifiers=[
 ]

 dependencies = [
-    'axial-positional-embedding>=0.3.4',
+    'axial-positional-embedding>=0.3.5',
     'beartype',
     'einx>=0.3.0',
     'einops>=0.8.0',

transfusion_pytorch/transfusion.py

Lines changed: 4 additions & 43 deletions
@@ -521,45 +521,6 @@ def min_p_filter(logits, min_p = 0.1):
     limit = min_p * max_probs
     return torch.where(probs < limit, float('-inf'), logits)

-# MLP parameterized N-dimensional positions
-
-class MLPAxialPositions(Module):
-    def __init__(
-        self,
-        *,
-        num_dimensions, # 2 for images, 3 for video, etc etc
-        dim,
-        expand_factor = 2.
-    ):
-        super().__init__()
-        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(
-            dim = dim,
-            num_axial_dims = num_dimensions,
-            mlp_expansion = expand_factor
-        )
-
-        # tensor typing
-
-        self._d = dim
-
-    @property
-    def device(self):
-        return next(self.parameters()).device
-
-    @typecheck
-    def forward(
-        self,
-        modality_shape: Int['p'] | torch.Size,
-        flatten_dims = False
-    ) -> Float['... {self._d}']:
-
-        pos_emb = self.axial_pos_emb(modality_shape)
-
-        if flatten_dims:
-            pos_emb = rearrange(pos_emb, '... d -> (...) d')
-
-        return pos_emb
-
 # random fourier embedding

 class RandomFourierEmbed(Module):
@@ -1222,9 +1183,9 @@ def __init__(

             assert exists(modality_ndim), '`modality_num_dim` must be set if you wish to automatically inject axial positional embeddings'

-            pos_generating_mlp = MLPAxialPositions(
+            pos_generating_mlp = ContinuousAxialPositionalEmbedding(
                 dim = dim,
-                num_dimensions = modality_ndim,
+                num_axial_dims = modality_ndim,
             )

             self.pos_emb_mlp.append(pos_generating_mlp)
@@ -1909,7 +1870,7 @@ def forward_modality(
         # maybe add axial pos emb

         if mod.add_pos_emb:
-            axial_pos_emb = mod.pos_emb_mlp(tensor(axial_dims), flatten_dims = True)
+            axial_pos_emb = mod.pos_emb_mlp(tensor(axial_dims), flatten = True)
             noised_tokens = noised_tokens + axial_pos_emb

         # attention
@@ -2365,7 +2326,7 @@ def inner(pred_flow):
                 if need_axial_pos_emb:

                     if exists(mod.pos_emb_mlp):
-                        pos_emb = mod.pos_emb_mlp(tensor(modality_shape_tuple), flatten_dims= True)
+                        pos_emb = mod.pos_emb_mlp(tensor(modality_shape_tuple), flatten = True)

                         pos_emb = F.pad(pos_emb, (0, 0, precede_modality_tokens, succeed_modality_tokens), value = 0.)
                     else:
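
For context, the removed MLPAxialPositions class was a thin wrapper that constructed a ContinuousAxialPositionalEmbedding and optionally flattened its output, so the code now calls the external class directly. Below is a minimal usage sketch under the assumptions visible in this diff: the class is importable from the axial_positional_embedding package pinned above (>=0.3.5), the constructor takes dim and num_axial_dims, and the forward call takes a tensor of axial sizes plus a flatten keyword. The dim value and the 8 x 8 grid are illustrative only.

    import torch
    from axial_positional_embedding import ContinuousAxialPositionalEmbedding  # assumed import path

    dim = 512            # illustrative model dimension
    modality_ndim = 2    # 2 axial dims for images, 3 for video, etc.

    # replaces the old MLPAxialPositions(dim = dim, num_dimensions = modality_ndim)
    pos_emb_mlp = ContinuousAxialPositionalEmbedding(
        dim = dim,
        num_axial_dims = modality_ndim
    )

    # shape of one modality, e.g. an 8 x 8 grid of latent tokens (illustrative)
    axial_dims = torch.tensor([8, 8])

    # flatten = True collapses the axial grid to a single token axis,
    # mirroring the old wrapper's flatten_dims = True behavior
    axial_pos_emb = pos_emb_mlp(axial_dims, flatten = True)

    print(axial_pos_emb.shape)  # expected (64, 512): one embedding per flattened position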
