@@ -90,6 +90,7 @@ class LossBreakdown(NamedTuple):
     text: Scalar
     flow: list[Scalar]
     velocity: list[Scalar] | None
+    recon: list[Scalar] | None
 
 class ModalityInfo(NamedTuple):
     encoder: Module | None
@@ -1085,6 +1086,7 @@ def __init__(
         flow_loss_weight = 1.,
         text_loss_weight = 1.,
         velocity_consistency_loss_weight = 0.1,
+        reconstruction_loss_weight = 0.,
         modality_encoder_decoder_requires_batch_dim = True, # whether the modality encoder / decoder requires batch dimension, will auto assume it is needed
         odeint_kwargs: dict = dict(
             atol = 1e-5,
@@ -1277,6 +1279,11 @@ def __init__(
 
         self.velocity_consistency_loss_weight = velocity_consistency_loss_weight
 
+        # additional reconstruction loss, through the decoder
+
+        self.has_recon_loss = reconstruction_loss_weight > 0.
+        self.reconstruction_loss_weight = reconstruction_loss_weight
+
         # flow sampling related
 
         self.odeint_fn = partial(odeint, **odeint_kwargs)
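
For context, the auxiliary loss added here is opt-in: with the default `reconstruction_loss_weight = 0.`, `has_recon_loss` stays false and nothing else in this commit changes behavior. A minimal sketch of switching it on at construction follows; the import path and every kwarg other than `reconstruction_loss_weight` and `velocity_consistency_loss_weight` are illustrative assumptions, not taken from this diff.

from transfusion_pytorch import Transfusion  # assumed import path

model = Transfusion(
    num_text_tokens = 256,                    # illustrative values
    dim_latent = 384,
    transformer = dict(dim = 512, depth = 8),
    velocity_consistency_loss_weight = 0.1,
    reconstruction_loss_weight = 0.1          # > 0. sets has_recon_loss = True
)
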
@@ -1711,10 +1718,10 @@ def forward_modality(
         return_loss = True,
         return_loss_breakdown = False
     ) -> Scalar | Float['b ...']:
-
         requires_velocity_consistency = exists(velocity_consistency_ema_model)
 
         modalities = modalities.to(self.device)
+        orig_modalities = modalities
 
         if self.num_modalities > 1:
             assert exists(modality_type), '`modality_type` must be explicitly passed in on forward when training on greater than 1 modality'
@@ -1754,6 +1761,7 @@ def forward_modality(
             noised_tokens = padded_times * tokens + (1. - padded_times) * noise
 
             flow = tokens - noise
+
         else:
             noised_tokens = tokens
 
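
The context lines above pin down the noising schedule and the flow target: `noised_tokens = t * tokens + (1 - t) * noise` and `flow = tokens - noise`. From those two definitions, one Euler step of size `1 - t` from the noised point recovers the clean tokens exactly when the flow is exact, which is the identity the reconstruction in the next hunk relies on before decoding. A self-contained sketch of just that identity; names mirror the diff, but the snippet is standalone and not from the repo.

import torch

tokens = torch.randn(2, 16, 384)                 # stand-in for clean modality tokens
noise  = torch.randn_like(tokens)
times  = torch.rand(2, 1, 1)                     # broadcastable t in [0, 1)

noised = times * tokens + (1. - times) * noise   # same schedule as above
flow   = tokens - noise                          # the regression target

recon  = noised + flow * (1. - times)            # one Euler step of size (1 - t)

assert torch.allclose(recon, tokens, atol = 1e-5)

With the model's predicted flow substituted for the true flow, `recon` becomes an estimate of the clean tokens, which is what gets decoded and compared against `orig_modalities` below.
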
@@ -1816,17 +1824,37 @@ def forward_modality(
 
             velocity_loss = F.mse_loss(flow, flow_with_delta_time)
 
+        # maybe recon loss
+
+        recon_loss = self.zero
+
+        if self.has_recon_loss:
+            assert encode_modality
+
+            recon = noised_tokens + pred_flow * (1. - padded_times)
+
+            if exists(mod.decoder):
+                with torch.no_grad():
+                    mod.decoder.eval()
+                    recon = self.maybe_add_temp_batch_dim(mod.decoder)(recon)
+
+            recon_loss = F.mse_loss(
+                recon,
+                orig_modalities
+            )
+
         # total loss
 
         total_loss = (
             flow_loss +
-            velocity_loss * self.velocity_consistency_loss_weight
+            velocity_loss * self.velocity_consistency_loss_weight +
+            recon_loss * self.reconstruction_loss_weight
         )
 
         if not return_loss_breakdown:
             return total_loss
 
-        return total_loss, (flow_loss, velocity_loss)
+        return total_loss, (flow_loss, velocity_loss, recon_loss)
 
     @torch.no_grad()
     @eval_decorator
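
Since the breakdown is now a three-tuple, callers that pass `return_loss_breakdown = True` need to unpack the extra entry. A hypothetical call is sketched below; only the tuple shape comes from this diff, while the positional `modalities` argument is inferred from the context lines above, and `model` / `modalities` are assumed to be set up elsewhere.

# hypothetical usage; `recon_loss` is the new third entry in the breakdown
total_loss, (flow_loss, velocity_loss, recon_loss) = model.forward_modality(
    modalities,
    return_loss_breakdown = True
)

total_loss.backward()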