@@ -113,6 +113,7 @@ class ModalityInfo(NamedTuple):
113
113
eom_id : int
114
114
to_shape_fn : Callable | None
115
115
channel_first_latent : bool
116
+ modality_type : int
116
117
117
118
# helper functions
118
119
@@ -1489,7 +1490,8 @@ def get_modality_info(
1489
1490
som_id = som_id ,
1490
1491
eom_id = eom_id ,
1491
1492
to_shape_fn = to_shape_fn ,
1492
- channel_first_latent = channel_first_latent
1493
+ channel_first_latent = channel_first_latent ,
1494
+ modality_type = modality_type
1493
1495
)
1494
1496
1495
1497
def get_all_modality_info (self ) -> list [ModalityInfo ]:
@@ -2264,10 +2266,24 @@ def forward(
2264
2266
2265
2267
text = []
2266
2268
2267
- flows = defaultdict (list ) # store flows for loss
2269
+ # auto move all tensors to device of model
2270
+
2271
+ modalities = tree_map_tensor (modalities , lambda t : t .to (device ))
2272
+
2273
+ # for all modalities, batch process same shaped modalities of the same type
2274
+
2275
+ if not is_decoding :
2276
+ for mod in self .get_all_modality_info ():
2277
+ encode_fn = default (mod .encoder , nn .Identity ())
2278
+
2279
+ with torch .no_grad ():
2280
+ encode_fn .eval ()
2281
+ modalities = apply_fn_modality_type (encode_fn , modalities , modality_type = mod .modality_type )
2268
2282
2269
2283
# for parsing out the predicted flow from flattened sequence of tokens coming out of transformer
2270
2284
2285
+ flows = defaultdict (list ) # store flows for loss
2286
+
2271
2287
get_pred_flows : GetPredFlows = defaultdict (list ) # functions for parsing modalities from Float['b n d'] for model back to latents or pixel space
2272
2288
2273
2289
def model_to_pred_flow (batch_index , start_index , modality_length , unpack_fn ):
@@ -2322,22 +2338,13 @@ def inner(pred_flow):
2322
2338
if is_text :
2323
2339
modality_tensor = modality
2324
2340
else :
2325
- modality_type , modality_tensor = modality
2341
+ modality_type , modality_tensor , * _ = modality
2326
2342
2327
2343
# auto move modality tensor to correct device
2328
2344
2329
- modality_tensor = modality_tensor .to (device )
2330
-
2331
2345
mod = self .get_modality_info (modality_type )
2332
2346
2333
2347
if is_modality :
2334
- if not is_decoding :
2335
-
2336
- if exists (mod .encoder ):
2337
- with torch .no_grad ():
2338
- mod .encoder .eval ()
2339
- modality_tensor = self .maybe_add_temp_batch_dim (mod .encoder )(modality_tensor ).detach ()
2340
-
2341
2348
assert 0 <= modality_type < self .num_modalities , f'received a modality index that is out of range. only { self .num_modalities } modalities specified'
2342
2349
2343
2350
channel_dim = 0 if mod .channel_first_latent else - 1
0 commit comments