
Commit 11d96ab

add optional reconstruction loss off decoder for .forward_modality for starters

1 parent: 78a2f3d
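In short: take a one-step estimate of the clean latents from the predicted flow, optionally decode it with the (frozen) modality decoder, and add an MSE term against the raw inputs, scaled by `reconstruction_loss_weight`. A minimal standalone sketch of that idea follows; the tensor names mirror the diff below, but the function itself is illustrative, not the repository's code:

```python
import torch
import torch.nn.functional as F
from torch import nn

def recon_loss_sketch(
    pred_flow: torch.Tensor,        # model's predicted flow (tokens - noise)
    noise: torch.Tensor,            # the noise the latents were corrupted with
    padded_times: torch.Tensor,     # flow-matching times, broadcastable to the latents
    orig_modalities: torch.Tensor,  # raw inputs, before the modality encoder
    decoder: nn.Module | None = None
) -> torch.Tensor:
    # one-step estimate of the clean latents from the predicted flow,
    # mirroring `recon = noise + pred_flow * (1. - padded_times)` in the diff
    recon = noise + pred_flow * (1. - padded_times)

    if decoder is not None:
        # the commit decodes under no_grad with the decoder in eval mode,
        # so the decoder itself is not trained by this auxiliary loss
        with torch.no_grad():
            decoder.eval()
            recon = decoder(recon)

    # mean squared error against the un-encoded inputs
    return F.mse_loss(recon, orig_modalities)
```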

File tree

3 files changed (+33, -4 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.5.1"
+version = "0.5.2"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

train_image_only.py

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@ def forward(self, x):
     add_pos_emb = True,
     modality_num_dim = 2,
     velocity_consistency_loss_weight = 0.1,
+    reconstruction_loss_weight = 0.1,
     transformer = dict(
         dim = 64,
         depth = 4,
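For context, a hedged sketch of how the new knob slots into a config like the one above. Every kwarg not shown in the diff (token count, latent dim, the stand-in encoder/decoder modules) is an assumption about the surrounding script, not verbatim repository code; note the forward pass asserts `encode_modality` when the weight is nonzero, so a modality encoder/decoder pair needs to be wired in:

```python
from torch import nn
from transfusion_pytorch import Transfusion

model = Transfusion(
    num_text_tokens = 256,                                # assumed value
    dim_latent = 16,                                      # assumed value
    modality_encoder = nn.Conv2d(3, 16, 3, padding = 1),  # assumed stand-in encoder
    modality_decoder = nn.Conv2d(16, 3, 3, padding = 1),  # assumed stand-in decoder
    add_pos_emb = True,
    modality_num_dim = 2,
    velocity_consistency_loss_weight = 0.1,
    reconstruction_loss_weight = 0.1,  # new in this commit; the default 0. disables it
    transformer = dict(
        dim = 64,
        depth = 4
    )
)
```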

transfusion_pytorch/transfusion.py

Lines changed: 31 additions & 3 deletions

@@ -90,6 +90,7 @@ class LossBreakdown(NamedTuple):
     text: Scalar
     flow: list[Scalar]
     velocity: list[Scalar] | None
+    recon: list[Scalar] | None

 class ModalityInfo(NamedTuple):
     encoder: Module | None

@@ -1085,6 +1086,7 @@ def __init__(
         flow_loss_weight = 1.,
         text_loss_weight = 1.,
         velocity_consistency_loss_weight = 0.1,
+        reconstruction_loss_weight = 0.,
         modality_encoder_decoder_requires_batch_dim = True, # whether the modality encoder / decoder requires batch dimension, will auto assume it is needed
         odeint_kwargs: dict = dict(
             atol = 1e-5,

@@ -1277,6 +1279,11 @@ def __init__(

         self.velocity_consistency_loss_weight = velocity_consistency_loss_weight

+        # additional reconstruction loss, through the decoder
+
+        self.has_recon_loss = reconstruction_loss_weight > 0.
+        self.reconstruction_loss_weight = reconstruction_loss_weight
+
         # flow sampling related

         self.odeint_fn = partial(odeint, **odeint_kwargs)

@@ -1711,10 +1718,10 @@ def forward_modality(
         return_loss = True,
         return_loss_breakdown = False
     ) -> Scalar | Float['b ...']:
-
         requires_velocity_consistency = exists(velocity_consistency_ema_model)

         modalities = modalities.to(self.device)
+        orig_modalities = modalities

         if self.num_modalities > 1:
             assert exists(modality_type), '`modality_type` must be explicitly passed in on forward when training on greater than 1 modality'

@@ -1754,6 +1761,7 @@ def forward_modality(
             noised_tokens = padded_times * tokens + (1. - padded_times) * noise

             flow = tokens - noise
+
         else:
             noised_tokens = tokens

@@ -1816,17 +1824,37 @@ def forward_modality(

             velocity_loss = F.mse_loss(flow, flow_with_delta_time)

+        # maybe recon loss
+
+        recon_loss = self.zero
+
+        if self.has_recon_loss:
+            assert encode_modality
+
+            recon = noise + pred_flow * (1. - padded_times)
+
+            if exists(mod.decoder):
+                with torch.no_grad():
+                    mod.decoder.eval()
+                    recon = self.maybe_add_temp_batch_dim(mod.decoder)(recon)
+
+            recon_loss = F.mse_loss(
+                recon,
+                orig_modalities
+            )
+
         # total loss

         total_loss = (
             flow_loss +
-            velocity_loss * self.velocity_consistency_loss_weight
+            velocity_loss * self.velocity_consistency_loss_weight +
+            recon_loss * self.reconstruction_loss_weight
         )

         if not return_loss_breakdown:
             return total_loss

-        return total_loss, (flow_loss, velocity_loss)
+        return total_loss, (flow_loss, velocity_loss, recon_loss)

     @torch.no_grad()
     @eval_decorator
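With `return_loss_breakdown = True`, callers now unpack a three-tuple instead of two. A usage sketch, assuming a `model` configured with `reconstruction_loss_weight > 0` and a hypothetical `images` batch from the surrounding training script:

```python
# `model` and `images` are assumed, not from this commit
total_loss, (flow_loss, velocity_loss, recon_loss) = model.forward_modality(
    images,
    return_loss_breakdown = True
)

total_loss.backward()
```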
