@@ -40,6 +40,9 @@ def pack_one(t, pattern):
def unpack_one(t, ps, pattern):
    """Unpack `t` with `unpack(t, ps, pattern)` and return only the first result.

    Counterpart of `pack_one`: forwards all three arguments unchanged to
    `unpack` and discards everything after element 0 of what it returns.
    """
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
4242
def log(t, eps = 1e-20):
    """Numerically safe natural log: clamp `t` from below, then take the log.

    Args:
        t: tensor to take the log of.
        eps: lower bound applied to `t` before the log, so zero (or negative)
            entries produce a large negative finite value instead of -inf / nan.

    Returns:
        A tensor equal to `t.clamp(min = eps).log()`, where for floating point
        inputs `eps` is never allowed to fall below the smallest positive
        normal number of `t`'s dtype.
    """
    if t.is_floating_point():
        # function-scope import so this helper does not rely on a module-level name
        import torch

        # the default 1e-20 is not representable in float16 and would round to 0,
        # turning the clamp into a no-op and letting log(0) = -inf through.
        # for float32 / float64 / bfloat16 this leaves eps untouched.
        eps = max(eps, torch.finfo(t.dtype).tiny)

    return t.clamp(min = eps).log()
45+
def softclamp(t, value):
    """Smoothly squash `t` into the open interval (-value, value) using tanh.

    Computes `tanh(t / value) * value`: close to the identity when `t` is small
    relative to `value`, and saturating towards +/- `value` for large magnitudes.
    """
    scaled = t / value
    squashed = scaled.tanh()
    return squashed * value
4548
@@ -181,6 +184,7 @@ def __init__(
181184 query_bias = True ,
182185 window_size = None ,
183186 num_memory_kv : int = 0 ,
187+ laser = False ,
184188 enable_attn_softclamp = False ,
185189 attn_softclamp_value = 50. ,
186190 softmax_full_precision = False
@@ -202,6 +206,7 @@ def __init__(
202206 dim_inner = dim_head * heads
203207
204208 self .attend = Attend (
209+ laser = laser ,
205210 dropout = dropout ,
206211 window_size = window_size ,
207212 enable_attn_softclamp = enable_attn_softclamp ,
@@ -299,6 +304,7 @@ class Attend(Module):
299304 def __init__ (
300305 self ,
301306 dropout = 0. ,
307+ laser = False ,
302308 window_size = None ,
303309 scale : float | None = None ,
304310 enable_attn_softclamp = False ,
@@ -327,6 +333,10 @@ def __init__(
327333
328334 self .attn_dropout = nn .Dropout (dropout )
329335
336+ # laser attention
337+
338+ self .laser = laser
339+
330340 # softclamp attention logits
331341 # being adopted by a number of recent llms (gemma, grok)
332342
@@ -447,10 +457,21 @@ def local_attn(
447457
448458 attn = sim .softmax (dim = - 1 )
449459
460+ # maybe laser
461+
462+ if self .laser :
463+ v_max = v .amax (dim = - 2 , keepdim = True )
464+ v = (v - v_max ).exp ()
465+
450466 # aggregate
451467
452468 out = einsum (attn , v , "... i j, ... j d -> ... i d" )
453469
470+ # maybe laser
471+
472+ if self .laser :
473+ out = log (out ) + v_max
474+
454475 # un-window the output
455476
456477 out = rearrange (out , "b h n w d -> b h (n w) d" )
@@ -546,8 +567,19 @@ def forward(
546567
547568 attn = self .attn_dropout (attn )
548569
570+ # maybe laser
571+
572+ if self .laser :
573+ v_max = v .amax (dim = - 2 , keepdim = True )
574+ v = (v - v_max ).exp ()
575+
549576 # aggregate values
550577
551578 out = einsum (attn , v , "b h i j, b h j d -> b h i d" )
552579
580+ # maybe laser
581+
582+ if self .laser :
583+ out = log (out ) + v_max
584+
553585 return out
0 commit comments