@@ -1546,7 +1546,9 @@ def ode_step_fn(step_times, denoised):
1546
1546
1547
1547
parse_embed = get_pred_flows [curr_modality_id ][- 1 ]
1548
1548
1549
- flow = add_temp_batch_dim (mod .model_to_latent )(parse_embed (embeds ))
1549
+ parsed_embed = parse_embed (embeds , need_splice = not exists (cache ))
1550
+
1551
+ flow = add_temp_batch_dim (mod .model_to_latent )(parsed_embed )
1550
1552
1551
1553
return flow
1552
1554
@@ -2161,10 +2163,14 @@ def forward(
2161
2163
2162
2164
def model_to_pred_flow (batch_index , start_index , modality_length , unpack_fn ):
2163
2165
2164
- def inner (embed : Float ['b n d' ]) -> Float ['...' ]:
2165
- modality_embed = embed [batch_index , start_index :(start_index + modality_length )]
2166
- modality_embed = unpack_fn (modality_embed )
2167
- return modality_embed
2166
+ def inner (embed : Float ['b n d' ], need_splice = True ) -> Float ['...' ]:
2167
+ embed = embed [batch_index ]
2168
+
2169
+ if need_splice :
2170
+ embed = embed [start_index :(start_index + modality_length )]
2171
+
2172
+ embed = unpack_fn (embed )
2173
+ return embed
2168
2174
2169
2175
return inner
2170
2176
@@ -2334,9 +2340,9 @@ def inner(embed: Float['b n d']) -> Float['...']:
2334
2340
modality_get_pred_flows = get_pred_flows [modality_id ]
2335
2341
2336
2342
modality_pred_flows = []
2343
+
2337
2344
for get_pred_flow in modality_get_pred_flows :
2338
2345
pred_flow = get_pred_flow (embed )
2339
-
2340
2346
pred_flow = add_temp_batch_dim (mod .model_to_latent )(pred_flow )
2341
2347
modality_pred_flows .append (pred_flow )
2342
2348
0 commit comments