Skip to content

Commit 865b893

Browse files
committed
able to return the initial untransformed modality when encoding modalities in batches, for the recon loss
1 parent d62824a commit 865b893

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

transfusion_pytorch/transfusion.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,8 @@ def inverse(inverse_inp):
583583
def apply_fn_modality_type(
584584
fn: Callable,
585585
modalities: ModalitySample | list[ModalitySample],
586-
modality_type = 0
586+
modality_type = 0,
587+
return_untransformed = False
587588
) -> ModalitySample | list[ModalitySample]:
588589

589590
modalities, tree_spec = tree_flatten(modalities, is_leaf = lambda el: isinstance(el, tuple))
@@ -610,7 +611,10 @@ def apply_fn_modality_type(
610611

611612
# add back the type
612613

613-
out = [(modality_type, m) for m in out]
614+
if return_untransformed:
615+
out = [(modality_type, transformed_m, prev_m) for transformed_m, prev_m in zip(out, modalities)]
616+
else:
617+
out = [(modality_type, transformed_m) for transformed_m in out]
614618

615619
# replace transformed modalities and untree flatten
616620

0 commit comments

Comments
 (0)