@@ -1276,11 +1276,30 @@ def __init__(
1276
1276
1277
1277
self .to_modality_shape_fn = cast_tuple (to_modality_shape_fn , self .num_modalities )
1278
1278
1279
+ # default token lengths for respective modality
1280
+ # fallback if the language model does not come up with valid dimensions
1281
+
1282
+ if not exists (modality_default_shape ) or is_bearable (modality_default_shape , tuple [int , ...]):
1283
+ modality_default_shape = (modality_default_shape ,) * self .num_modalities
1284
+
1285
+ self .modality_default_shape = modality_default_shape
1286
+
1287
+ assert len (self .modality_default_shape ) == self .num_modalities
1288
+
1289
+ self .fallback_to_default_shape_if_invalid = fallback_to_default_shape_if_invalid
1290
+
1291
+ # default `modality_num_dim` to the length of each modality's default shape, when default shapes are specified but `modality_num_dim` is not
1292
+
1293
+ modality_num_dim = default (modality_num_dim , tuple (len (shape ) for shape in self .modality_default_shape ))
1294
+
1279
1295
# specifying the number of dimensions for the modality, which will be hard validated
1280
1296
1281
1297
self .modality_num_dim = cast_tuple (modality_num_dim , self .num_modalities )
1298
+
1282
1299
assert len (self .modality_num_dim ) == self .num_modalities
1283
1300
1301
+ assert all ([not exists (ndim ) or not exists (shape ) or len (shape ) == ndim for ndim , shape in zip (self .modality_num_dim , self .modality_default_shape )])
1302
+
1284
1303
# whether to add an extra axial positional embedding per modality
1285
1304
1286
1305
self .add_pos_emb = cast_tuple (add_pos_emb , self .num_modalities )
@@ -1318,18 +1337,6 @@ def __init__(
1318
1337
1319
1338
self .maybe_add_temp_batch_dim = add_temp_batch_dim if modality_encoder_decoder_requires_batch_dim else identity
1320
1339
1321
- # default token lengths for respective modality
1322
- # fallback if the language model does not come up with valid dimensions
1323
-
1324
- if not exists (modality_default_shape ) or is_bearable (modality_default_shape , tuple [int , ...]):
1325
- modality_default_shape = (modality_default_shape ,) * self .num_modalities
1326
-
1327
- self .modality_default_shape = modality_default_shape
1328
-
1329
- assert len (self .modality_default_shape ) == self .num_modalities
1330
-
1331
- self .fallback_to_default_shape_if_invalid = fallback_to_default_shape_if_invalid
1332
-
1333
1340
# store number of text tokens
1334
1341
1335
1342
self .num_text_tokens = num_text_tokens
0 commit comments