
Commit 11653db

complete the optional reconstruction loss for mixed modality training
1 parent b6e9226 commit 11653db

File tree

3 files changed: +68 −19 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.5.4"
+version = "0.5.5"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_transfusion.py

Lines changed: 4 additions & 1 deletion

@@ -20,9 +20,11 @@

 @pytest.mark.parametrize('cache_kv', (False, True))
 @pytest.mark.parametrize('use_flex_attn', (False, True))
+@pytest.mark.parametrize('reconstruction_loss_weight', (0., 0.1))
 def test_transfusion(
     cache_kv: bool,
     use_flex_attn: bool,
+    reconstruction_loss_weight: float
 ):

     if use_flex_attn and (not exists(flex_attention) or not cuda_available):
@@ -33,8 +35,9 @@ def test_transfusion(

     model = Transfusion(
         num_text_tokens = text_tokens,
-        dim_latent = (384, 192), # specify multiple latent dimensions
+        dim_latent = (384, 192),
         modality_default_shape = ((32,), (64,)),
+        reconstruction_loss_weight = reconstruction_loss_weight,
         transformer = dict(
             dim = 64,
             depth = 2,
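Aside: the new parametrize line runs every test combination with the reconstruction loss both disabled (0.) and enabled (0.1). A minimal sketch of how the option is wired up at construction, mirroring the test above (the vocab size and transformer settings are illustrative placeholders, not values from the commit):

    from transfusion_pytorch import Transfusion

    # hypothetical smoke test; mirrors the constructor call in the test above
    model = Transfusion(
        num_text_tokens = 256,                        # placeholder vocab size
        dim_latent = (384, 192),                      # two modality types, two latent dims
        modality_default_shape = ((32,), (64,)),
        reconstruction_loss_weight = 0.1,             # 0. leaves the recon loss disabled
        transformer = dict(
            dim = 64,
            depth = 2
        )
    )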

transfusion_pytorch/transfusion.py

Lines changed: 63 additions & 17 deletions

@@ -89,7 +89,7 @@ class LossBreakdown(NamedTuple):
     total: Scalar
     text: Scalar
     flow: list[Scalar]
-    velocity: list[Scalar] | None
+    velocity: list[Scalar] | None = None
     recon: list[Scalar] | None = None

 class ModalityInfo(NamedTuple):
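With `velocity` now defaulting to `None`, a `LossBreakdown` can be built without velocity-consistency losses, which keeps construction convenient now that `recon` trails it as a fifth field. A quick illustration with placeholder values (assuming `LossBreakdown` is importable from the module shown above):

    from transfusion_pytorch.transfusion import LossBreakdown

    # placeholder scalars, only to show the construction options
    total, text, flows = 1.0, 0.5, [0.25]

    breakdown = LossBreakdown(total, text, flows)                   # velocity = None, recon = None
    breakdown = LossBreakdown(total, text, flows, recon = [0.25])   # skip velocity by keyword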
@@ -2057,8 +2057,38 @@ def forward(
         text = []

         flows = defaultdict(list) # store flows for loss
+
+        # for parsing out the predicted flow from flattened sequence of tokens coming out of transformer
+
         get_pred_flows: GetPredFlows = defaultdict(list) # functions for parsing modalities from Float['b n d'] for model back to latents or pixel space

+        def model_to_pred_flow(batch_index, start_index, modality_length, unpack_fn):
+
+            def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:
+                embed = embed[batch_index]
+
+                if need_splice:
+                    embed = embed[start_index:(start_index + modality_length)]
+
+                embed = unpack_fn(embed)
+                return embed
+
+            return inner
+
+        # for going from predicted flow -> reconstruction
+
+        get_recon_losses: Callable[[Tensor], Tensor] = defaultdict(list)
+
+        def get_recon_loss(noise, times, modality):
+
+            def inner(pred_flow):
+                recon_modality = noise + pred_flow * (1. - times)
+                return F.mse_loss(modality, recon_modality)
+
+            return inner
+
+        # go through all modality samples and do necessary transform
+
         for batch_index, batch_modalities in enumerate(modalities):

             modality_index = 0
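Both helpers follow the same deferred-closure pattern: capture what is known during packing (the batch position and unpack function, or the noise, time, and noising target), and accept the one missing piece, the model's predicted flow, only after the forward pass. A self-contained sketch of the `get_recon_loss` half of that pattern (the shapes, scalar time, and dummy tensors are illustrative assumptions, not values from the repo):

    import torch
    import torch.nn.functional as F

    def make_recon_loss(noise, times, target):
        # noise / times / target are captured now, at noising time ...
        def inner(pred_flow):
            # ... and only the predicted flow arrives later
            recon = noise + pred_flow * (1. - times)   # map the predicted flow back to tensor space
            return F.mse_loss(target, recon)
        return inner

    pending = []                                       # one closure per modality occurrence

    noise  = torch.randn(32, 384)
    times  = torch.rand(())                            # scalar time in [0, 1)
    target = torch.randn(32, 384)                      # stand-in for the noised modality
    pending.append(make_recon_loss(noise, times, target))

    pred_flow = torch.randn(32, 384)                   # stand-in for the parsed transformer output
    losses = [fn(pred_flow) for fn in pending]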
@@ -2147,6 +2177,10 @@ def forward(

                 modality_tensor = noised_modality

+                # store function for deriving reconstruction loss from decoder
+
+                get_recon_losses[modality_type].append(get_recon_loss(noise, modality_time, modality_tensor))
+
                 # go through maybe encoder

                 modality_tensor = add_temp_batch_dim(mod.latent_to_model)(modality_tensor)
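Note the capture point: `modality_tensor` has just been reassigned to `noised_modality`, so each closure holds the noise, the sampled time, and the noised tensor for that occurrence, before the `latent_to_model` encoder touches it. Appending under `modality_type` mirrors how `get_pred_flows` is populated, which is what lets the loss section further down zip the two lists together one-to-one.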
@@ -2189,19 +2223,6 @@ def forward(

                 modality_tensor, unpack_modality_shape = pack_one_with_inverse(modality_tensor, '* d')

-                def model_to_pred_flow(batch_index, start_index, modality_length, unpack_fn):
-
-                    def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:
-                        embed = embed[batch_index]
-
-                        if need_splice:
-                            embed = embed[start_index:(start_index + modality_length)]
-
-                        embed = unpack_fn(embed)
-                        return embed
-
-                    return inner
-
                 inverse_fn = model_to_pred_flow(batch_index, offset + precede_modality_tokens, modality_length, unpack_modality_shape)

                 get_pred_flows[modality_type].append(inverse_fn)
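`model_to_pred_flow` itself is unchanged by this hunk; it has only been hoisted out of the per-modality loop to the top of `forward` (see the earlier hunk), which is safe because the closure reads nothing from the enclosing loop beyond its explicit arguments.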
@@ -2362,19 +2383,30 @@ def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:
         # flow loss

         pred_flows = []
+        recon_losses = []

         for modality_id in range(self.num_modalities):
             mod = self.get_modality_info(modality_id)
+
             modality_get_pred_flows = get_pred_flows[modality_id]
+            modality_get_recon_losses = get_recon_losses[modality_id]

             modality_pred_flows = []
+            modality_recon_losses = []
+
+            for get_pred_flow, get_recon_loss in zip(modality_get_pred_flows, modality_get_recon_losses):

-            for get_pred_flow in modality_get_pred_flows:
                 pred_flow = get_pred_flow(embed)
                 pred_flow = add_temp_batch_dim(mod.model_to_latent)(pred_flow)
                 modality_pred_flows.append(pred_flow)

+                if not return_loss or not self.has_recon_loss:
+                    continue
+
+                modality_recon_losses.append(get_recon_loss(pred_flow))
+
             pred_flows.append(modality_pred_flows)
+            recon_losses.append(modality_recon_losses)

         # early return for velocity consistency ema model

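Two details worth noting in the loop above: the reconstruction loss is taken on `pred_flow` after it has been mapped back through `mod.model_to_latent`, i.e. in latent space against the noised latents captured earlier; and the `continue` guard leaves every closure unevaluated whenever `return_loss` is false or the recon loss is disabled, so the extra MSE costs nothing in that case.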
@@ -2448,7 +2480,7 @@ def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:

             velocity_match_losses = []

-            for mod, ema_pred_flow, pred_flow, is_one_modality in zip(self.get_all_modality_info(), ema_pred_flows, pred_flows, is_modalities.unbind(dim = 1)):
+            for ema_pred_flow, pred_flow in zip(ema_pred_flows, pred_flows):

                 pack_pattern = 'd *' if mod.channel_first_latent else '* d'
                 pred_flow, _ = pack(pred_flow, pack_pattern)
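One thing to watch in the simplified zip: the body still reads `mod.channel_first_latent`, but `mod` is no longer rebound per iteration, so it keeps whatever value the preceding flow-loss loop left behind (the last modality's info). That is only equivalent when all modalities share the same latent layout; otherwise the loop would need to look the modality info up again.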
@@ -2466,9 +2498,23 @@ def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:
                 (stack(velocity_match_losses) * modality_loss_weights).sum() * self.velocity_consistency_loss_weight
             )

+        # maybe reconstruction loss
+
+        if self.has_recon_loss:
+
+            averaged_recon_losses = []
+
+            for modality_recon_loss in recon_losses:
+                averaged_recon_losses.append(sum(modality_recon_loss) / len(modality_recon_loss))
+
+            total_loss = (
+                total_loss +
+                (stack(averaged_recon_losses) * modality_loss_weights).sum() * self.reconstruction_loss_weight
+            )
+
         # return total loss if no breakdown needed

         if not return_breakdown:
             return total_loss

-        return total_loss, LossBreakdown(total_loss, text_loss, flow_losses, velocity_match_losses)
+        return total_loss, LossBreakdown(total_loss, text_loss, flow_losses, velocity_match_losses, recon_losses)
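The recon term first averages the occurrence losses within each modality type, then applies the per-modality weights and the global `reconstruction_loss_weight`. A toy trace of that arithmetic (all numbers invented):

    from torch import stack, tensor

    # two modality types, two occurrence losses each (illustrative values only)
    recon_losses = [[tensor(0.4), tensor(0.6)], [tensor(1.0), tensor(3.0)]]

    averaged = stack([sum(losses) / len(losses) for losses in recon_losses])  # tensor([0.5000, 2.0000])

    modality_loss_weights = tensor([1., 1.])
    reconstruction_loss_weight = 0.1

    recon_term = (averaged * modality_loss_weights).sum() * reconstruction_loss_weight
    # (0.5 + 2.0) * 0.1 = 0.25, added onto total_loss

Note that `len(modality_recon_loss)` is zero for a modality type with no occurrences in the batch, so the averaging assumes every modality type shows up at least once whenever the recon loss is enabled.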
