@@ -249,6 +249,9 @@ def max_neg_value(t):
249
249
def append_dims(t, ndims):
    """Return `t` reshaped so that `ndims` extra size-1 axes follow its existing axes."""
    # existing sizes first, then one `1` per requested trailing axis
    trailing_ones = (1,) * ndims
    target_shape = (*t.shape, *trailing_ones)
    return t.reshape(*target_shape)
251
251
252
def is_empty(t):
    """Return True when `t.numel()` reports zero elements."""
    element_count = t.numel()
    return element_count == 0
254
+
252
255
def log(t, eps = 1e-20):
    """Natural log of `t` after clamping its values to be at least `eps`."""
    # clamp from below first so that zero inputs do not produce -inf
    floored = t.clamp(min = eps)
    return torch.log(floored)
254
257
@@ -1856,8 +1859,6 @@ def generate_modality_only(
1856
1859
1857
1860
assert exists (modality_shape )
1858
1861
1859
- modality_length = math .prod (modality_shape )
1860
-
1861
1862
noise = torch .randn ((batch_size , * modality_shape , mod .dim_latent ), device = device )
1862
1863
1863
1864
if mod .channel_first_latent :
@@ -1986,6 +1987,49 @@ def forward(
1986
1987
1987
1988
need_axial_pos_emb = any (self .add_pos_emb )
1988
1989
1990
+ # standardize modalities to be tuple - type 0 modality is implicit if not given
1991
+ # also store modality lengths for determining noising times
1992
+
1993
+ modality_lens = []
1994
+
1995
+ for batch_modalities in modalities :
1996
+ batch_modality_lens = []
1997
+
1998
+ for ind , modality in enumerate (batch_modalities ):
1999
+ if is_tensor (modality ) and modality .dtype == torch .float :
2000
+ modality = (0 , modality )
2001
+
2002
+ if not isinstance (modality , tuple ):
2003
+ continue
2004
+
2005
+ modality_type , modality_tensor = modality
2006
+ mod = self .get_modality_info (modality_type )
2007
+
2008
+ batch_modalities [ind ] = modality
2009
+ seq_dim = 0 if mod .channel_first_latent else - 1
2010
+ batch_modality_lens .append (modality_tensor .shape [seq_dim ])
2011
+
2012
+ modality_lens .append (tensor_ (batch_modality_lens ))
2013
+
2014
+ modality_lens = pad_sequence (modality_lens )
2015
+
2016
+ # determine the times
2017
+
2018
+ if not exists (times ):
2019
+ if is_empty (modality_lens ):
2020
+ times = modality_lens .float ()
2021
+ else :
2022
+ modality_length_to_times_fn = default (modality_length_to_times_fn , default_modality_length_to_time_fn )
2023
+
2024
+ if exists (modality_length_to_times_fn ):
2025
+ times = modality_length_to_times_fn (modality_lens )
2026
+
2027
+ # if needs velocity matching, make sure times are in the range of 0 - (1. - <velocity consistency delta time>)
2028
+
2029
+ if need_velocity_matching :
2030
+ orig_times = times .clone ()
2031
+ times = times * (1. - velocity_consistency_delta_time )
2032
+
1989
2033
# process list of text and modalities interspersed with one another
1990
2034
1991
2035
modality_positions = []
@@ -2005,9 +2049,6 @@ def forward(
2005
2049
# if non-text modality detected and not given as a tuple
2006
2050
# cast to (int, Tensor) where int is defaulted to type 0 (convenience for one modality)
2007
2051
2008
- if is_tensor (modality ) and modality .dtype == torch .float :
2009
- modality = (0 , modality )
2010
-
2011
2052
is_text = not isinstance (modality , tuple )
2012
2053
is_modality = not is_text
2013
2054
@@ -2142,7 +2183,7 @@ def forward(
2142
2183
modality_positions = modality_positions_to_tensor (modality_positions , device = device )
2143
2184
2144
2185
if modality_positions .shape [- 1 ] == 2 : # Int['b m 2'] -> Int['b m 3'] if type is not given (one modality)
2145
- modality_positions = F .pad (modality_positions , (1 , 0 ), value = 0 )
2186
+ modality_positions = F .pad (modality_positions , (1 , 0 ))
2146
2187
2147
2188
# for now use dummy padding modality position info if empty (all zeros)
2148
2189
@@ -2178,19 +2219,8 @@ def forward(
2178
2219
2179
2220
# noise the modality tokens
2180
2221
2181
- if not exists (times ):
2182
- modality_length_to_times_fn = default (modality_length_to_times_fn , default_modality_length_to_time_fn )
2183
-
2184
- if exists (modality_length_to_times_fn ):
2185
- times = modality_length_to_times_fn (modality_positions [..., - 1 ])
2186
-
2187
2222
times_per_token = einsum (is_modalities .float (), times , 'b t m n, b m -> b t n' )
2188
2223
2189
- # if needs velocity matching, make sure times are in the range of 0 - (1. - <velocity consistency delta time>)
2190
-
2191
- if need_velocity_matching :
2192
- times_per_token = times_per_token * (1. - velocity_consistency_delta_time )
2193
-
2194
2224
# noise only if returning loss
2195
2225
2196
2226
if return_loss :
@@ -2370,7 +2400,7 @@ def forward(
2370
2400
2371
2401
ema_pred_flows = velocity_consistency_ema_model (
2372
2402
velocity_modalities ,
2373
- times = times + velocity_consistency_delta_time ,
2403
+ times = orig_times + velocity_consistency_delta_time ,
2374
2404
return_only_pred_flows = True
2375
2405
)
2376
2406
0 commit comments