Commit 49ddf83

Fix residual in MLP when not fused_add_norm
1 parent a71bb5a · commit 49ddf83

File tree

1 file changed: 1 addition, 1 deletion


mamba_ssm/modules/block.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -69,7 +69,7 @@ def forward(
         if self.mlp is not None:
             if not self.fused_add_norm:
                 residual = hidden_states + residual
-                residual = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                 if self.residual_in_fp32:
                     residual = residual.to(torch.float32)
             else:
```
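
Before this change, the non-fused path assigned the output of `self.norm2` back to `residual`, so the residual stream carried post-norm activations instead of the raw sum `hidden_states + residual`. That also made the two paths disagree, since the fused kernel returns the normalized output and the updated residual separately. Assigning the norm output to `hidden_states` restores the intended pre-norm pattern: the MLP consumes the normalized tensor while the un-normalized sum propagates as the residual. The sketch below is a minimal standalone reproduction of the corrected dataflow; the module skeleton (`MlpSubBlock`, the MLP's hidden width) is illustrative and not the library's actual `Block` class, though `norm2` and `mlp` match the names in the diff.

```python
import torch
import torch.nn as nn


class MlpSubBlock(nn.Module):
    """Illustrative stand-in for the MLP half of mamba_ssm's Block.

    Hypothetical class: only the norm2/mlp attribute names and the
    residual dataflow mirror the real module.
    """

    def __init__(self, dim: int, residual_in_fp32: bool = True):
        super().__init__()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        self.residual_in_fp32 = residual_in_fp32

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor):
        # Accumulate the residual stream first; this raw sum is what the
        # next block must receive.
        residual = hidden_states + residual
        # Post-fix behavior: the normalized tensor goes into hidden_states
        # and feeds the MLP. The pre-fix bug assigned it to `residual`,
        # overwriting the raw sum with post-norm activations.
        hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


# Quick check: the returned residual equals the raw sum, untouched by norm2.
x = torch.randn(2, 8, 16)
res = torch.zeros_like(x)
block = MlpSubBlock(dim=16)
out, new_res = block(x, res)
assert torch.allclose(new_res, (x + res).to(new_res.dtype))
```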
