@@ -40,6 +40,9 @@ def pack_one(t, pattern):
 def unpack_one(t, ps, pattern):
     return unpack(t, ps, pattern)[0]
 
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def softclamp(t, value):
     return (t / value).tanh() * value
 
@@ -181,6 +184,7 @@ def __init__(
         query_bias = True,
         window_size = None,
         num_memory_kv: int = 0,
+        laser = False,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
         softmax_full_precision = False
@@ -222,6 +226,10 @@ def __init__(
         self.memory_kv = nn.Parameter(torch.zeros(2, heads, num_memory_kv, dim_head))
         nn.init.normal_(self.memory_kv, std = 0.02)
 
+        # laser attention
+
+        self.laser = laser
+
         # gating of value
         # allows attention to attend to nothing
 
@@ -262,6 +270,12 @@ def forward(
 
         q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
 
+        # maybe laser
+
+        if self.laser:
+            v_max = v.amax(dim = -2, keepdim = True)
+            v = (v - v_max).exp()
+
         # attention
 
         out = self.attend(
@@ -272,6 +286,11 @@ def forward(
             memory_kv = self.memory_kv
         )
 
+        # maybe laser
+
+        if self.laser:
+            out = log(out) + v_max
+
         # merge heads
 
         out = self.merge_heads(out)
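Taken together, the two `# maybe laser` blocks implement the LASER trick: the attention-weighted sum is taken over `exp(v)` rather than `v`, then mapped back through `log`, with the per-dimension max `v_max` subtracted before the `exp` and re-added after the `log` so the transform stays numerically stable. A minimal standalone sketch of the same computation (`laser_weighted_values` and the shapes below are hypothetical illustrations, not the repository's `self.attend`):

```python
import torch

def laser_weighted_values(attn, v, eps = 1e-20):
    # attn: (..., seq_q, seq_k) rows of softmax weights
    # v:    (..., seq_k, dim_head) value vectors

    # subtract the per-dimension max over the sequence axis so exp() cannot overflow
    v_max = v.amax(dim = -2, keepdim = True)
    v_exp = (v - v_max).exp()

    # ordinary attention-weighted sum, but over exp-transformed values
    out = attn @ v_exp

    # map back with a clamped log (mirroring the log helper above) and restore the max
    return out.clamp(min = eps).log() + v_max

# usage: computes log(attn @ exp(v)) without overflow or log(0)
attn = torch.randn(2, 8, 16, 16).softmax(dim = -1)
v    = torch.randn(2, 8, 16, 64)
out  = laser_weighted_values(attn, v)
```

Since the softmax weights are non-negative and sum to one, `attn @ exp(v - v_max)` is a convex combination of values in `(0, 1]`, so the `eps` clamp in `log` only triggers in degenerate cases such as fully masked rows.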