Skip to content

Commit f63d7f2

Browse files
authored
fix: prevent int32 overflow in k-grouped GEMM size calculations (#226)
1 parent ec5e9ed commit f63d7f2

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

csrc/apis/gemm.hpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -280,13 +280,13 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
280 280

281 281
// Shape checks
282 282
const auto& [num_groups, m, n] = get_shape<3>(d);
283-
const auto& sum_mk = a.first.numel();
284-
const auto& sum_nk = b.first.numel();
285-
int sum_k = 0;
283+
const auto& sum_mk = static_cast<uint64_t>(a.first.numel());
284+
const auto& sum_nk = static_cast<uint64_t>(b.first.numel());
285+
uint64_t sum_k = 0;
286 286
for (const auto& k: ks)
287-
sum_k += k;
288-
DG_HOST_ASSERT(sum_mk == m * sum_k);
289-
DG_HOST_ASSERT(sum_nk == n * sum_k);
287+
sum_k += static_cast<uint64_t>(k);
288+
DG_HOST_ASSERT(sum_mk == static_cast<uint64_t>(m) * sum_k);
289+
DG_HOST_ASSERT(sum_nk == static_cast<uint64_t>(n) * sum_k);
290 290

291 291
// Contiguity checks
292 292
DG_HOST_ASSERT(a.first.is_contiguous());

0 commit comments

Comments
 (0)