
Commit dec4d06

cthi authored and facebook-github-bot committed
Demonstration of per-op targets
Summary: For some new kernels where we are trying to add a lot of instances, the extra code size can cause bloat and relocation issues in targets that include fbgemm. Although we should invest in ways to cut down our kernel code size (I have some other ideas for this later), as the number of kernels grows the chances of hitting this will inevitably increase, and it is rather painful to deal with in fbcode.

Currently we take a monolithic approach: all ops are pulled in when `:quantize_ops_gpu` is added as a dep. Instead, we can take a granular approach where kernels are pulled in selectively. This approach seems to work well; this diff demonstrates it by adding only the bf16 grouped grad/wgrad kernels to `:quantize_bench` but not to `:quantize_ops_gpu`. The only minor downside is that users of these ops have to know to add the dep to their buck target, but I think the upside is high as we grow the number of kernels in fbgemm and support more users of fbgemm, especially those with stricter code-size requirements (e.g. the Sigrid predictor).

Reviewed By: jiawenliu64

Differential Revision: D83056686
1 parent 1c7b6d4 commit dec4d06
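To make the consumer side of this concrete, the sketch below shows roughly what a script such as quantize_bench can rely on once its buck target lists the per-op deps: it attempts to load only the libraries it cares about and degrades gracefully when a dep was not added. Only the two dep labels and the OSError-tolerant load pattern come from this diff; the surrounding script and the hasattr availability check are illustrative.

import torch

# Per-op libraries this script opts into (labels taken from this diff); they
# resolve only if the script's buck target also lists them as deps.
OPTIONAL_GEMM_OPS = [
    "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions:cutlass_bf16bf16bf16_grouped_grad",
    "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions:cutlass_bf16bf16bf16_grouped_wgrad",
]

for op in OPTIONAL_GEMM_OPS:
    try:
        torch.ops.load_library(op)
    except OSError:
        # The dep was not linked into this binary; the op stays unregistered.
        pass

# Downstream code can gate on whether the op actually got registered.
if hasattr(torch.ops.fbgemm, "bf16bf16bf16_grouped_grad"):
    print("bf16bf16bf16_grouped_grad is available")
else:
    print("skipping bf16bf16bf16_grouped_grad (per-op dep not added)")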

5 files changed: +64 -41 lines

fbgemm_gpu/experimental/gen_ai/gen_ai/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -46,3 +46,15 @@
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:gather_scatter_ops"
 )
+
+gemm_ops = [
+    "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions:cutlass_bf16bf16bf16_grouped_grad",
+    "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions:cutlass_bf16bf16bf16_grouped_wgrad",
+]
+for op in gemm_ops:
+    try:
+        torch.ops.load_library(
+            op,
+        )
+    except OSError:
+        pass

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_grad.cu

Lines changed: 25 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <torch/library.h>
 
 #include "bf16bf16bf16_grouped_grad/bf16bf16bf16_grouped_grad_manifest.cuh"
 #include "fbgemm_gpu/quantize/tuning_cache.hpp"
@@ -334,4 +335,28 @@ at::Tensor bf16bf16bf16_grouped_grad(at::Tensor, at::Tensor, at::Tensor) {
 
 #endif
 
+at::Tensor bf16bf16bf16_grouped_grad_meta(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor /* M_sizes */) {
+  const at::SymInt total_M = X.sym_size(0);
+  const at::SymInt N = W.sym_size(1);
+  at::Tensor Y =
+      at::empty_symint({total_M, N}, X.options().dtype(at::kBFloat16));
+  return Y;
+}
+
+TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
+  m.impl("bf16bf16bf16_grouped_grad", bf16bf16bf16_grouped_grad);
+}
+
+TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
+  m.impl("bf16bf16bf16_grouped_grad", bf16bf16bf16_grouped_grad_meta);
+}
+
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  m.def(
+      "bf16bf16bf16_grouped_grad(Tensor X, Tensor W, Tensor M_sizes) -> Tensor");
+}
+
 } // namespace fbgemm_gpu
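The Meta registration above only derives the output shape ([total_M, N] in bf16), which is what lets fake-tensor tracing and torch.compile reason about the op without touching CUDA. A minimal Python sketch of that behavior, assuming the per-op library has been loaded so both the schema and the Meta kernel are registered; the shapes are purely illustrative:

import torch

# Illustrative shapes only: the Meta function above reads X.size(0) and
# W.size(1), so the output is [total_M, N] in bf16 with no kernel launch.
X = torch.empty(256, 512, dtype=torch.bfloat16, device="meta")
W = torch.empty(1024, 512, dtype=torch.bfloat16, device="meta")
M_sizes = torch.empty(4, dtype=torch.int64, device="meta")

Y = torch.ops.fbgemm.bf16bf16bf16_grouped_grad(X, W, M_sizes)
print(Y.shape, Y.dtype)  # expected: torch.Size([256, 512]) torch.bfloat16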

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu

Lines changed: 27 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <torch/library.h>
 
 #include "bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh"
 #include "fbgemm_gpu/quantize/tuning_cache.hpp"
@@ -1093,4 +1094,30 @@ at::Tensor bf16bf16bf16_grouped_wgrad(
 
 #endif
 
+at::Tensor bf16bf16bf16_grouped_wgrad_meta(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> /* output = std::nullopt */,
+    bool /* output_accum = false */) {
+  const at::SymInt G = M_sizes.size(0);
+  const at::SymInt N = X.sym_size(1);
+  const at::SymInt K = W.sym_size(1);
+  at::Tensor Y = at::empty_symint({G, N, K}, X.options().dtype(at::kBFloat16));
+  return Y;
+}
+
+TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
+  m.impl("bf16bf16bf16_grouped_wgrad", bf16bf16bf16_grouped_wgrad);
+}
+
+TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
+  m.impl("bf16bf16bf16_grouped_wgrad", bf16bf16bf16_grouped_wgrad_meta);
+}
+
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  m.def(
+      "bf16bf16bf16_grouped_wgrad(Tensor X, Tensor W, Tensor M_sizes, Tensor(a!)? output=None, bool output_accum=False) -> Tensor");
+}
+
 } // namespace fbgemm_gpu
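Same idea for wgrad: the Meta kernel above only computes [G, N, K] from M_sizes.size(0), X.size(1), and W.size(1). A hedged sketch with illustrative shapes, again assuming the per-op library is loaded:

import torch

# G comes from M_sizes, N from X, K from W, per the Meta function above;
# the optional output/output_accum arguments fall back to their defaults.
X = torch.empty(256, 512, dtype=torch.bfloat16, device="meta")
W = torch.empty(256, 128, dtype=torch.bfloat16, device="meta")
M_sizes = torch.empty(4, dtype=torch.int64, device="meta")

Y = torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(X, W, M_sizes)
print(Y.shape)  # expected: torch.Size([4, 512, 128])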

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 0 additions & 37 deletions
@@ -78,14 +78,6 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
     at::Tensor zero_start_index_M);
 at::Tensor
 bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes);
-at::Tensor
-bf16bf16bf16_grouped_grad(at::Tensor X, at::Tensor W, at::Tensor M_sizes);
-at::Tensor bf16bf16bf16_grouped_wgrad(
-    at::Tensor X,
-    at::Tensor W,
-    at::Tensor M_sizes,
-    std::optional<at::Tensor> output = std::nullopt,
-    bool output_accum = false);
 at::Tensor f8f8bf16_rowwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -325,8 +317,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("bf16i4bf16_shuffled", bf16i4bf16_shuffled);
   m.impl("f8i4bf16_shuffled_grouped", f8i4bf16_shuffled_grouped);
   m.impl("bf16i4bf16_shuffled_grouped", bf16i4bf16_shuffled_grouped);
-  m.impl("bf16bf16bf16_grouped_grad", bf16bf16bf16_grouped_grad);
-  m.impl("bf16bf16bf16_grouped_wgrad", bf16bf16bf16_grouped_wgrad);
   m.impl("preshuffle_i4", preshuffle_i4);
   m.impl("bf16i4bf16_shuffled_batched", bf16i4bf16_shuffled_batched);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
@@ -382,7 +372,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("bf16i4bf16_shuffled", bf16i4bf16_shuffled);
   m.impl("f8i4bf16_shuffled_grouped", f8i4bf16_shuffled_grouped);
   m.impl("bf16i4bf16_shuffled_grouped", bf16i4bf16_shuffled_grouped);
-  m.impl("bf16bf16bf16_grouped_grad", bf16bf16bf16_grouped_grad);
   m.impl("preshuffle_i4", preshuffle_i4);
   m.impl("bf16i4bf16_shuffled_batched", bf16i4bf16_shuffled_batched);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
@@ -800,30 +789,6 @@ at::Tensor bf16bf16bf16_grouped_stacked_meta(
   return Y;
 }
 
-at::Tensor bf16bf16bf16_grouped_grad_meta(
-    at::Tensor X,
-    at::Tensor W,
-    at::Tensor /* M_sizes */) {
-  const at::SymInt total_M = X.sym_size(0);
-  const at::SymInt N = W.sym_size(1);
-  at::Tensor Y =
-      at::empty_symint({total_M, N}, X.options().dtype(at::kBFloat16));
-  return Y;
-}
-
-at::Tensor bf16bf16bf16_grouped_wgrad_meta(
-    at::Tensor X,
-    at::Tensor W,
-    at::Tensor M_sizes,
-    std::optional<at::Tensor> /* output = std::nullopt */,
-    bool /* output_accum = false */) {
-  const at::SymInt G = M_sizes.size(0);
-  const at::SymInt N = X.sym_size(1);
-  const at::SymInt K = W.sym_size(1);
-  at::Tensor Y = at::empty_symint({G, N, K}, X.options().dtype(at::kBFloat16));
-  return Y;
-}
-
 at::Tensor f8f8bf16_rowwise_grouped_stacked_meta(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -864,8 +829,6 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta);
   m.impl("bf16i4bf16_shuffled_batched", bf16i4bf16_shuffled_batched_meta);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
-  m.impl("bf16bf16bf16_grouped_grad", bf16bf16bf16_grouped_grad_meta);
-  m.impl("bf16bf16bf16_grouped_wgrad", bf16bf16bf16_grouped_wgrad_meta);
   m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
   m.impl("scaled_fp4_quant", scaled_fp4_quant_meta);
   m.impl("preshuffle_i4", preshuffle_i4_meta);
fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp

Lines changed: 0 additions & 4 deletions
@@ -64,10 +64,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "bf16bf16bf16_grouped_dynamic(Tensor X, Tensor W, Tensor zero_start_index_M) -> Tensor");
   m.def(
       "bf16bf16bf16_grouped_stacked(Tensor X, Tensor W, Tensor M_sizes) -> Tensor");
-  m.def(
-      "bf16bf16bf16_grouped_grad(Tensor X, Tensor W, Tensor M_sizes) -> Tensor");
-  m.def(
-      "bf16bf16bf16_grouped_wgrad(Tensor X, Tensor W, Tensor M_sizes, Tensor(a!)? output=None, bool output_accum=False) -> Tensor");
   m.def(
       "f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor");
   m.def(
