Skip to content

Commit 4325dd8

Browse files
committed
batched decoding
1 parent 22a9c50 commit 4325dd8

File tree

2 files changed

+8
-18
lines changed

2 files changed

+8
-18
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.10.0"
3+
version = "0.10.1"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,26 +1805,16 @@ def ode_step_fn(step_times, denoised):
18051805
if return_unprocessed_modalities:
18061806
return modality_sample
18071807

1808-
# post process modalities
1808+
# post process modality sample, decoding modality types if they have a decoder
18091809

1810-
processed_modality_sample = []
1810+
for mod in self.get_all_modality_info():
1811+
decoder_fn = default(mod.decoder, nn.Identity())
18111812

1812-
for sample in modality_sample:
1813-
if not isinstance(sample, tuple):
1814-
processed_modality_sample.append(sample)
1815-
continue
1816-
1817-
modality_id, modality = sample
1818-
1819-
mod = self.get_modality_info(modality_id)
1820-
1821-
if exists(mod.decoder):
1822-
mod.decoder.eval()
1823-
modality = self.maybe_add_temp_batch_dim(mod.decoder)(modality)
1824-
1825-
processed_modality_sample.append((modality_id, modality))
1813+
with torch.no_grad():
1814+
decoder_fn.eval()
1815+
modality_sample = apply_fn_modality_type(decoder_fn, modality_sample, modality_type = mod.modality_type)
18261816

1827-
return processed_modality_sample
1817+
return modality_sample
18281818

18291819
@typecheck
18301820
def forward_text(

0 commit comments

Comments
 (0)