Skip to content

Commit c6f9ccf

Browse files
committed
fix cache kv edge case
1 parent d7e726b commit c6f9ccf

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.5.0"
3+
version = "0.5.1"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,7 +1546,9 @@ def ode_step_fn(step_times, denoised):
15461546

15471547
parse_embed = get_pred_flows[curr_modality_id][-1]
15481548

1549-
flow = add_temp_batch_dim(mod.model_to_latent)(parse_embed(embeds))
1549+
parsed_embed = parse_embed(embeds, need_splice = not exists(cache))
1550+
1551+
flow = add_temp_batch_dim(mod.model_to_latent)(parsed_embed)
15501552

15511553
return flow
15521554

@@ -2161,10 +2163,14 @@ def forward(
21612163

21622164
def model_to_pred_flow(batch_index, start_index, modality_length, unpack_fn):
21632165

2164-
def inner(embed: Float['b n d']) -> Float['...']:
2165-
modality_embed = embed[batch_index, start_index:(start_index + modality_length)]
2166-
modality_embed = unpack_fn(modality_embed)
2167-
return modality_embed
2166+
def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:
2167+
embed = embed[batch_index]
2168+
2169+
if need_splice:
2170+
embed = embed[start_index:(start_index + modality_length)]
2171+
2172+
embed = unpack_fn(embed)
2173+
return embed
21682174

21692175
return inner
21702176

@@ -2334,9 +2340,9 @@ def inner(embed: Float['b n d']) -> Float['...']:
23342340
modality_get_pred_flows = get_pred_flows[modality_id]
23352341

23362342
modality_pred_flows = []
2343+
23372344
for get_pred_flow in modality_get_pred_flows:
23382345
pred_flow = get_pred_flow(embed)
2339-
23402346
pred_flow = add_temp_batch_dim(mod.model_to_latent)(pred_flow)
23412347
modality_pred_flows.append(pred_flow)
23422348

0 commit comments

Comments (0)