Skip to content

Commit 909f970

Browse files
committed
typo fix
1 parent 50bffae commit 909f970

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/ops/test_mamba_cu_seqlens_equivalence.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def main():
100100
out = mamba(packed_hidden_states, cu_seqlens)
101101

102102
# Testing the max/mean diff
103-
print(f'Output max diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().max().item()}')
104-
print(f'Output mean diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().mean().item()}')
103+
print(f'max diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().max().item()}')
104+
print(f'mean diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().mean().item()}')
105105
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
106106

107107
# bwd for mamba w/ cu_seqlens
@@ -117,9 +117,9 @@ def main():
117117
# check bwd pass
118118
assert set(mamba_grad.keys()) == set(mamba_ref_grad.keys())
119119
for name in mamba_ref_grad:
120-
print(f'Output max diff for {name} in varlen_mamba bwd pass: {( - mamba_ref_grad[name]).abs().max().item()}')
121-
print(f'Output mean diff for {name} in varlen_mamba bwd pass: {(mamba_grad[name] - mamba_ref_grad[name]).abs().mean().item()}')
120+
print(f'max diff for {name} in varlen_mamba bwd pass: {(mamba_grad[name] - mamba_ref_grad[name]).abs().max().item()}')
121+
print(f'mean diff for {name} in varlen_mamba bwd pass: {(mamba_grad[name] - mamba_ref_grad[name]).abs().mean().item()}')
122122
assert torch.allclose(mamba_grad[name], mamba_ref_grad[name], rtol=rtol, atol=atol)
123123

124124
if __name__ == "__main__":
125-
main()
125+
main()

0 commit comments

Comments
 (0)