
Commit 8784fab

cthi authored and facebook-github-bot committed
Use ATen API to get the device major arch (#4901)
Summary:
Pull Request resolved: #4901

X-link: facebookresearch/FBGEMM#1928

There is some weird stuff going on, and CUDA13 build is broken in torch when we bump the pin. I have a gut feeling there is some weirdness with the cuda headers. To fix this instead we can just try to use the more modern ATen API, as that seems ok.

Reviewed By: q10

Differential Revision: D82855103

fbshipit-source-id: 0cb5931de3f4997702da3c9b101e8cd3798ee51f
1 parent 6968a68 commit 8784fab

2 files changed: 16 additions and 35 deletions

fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/utils.h

Lines changed: 16 additions & 1 deletion
@@ -11,6 +11,8 @@
 #include <climits>
 #include <cstdint>
 
+#include <ATen/cuda/CUDAContext.h>
+
 namespace fbgemm_gpu {
 
 constexpr int64_t nextPowerOf2(int64_t num) {
@@ -19,6 +21,19 @@ constexpr int64_t nextPowerOf2(int64_t num) {
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 
-int getDeviceArch();
+inline int getDeviceArch() {
+  static int arch = []() {
+    const int majorVersion =
+        at::cuda::getDeviceProperties(at::cuda::current_device())->major;
+    if (majorVersion >= 10) {
+      int runtimeVersion = 0;
+      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
+      TORCH_CHECK(
+          runtimeVersion >= 12080, "SM100a+ kernels require cuda >= 12.8");
+    }
+    return majorVersion;
+  }();
+  return arch;
+}
 
 } // namespace fbgemm_gpu
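The new header-only `getDeviceArch()` relies on a function-local `static` initialized by an immediately invoked lambda, so the device query and the runtime version check run once per process. The sketch below isolates that pattern without any CUDA or ATen dependency. `queryMajorVersion`, `queryRuntimeVersion`, `g_query_count`, and `getDeviceArchSketch` are hypothetical stand-ins introduced only for illustration; they are not part of FBGEMM or ATen, and the hard-coded return values are assumptions.

```cpp
#include <stdexcept>

// Hypothetical counter so the caching behavior can be observed.
static int g_query_count = 0;

// Hypothetical stand-in for reading the device's major compute capability.
inline int queryMajorVersion() {
  ++g_query_count;
  return 10;  // assumed value for the sketch
}

// Hypothetical stand-in for the runtime version query. The CUDA runtime
// encodes its version as 1000 * major + 10 * minor, so 12.8 is 12080.
inline int queryRuntimeVersion() {
  return 12080;  // assumed value for the sketch
}

inline int getDeviceArchSketch() {
  // A function-local static is initialized exactly once (thread-safe since
  // C++11); the lambda body runs only on the first call.
  static const int arch = []() {
    const int majorVersion = queryMajorVersion();
    if (majorVersion >= 10) {
      if (queryRuntimeVersion() < 12080) {
        throw std::runtime_error("SM100a+ kernels require cuda >= 12.8");
      }
    }
    return majorVersion;
  }();
  return arch;
}
```

Calling `getDeviceArchSketch()` repeatedly returns the same value and triggers the underlying query only once. One consequence of this design, which also appears to hold for the version in the diff, is that the cached value reflects whichever device was current at the first call.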

fbgemm_gpu/experimental/gen_ai/src/quantize/common/utils.cpp

Lines changed: 0 additions & 34 deletions
This file was deleted.
