@@ -66,7 +66,8 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
66
66
[32 , 64 , 128 , 192 , 224 , 256 , 512 ])
67
67
@pytest .mark .parametrize ("K" , [128 , 256 , 1024 ])
68
68
@pytest .mark .parametrize ("N" , [128 , 256 , 512 , 1024 ])
69
- @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
69
+ @pytest .mark .parametrize ("dtype" ,
70
+ [torch .float32 , torch .float16 , torch .bfloat16 ])
70
71
def test_batched_mm (num_experts : int , max_tokens_per_expert : int , K : int ,
71
72
N : int , dtype : torch .dtype ):
72
73
@@ -104,4 +105,10 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
104
105
ref_output = ref_impl (tensors .A , tensors .B , ref_output ,
105
106
tensors .num_expert_tokens )
106
107
107
- torch .testing .assert_close (test_output , ref_output , atol = 1e-3 , rtol = 1e-3 )
108
+ rtol , atol = {
109
+ torch .float16 : (6e-2 , 6e-2 ),
110
+ torch .bfloat16 : (6e-2 , 6e-2 ),
111
+ torch .float32 : (1e-2 , 1e-2 ),
112
+ }[test_output .dtype ]
113
+
114
+ torch .testing .assert_close (test_output , ref_output , atol = atol , rtol = rtol )
0 commit comments