1 parent 95d3ab6 commit 308c470
alphafold3_pytorch/attention.py
@@ -374,11 +374,11 @@ def local_attn(
     # just do radius of 1 for now
     # perhaps not even necessary, and could try shifted windows (a la Swin)

-    k, v = tuple(pad_at_dim(t, (1, 0), dim = -2) for t in (k, v))
-    mask = F.pad(mask, (1, 0), value = False)
+    k, v = tuple(pad_at_dim(t, (1, 0), dim = -3) for t in (k, v))
+    mask = pad_at_dim(mask, (1, 0), dim = -2, value = False)

-    k, v = tuple(torch.cat((t[..., :-1, :], t[..., 1:, :]), dim = -2) for t in (k, v))
-    mask = torch.cat((mask[..., :-1], mask[..., 1:]), dim = -1)
+    k, v = tuple(torch.cat((t[..., :-1, :, :], t[..., 1:, :, :]), dim = -2) for t in (k, v))
+    mask = torch.cat((mask[..., :-1, :], mask[..., 1:, :]), dim = -1)

     # handle attention bias (inefficiently)
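For context on what the attention.py change does: these lines gather, for each query window, the keys and values of the previous window as well as its own, giving the local attention a look-back radius of one window. Below is a minimal, self-contained sketch of that step. The tensor shapes and the small pad_at_dim helper are assumptions for illustration only, not the repository's actual definitions; the point is to show why the window dimension is -3 for k / v (which carry a trailing head-feature dimension) but -2 for the mask.

# minimal sketch, NOT the repository's code; shapes below are assumptions
#   k, v : (batch, heads, windows, window_size, dim_head)
#   mask : (batch, windows, window_size)

import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim, value = 0.):
    # pad a single dimension of t; `pad` is (left, right) for that dimension
    dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

batch, heads, windows, window_size, dim_head = 1, 2, 4, 8, 16

k = torch.randn(batch, heads, windows, window_size, dim_head)
v = torch.randn(batch, heads, windows, window_size, dim_head)
mask = torch.ones(batch, windows, window_size, dtype = torch.bool)

# pad one empty window at the front: the windows dim is -3 for k / v, -2 for mask
k, v = tuple(pad_at_dim(t, (1, 0), dim = -3) for t in (k, v))
mask = pad_at_dim(mask, (1, 0), dim = -2, value = False)

# window i now sees (window i - 1, window i) concatenated along the window-size dim
k, v = tuple(torch.cat((t[..., :-1, :, :], t[..., 1:, :, :]), dim = -2) for t in (k, v))
mask = torch.cat((mask[..., :-1, :], mask[..., 1:, :]), dim = -1)

print(k.shape)     # torch.Size([1, 2, 4, 16, 16]) - window_size doubled
print(mask.shape)  # torch.Size([1, 4, 16])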
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.6.6"
+version = "0.6.7"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" },