File tree 2 files changed +8
-18
lines changed
2 files changed +8
-18
lines changed Original file line number Diff line number Diff line change 1
1
[project ]
2
2
name = " transfusion-pytorch"
3
- version = " 0.10.0 "
3
+ version = " 0.10.1 "
4
4
description = " Transfusion in Pytorch"
5
5
authors = [
6
6
{ name = " Phil Wang" , email = " lucidrains@gmail.com" }
Original file line number Diff line number Diff line change @@ -1805,26 +1805,16 @@ def ode_step_fn(step_times, denoised):
1805
1805
if return_unprocessed_modalities :
1806
1806
return modality_sample
1807
1807
1808
- # post process modalities
1808
+ # post process modality sample, decoding modality types if they have a decoder
1809
1809
1810
- processed_modality_sample = []
1810
+ for mod in self .get_all_modality_info ():
1811
+ decoder_fn = default (mod .decoder , nn .Identity ())
1811
1812
1812
- for sample in modality_sample :
1813
- if not isinstance (sample , tuple ):
1814
- processed_modality_sample .append (sample )
1815
- continue
1816
-
1817
- modality_id , modality = sample
1818
-
1819
- mod = self .get_modality_info (modality_id )
1820
-
1821
- if exists (mod .decoder ):
1822
- mod .decoder .eval ()
1823
- modality = self .maybe_add_temp_batch_dim (mod .decoder )(modality )
1824
-
1825
- processed_modality_sample .append ((modality_id , modality ))
1813
+ with torch .no_grad ():
1814
+ decoder_fn .eval ()
1815
+ modality_sample = apply_fn_modality_type (decoder_fn , modality_sample , modality_type = mod .modality_type )
1826
1816
1827
- return processed_modality_sample
1817
+ return modality_sample
1828
1818
1829
1819
@typecheck
1830
1820
def forward_text (
You can’t perform that action at this time.
0 commit comments