Skip to content

Commit 308c470

Browse files
committed
fix local windowed attention to be overlapping
1 parent 95d3ab6 commit 308c470

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

alphafold3_pytorch/attention.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,11 +374,11 @@ def local_attn(
374374
# just do radius of 1 for now
375375
# perhaps not even necessary, and could try shifted windows (a la Swin)
376376

377-
k, v = tuple(pad_at_dim(t, (1, 0), dim = -2) for t in (k, v))
378-
mask = F.pad(mask, (1, 0), value = False)
377+
k, v = tuple(pad_at_dim(t, (1, 0), dim = -3) for t in (k, v))
378+
mask = pad_at_dim(mask, (1, 0), dim = -2, value = False)
379379

380-
k, v = tuple(torch.cat((t[..., :-1, :], t[..., 1:, :]), dim = -2) for t in (k, v))
381-
mask = torch.cat((mask[..., :-1], mask[..., 1:]), dim = -1)
380+
k, v = tuple(torch.cat((t[..., :-1, :, :], t[..., 1:, :, :]), dim = -2) for t in (k, v))
381+
mask = torch.cat((mask[..., :-1, :], mask[..., 1:, :]), dim = -1)
382382

383383
# handle attention bias (inefficiently)
384384

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.6.6"
3+
version = "0.6.7"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" },

0 commit comments

Comments
 (0)