@@ -89,7 +89,7 @@ class LossBreakdown(NamedTuple):
89
89
total : Scalar
90
90
text : Scalar
91
91
flow : list [Scalar ]
92
- velocity : list [Scalar ] | None
92
+ velocity : list [Scalar ] | None = None
93
93
recon : list [Scalar ] | None = None
94
94
95
95
class ModalityInfo (NamedTuple ):
@@ -2057,8 +2057,38 @@ def forward(
2057
2057
text = []
2058
2058
2059
2059
flows = defaultdict (list ) # store flows for loss
2060
+
2061
+ # for parsing out the predicted flow from flattened sequence of tokens coming out of transformer
2062
+
2060
2063
get_pred_flows : GetPredFlows = defaultdict (list ) # functions for parsing modalities from Float['b n d'] for model back to latents or pixel space
2061
2064
2065
def model_to_pred_flow(batch_index, start_index, modality_length, unpack_fn):
    """Build a closure that parses one modality's predicted flow out of the
    flattened sequence of tokens coming from the transformer.

    Args:
        batch_index: index of the sample within the batch dimension.
        start_index: offset of this modality's first token in the sequence.
        modality_length: number of tokens belonging to this modality.
        unpack_fn: inverse of the packing step — restores the modality's
            original (pre-flatten) shape, e.g. back to latent / pixel space.

    Returns:
        A function that maps the model embedding ``Float['b n d']`` back to
        this modality's unpacked tensor.
    """

    # NOTE: annotations kept as strings so they carry no runtime dependency
    # on the jaxtyping `Float` symbol when the closure is created.
    def inner(embed: "Float['b n d']", need_splice = True) -> "Float['...']":
        # select this sample's token sequence
        embed = embed[batch_index]

        # splice out this modality's tokens unless the caller already did
        if need_splice:
            embed = embed[start_index:(start_index + modality_length)]

        # restore the original (pre-pack) shape
        embed = unpack_fn(embed)
        return embed

    return inner
2077
+
2078
+ # for going from predicted flow -> reconstruction
2079
+
2080
+ get_recon_losses : Callable [[Tensor ], Tensor ] = defaultdict (list )
2081
+
2082
def get_recon_loss(noise, times, modality):
    """Build a closure that turns a predicted flow into a reconstruction loss.

    Under rectified-flow parameterization, the clean sample is recovered as
    ``noise + pred_flow * (1 - t)``; the loss is the MSE between that
    reconstruction and the original modality tensor.

    Args:
        noise: the noise tensor that was mixed into the modality.
        times: the flow time(s) ``t`` used when noising (broadcastable
            against the predicted flow — assumed, confirm against caller).
        modality: the clean (un-noised) modality tensor to reconstruct.

    Returns:
        A function mapping a predicted flow tensor to a scalar MSE loss.
    """

    def inner(pred_flow):
        # recover the clean sample from the predicted velocity/flow
        recon_modality = noise + pred_flow * (1. - times)
        return F.mse_loss(modality, recon_modality)

    return inner
2089
+
2090
+ # go through all modality samples and do necessary transform
2091
+
2062
2092
for batch_index , batch_modalities in enumerate (modalities ):
2063
2093
2064
2094
modality_index = 0
@@ -2147,6 +2177,10 @@ def forward(
2147
2177
2148
2178
modality_tensor = noised_modality
2149
2179
2180
+ # store function for deriving reconstruction loss from decoder
2181
+
2182
+ get_recon_losses [modality_type ].append (get_recon_loss (noise , modality_time , modality_tensor ))
2183
+
2150
2184
# go through maybe encoder
2151
2185
2152
2186
modality_tensor = add_temp_batch_dim (mod .latent_to_model )(modality_tensor )
@@ -2189,19 +2223,6 @@ def forward(
2189
2223
2190
2224
modality_tensor , unpack_modality_shape = pack_one_with_inverse (modality_tensor , '* d' )
2191
2225
2192
- def model_to_pred_flow (batch_index , start_index , modality_length , unpack_fn ):
2193
-
2194
- def inner (embed : Float ['b n d' ], need_splice = True ) -> Float ['...' ]:
2195
- embed = embed [batch_index ]
2196
-
2197
- if need_splice :
2198
- embed = embed [start_index :(start_index + modality_length )]
2199
-
2200
- embed = unpack_fn (embed )
2201
- return embed
2202
-
2203
- return inner
2204
-
2205
2226
inverse_fn = model_to_pred_flow (batch_index , offset + precede_modality_tokens , modality_length , unpack_modality_shape )
2206
2227
2207
2228
get_pred_flows [modality_type ].append (inverse_fn )
@@ -2362,19 +2383,30 @@ def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:
2362
2383
# flow loss
2363
2384
2364
2385
pred_flows = []
2386
+ recon_losses = []
2365
2387
2366
2388
for modality_id in range (self .num_modalities ):
2367
2389
mod = self .get_modality_info (modality_id )
2390
+
2368
2391
modality_get_pred_flows = get_pred_flows [modality_id ]
2392
+ modality_get_recon_losses = get_recon_losses [modality_id ]
2369
2393
2370
2394
modality_pred_flows = []
2395
+ modality_recon_losses = []
2396
+
2397
+ for get_pred_flow , get_recon_loss in zip (modality_get_pred_flows , modality_get_recon_losses ):
2371
2398
2372
- for get_pred_flow in modality_get_pred_flows :
2373
2399
pred_flow = get_pred_flow (embed )
2374
2400
pred_flow = add_temp_batch_dim (mod .model_to_latent )(pred_flow )
2375
2401
modality_pred_flows .append (pred_flow )
2376
2402
2403
+ if not return_loss or not self .has_recon_loss :
2404
+ continue
2405
+
2406
+ modality_recon_losses .append (get_recon_loss (pred_flow ))
2407
+
2377
2408
pred_flows .append (modality_pred_flows )
2409
+ recon_losses .append (modality_recon_losses )
2378
2410
2379
2411
# early return for velocity consistency ema model
2380
2412
@@ -2448,7 +2480,7 @@ def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:
2448
2480
2449
2481
velocity_match_losses = []
2450
2482
2451
- for mod , ema_pred_flow , pred_flow , is_one_modality in zip (self . get_all_modality_info (), ema_pred_flows , pred_flows , is_modalities . unbind ( dim = 1 ) ):
2483
+ for ema_pred_flow , pred_flow in zip (ema_pred_flows , pred_flows ):
2452
2484
2453
2485
pack_pattern = 'd *' if mod .channel_first_latent else '* d'
2454
2486
pred_flow , _ = pack (pred_flow , pack_pattern )
@@ -2466,9 +2498,23 @@ def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:
2466
2498
(stack (velocity_match_losses ) * modality_loss_weights ).sum () * self .velocity_consistency_loss_weight
2467
2499
)
2468
2500
2501
+ # maybe reconstruction loss
2502
+
2503
+ if self .has_recon_loss :
2504
+
2505
+ averaged_recon_losses = []
2506
+
2507
+ for modality_recon_loss in recon_losses :
2508
+ averaged_recon_losses .append (sum (modality_recon_loss ) / len (modality_recon_loss ))
2509
+
2510
+ total_loss = (
2511
+ total_loss +
2512
+ (stack (averaged_recon_losses ) * modality_loss_weights ).sum () * self .reconstruction_loss_weight
2513
+ )
2514
+
2469
2515
# return total loss if no breakdown needed
2470
2516
2471
2517
if not return_breakdown :
2472
2518
return total_loss
2473
2519
2474
- return total_loss , LossBreakdown (total_loss , text_loss , flow_losses , velocity_match_losses )
2520
+ return total_loss , LossBreakdown (total_loss , text_loss , flow_losses , velocity_match_losses , recon_losses )
0 commit comments