
Commit 3d90489

autoset modality_num_dim if modality_default_shape specified

1 parent 4325dd8

File tree: 3 files changed (+21, -15 lines)


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.10.1"
+version = "0.10.2"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_transfusion.py

Lines changed: 1 addition & 2 deletions
@@ -234,8 +234,7 @@ def test_velocity_consistency():
         num_text_tokens = 12,
         dim_latent = 384,
         channel_first_latent = True,
-        modality_default_shape = ((4, 4)),
-        modality_num_dim = 2,
+        modality_default_shape = (4, 4),
         modality_encoder = mock_encoder,
         modality_decoder = mock_decoder,
         transformer = dict(
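
The test change above shows the new ergonomics: with `modality_default_shape = (4, 4)` supplied, `modality_num_dim = 2` is redundant and the extra parentheses around the shape are dropped. A minimal usage sketch along the same lines, not verified against the repo; the `transformer` hyperparameters are illustrative and the test's mock encoder/decoder are omitted:

from transfusion_pytorch import Transfusion

model = Transfusion(
    num_text_tokens = 12,
    dim_latent = 384,
    channel_first_latent = True,
    modality_default_shape = (4, 4),  # modality_num_dim is now inferred as 2
    transformer = dict(
        dim = 64,
        depth = 1
    )
)

# assumption based on the diff below: the inferred value is stored per modality
assert model.modality_num_dim == (2,)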

transfusion_pytorch/transfusion.py

Lines changed: 19 additions & 12 deletions
@@ -1276,11 +1276,30 @@ def __init__(
 
         self.to_modality_shape_fn = cast_tuple(to_modality_shape_fn, self.num_modalities)
 
+        # default token lengths for respective modality
+        # fallback if the language model does not come up with valid dimensions
+
+        if not exists(modality_default_shape) or is_bearable(modality_default_shape, tuple[int, ...]):
+            modality_default_shape = (modality_default_shape,) * self.num_modalities
+
+        self.modality_default_shape = modality_default_shape
+
+        assert len(self.modality_default_shape) == self.num_modalities
+
+        self.fallback_to_default_shape_if_invalid = fallback_to_default_shape_if_invalid
+
+        # default `modality_num_dim` to `len(modality_default_shape)` if latter is specified but former not
+
+        modality_num_dim = default(modality_num_dim, tuple(len(shape) for shape in self.modality_default_shape))
+
         # specifying the number of dimensions for the modality, which will be hard validated
 
         self.modality_num_dim = cast_tuple(modality_num_dim, self.num_modalities)
+
         assert len(self.modality_num_dim) == self.num_modalities
 
+        assert all([not exists(ndim) or not exists(shape) or len(shape) == ndim for ndim, shape in zip(self.modality_num_dim, self.modality_default_shape)])
+
         # whether to add an extra axial positional embedding per modality
 
         self.add_pos_emb = cast_tuple(add_pos_emb, self.num_modalities)

@@ -1318,18 +1337,6 @@ def __init__(
 
         self.maybe_add_temp_batch_dim = add_temp_batch_dim if modality_encoder_decoder_requires_batch_dim else identity
 
-        # default token lengths for respective modality
-        # fallback if the language model does not come up with valid dimensions
-
-        if not exists(modality_default_shape) or is_bearable(modality_default_shape, tuple[int, ...]):
-            modality_default_shape = (modality_default_shape,) * self.num_modalities
-
-        self.modality_default_shape = modality_default_shape
-
-        assert len(self.modality_default_shape) == self.num_modalities
-
-        self.fallback_to_default_shape_if_invalid = fallback_to_default_shape_if_invalid
-
         # store number of text tokens
 
         self.num_text_tokens = num_text_tokens
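
Pulled out of the constructor, the inference rule in the first hunk reduces to a small pure function. A standalone sketch for illustration only: `exists` mirrors the repo helper, while `infer_num_dim` is a hypothetical name, not part of the library.

def exists(v):
    return v is not None

def infer_num_dim(default_shape, num_dim = None, num_modalities = 1):
    # a single shape (tuple of ints) or None applies to every modality
    if not exists(default_shape) or all(isinstance(d, int) for d in default_shape):
        default_shape = (default_shape,) * num_modalities

    # autoset: derive the number of dimensions from each default shape
    if not exists(num_dim):
        num_dim = tuple(len(shape) if exists(shape) else None for shape in default_shape)

    if isinstance(num_dim, int):
        num_dim = (num_dim,) * num_modalities

    # hard validate: an explicit num_dim must agree with its default shape
    assert all(
        not exists(nd) or not exists(shape) or len(shape) == nd
        for nd, shape in zip(num_dim, default_shape)
    )

    return num_dim

assert infer_num_dim((4, 4)) == (2,)                                # inferred from the shape
assert infer_num_dim(((4, 4), (8,)), num_modalities = 2) == (2, 1)  # one shape per modality
assert infer_num_dim((4, 4), num_dim = 2) == (2,)                   # explicit and consistent

One difference from the diff: assuming the usual eager `default` helper, `default(modality_num_dim, tuple(len(shape) for shape in ...))` computes its fallback argument unconditionally and so presumes every entry of `self.modality_default_shape` exists; the sketch guards each entry with `exists` instead.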
