
Commit 16dd8b3

jiawenliu64 authored and facebook-github-bot committed
Add output as an option in CUTLASS grouped GEMM (#4931)
Summary:
X-link: facebookresearch/FBGEMM#1954

Enable output as an option in CUTLASS grouped GEMM, since pretraining needs to pass an empty, preallocated output tensor for the fprop and dgrad use cases.

Differential Revision: D83126291
1 parent 826064d commit 16dd8b3
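
For reference, a minimal caller-side sketch of the new optional-output path, based on the signature declared in quantize.cpp in this diff. The wrapper function, tensor shapes, and buffer handling below are illustrative assumptions, not part of the commit:

#include <optional>
#include <ATen/ATen.h>

// Declaration matching the one added to quantize.cpp in this commit.
namespace fbgemm_gpu {
at::Tensor bf16bf16bf16_grouped_stacked(
    at::Tensor X,
    at::Tensor W,
    at::Tensor M_sizes,
    std::optional<at::Tensor> Y = std::nullopt);
} // namespace fbgemm_gpu

// Illustrative wrapper only: X is [total_M, K] bf16, W is [G, N, K] bf16,
// and M_sizes holds the per-group row counts on the same device as X.
at::Tensor forward_with_reused_buffer(
    const at::Tensor& X,
    const at::Tensor& W,
    const at::Tensor& M_sizes,
    at::Tensor& persistent_out) { // preallocated flat [total_M * N] bf16 buffer
  // Default path (unchanged by this commit): omit Y and the kernel allocates
  // a fresh bf16 output of total_M * N elements itself.
  //   return fbgemm_gpu::bf16bf16bf16_grouped_stacked(X, W, M_sizes);

  // New path: pass the preallocated buffer as the optional Y argument so
  // pretraining fprop/dgrad can write into caller-owned memory.
  return fbgemm_gpu::bf16bf16bf16_grouped_stacked(
      X, W, M_sizes, persistent_out);
}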

File tree

5 files changed: +255 −42 lines changed


fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 20 additions & 6 deletions
@@ -345,8 +345,11 @@ at::Tensor bf16bf16bf16_grouped_cat(at::TensorList X, at::TensorList W) {
   return _bf16bf16bf16_grouped<at::Tensor>(X, W);
 }
 
-at::Tensor
-bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
+at::Tensor bf16bf16bf16_grouped_stacked(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> Y) {
   int64_t total_M = X.size(0);
   int64_t N = W.size(1);
   int64_t K = W.size(2);
@@ -356,14 +359,21 @@ bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
       "M_sizes must be on same device as inputs.");
   TORCH_CHECK(
       W.dim() == 3 && W.size(0) == G, "Weights should be shape [G, N, K].")
-  at::Tensor Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
+
+  at::Tensor output_tensor;
+  if (Y.has_value()) {
+    output_tensor = Y.value();
+  } else {
+    output_tensor = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
+  }
+
   // Early exit for empty inputs.
   if (total_M == 0) {
-    return Y.view({total_M, N});
+    return output_tensor.view({total_M, N});
   }
   // Return continuous view of output.
   at::Tensor out = dispatch_bf16_grouped_kernel<at::Tensor>(
-      G, total_M, N, K, X, W, Y, std::nullopt, M_sizes);
+      G, total_M, N, K, X, W, output_tensor, std::nullopt, M_sizes);
   return out.view({total_M, N});
 }
 
@@ -411,7 +421,11 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
 
-at::Tensor bf16bf16bf16_grouped_stacked(at::Tensor, at::Tensor, at::Tensor) {
+at::Tensor bf16bf16bf16_grouped_stacked(
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    std::optional<at::Tensor>) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
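
As far as this hunk shows, a caller-supplied Y is consumed via Y.value() with no shape, dtype, or device validation, so a caller reusing a buffer may want to assert that it matches what the kernel would otherwise allocate. A hedged sketch of such a guard, which is not part of this commit:

#include <ATen/ATen.h>

// Illustrative caller-side guard mirroring the kernel's own allocation:
// a contiguous bf16 buffer of total_M * N elements on X's device.
void check_preallocated_output(
    const at::Tensor& X, const at::Tensor& W, const at::Tensor& Y) {
  const int64_t total_M = X.size(0);
  const int64_t N = W.size(1);
  TORCH_CHECK(Y.is_contiguous(), "Preallocated output must be contiguous.");
  TORCH_CHECK(
      Y.scalar_type() == at::kBFloat16, "Preallocated output must be bf16.");
  TORCH_CHECK(
      Y.numel() == total_M * N, "Preallocated output has the wrong size.");
  TORCH_CHECK(
      Y.device() == X.device(), "Preallocated output is on the wrong device.");
}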

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_grad.cu

Lines changed: 30 additions & 12 deletions
@@ -300,8 +300,11 @@ at::Tensor dispatch_bf16_grouped_kernel(
   return kernel(X, W, output, M_sizes);
 }
 
-at::Tensor
-bf16bf16bf16_grouped_grad(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
+at::Tensor bf16bf16bf16_grouped_grad(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> Y) {
   int64_t total_M = X.size(0);
   int64_t N = W.size(1);
   int64_t K = W.size(2);
@@ -315,20 +318,29 @@ bf16bf16bf16_grouped_grad(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
   TORCH_CHECK(X.stride(-1) == 1, "Activation memory layout must be row-major.");
   TORCH_CHECK(W.stride(-2) == 1, "Weight memory layout must be column-major.");
 
-  at::Tensor Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
+  at::Tensor output_tensor;
+  if (Y.has_value()) {
+    output_tensor = Y.value();
+  } else {
+    output_tensor = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
+  }
   // Early exit for empty inputs.
   if (total_M == 0) {
-    return Y.view({total_M, N});
+    return output_tensor.view({total_M, N});
   }
   // Return continuous view of output.
-  at::Tensor out =
-      dispatch_bf16_grouped_kernel(G, total_M, N, K, X, W, Y, M_sizes);
+  at::Tensor out = dispatch_bf16_grouped_kernel(
+      G, total_M, N, K, X, W, output_tensor, M_sizes);
   return out.view({total_M, N});
 }
 
 #else
 
-at::Tensor bf16bf16bf16_grouped_grad(at::Tensor, at::Tensor, at::Tensor) {
+at::Tensor bf16bf16bf16_grouped_grad(
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    std::optional<at::Tensor>) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
@@ -338,12 +350,18 @@ at::Tensor bf16bf16bf16_grouped_grad(at::Tensor, at::Tensor, at::Tensor) {
 at::Tensor bf16bf16bf16_grouped_grad_meta(
     at::Tensor X,
     at::Tensor W,
-    at::Tensor /* M_sizes */) {
+    at::Tensor /* M_sizes */,
+    std::optional<at::Tensor> Y) {
   const at::SymInt total_M = X.sym_size(0);
   const at::SymInt N = W.sym_size(1);
-  at::Tensor Y =
-      at::empty_symint({total_M, N}, X.options().dtype(at::kBFloat16));
-  return Y;
+
+  if (Y.has_value()) {
+    return Y.value();
+  } else {
+    at::Tensor output =
+        at::empty_symint({total_M, N}, X.options().dtype(at::kBFloat16));
+    return output;
+  }
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
@@ -356,7 +374,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
-      "bf16bf16bf16_grouped_grad(Tensor X, Tensor W, Tensor M_sizes) -> Tensor");
+      "bf16bf16bf16_grouped_grad(Tensor X, Tensor W, Tensor M_sizes, Tensor? Y=None) -> Tensor");
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 15 additions & 6 deletions
@@ -76,8 +76,11 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
     at::Tensor X,
     at::Tensor W,
     at::Tensor zero_start_index_M);
-at::Tensor
-bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes);
+at::Tensor bf16bf16bf16_grouped_stacked(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> Y = std::nullopt);
 at::Tensor f8f8bf16_rowwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -781,12 +784,18 @@ at::Tensor bf16bf16bf16_grouped_dynamic_meta(
 at::Tensor bf16bf16bf16_grouped_stacked_meta(
     at::Tensor X,
     at::Tensor W,
-    at::Tensor /* M_sizes */) {
+    at::Tensor /* M_sizes */,
+    std::optional<at::Tensor> Y) {
   const at::SymInt total_M = X.sym_size(0);
   const at::SymInt N = W.sym_size(1);
-  at::Tensor Y =
-      at::empty_symint({total_M, N}, X.options().dtype(at::kBFloat16));
-  return Y;
+
+  if (Y.has_value()) {
+    return Y.value();
+  } else {
+    at::Tensor output =
+        at::empty_symint({total_M, N}, X.options().dtype(at::kBFloat16));
+    return output;
+  }
 }
 
 at::Tensor f8f8bf16_rowwise_grouped_stacked_meta(

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "bf16bf16bf16_grouped_dynamic(Tensor X, Tensor W, Tensor zero_start_index_M) -> Tensor");
   m.def(
-      "bf16bf16bf16_grouped_stacked(Tensor X, Tensor W, Tensor M_sizes) -> Tensor");
+      "bf16bf16bf16_grouped_stacked(Tensor X, Tensor W, Tensor M_sizes, Tensor? Y=None) -> Tensor");
   m.def(
       "f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor");
   m.def(
