@@ -400,7 +400,10 @@ def __init__(
400400 self .conditioners = nn .ModuleList ([])
401401
402402 self .hidden_dims = hidden_dims
403- self .hiddens_channel_first = hiddens_channel_first # whether hiddens to be conditioned is channel first or last
403+ self .num_condition_fns = len (hidden_dims )
404+ self .hiddens_channel_first = cast_tuple (hiddens_channel_first , self .num_condition_fns ) # whether hiddens to be conditioned is channel first or last
405+
406+ assert len (self .hiddens_channel_first ) == self .num_condition_fns
404407
405408 self .cond_drop_prob = cond_drop_prob
406409
@@ -447,13 +450,11 @@ def forward(
447450 elif exists (text_embeds ):
448451 batch = text_embeds .shape [0 ]
449452
450- device = self .device
451-
452453 if not exists (text_embeds ):
453454 text_embeds = self .embed_texts (texts )
454455
455456 if cond_drop_prob > 0. :
456- prob_keep_mask = prob_mask_like ((batch , 1 ), 1. - cond_drop_prob , device = device )
457+ prob_keep_mask = prob_mask_like ((batch , 1 ), 1. - cond_drop_prob , device = self . device )
457458 null_text_embeds = rearrange (self .null_text_embed , 'd -> 1 d' )
458459
459460 text_embeds = torch .where (
@@ -462,8 +463,18 @@ def forward(
462463 null_text_embeds
463464 )
464465
465- text_embeds = repeat (text_embeds , 'b ... -> (b r) ...' , r = repeat_batch )
466+ # prepare the conditioning functions
467+
468+ repeat_batch = cast_tuple (repeat_batch , self .num_condition_fns )
469+
470+ cond_fns = []
471+
472+ for cond , cond_hiddens_channel_first , cond_repeat_batch in zip (self .conditioners , self .hiddens_channel_first , repeat_batch ):
473+ cond_text_embeds = repeat (text_embeds , 'b ... -> (b r) ...' , r = cond_repeat_batch )
474+ cond_fn = partial (cond , cond_text_embeds )
475+
476+ wrapper_fn = rearrange_channel_first if cond_hiddens_channel_first else rearrange_channel_last
466477
467- wrapper_fn = rearrange_channel_first if self . hiddens_channel_first else rearrange_channel_last
478+ cond_fns . append ( wrapper_fn ( cond_fn ))
468479
469- return tuple (wrapper_fn ( partial ( cond , text_embeds )) for cond in self . conditioners )
480+ return tuple (cond_fns )
0 commit comments