Skip to content

Commit 8b354fa

Browse files
committed
more efficient flow loss
1 parent 0476b4e commit 8b354fa

File tree

2 files changed

+8
-12
lines changed

2 files changed

+8
-12
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.4.15"
3+
version = "0.4.16"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,20 +2334,16 @@ def forward(
23342334
for modality_id, (flow, pred_flow, is_one_modality) in enumerate(zip(flows, pred_flows, is_modalities.unbind(dim = 1))):
23352335
mod = self.get_modality_info(modality_id)
23362336

2337-
flow_loss = F.mse_loss(
2338-
pred_flow,
2339-
flow,
2340-
reduction = 'none'
2341-
)
2342-
2343-
if mod.channel_first_latent:
2344-
flow_loss = rearrange(flow_loss, 'b d ... -> b ... d')
2345-
23462337
is_one_modality = reduce(is_one_modality, 'b m n -> b n', 'any')
2338+
modality_loss_weight = is_one_modality.sum() / total_tokens
23472339

2348-
flow_loss = flow_loss[is_one_modality].mean()
2340+
if mod.channel_first_latent:
2341+
pred_flow, flow = tuple(rearrange(t, 'b d ... -> b ... d') for t in (pred_flow, flow))
23492342

2350-
modality_loss_weight = is_one_modality.sum() / total_tokens
2343+
flow_loss = F.mse_loss(
2344+
pred_flow[is_one_modality],
2345+
flow[is_one_modality]
2346+
)
23512347

23522348
modality_loss_weights.append(modality_loss_weight)
23532349

0 commit comments

Comments
 (0)