@@ -184,8 +184,6 @@ def __init__(
184184 query_bias = True ,
185185 window_size = None ,
186186 num_memory_kv : int = 0 ,
187- laser = False ,
188- laser_softclamp_value = 15. ,
189187 enable_attn_softclamp = False ,
190188 attn_softclamp_value = 50. ,
191189 softmax_full_precision = False ,
@@ -211,8 +209,6 @@ def __init__(
211209 dropout = dropout ,
212210 window_size = window_size ,
213211 enable_attn_softclamp = enable_attn_softclamp ,
214- laser = laser ,
215- laser_softclamp_value = laser_softclamp_value ,
216212 attn_softclamp_value = attn_softclamp_value ,
217213 softmax_full_precision = softmax_full_precision
218214 )
@@ -322,8 +318,6 @@ class Attend(Module):
322318 def __init__ (
323319 self ,
324320 dropout = 0. ,
325- laser = False ,
326- laser_softclamp_value = 15. ,
327321 window_size = None ,
328322 scale : float | None = None ,
329323 enable_attn_softclamp = False ,
@@ -352,11 +346,6 @@ def __init__(
352346
353347 self .attn_dropout = nn .Dropout (dropout )
354348
355- # laser attention
356-
357- self .laser = laser
358- self .laser_softclamp_value = laser_softclamp_value
359-
360349 # softclamp attention logits
361350 # being adopted by a number of recent llms (gemma, grok)
362351
@@ -477,20 +466,10 @@ def local_attn(
477466
478467 attn = sim .softmax (dim = - 1 )
479468
480- # maybe laser
481-
482- if self .laser :
483- v = softclamp (v , self .laser_softclamp_value )
484-
485469 # aggregate
486470
487471 out = einsum (attn , v , "... i j, ... j d -> ... i d" )
488472
489- # maybe laser
490-
491- if self .laser :
492- out = log (out )
493-
494473 # un-window the output
495474
496475 out = rearrange (out , "b h n w d -> b h (n w) d" )
@@ -586,19 +565,8 @@ def forward(
586565
587566 attn = self .attn_dropout (attn )
588567
589- # maybe laser
590-
591- if self .laser :
592- v_max = v .amax (dim = - 2 , keepdim = True )
593- v = (v - v_max ).exp ()
594-
595568 # aggregate values
596569
597570 out = einsum (attn , v , "b h i j, b h j d -> b h i d" )
598571
599- # maybe laser
600-
601- if self .laser :
602- out = log (out ) + v_max
603-
604572 return out
0 commit comments