@@ -314,7 +314,11 @@ def __init__(

         # attention

-        full_attn = cast_tuple(full_attn, length = len(dim_mults))
+        num_stages = len(dim_mults)
+        full_attn = cast_tuple(full_attn, num_stages)
+        attn_heads = cast_tuple(attn_heads, num_stages)
+        attn_dim_head = cast_tuple(attn_dim_head, num_stages)
+
         assert len(full_attn) == len(dim_mults)

         FullAttention = partial(Attention, flash = flash_attn)
@@ -325,32 +329,32 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)

-        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(in_out, full_attn)):
+        for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)):
             is_last = ind >= (num_resolutions - 1)

             attn_klass = FullAttention if layer_full_attn else LinearAttention

             self.downs.append(nn.ModuleList([
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
-                attn_klass(dim_in, dim_head = attn_dim_head, heads = attn_heads),
+                attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
                 Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
             ]))

         mid_dim = dims[-1]
         self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
-        self.mid_attn = FullAttention(mid_dim)
+        self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1])
         self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

-        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(reversed(in_out), reversed(full_attn))):
+        for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
             is_last = ind == (len(in_out) - 1)

             attn_klass = FullAttention if layer_full_attn else LinearAttention

             self.ups.append(nn.ModuleList([
                 block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                 block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
-                attn_klass(dim_out, dim_head = attn_dim_head, heads = attn_heads),
+                attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
                 Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
             ]))

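As a usage sketch (assuming this diff targets the `Unet` constructor whose other arguments, such as `dim` and `dim_mults`, behave as in the surrounding code): since `attn_heads` and `attn_dim_head` are now broadcast through `cast_tuple` alongside `full_attn`, each may be passed either as a single value or as one value per resolution stage. The concrete numbers below are illustrative assumptions, not values taken from this diff.

```python
# hypothetical example: per-stage attention settings after this change
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),                 # four stages, so the tuples below have four entries
    full_attn = (False, False, False, True),  # linear attention everywhere except the last stage
    attn_heads = (2, 2, 4, 8),                # assumed: one head count per stage
    attn_dim_head = (16, 16, 32, 64)          # assumed: one head dimension per stage
)
```

Passing plain scalars (e.g. `attn_heads = 4`) still works, since `cast_tuple` repeats a scalar `num_stages` times; the middle block reuses the last stage's `attn_heads[-1]` and `attn_dim_head[-1]`.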