
Commit ffb1a08

Author: pytorchbot
Committed: 2025-09-27 nightly release (dbc93d4)
1 parent 85eed57 · commit ffb1a08

File tree: 15 files changed (+176, -70 lines)

fbgemm_gpu/cmake/Fbgemm.cmake

Lines changed: 3 additions & 2 deletions
@@ -26,7 +26,8 @@ set(fbgemm_sources_avx2
     "${FBGEMM}/src/QuantUtilsAvx2.cc")
 
 set(fbgemm_sources_avx512
-  "${FBGEMM}/src/EmbeddingSpMDMAvx512.cc")
+  "${FBGEMM}/src/EmbeddingSpMDMAvx512.cc"
+  "${FBGEMM}/src/QuantUtilsAvx512.cc")
 
 if(CXX_AVX2_FOUND)
   set_source_files_properties(${fbgemm_sources_avx2}
@@ -46,7 +47,7 @@ if(CXX_AVX2_FOUND)
     ${fbgemm_sources}
     ${fbgemm_sources_avx2})
 endif()
-if((NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM) AND CXX_AVX512_FOUND)
+if(CXX_AVX512_FOUND)
   set(fbgemm_sources
     ${fbgemm_sources}
     ${fbgemm_sources_avx2}

fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp

Lines changed: 1 addition & 3 deletions
@@ -18,7 +18,6 @@
 #include "fbgemm_gpu/embedding_common.h"
 #include "fbgemm/FbgemmEmbedding.h"
 #include "fbgemm_gpu/utils/tensor_utils.h"
-#include "fbgemm_gpu/config/feature_gates.h"
 
 #if defined(__x86_64__) || defined(__i386__) || (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
 #include <immintrin.h>
@@ -191,9 +190,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
   {% else %}
   TORCH_CHECK(D > 0);
   {% endif %}
-  const static bool disablePinnedMemory = fbgemm_gpu::config::is_feature_enabled_from_env(fbgemm_gpu::config::FeatureGateName::TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY);
   bool pinned_memory = false;
-  if (!disablePinnedMemory && at::Context::hasCUDA() && at::getNumGPUs() > 0) {
+  if (at::Context::hasCUDA() && at::getNumGPUs() > 0) {
     pinned_memory = true;
   }
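Note on the change above: with the TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY gate removed, the CPU output is pinned whenever a CUDA context is available. Pinned (page-locked) host memory is what allows fast, potentially asynchronous host-to-device copies. A standalone sketch of requesting a pinned output through ATen (illustrative only, not part of this diff):

    #include <ATen/ATen.h>

    int main() {
      // Pin the CPU buffer only when a GPU is present, mirroring the
      // condition in the patched code.
      const bool pin = at::hasCUDA() && at::getNumGPUs() > 0;
      at::Tensor output = at::empty(
          {8, 128},
          at::TensorOptions().dtype(at::kFloat).pinned_memory(pin));
      return output.is_pinned() ? 0 : 1;
    }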

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp

Lines changed: 5 additions & 4 deletions
@@ -267,14 +267,15 @@ class Sm100FmhaBwd {
     auto [Q_, K, D, D_VO, HB] = args.problem_shape;
     auto [H, B] = product_each(HB);
     D = cutlass::round_up(D, 8); // Alignment
-    int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
+    size_t Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
     size_t workspace_bytes = 0;
+    size_t accum_size = sizeof(ElementAccumulator);
     // OdO vector
-    workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
+    workspace_bytes += static_cast<size_t>(B)*static_cast<size_t>(H)*Q * accum_size;
     // scaled LSE vector
-    workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
+    workspace_bytes += static_cast<size_t>(B)*static_cast<size_t>(H)*Q * accum_size;
     // FP32 versions of outputs that are churned (start off with Q only)
-    workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
+    workspace_bytes += static_cast<size_t>(B)*static_cast<size_t>(H)*Q*static_cast<size_t>(D) * accum_size;
     return workspace_bytes;
   }
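The casts above matter because B, H, Q, and D were previously all 32-bit ints, so the byte-count products could wrap before ever reaching the size_t accumulator. A minimal demonstration with hypothetical but plausible attention shapes:

    #include <cstddef>
    #include <cstdio>

    int main() {
      int B = 64, H = 32, Q = 16384, D = 128;  // hypothetical shapes
      // Evaluated left to right in 32-bit int: 64*32*16384*128 = 2^32,
      // which overflows int (undefined behavior; typically wraps) before
      // the multiply by sizeof(float) promotes the result to size_t.
      size_t bad = static_cast<size_t>(B * H * Q * D) * sizeof(float);
      // Widening the first operand forces the whole chain into 64-bit math.
      size_t good = static_cast<size_t>(B) * H * Q * D * sizeof(float);
      std::printf("bad=%zu good=%zu\n", bad, good);
      return 0;
    }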

fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h

Lines changed: 1 addition & 9 deletions
@@ -62,8 +62,7 @@ namespace fbgemm_gpu::config {
   X(TBE_ROCM_INFERENCE_PACKED_BAGS) \
   X(TBE_ROCM_HIP_BACKWARD_KERNEL) \
   X(BOUNDS_CHECK_INDICES_V2) \
-  X(TBE_REPORT_INPUT_PARAMS) \
-  X(TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY)
+  X(TBE_REPORT_INPUT_PARAMS)
 // X(EXAMPLE_FEATURE_FLAG)
 
 /// @ingroup fbgemm-gpu-config
@@ -92,13 +91,6 @@ bool check_feature_gate_key(const std::string& key);
 /// is enabled.
 bool is_feature_enabled(const FeatureGateName& feature);
 
-/// @ingroup fbgemm-gpu-config
-///
-/// @brief For the given `FeatureGateName`, check if the corresponding inference
-/// feature is enabled in the env vars only. Only applicable for inference
-/// features suitable for env var rollouts
-bool is_feature_enabled_from_env(const FeatureGateName& feature);
-
 #ifdef FBGEMM_FBCODE
 bool is_feature_enabled(const FbFeatureGateName& feature);
 #endif
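The X(...) entries edited above are part of an X-macro list: the same list of flag names is expanded once to define the enum and again to generate string conversions, so a flag is added or removed in exactly one place. A generic sketch of the pattern (flag names hypothetical, not FBGEMM's macros):

    #include <string>

    // Single source of truth for the flag names.
    #define MY_FLAGS(X) \
      X(FLAG_A)         \
      X(FLAG_B)

    enum class Flag {
    #define X(name) name,
      MY_FLAGS(X)
    #undef X
    };

    inline std::string to_string(Flag f) {
      switch (f) {
    #define X(name)    \
      case Flag::name: \
        return #name;
        MY_FLAGS(X)
    #undef X
      }
      return "";
    }

Deleting one X(...) line removes the enumerator, its string form, and every other expansion in a single edit, which is exactly what this diff does for TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY.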

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

Lines changed: 4 additions & 0 deletions
@@ -411,6 +411,7 @@ at::Tensor FP8rowwise_to_float_cpu(
     const bool forward = true,
     const int64_t output_dtype = 0);
 at::Tensor fused8bitrowwise_to_half_cpu(const at::Tensor& input);
+at::Tensor fused8bitrowwise_to_bfloat16_cpu(const at::Tensor& input);
 at::Tensor fused8bitrowwise_to_float_or_half_cpu(
     const at::Tensor& input,
     const int64_t output_dtype,
@@ -469,6 +470,9 @@ at::Tensor _fusednbitrowwise_to_float_or_half_gpu(
 at::Tensor& _fused8bitrowwise_to_float_cpu_out(
     at::Tensor& output,
     const at::Tensor& input);
+at::Tensor& _fused8bitrowwise_to_bfloat16_cpu_out(
+    at::Tensor& output,
+    const at::Tensor& input);
 at::Tensor& _float_to_fused8bitrowwise_cpu_out(
     at::Tensor& output,
     const at::Tensor& input);

fbgemm_gpu/src/config/feature_gates.cpp

Lines changed: 13 additions & 24 deletions
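The rewrite below folds check_feature_gate_key_impl back into check_feature_gate_key and drops the env-var-only entry point. The function memoizes each key in a function-local static map so JK and environment-variable probes run at most once per key. In isolation, that cache-on-first-query pattern looks roughly like this (a hypothetical sketch, not the FBGEMM code):

    #include <cstdlib>
    #include <map>
    #include <string>

    // Stand-in for the expensive probe (env var, JK, etc.); hypothetical.
    static bool expensive_check(const std::string& key) {
      const char* v = std::getenv(("MY_PREFIX_" + key).c_str());
      return v != nullptr && std::string(v) == "1";
    }

    bool cached_check(const std::string& key) {
      // Process-wide memo; the C++17 if-initializer scopes the iterator.
      static std::map<std::string, bool> cache;
      if (const auto it = cache.find(key); it != cache.end()) {
        return it->second;
      }
      const bool value = expensive_check(key);
      cache.emplace(key, value);
      return value;
    }

Like the patched function, the sketch is not synchronized, so concurrent first lookups of the same key may both run the probe; that is harmless when the probe is deterministic.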
@@ -45,50 +45,39 @@ bool ev_check_key(const std::string& key) {
   }
 }
 
-static bool check_feature_gate_key_impl(
-    const std::string& key,
-    bool check_env_vars_only) {
+DLL_PUBLIC bool check_feature_gate_key(const std::string& key) {
   // Cache feature flags to avoid repeated JK and env var checks
   static std::map<std::string, bool> feature_flags_cache;
-  if (const auto search = feature_flags_cache.find(key);
-      search != feature_flags_cache.end()) {
-    return search->second;
-  }
 #ifdef FBGEMM_FBCODE
-  const auto value =
-      check_env_vars_only ? ev_check_key(key) : jk_check_key(key);
-#else
-  const auto value = ev_check_key(key);
+  static const auto no_jk = ev_check_key("NO_JK");
 #endif
 
-  feature_flags_cache.insert({key, value});
-  return value;
-}
+  if (const auto search = feature_flags_cache.find(key);
+      search != feature_flags_cache.end()) {
+    return search->second;
 
-DLL_PUBLIC bool check_feature_gate_key(const std::string& key) {
+  } else {
+    const auto value =
 #ifdef FBGEMM_FBCODE
-  static const auto no_jk = ev_check_key("NO_JK");
+        (no_jk) ? ev_check_key(key) : jk_check_key(key);
 #else
-  static const auto no_jk = false;
+        ev_check_key(key);
 #endif
 
-  return check_feature_gate_key_impl(key, no_jk);
+    feature_flags_cache.insert({key, value});
+    return value;
+  }
 }
 
 DLL_PUBLIC bool is_feature_enabled(const FeatureGateName& feature) {
   return check_feature_gate_key(to_string(feature));
 }
 
-DLL_PUBLIC bool is_feature_enabled_from_env(const FeatureGateName& feature) {
-  return check_feature_gate_key_impl(
-      to_string(feature), /* check_env_vars_only */ true);
-}
-
 #ifdef FBGEMM_FBCODE
 DLL_PUBLIC bool is_feature_enabled(const FbFeatureGateName& feature) {
   return check_feature_gate_key(to_string(feature));
 }
-#endif // FBGEMM_FBCODE
+#endif
 
 } // namespace fbgemm_gpu::config

fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp

Lines changed: 29 additions & 4 deletions
@@ -55,7 +55,7 @@ Tensor& _float_to_fused8bitrowwise_cpu_out_t(
   return output;
 }
 
-template <typename output_t>
+template <typename output_t, bool is_uint16_t_of_type_bf16 = false>
 Tensor& _fused8bitrowwise_to_float_cpu_out_t(
     Tensor& output,
     const Tensor& input) {
@@ -78,7 +78,9 @@ Tensor& _fused8bitrowwise_to_float_cpu_out_t(
   auto output_data = static_cast<output_t*>(
       output.data_ptr()); // output.data_ptr<output_t>(); -> Yields
                           // unresolved data_ptr symbol.
-  fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<output_t>(
+  fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<
+      output_t,
+      is_uint16_t_of_type_bf16>(
       input.data_ptr<uint8_t>(), nrows, ncols, output_data);
 
   return output;
@@ -217,11 +219,19 @@ Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
 Tensor& _fused8bitrowwise_to_float_cpu_out(
     Tensor& output,
     const Tensor& input) {
-  return _fused8bitrowwise_to_float_cpu_out_t<float>(output, input);
+  return _fused8bitrowwise_to_float_cpu_out_t<float, false>(output, input);
 }
 
 Tensor& fused8bitrowwise_to_half_cpu_out(Tensor& output, const Tensor& input) {
-  return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16>(output, input);
+  return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16, false>(
+      output, input);
+}
+
+Tensor& _fused8bitrowwise_to_bfloat16_cpu_out(
+    Tensor& output,
+    const Tensor& input) {
+  return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::bfloat16, true>(
+      output, input);
 }
 
 /// @ingroup quantize-data-cpu
@@ -285,6 +295,13 @@ Tensor fused8bitrowwise_to_half_cpu(const Tensor& input) {
   return fused8bitrowwise_to_half_cpu_out(output, input);
 }
 
+/// @ingroup quantize-data-cpu
+///
+Tensor fused8bitrowwise_to_bfloat16_cpu(const Tensor& input) {
+  auto output = at::empty({0}, input.options().dtype(at::kBFloat16));
+  return _fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+}
+
 /// @ingroup quantize-data-cpu
 ///
 Tensor fused8bitrowwise_to_float_or_half_cpu(
@@ -305,6 +322,10 @@ Tensor fused8bitrowwise_to_float_or_half_cpu(
       output = at::empty({0}, input.options().dtype(at::kHalf));
       output = fused8bitrowwise_to_half_cpu_out(output, input);
       break;
+    case SparseType::BF16:
+      output = at::empty({0}, input.options().dtype(at::kBFloat16));
+      output = _fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+      break;
     default:
       TORCH_CHECK(false);
   }
@@ -582,6 +603,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor",
       {PT2_COMPLIANT_TAG});
   m.def("Fused8BitRowwiseQuantizedToHalf(Tensor input) -> Tensor");
+  m.def("Fused8BitRowwiseQuantizedToBfloat16(Tensor input) -> Tensor");
   m.def(
       "Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0, bool scale_bias_last=True, bool quant_padding_float_type=True) -> Tensor");
   m.def(
@@ -648,6 +670,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   DISPATCH_TO_CPU(
       "Fused8BitRowwiseQuantizedToHalf",
      fbgemm_gpu::fused8bitrowwise_to_half_cpu);
+  DISPATCH_TO_CPU(
+      "Fused8BitRowwiseQuantizedToBfloat16",
+      fbgemm_gpu::fused8bitrowwise_to_bfloat16_cpu);
   DISPATCH_TO_CPU(
       "Fused8BitRowwiseQuantizedToFloatOrHalf",
       fbgemm_gpu::fused8bitrowwise_to_float_or_half_cpu);
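A note on the new is_uint16_t_of_type_bf16 template flag used above: both fp16 and bf16 values occupy a uint16_t in memory, so the element type alone cannot tell the kernel which 16-bit format to produce; the boolean selects the bf16 interpretation. For intuition, a scalar sketch of bf16 narrowing and widening (illustrative only, not fbgemm's vectorized kernel; NaN handling omitted):

    #include <cstdint>
    #include <cstring>

    // bf16 keeps the top 16 bits of an IEEE-754 float.
    uint16_t float_to_bf16(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      // Round to nearest even on the discarded lower half.
      bits += 0x7FFF + ((bits >> 16) & 1);
      return static_cast<uint16_t>(bits >> 16);
    }

    float bf16_to_float(uint16_t h) {
      const uint32_t bits = static_cast<uint32_t>(h) << 16;  // exact widening
      float f;
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }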

fbgemm_gpu/test/quantize/fused_8bit_rowwise_test.py

Lines changed: 15 additions & 2 deletions
@@ -141,9 +141,9 @@ def quantize_and_dequantize_op_test_helper(  # noqa: C901
 
     assume(ncols % (2 * num_elem_per_byte) == 0)
     if not test_cuda:
-        # cpu path does not support bf16
+        # cpu path only supports bf16 dequantization
         if output_dtype == SparseType.BF16:
-            return
+            input_data = input_data.float()
     if test_generic_op:
         quantized_data = (
             torch.ops.fbgemm.FloatOrHalfToFused8BitRowwiseQuantized(input_data)
@@ -171,6 +171,15 @@ def quantize_and_dequantize_op_test_helper(  # noqa: C901
             dequantized_data = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToHalf(
                 quantized_data
             )
+        elif output_dtype == SparseType.BF16:
+            quantized_data = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(
+                input_data,
+            )
+            dequantized_data = (
+                torch.ops.fbgemm.Fused8BitRowwiseQuantizedToBfloat16(
+                    quantized_data,
+                )
+            )
         else:
             raise NotImplementedError("Unsupported dtype")
 
@@ -185,6 +194,10 @@ def quantize_and_dequantize_op_test_helper(  # noqa: C901
         torch.testing.assert_close(dequantized_data.float(), reference.float())
     elif output_dtype == SparseType.FP16:
         torch.testing.assert_close(dequantized_data.half(), reference.half())
+    elif output_dtype == SparseType.BF16:
+        torch.testing.assert_close(
+            dequantized_data.bfloat16(), reference.bfloat16()
+        )
     if test_cuda and gpu_available:
         if nrows == 0 or ncols == 0:
             return

fbgemm_gpu/test/tbe/inference/nbit_forward_test.py

Lines changed: 0 additions & 10 deletions
@@ -8,7 +8,6 @@
 # pyre-strict
 # pyre-ignore-all-errors[56]
 
-import os
 import random
 import unittest
 from typing import Any, Callable, Optional, Union
@@ -124,12 +123,6 @@
 
 @optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
 class NBitFowardTest(NBitFowardTestCommon):
-    def _is_cpu_output_on_pinned_memory(self) -> bool:
-        return (
-            os.getenv("FBGEMM_TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY") != "1"
-            and torch.cuda.is_available()
-        )
-
     def execute_nbit_forward_fused_pooled_emb_quant_(
         self,
         T: int,
@@ -905,9 +898,6 @@ def test_nbit_forward_cpu_seq_int8(
         lengths = torch.cat(lengths_list, 0)
         offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
         quant_cc_output = quant_cc(indices.int(), offsets.int())
-        self.assertEqual(
-            quant_cc_output.is_pinned(), self._is_cpu_output_on_pinned_memory()
-        )
         tables_rows = [
             T for T, _, _ in quant_cc.split_embedding_weights_with_scale_bias(0)
         ]

include/fbgemm/QuantUtils.h

Lines changed: 3 additions & 2 deletions
@@ -10,6 +10,7 @@
 
 #include "./FbgemmBuild.h" // @manual
 #include "./QuantUtilsAvx2.h" // @manual
+#include "./QuantUtilsAvx512.h" // @manual
 #include "./QuantUtilsNeon.h" // @manual
 #include "./Types.h" // @manual
 #include "./Utils.h" // @manual
@@ -330,7 +331,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
 * This version intentionally supports only 8-bit because
 * the corresponding quantize version only supports 8-bit.
 */
-template <typename OutputType>
+template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
 FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
     const uint8_t* input,
     size_t input_rows,
@@ -377,7 +378,7 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
 * Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized.
 * This should not be called directly except in testing.
 */
-template <typename OutputType>
+template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
 FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
     const uint8_t* input,
     size_t input_rows,
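A sketch of how a caller might select the new bf16 behavior through the added template parameter (buffer shapes hypothetical; in the fused 8-bit rowwise format each stored row carries its uint8 payload followed by a float scale and a float bias, so the dequantized row is 8 bytes narrower than the stored one):

    #include <cstdint>
    #include <vector>
    #include "fbgemm/QuantUtils.h"

    void dequantize_bf16(const uint8_t* quantized, size_t rows, size_t cols) {
      // cols counts the stored bytes per row, including scale and bias.
      std::vector<fbgemm::bfloat16> out(rows * (cols - 2 * sizeof(float)));
      // The second template argument requests bf16 semantics for the
      // uint16_t output, matching the new CPU op's call site above.
      fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<
          fbgemm::bfloat16,
          /*is_uint16_t_of_type_bf16=*/true>(quantized, rows, cols, out.data());
    }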
