@@ -201,11 +201,11 @@ def concat_contiguous_text(
     """ within a modality sample, any two tensors of type int / long will be concatted together if next to each other, so all text is followed by a modality, and all modality followed by text """
 
     output = []
-    curr_modality = None
 
     for modality in modality_sample:
         if (
             len(output) > 0 and
+            is_tensor(output[-1]) and is_tensor(modality) and
             output[-1].dtype == modality.dtype and
             modality.dtype in (torch.int, torch.long)
         ):
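The new `is_tensor` guard matters because entries in a modality sample are not always tensors: a modality can appear as a `(modality_type, tensor)` tuple, which has no `.dtype`, so the old dtype comparison would raise `AttributeError`. A minimal runnable sketch of the patched loop; the `torch.cat` merge branch is not shown in the hunk, so it is assumed here:

```python
import torch
from torch import is_tensor, tensor

def concat_contiguous_text_sketch(modality_sample):
    # sketch of the patched loop; the torch.cat merge is assumed, not from the hunk
    output = []
    for modality in modality_sample:
        if (
            len(output) > 0 and
            is_tensor(output[-1]) and is_tensor(modality) and  # the added guard
            output[-1].dtype == modality.dtype and
            modality.dtype in (torch.int, torch.long)
        ):
            output[-1] = torch.cat((output[-1], modality))
        else:
            output.append(modality)
    return output

sample = [
    tensor([1, 2]),           # text tokens (long)
    tensor([3]),              # contiguous text, merged into the previous tensor
    (0, torch.randn(4, 8)),   # tuple entry: without the guard, .dtype would crash here
    tensor([4, 5]),
]

merged = concat_contiguous_text_sketch(sample)
assert torch.equal(merged[0], tensor([1, 2, 3]))
assert len(merged) == 3
```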
@@ -1365,6 +1365,19 @@ def get_modality_info(
     def get_all_modality_info(self) -> list[ModalityInfo]:
         return [self.get_modality_info(i) for i in range(self.num_modalities)]
 
+    def get_modality_shape(
+        self,
+        modality: Float['...'],
+        modality_type: int | None = None
+    ) -> tuple[int, ...]:
+
+        mod = self.get_modality_info(modality_type)
+
+        if mod.channel_first_latent:
+            modality = rearrange(modality, 'c ... -> ... c')
+
+        return tuple(modality.shape[:-1])
+
     def parameters_without_encoder_decoder(self):
         return (
             set(self.parameters()) -
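A usage sketch of what `get_modality_shape` returns, with the `ModalityInfo` lookup stubbed out as a plain `channel_first_latent` flag: for a channel-first latent the channel axis is moved last before being dropped, so only the axes describing the modality's extent survive.

```python
import torch
from einops import rearrange

def get_modality_shape_sketch(modality, channel_first_latent):
    # stands in for self.get_modality_info(modality_type).channel_first_latent
    if channel_first_latent:
        modality = rearrange(modality, 'c ... -> ... c')
    return tuple(modality.shape[:-1])

latent = torch.randn(4, 16, 16)     # (c, h, w) channel-first latent
assert get_modality_shape_sketch(latent, True) == (16, 16)

feats = torch.randn(10, 512)        # (time, dim) channel-last features
assert get_modality_shape_sketch(feats, False) == (10,)
```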
@@ -1402,7 +1415,7 @@ def create_ema(
     @typecheck
     def sample(
         self,
-        prompt: ModalitySample | None = None,
+        prompt: ModalitySample | Tensor | tuple[int, Float['...']] | None = None,
         max_length = 2048,
         text_temperature = 1.5,
         text_min_p = 0.1,
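With the widened signature, `sample` can now be prompted three ways. A sketch, where `model` is a hypothetical trained instance (the calls are commented out since they need trained weights):

```python
import torch

text_prompt  = torch.tensor([5, 9, 11])     # int / long tensor -> treated as text tokens
image_prompt = torch.randn(4, 16, 16)       # float tensor -> raw modality, type 0 implied
typed_prompt = (0, torch.randn(4, 16, 16))  # explicit (modality_type, tensor) pair

# out = model.sample(prompt = text_prompt)
# out = model.sample(prompt = image_prompt)
# out = model.sample(prompt = typed_prompt)
```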
@@ -1415,22 +1428,52 @@ def sample(
 
         device = self.device
 
+        # take care of prompt being a raw tensor, either text or raw modality (image, video, actions, latents, etc)
+
+        if is_tensor(prompt) and prompt.dtype == torch.float: # is modality with type 0 implicit
+            prompt = (0, prompt)
+
+        if is_tensor(prompt) and prompt.dtype in (torch.int, torch.long): # is text only prompt
+            prompt = [prompt]
+
+        elif isinstance(prompt, tuple):
+            modality_type, modality = prompt
+
+            mod = self.get_modality_info(modality_type)
+
+            if exists(mod.encoder):
+                with torch.no_grad():
+                    mod.encoder.eval()
+                    modality = self.maybe_add_temp_batch_dim(mod.encoder)(modality).detach()
+
+            modality_shape_tuple = self.get_modality_shape(modality, modality_type)
+            modality_shape_str = join([*map(str, modality_shape_tuple)], ',')
+            modality_meta_info = self.char_tokenizer(modality_shape_str, device = device)
+
+            prompt = [
+                tensor([self.meta_id]),
+                modality_meta_info,
+                tensor([mod.som_id]),
+                (modality_type, modality),
+                tensor([mod.eom_id]),
+            ]
+
+        # sos
+
         init_text_seq = tensor([self.sos_id], device = device)
 
         # just take care of prompt being zero dimensions
 
-        prompt = tree_map_tensor(prompt, lambda t: rearrange(t, '-> 1') if t.ndim == 0 else t)
-
         modality_sample = [init_text_seq, *default(prompt, [])]
 
         # take care of moving to device
 
         modality_sample = tree_map_tensor(modality_sample, lambda t: t.to(device))
+        modality_sample = tree_map_tensor(modality_sample, lambda t: rearrange(t, '-> 1') if t.ndim == 0 else t)
 
         modality_sample = concat_contiguous_text(modality_sample)
 
         *_, last_modality_sample = modality_sample
-        assert last_modality_sample.dtype in (torch.int, torch.long), 'prompt must be text tokens'
 
         curr_length = 0
         curr_modality_id = None
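Reading the new branch end to end: a raw modality prompt is first encoded (when the modality has an encoder), its shape is serialized to a comma-separated string and character-tokenized, and the result is wrapped in meta / start-of-modality / end-of-modality special tokens. A self-contained sketch of the resulting layout, with placeholder ids and a toy character tokenizer standing in for the model's own:

```python
import torch

meta_id, som_id, eom_id = 256, 257, 258                       # hypothetical special token ids
char_tokenize = lambda s: torch.tensor([ord(c) for c in s])   # stand-in for char_tokenizer

modality_type = 0
modality = torch.randn(16, 16, 4)          # encoded latent, channel-last
shape_str = ','.join(map(str, (16, 16)))   # "16,16" - channel axis dropped by get_modality_shape

prompt = [
    torch.tensor([meta_id]),    # begin shape metadata
    char_tokenize(shape_str),   # the shape, spelled out as character tokens
    torch.tensor([som_id]),     # start-of-modality id for this modality type
    (modality_type, modality),  # the modality itself, kept as a (type, tensor) pair
    torch.tensor([eom_id]),     # end-of-modality id
]
```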