
Commit a80532c

axial positional embeddings will be calculated only once per modality type with the max axial dimensions
1 parent 0a746ed commit a80532c

2 files changed: 39 additions & 6 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.9.2"
+version = "0.9.3"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 38 additions & 5 deletions
@@ -2172,6 +2172,10 @@ def inner(pred_flow):

                return inner

+        # prepare storing of sizes of all modalities that require axial positions, for delayed application for efficiency
+
+        pos_emb_max_axial_dims: dict[int, list[Tensor]] = defaultdict(list)
+
         # go through all modality samples and do necessary transform

         for batch_index, batch_modalities in enumerate(modalities):
@@ -2326,9 +2330,9 @@ def inner(pred_flow):
                 if need_axial_pos_emb:

                     if exists(mod.pos_emb_mlp):
-                        pos_emb = mod.pos_emb_mlp(tensor(modality_shape_tuple), flatten = True)
+                        pos_emb_max_axial_dims[modality_type].append(tensor(modality_shape_tuple))
+                        pos_emb = (modality_type, modality_shape_tuple, (precede_modality_tokens, succeed_modality_tokens))

-                        pos_emb = F.pad(pos_emb, (0, 0, precede_modality_tokens, succeed_modality_tokens), value = 0.)
                     else:
                         pos_emb = torch.zeros(text_tensor.shape[0], self.dim, device = device)
@@ -2337,7 +2341,7 @@ def inner(pred_flow):
            text.append(cat(batch_text))

            if need_axial_pos_emb:
-                modality_pos_emb.append(cat(batch_modality_pos_emb, dim = -2))
+                modality_pos_emb.append(batch_modality_pos_emb)

            modality_tokens.append(cat(batch_modality_tokens))
            modality_positions.append(batch_modality_positions)
@@ -2351,8 +2355,38 @@ def inner(pred_flow):

        modality_tokens = pad_sequence(modality_tokens, padding_value = 0.)

+        # handle modality positional embedding
+
         if need_axial_pos_emb:
-            modality_pos_emb = pad_sequence(modality_pos_emb, padding_value = 0.)
+            pos_emb_max_axial_dims = {mod_type: stack(sizes, dim = -1).amax(dim = -1) for mod_type, sizes in pos_emb_max_axial_dims.items()}
+            factorized_pos_emb = {mod_type: self.get_modality_info(mod_type).pos_emb_mlp(max_size, return_factorized = True) for mod_type, max_size in pos_emb_max_axial_dims.items()}
+
+            # lazy evaluate the modality positional embedding from the factorized positional embedding from maximum axial dims
+
+            evaluated_pos_emb = []
+
+            for batch_modality_pos_emb in modality_pos_emb:
+                evaluated_batch_pos_emb = []
+
+                for maybe_pos_emb_config in batch_modality_pos_emb:
+
+                    if is_tensor(maybe_pos_emb_config):
+                        evaluated_batch_pos_emb.append(maybe_pos_emb_config)
+                        continue
+
+                    mod_type, mod_size, padding = maybe_pos_emb_config
+
+                    mod_info = self.get_modality_info(mod_type)
+                    mod_factorized_pos_emb = factorized_pos_emb[mod_type]
+
+                    mod_pos_emb = mod_info.pos_emb_mlp.combine_factorized(mod_factorized_pos_emb, mod_size, flatten = True)
+                    mod_pos_emb = F.pad(mod_pos_emb, (0, 0, *padding), value = 0.) # handle padding for preceding and succeeding meta tokens
+
+                    evaluated_batch_pos_emb.append(mod_pos_emb)
+
+                evaluated_pos_emb.append(cat(evaluated_batch_pos_emb, dim = -2))
+
+            modality_pos_emb = pad_sequence(evaluated_pos_emb, padding_value = 0.)

         # handle training mode and removal of last token

@@ -2384,7 +2418,6 @@ def inner(pred_flow):
        if modality_positions.numel() == 0:
            modality_positions = F.pad(modality_positions, (0, 0, 0, 1))

-
        # sort the modalities tensor and sanitize, readying for noising of modalities

        modality_positions, sorted_indices = order_modality_positions_by_seq_offset(modality_positions)

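For context, a minimal, self-contained sketch of the optimization this commit implements: instead of running the axial positional embedding MLP once per modality instance, the axial sizes are collected per modality type, the per-axis (factorized) embeddings are evaluated once at the batch-maximum dims, and each instance's embedding is then assembled by slicing and combining those factors. ToyAxialPosEmb, factorize, and the broadcast-sum combination below are illustrative stand-ins, not the library's actual pos_emb_mlp / combine_factorized implementation.

# a toy 2-axis axial positional embedding, analogous in spirit (not implementation) to pos_emb_mlp
import torch
from torch import nn, arange, stack, tensor

class ToyAxialPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # one small MLP per axis; purely illustrative
        self.mlps = nn.ModuleList([nn.Sequential(nn.Linear(1, dim), nn.SiLU(), nn.Linear(dim, dim)) for _ in range(2)])

    def factorize(self, max_size):
        # evaluate each axis MLP once, at the largest extent needed across the whole batch
        return [mlp(arange(n, dtype = torch.float)[:, None]) for mlp, n in zip(self.mlps, max_size.tolist())]

    def combine_factorized(self, factorized, size, flatten = True):
        # slice each factor down to this instance's size, then broadcast-sum into a (h, w, dim) grid
        h, w = size
        emb = factorized[0][:h, None, :] + factorized[1][None, :w, :]
        return emb.reshape(h * w, -1) if flatten else emb

pos_emb_mlp = ToyAxialPosEmb(dim = 32)

# axial shapes of three instances of the same modality type within one batch
shapes = [(4, 6), (8, 3), (5, 5)]

# before the commit: one positional embedding forward per instance
# after the commit: one factorized forward at the max axial dims, then cheap per-instance combines
max_size = stack([tensor(s) for s in shapes], dim = -1).amax(dim = -1)   # tensor([8, 6])
factorized = pos_emb_mlp.factorize(max_size)

per_instance_pos_emb = [pos_emb_mlp.combine_factorized(factorized, s) for s in shapes]
print([tuple(t.shape) for t in per_instance_pos_emb])   # [(24, 32), (24, 32), (25, 32)]

In the actual change above, the per-instance combination is additionally deferred: during the batch transform each instance only records (modality_type, shape, padding), and the embeddings are materialized in a single pass once the batch-maximum axial dims per modality type are known, then padded for any preceding and succeeding meta tokens.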