
Commit 1688ff8

address zero-dimensional modality, for #39
1 parent 3d90489 commit 1688ff8

File tree

3 files changed: +37 −4 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.10.2"
+version = "0.10.3"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_transfusion.py

Lines changed: 30 additions & 1 deletion

@@ -87,7 +87,7 @@ def test_auto_modality_transform(
         num_text_tokens = text_tokens,
         dim_latent = 384,
         channel_first_latent = True,
-        modality_default_shape = (32,),
+        modality_default_shape = (2, 2),
         transformer = dict(
             dim = 64,
             depth = 2,
@@ -385,3 +385,32 @@ def test_apply_fn_modality_type():

     assert (modalities[0][0][-1] == 1).all()
     assert (modalities[2][0][-1] == 2).all()
+
+
+def test_zero_dimensional():
+
+    model = Transfusion(
+        num_text_tokens = 256,
+        dim_latent = 384,
+        modality_default_shape = (),
+        transformer = dict(
+            dim = 512,
+            depth = 8,
+            num_residual_streams = 1
+        )
+    )
+
+    # any torch.long is text, torch.float is modalities
+
+    text_and_embeds = [
+        [randint(0, 256, (16,)), randn(384), randint(0, 256, (8,)), randn(384)],
+        [randint(0, 256, (16,)), randn(384), randint(0, 256, (5,)), randn(384), randint(0, 256, (9,))]
+    ]
+
+    loss = model(text_and_embeds)
+
+    loss.backward()
+
+    # after much training
+
+    one_multimodal_sample = model.sample()
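
The new test exercises a zero-dimensional modality: with `modality_default_shape = ()`, each modality instance is a lone latent vector rather than a sequence or grid of latents. A minimal sketch of the shape contract implied by the tests above (the variable names below are illustrative, not part of the commit):

from torch import randn

dim_latent = 384

# zero-dimensional modality: a single latent vector, no positional axes
zero_dim = randn(dim_latent)        # shape (384,)

# one-dimensional modality with channel-last latents (channel_first_latent = False)
one_dim = randn(32, dim_latent)     # shape (32, 384)

# two-dimensional modality with channel-first latents (channel_first_latent = True),
# matching the (2, 2) default shape now used in test_auto_modality_transform
two_dim = randn(dim_latent, 2, 2)   # shape (384, 2, 2)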

transfusion_pytorch/transfusion.py

Lines changed: 6 additions & 2 deletions

@@ -2224,6 +2224,7 @@ def forward(
         batch_num_modalities = 0

         for ind, modality in enumerate(batch_modalities):
+
            if is_tensor(modality) and modality.dtype == torch.float:
                modality = (0, modality)

@@ -2345,7 +2346,9 @@ def inner(pred_flow):
            assert 0 <= modality_type < self.num_modalities, f'received a modality index that is out of range. only {self.num_modalities} modalities specified'

            channel_dim = 0 if mod.channel_first_latent else -1
+
            assert mod.dim_latent == modality_tensor.shape[channel_dim], f'mismatch for modality latent dimension - expected {mod.dim_latent} but received {modality_tensor.shape[-1]} - modality shape is {tuple(modality_tensor.shape)}, perhaps you need to set `channel_first_latent` to the correct value'
+            assert mod.num_dim == (len(modality_tensor.shape) - 1), f'mismatch for modality number of dimensions - expected {mod.num_dim} but received {len(modality_tensor.shape) - 1} {modality_tensor.shape}'

            # auto ward against scalars (lone start end tokens)

@@ -2355,7 +2358,7 @@ def inner(pred_flow):
            # handle text

            if is_text:
-                assert modality_tensor.ndim == 1
+                assert modality_tensor.ndim == 1 and modality_tensor.dtype in (torch.int, torch.long)
                text_length = modality_tensor.shape[0]

                batch_text.append(modality_tensor)
@@ -2420,7 +2423,7 @@ def inner(pred_flow):
            # start by just storing the token length of the modality

            modality_shape_str = join([*map(str, modality_shape_tuple)], ',')
-            modality_meta_info = self.char_tokenizer(modality_shape_str, device = device)
+            modality_meta_info = self.char_tokenizer(modality_shape_str, device = device).long()

            precede_modality_tokens = len(modality_meta_info) + 2
            succeed_modality_tokens = 1
@@ -2450,6 +2453,7 @@ def inner(pred_flow):
            modality_tensor = F.pad(modality_tensor, (0, 0, precede_modality_tokens, succeed_modality_tokens))

            batch_modality_tokens.append(modality_tensor)
+
            batch_text.append(text_tensor)

            # handle axial positional embedding
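
Taken together, the hunks above let the zero-dimensional case pass validation: the new `num_dim` assertion accepts a tensor whose only axis is the latent dimension, and the shape metadata tokenized ahead of the modality collapses to an empty string. A standalone sketch of those checks for a channel-last, zero-dimensional modality (illustrative only, not the library's internal code):

import torch

dim_latent = 384
num_dim = 0                                  # zero positional axes
channel_first_latent = False

modality_tensor = torch.randn(dim_latent)    # shape (384,)

channel_dim = 0 if channel_first_latent else -1
assert dim_latent == modality_tensor.shape[channel_dim]
assert num_dim == (len(modality_tensor.shape) - 1)   # 0 == 1 - 1

# the shape string stored before the modality tokens is empty for ()
modality_shape_tuple = ()                    # no positional axes to record
modality_shape_str = ','.join(map(str, modality_shape_tuple))
assert modality_shape_str == ''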
