Skip to content

Commit 3b2076e

Browse files
committed
take a tiny necessary step
1 parent 8b354fa commit 3b2076e

File tree

1 file changed

+48
-18
lines changed

1 file changed

+48
-18
lines changed

transfusion_pytorch/transfusion.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ def max_neg_value(t):
249249
def append_dims(t, ndims):
250250
return t.reshape(*t.shape, *((1,) * ndims))
251251

252+
def is_empty(t):
253+
return t.numel() == 0
254+
252255
def log(t, eps = 1e-20):
253256
return torch.log(t.clamp(min = eps))
254257

@@ -1856,8 +1859,6 @@ def generate_modality_only(
18561859

18571860
assert exists(modality_shape)
18581861

1859-
modality_length = math.prod(modality_shape)
1860-
18611862
noise = torch.randn((batch_size, *modality_shape, mod.dim_latent), device = device)
18621863

18631864
if mod.channel_first_latent:
@@ -1986,6 +1987,49 @@ def forward(
19861987

19871988
need_axial_pos_emb = any(self.add_pos_emb)
19881989

1990+
# standardize modalities to be tuple - type 0 modality is implicit if not given
1991+
# also store modality lengths for determining noising times
1992+
1993+
modality_lens = []
1994+
1995+
for batch_modalities in modalities:
1996+
batch_modality_lens = []
1997+
1998+
for ind, modality in enumerate(batch_modalities):
1999+
if is_tensor(modality) and modality.dtype == torch.float:
2000+
modality = (0, modality)
2001+
2002+
if not isinstance(modality, tuple):
2003+
continue
2004+
2005+
modality_type, modality_tensor = modality
2006+
mod = self.get_modality_info(modality_type)
2007+
2008+
batch_modalities[ind] = modality
2009+
seq_dim = 0 if mod.channel_first_latent else -1
2010+
batch_modality_lens.append(modality_tensor.shape[seq_dim])
2011+
2012+
modality_lens.append(tensor_(batch_modality_lens))
2013+
2014+
modality_lens = pad_sequence(modality_lens)
2015+
2016+
# determine the times
2017+
2018+
if not exists(times):
2019+
if is_empty(modality_lens):
2020+
times = modality_lens.float()
2021+
else:
2022+
modality_length_to_times_fn = default(modality_length_to_times_fn, default_modality_length_to_time_fn)
2023+
2024+
if exists(modality_length_to_times_fn):
2025+
times = modality_length_to_times_fn(modality_lens)
2026+
2027+
# if needs velocity matching, make sure times are in the range of 0 - (1. - <velocity consistency delta time>)
2028+
2029+
if need_velocity_matching:
2030+
orig_times = times.clone()
2031+
times = times * (1. - velocity_consistency_delta_time)
2032+
19892033
# process list of text and modalities interspersed with one another
19902034

19912035
modality_positions = []
@@ -2005,9 +2049,6 @@ def forward(
20052049
# if non-text modality detected and not given as a tuple
20062050
# cast to (int, Tensor) where int is defaulted to type 0 (convenience for one modality)
20072051

2008-
if is_tensor(modality) and modality.dtype == torch.float:
2009-
modality = (0, modality)
2010-
20112052
is_text = not isinstance(modality, tuple)
20122053
is_modality = not is_text
20132054

@@ -2142,7 +2183,7 @@ def forward(
21422183
modality_positions = modality_positions_to_tensor(modality_positions, device = device)
21432184

21442185
if modality_positions.shape[-1] == 2: # Int['b m 2'] -> Int['b m 3'] if type is not given (one modality)
2145-
modality_positions = F.pad(modality_positions, (1, 0), value = 0)
2186+
modality_positions = F.pad(modality_positions, (1, 0))
21462187

21472188
# for now use dummy padding modality position info if empty (all zeros)
21482189

@@ -2178,19 +2219,8 @@ def forward(
21782219

21792220
# noise the modality tokens
21802221

2181-
if not exists(times):
2182-
modality_length_to_times_fn = default(modality_length_to_times_fn, default_modality_length_to_time_fn)
2183-
2184-
if exists(modality_length_to_times_fn):
2185-
times = modality_length_to_times_fn(modality_positions[..., -1])
2186-
21872222
times_per_token = einsum(is_modalities.float(), times, 'b t m n, b m -> b t n')
21882223

2189-
# if needs velocity matching, make sure times are in the range of 0 - (1. - <velocity consistency delta time>)
2190-
2191-
if need_velocity_matching:
2192-
times_per_token = times_per_token * (1. - velocity_consistency_delta_time)
2193-
21942224
# noise only if returning loss
21952225

21962226
if return_loss:
@@ -2370,7 +2400,7 @@ def forward(
23702400

23712401
ema_pred_flows = velocity_consistency_ema_model(
23722402
velocity_modalities,
2373-
times = times + velocity_consistency_delta_time,
2403+
times = orig_times + velocity_consistency_delta_time,
23742404
return_only_pred_flows = True
23752405
)
23762406

0 commit comments

Comments (0)