Commit 43f1b75

go for an improvised solution for what may be an issue with autoregressive and laser
1 parent bb54873 commit 43f1b75

6 files changed: +10, −8 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.6.3"
+version = "0.6.4"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

train_latent_with_text.py

Lines changed: 1 addition & 1 deletion

@@ -95,7 +95,7 @@ def encode_tokens(str: str) -> Tensor:
         dim = 128,
         depth = 8,
         dim_head = 64,
-        heads = 8
+        heads = 8,
     )
 ).cuda()

train_mnist.py

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ def forward(self, x):
         dim = 64,
         depth = 4,
         dim_head = 32,
-        heads = 8
+        heads = 8,
     )
 ).cuda()

train_mnist_vae.py

Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ def forward(self, x):
         dim = 64,
         depth = 4,
         dim_head = 32,
-        heads = 8
+        heads = 8,
     )
 ).cuda()

train_mnist_with_unet.py

Lines changed: 1 addition & 1 deletion

@@ -61,7 +61,7 @@ def forward(self, x):
         dim = 64,
         depth = 4,
         dim_head = 32,
-        heads = 8
+        heads = 8,
     )
 ).to(device)

transfusion_pytorch/transfusion.py

Lines changed: 5 additions & 3 deletions

@@ -757,6 +757,7 @@ def __init__(
         use_flex_attn = False,
         gate_values = True,
         laser = False,
+        laser_softclamp_value = 15.,
         learned_value_residual_mix = False
     ):
         super().__init__()

@@ -785,6 +786,7 @@ def __init__(
         self.softcap_value = softcap_value

         self.laser = laser
+        self.laser_softclamp_value = laser_softclamp_value

         self.dropout = nn.Dropout(dropout)

@@ -850,8 +852,8 @@ def forward(
         # laser attention

         if self.laser:
-            v_max = v.amax(dim = -2, keepdim = True).detach()
-            v = (v - v_max).exp()
+            v = softclamp(v, self.laser_softclamp_value)
+            v = v.exp()

         # whether to use flex attention or not

@@ -890,7 +892,7 @@ def forward(
         # laser attention

         if self.laser:
-            out = log(out) + v_max
+            out = log(out)

         # maybe gate values