
Commit 10354f9

Jason Chen authored and facebook-github-bot committed
Add Inference Feature to Skip Pinned Memory Creation (#4924)
Summary: Pull Request resolved: #4924

X-link: facebookresearch/FBGEMM#1948

Reviewed By: q10

Differential Revision: D83100663

fbshipit-source-id: d30d2742dd69834c7f0b7684d55033461fc3024a
1 parent a6a6007 commit 10354f9
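
In short, the commit adds a kill switch: when the TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY feature gate is enabled through an environment variable, the CPU inference path stops allocating its output in pinned (page-locked) host memory. The sketch below condenses the decision logic this commit introduces; it is illustrative rather than a literal excerpt, though the gate name and the ATen calls match the template diff further down.

    // Condensed sketch of the gating logic introduced by this commit
    // (see the template diff below for the real call site).
    #include <ATen/Context.h>

    #include "fbgemm_gpu/config/feature_gates.h"

    bool use_pinned_memory_for_cpu_output() {
      // Checked once, via env vars only; no JK lookup, so inference hosts
      // can flip the behavior without a config push.
      const static bool disablePinnedMemory =
          fbgemm_gpu::config::is_feature_enabled_from_env(
              fbgemm_gpu::config::FeatureGateName::
                  TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY);
      // Pinned memory is only requested when a GPU is present to copy into.
      return !disablePinnedMemory && at::Context::hasCUDA() &&
          at::getNumGPUs() > 0;
    }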

File tree

4 files changed: +46 -15 lines

fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h
fbgemm_gpu/src/config/feature_gates.cpp
fbgemm_gpu/test/tbe/inference/nbit_forward_test.py

fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,7 @@
 #include "fbgemm_gpu/embedding_common.h"
 #include "fbgemm/FbgemmEmbedding.h"
 #include "fbgemm_gpu/utils/tensor_utils.h"
+#include "fbgemm_gpu/config/feature_gates.h"
 
 #if defined(__x86_64__) || defined(__i386__) || (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
 #include <immintrin.h>
@@ -190,8 +191,9 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
 {% else %}
   TORCH_CHECK(D > 0);
 {% endif %}
+  const static bool disablePinnedMemory = fbgemm_gpu::config::is_feature_enabled_from_env(fbgemm_gpu::config::FeatureGateName::TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY);
   bool pinned_memory = false;
-  if (at::Context::hasCUDA() && at::getNumGPUs() > 0) {
+  if (!disablePinnedMemory && at::Context::hasCUDA() && at::getNumGPUs() > 0) {
     pinned_memory = true;
   }
 
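
The pinned_memory flag computed above is later threaded into the output allocation; that part of the template is unchanged and not shown in this diff. For readers unfamiliar with the mechanism, a hypothetical allocation using standard ATen options would look like this (the function name is an illustration, not code from the repository):

    // Illustration only: how a pinned_memory flag is typically passed to an
    // ATen allocation. Not taken from the template; the dtype is a stand-in.
    #include <ATen/ATen.h>

    at::Tensor alloc_cpu_output(int64_t rows, int64_t cols, bool pinned_memory) {
      // pinned_memory(true) requests page-locked host memory, which permits
      // faster, asynchronous host-to-device copies later on.
      return at::empty(
          {rows, cols},
          at::TensorOptions().dtype(at::kByte).pinned_memory(pinned_memory));
    }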

fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h

Lines changed: 9 additions & 1 deletion
@@ -62,7 +62,8 @@
   X(TBE_ROCM_INFERENCE_PACKED_BAGS) \
   X(TBE_ROCM_HIP_BACKWARD_KERNEL) \
   X(BOUNDS_CHECK_INDICES_V2) \
-  X(TBE_REPORT_INPUT_PARAMS)
+  X(TBE_REPORT_INPUT_PARAMS) \
+  X(TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY)
   // X(EXAMPLE_FEATURE_FLAG)
 
 /// @ingroup fbgemm-gpu-config
@@ -91,6 +92,13 @@ bool check_feature_gate_key(const std::string& key);
 /// is enabled.
 bool is_feature_enabled(const FeatureGateName& feature);
 
+/// @ingroup fbgemm-gpu-config
+///
+/// @brief For the given `FeatureGateName`, check if the corresponding inference
+/// feature is enabled in the env vars only. Only applicable for inference
+/// features suitable for env var rollouts
+bool is_feature_enabled_from_env(const FeatureGateName& feature);
+
 #ifdef FBGEMM_FBCODE
 bool is_feature_enabled(const FbFeatureGateName& feature);
 #endif
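
The X(...) entries above form an X-macro list: each flag name is written once and then expanded into several definitions. A minimal, self-contained illustration of the pattern follows; the macro and enum names here are invented for the example, since the real ones in feature_gates.h are not shown in this diff.

    // X-macro pattern, illustrative names only.
    #include <string>

    #define EXAMPLE_FLAGS(X)     \
      X(TBE_REPORT_INPUT_PARAMS) \
      X(TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY)

    // One expansion produces the enum...
    enum class ExampleFlag {
    #define X(name) name,
      EXAMPLE_FLAGS(X)
    #undef X
    };

    // ...and another produces a matching to_string().
    inline std::string to_string(ExampleFlag flag) {
      switch (flag) {
    #define X(name)           \
      case ExampleFlag::name: \
        return #name;
        EXAMPLE_FLAGS(X)
    #undef X
      }
      return "";
    }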

fbgemm_gpu/src/config/feature_gates.cpp

Lines changed: 24 additions & 13 deletions
@@ -45,39 +45,50 @@ bool ev_check_key(const std::string& key) {
   }
 }
 
-DLL_PUBLIC bool check_feature_gate_key(const std::string& key) {
+static bool check_feature_gate_key_impl(
+    const std::string& key,
+    bool check_env_vars_only) {
   // Cache feature flags to avoid repeated JK and env var checks
   static std::map<std::string, bool> feature_flags_cache;
-#ifdef FBGEMM_FBCODE
-  static const auto no_jk = ev_check_key("NO_JK");
-#endif
-
   if (const auto search = feature_flags_cache.find(key);
       search != feature_flags_cache.end()) {
     return search->second;
+  }
+#ifdef FBGEMM_FBCODE
+  const auto value =
+      check_env_vars_only ? ev_check_key(key) : jk_check_key(key);
+#else
+  const auto value = ev_check_key(key);
+#endif
 
-  } else {
-    const auto value =
+  feature_flags_cache.insert({key, value});
+  return value;
+}
+
+DLL_PUBLIC bool check_feature_gate_key(const std::string& key) {
 #ifdef FBGEMM_FBCODE
-        (no_jk) ? ev_check_key(key) : jk_check_key(key);
+  static const auto no_jk = ev_check_key("NO_JK");
 #else
-        ev_check_key(key);
+  static const auto no_jk = false;
 #endif
 
-    feature_flags_cache.insert({key, value});
-    return value;
-  }
+  return check_feature_gate_key_impl(key, no_jk);
 }
 
 DLL_PUBLIC bool is_feature_enabled(const FeatureGateName& feature) {
   return check_feature_gate_key(to_string(feature));
 }
 
+DLL_PUBLIC bool is_feature_enabled_from_env(const FeatureGateName& feature) {
+  return check_feature_gate_key_impl(
+      to_string(feature), /* check_env_vars_only */ true);
+}
+
 #ifdef FBGEMM_FBCODE
 DLL_PUBLIC bool is_feature_enabled(const FbFeatureGateName& feature) {
   return check_feature_gate_key(to_string(feature));
 }
-#endif
+#endif // FBGEMM_FBCODE
 
 } // namespace fbgemm_gpu::config
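
ev_check_key itself is unchanged and not shown here. Judging from the Python test below, which reads FBGEMM_TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY and compares against "1", it presumably resolves a flag name to an FBGEMM_-prefixed environment variable. A hypothetical stand-in:

    // Hypothetical stand-in for ev_check_key (the real implementation is not
    // part of this diff). The FBGEMM_ prefix and the "1" comparison are
    // inferred from the Python test below.
    #include <cstdlib>
    #include <string>

    bool ev_check_key_sketch(const std::string& key) {
      const std::string env_name = "FBGEMM_" + key;
      const char* value = std::getenv(env_name.c_str());
      return value != nullptr && std::string(value) == "1";
    }

One design consequence worth noting: check_feature_gate_key_impl caches by key name alone, so the first lookup for a given flag, whether routed through JK or the environment, fixes the cached value for the rest of the process.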
8394

fbgemm_gpu/test/tbe/inference/nbit_forward_test.py

Lines changed: 10 additions & 0 deletions
@@ -8,6 +8,7 @@
 # pyre-strict
 # pyre-ignore-all-errors[56]
 
+import os
 import random
 import unittest
 from typing import Any, Callable, Optional, Union
@@ -123,6 +124,12 @@
 
 @optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
 class NBitFowardTest(NBitFowardTestCommon):
+    def _is_cpu_output_on_pinned_memory(self) -> bool:
+        return (
+            os.getenv("FBGEMM_TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY") != "1"
+            and torch.cuda.is_available()
+        )
+
     def execute_nbit_forward_fused_pooled_emb_quant_(
         self,
         T: int,
@@ -898,6 +905,9 @@ def test_nbit_forward_cpu_seq_int8(
         lengths = torch.cat(lengths_list, 0)
         offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
         quant_cc_output = quant_cc(indices.int(), offsets.int())
+        self.assertEqual(
+            quant_cc_output.is_pinned(), self._is_cpu_output_on_pinned_memory()
+        )
         tables_rows = [
             T for T, _, _ in quant_cc.split_embedding_weights_with_scale_bias(0)
         ]
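
The new assertion is environment-sensitive by design: with FBGEMM_TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY=1 exported before the test run, _is_cpu_output_on_pinned_memory() returns False and the test expects a pageable (non-pinned) output tensor; otherwise a pinned output is expected whenever CUDA is available. The same test therefore exercises both sides of the gate.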
