Skip to content

Commit 3f10988

Browse files
Varun Sundar Rabindranathbnellnm
authored andcommitted
relax test_batched_moe tolerances
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent c4086d7 commit 3f10988

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tests/kernels/moe/test_batched_moe.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
6666
[32, 64, 128, 192, 224, 256, 512])
6767
@pytest.mark.parametrize("K", [128, 256, 1024])
6868
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
69-
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
69+
@pytest.mark.parametrize("dtype",
70+
[torch.float32, torch.float16, torch.bfloat16])
7071
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
7172
N: int, dtype: torch.dtype):
7273

@@ -104,4 +105,10 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
104105
ref_output = ref_impl(tensors.A, tensors.B, ref_output,
105106
tensors.num_expert_tokens)
106107

107-
torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3)
108+
rtol, atol = {
109+
torch.float16: (6e-2, 6e-2),
110+
torch.bfloat16: (6e-2, 6e-2),
111+
torch.float32: (1e-2, 1e-2),
112+
}[test_output.dtype]
113+
114+
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)