
Commit d94a5a3

address #237

1 parent: dcc3da8
2 files changed: +10 -6 lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 9 additions & 5 deletions
@@ -314,7 +314,11 @@ def __init__(
 
         # attention
 
-        full_attn = cast_tuple(full_attn, length = len(dim_mults))
+        num_stages = len(dim_mults)
+        full_attn = cast_tuple(full_attn, num_stages)
+        attn_heads = cast_tuple(attn_heads, num_stages)
+        attn_dim_head = cast_tuple(attn_dim_head, num_stages)
+
         assert len(full_attn) == len(dim_mults)
 
         FullAttention = partial(Attention, flash = flash_attn)
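For context: `cast_tuple` is the small helper this hunk leans on to broadcast a scalar argument into a per-stage tuple, so `attn_heads` and `attn_dim_head` can now be given either as a single value or as one value per resolution stage, just like `full_attn`. A minimal sketch of the assumed behavior (the exact implementation in the repo may differ):

```python
# minimal sketch of the assumed cast_tuple behavior:
# pass tuples through unchanged, broadcast scalars to the requested length
def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return (t,) * length

# with num_stages = 4
cast_tuple(8, 4)             # -> (8, 8, 8, 8)
cast_tuple((4, 4, 4, 8), 4)  # -> (4, 4, 4, 8)
```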
@@ -325,15 +329,15 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(in_out, full_attn)):
+        for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)):
             is_last = ind >= (num_resolutions - 1)
 
             attn_klass = FullAttention if layer_full_attn else LinearAttention
 
             self.downs.append(nn.ModuleList([
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
-                attn_klass(dim_in, dim_head = attn_dim_head, heads = attn_heads),
+                attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
                 Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
             ]))
 
@@ -342,15 +346,15 @@ def __init__(
         self.mid_attn = FullAttention(mid_dim)
         self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
 
-        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(reversed(in_out), reversed(full_attn))):
+        for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
             is_last = ind == (len(in_out) - 1)
 
             attn_klass = FullAttention if layer_full_attn else LinearAttention
 
             self.ups.append(nn.ModuleList([
                 block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                 block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
-                attn_klass(dim_out, dim_head = attn_dim_head, heads = attn_heads),
+                attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
                 Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
             ]))
 
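The upsampling loop now zips four reversed sequences at once; `zip(*map(reversed, (...)))` is simply a more compact spelling of reversing each sequence individually and zipping them, as the snippet below illustrates with hypothetical per-stage values:

```python
in_out    = [(64, 128), (128, 256), (256, 512)]
full_attn = (False, False, True)

# reversing each sequence by hand ...
a = list(zip(reversed(in_out), reversed(full_attn)))

# ... is equivalent to reversing them all via map and unpacking into zip
b = list(zip(*map(reversed, (in_out, full_attn))))

assert a == b  # [((256, 512), True), ((128, 256), False), ((64, 128), False)]
```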
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.8.6'
+__version__ = '1.8.7'
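Taken together, the change lets `attn_heads` and `attn_dim_head` be passed to `Unet` per stage, mirroring how `full_attn` already worked. A hedged usage sketch: the argument names come from the diff, while `dim`, `dim_mults`, and the specific values shown are illustrative assumptions.

```python
from denoising_diffusion_pytorch import Unet

# one entry per resolution stage (here len(dim_mults) == 4);
# scalars still work and are broadcast by cast_tuple
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    full_attn = (False, False, False, True),  # full attention only at the coarsest stage
    attn_heads = (4, 4, 4, 8),                # per-stage head counts (new in this commit)
    attn_dim_head = (32, 32, 32, 64)          # per-stage head dimensions (new in this commit)
)
```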
