Skip to content

Commit b85b20d

Browse files
committed
Add swiglu activation
1 parent c8396a3 commit b85b20d

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

mamba_ssm/ops/triton/k_activations.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -151,3 +151,19 @@ def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
151151
return dxy.reshape(*batch_shape, dxy.shape[-1])
152152
else:
153153
return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
154+
155+
156+
class SwiGLU(torch.autograd.Function):
    """Autograd function pairing ``_swiglu_fwd`` with ``_swiglu_bwd``.

    The forward pass delegates to ``_swiglu_fwd`` and the backward pass to
    ``_swiglu_bwd``. Only the input tensor ``xy`` is saved between the two.
    """

    @staticmethod
    def forward(ctx, xy):
        """Save ``xy`` on ``ctx`` and return the result of ``_swiglu_fwd(xy)``.

        NOTE(review): presumably ``xy`` packs both SwiGLU operands along its
        last dimension -- confirm against ``_swiglu_fwd``.
        """
        ctx.save_for_backward(xy)
        activated = _swiglu_fwd(xy)
        return activated

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient for ``xy`` as computed by ``_swiglu_bwd``.

        ``_swiglu_bwd`` is invoked with its default keyword arguments
        (``dxy=None``, ``recompute_output=False``, ``out=None``).
        """
        # Exactly one tensor was saved in ``forward``; unpack it explicitly.
        (saved_xy,) = ctx.saved_tensors
        grad_xy = _swiglu_bwd(saved_xy, dout)
        return grad_xy
167+
168+
169+
# Functional entry point: ``swiglu(xy)`` is shorthand for ``SwiGLU.apply(xy)``.
swiglu = SwiGLU.apply

0 commit comments

Comments
 (0)