Commit 7d92267

fix(GatedDeltaNet): Init param A from log of a uniform distrib (#906)
1 parent 35354fa commit 7d92267

1 file changed (+2 -1 lines changed)

ch04/08_deltanet/README.md

Lines changed: 2 additions & 1 deletion
@@ -166,7 +166,8 @@ class GatedDeltaNet(nn.Module):
         # A_log + W_alpha(x) + dt_bias
         self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
         self.dt_bias = nn.Parameter(torch.ones(num_heads))
-        self.A_log = nn.Parameter(torch.zeros(num_heads))
+        A_init = torch.empty(num_heads).uniform_(0, 16)
+        self.A_log = nn.Parameter(torch.log(A_init))
         # We could implement this as
         # W_alpha = nn.Linear(d_in, num_heads, bias=True)
         # but the bias is separate for interpretability and
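
The effect of the change can be sketched without PyTorch. The snippet below is an illustration only, not code from this commit: the helper names `init_a_log_old` and `init_a_log_new` are hypothetical, the stdlib `random`/`math` calls stand in for `torch.empty(num_heads).uniform_(0, 16)` and `torch.log`, and it assumes (as the parameter name suggests) that the model later recovers `A` as `exp(A_log)`.

```python
import math
import random


def init_a_log_old(num_heads):
    # Pre-commit init: A_log = 0 for every head, so exp(A_log) = 1 everywhere.
    return [0.0] * num_heads


def init_a_log_new(num_heads, low=0.0, high=16.0, seed=None):
    # Post-commit init: draw A ~ U(low, high) per head, then store log(A).
    # Stdlib stand-in for torch.log(torch.empty(num_heads).uniform_(0, 16)).
    # Note: math.log raises on an exact 0.0 sample (torch.log would give -inf);
    # that draw is vanishingly unlikely and is not handled in this sketch.
    rng = random.Random(seed)
    a_init = [rng.uniform(low, high) for _ in range(num_heads)]
    return [math.log(a) for a in a_init]


num_heads = 8

# Assumption: the model uses A = exp(A_log), so we exponentiate to compare.
old_a = [math.exp(v) for v in init_a_log_old(num_heads)]
new_a = [math.exp(v) for v in init_a_log_new(num_heads, seed=0)]

print("old A per head:", old_a)                          # identical across heads
print("new A per head:", [round(a, 2) for a in new_a])   # spread over (0, 16)
```

Under that assumption, the old initialization starts every head at the same value `A = 1`, while the new one gives each head a different starting value in `(0, 16)`.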