
Commit 1bcf7d3

Authored by q10, committed by facebook-github-bot
Migrate TBE inference kernels to FBGEMM_LAUNCH_KERNEL (#4092)
Summary:
X-link: facebookresearch/FBGEMM#1176

- Migrate TBE inference kernels to `FBGEMM_LAUNCH_KERNEL`

Reviewed By: spcyppt

Differential Revision: D73731461
1 parent b35c0f8 commit 1bcf7d3
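For context, the diff below swaps raw triple-chevron launches (each followed by a hand-written C10_CUDA_KERNEL_LAUNCH_CHECK()) for a single launcher macro that takes the grid, block, shared-memory size, stream, and kernel arguments in one call. The sketch below is not FBGEMM code; LAUNCH_KERNEL and scale_kernel are hypothetical stand-ins that illustrate the same pattern, assuming the launcher's main job is to centralize the launch configuration and the post-launch error check.

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for a launcher macro: it forwards the launch
// configuration and arguments to the kernel, then checks the launch result
// so callers no longer write the check by hand after every launch site.
#define LAUNCH_KERNEL(kernel, grid, block, smem, stream, ...)            \
  do {                                                                   \
    kernel<<<(grid), (block), (smem), (stream)>>>(__VA_ARGS__);          \
    const cudaError_t err_ = cudaGetLastError();                         \
    if (err_ != cudaSuccess) {                                           \
      std::fprintf(                                                      \
          stderr, "%s failed: %s\n", #kernel, cudaGetErrorString(err_)); \
      std::abort();                                                      \
    }                                                                    \
  } while (0)

// Toy kernel standing in for the TBE lookup kernels touched in this diff.
__global__ void scale_kernel(float* x, float a, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] *= a;
  }
}

int main() {
  constexpr int n = 1024;
  float* x = nullptr;
  cudaMalloc(&x, n * sizeof(float));

  // Old style (what the diff removes): a raw launch plus a separate,
  // easy-to-forget error check after the call site.
  scale_kernel<<<(n + 255) / 256, 256, 0, 0>>>(x, 2.0f, n);

  // New style (what the diff moves to): configuration, stream, arguments,
  // and the error check all go through one macro invocation.
  LAUNCH_KERNEL(scale_kernel, (n + 255) / 256, 256, 0, 0, x, 2.0f, n);

  cudaFree(x);
  return 0;
}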

2 files changed: 38 additions & 50 deletions


fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu

Lines changed: 21 additions & 29 deletions
@@ -7,6 +7,7 @@
  */
 
 #include "fbgemm_gpu/embedding_forward_template_helpers.cuh"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/tensor_accessor_builder.h"
 
 using namespace fbgemm_gpu;
@@ -170,28 +171,24 @@ Tensor pruned_hashmap_lookup_cuda(
 
   AT_DISPATCH_INDEX_TYPES(
       indices.scalar_type(), "pruned_hashmap_lookup_cuda_1", [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-        const auto func_name =
-            "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
-#endif
-
-        int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
+        FBGEMM_LAUNCH_KERNEL(
+            (int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<
+                index_t,
+                hash_t>),
             nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
             dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
             0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, hash_table, hash_t, 2, 64),
-            MAKE_PTA_WITH_NAME(
-                func_name, hash_table_offsets, int64_t, 1, 32),
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(indices, index_t, 1, 32),
+            PTA_B(offsets, index_t, 1, 32),
+            PTA_B(hash_table, hash_t, 2, 64),
+            PTA_B(hash_table_offsets, int64_t, 1, 32),
             B,
             T,
-            MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+            PTA_B(dense_indices, index_t, 1, 32));
       });
     });
 
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return dense_indices;
 }
 
@@ -235,29 +232,24 @@ Tensor pruned_array_lookup_cuda(
 
   AT_DISPATCH_INDEX_TYPES(
       indices.scalar_type(), "pruned_array_lookup_cuda_1", [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-        const auto func_name =
-            "int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel";
-#endif
-
-        int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
+        FBGEMM_LAUNCH_KERNEL(
+            (int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<
+                index_t,
+                remap_t>),
             nbit::div_round_up(
                 offsets.size(0), kForwardMaxThreads / kWarpSize),
             dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
             0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, index_remappings, remap_t, 1, 64),
-            MAKE_PTA_WITH_NAME(
-                func_name, index_remappings_offsets, int64_t, 1, 32),
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(indices, index_t, 1, 32),
+            PTA_B(offsets, index_t, 1, 32),
+            PTA_B(index_remappings, remap_t, 1, 64),
+            PTA_B(index_remappings_offsets, int64_t, 1, 32),
            B,
            T,
-            MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+            PTA_B(dense_indices, index_t, 1, 32));
       });
     });
 
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return dense_indices;
 }

fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu

Lines changed: 17 additions & 21 deletions
@@ -9,6 +9,7 @@
 // clang-format off
 {%- set wdesc = "weighted" if weighted else "unweighted" %}
 #include "fbgemm_gpu/embedding_forward_template_helpers.cuh"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/tensor_accessor_builder.h"
 #include "fbgemm_gpu/config/feature_gates.h"
 
@@ -63,51 +64,46 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
 {%- macro define_kernel_invocation(emb_weight_type) %}
 {%- set func_name = "nbit::" + emb_weight_type + "_split_embedding" + ("_nobag" if nobag else "") + "_codegen_forward_" + wdesc + "_kernel_small_L" %}
 
-  #ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name_{{ emb_weight_type }} = "{{ func_name }}_{{ emb_weight_type }}";
-  #endif
-
   #ifdef X
   #undef X
   #endif
 
-  // Define {{ emb_weight_type }} kernel invocation macro
   #define X(DeviceOnly, PackedMode, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \
-    {{ func_name }}<index_t, output_t, OutputRowsPerThread, kWarpsPerBlock, InputRowsInFlight, MinNum128BRows, MaxNum128BRows, DeviceOnly, PackedMode><<< \
+    FBGEMM_LAUNCH_KERNEL( \
+      ({{ func_name }}<index_t, output_t, OutputRowsPerThread, kWarpsPerBlock, InputRowsInFlight, MinNum128BRows, MaxNum128BRows, DeviceOnly, PackedMode>), \
       nbit::div_round_up(T * nbit::div_round_up(B, num_packed_bags * OutputRowsPerThread), kWarpsPerBlock), \
      dim3(kWarpSize, kWarpsPerBlock), \
      0, \
-      at::cuda::getCurrentCUDAStream()>>>( \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, dev_weights, uint8_t, 1, 64), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, uvm_weights, uint8_t, 1, 64), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, weights_placements, int32_t, 1, 32), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, weights_offsets, int64_t, 1, 32), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, weights_tys, uint8_t, 1, 32), \
+      at::cuda::getCurrentCUDAStream(), \
+      PTA_B(dev_weights, uint8_t, 1, 64), \
+      PTA_B(uvm_weights, uint8_t, 1, 64), \
+      PTA_B(weights_placements, int32_t, 1, 32), \
+      PTA_B(weights_offsets, int64_t, 1, 32), \
+      PTA_B(weights_tys, uint8_t, 1, 32), \
      {%- if not nobag %}
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, D_offsets, int32_t, 1, 32), \
+      PTA_B(D_offsets, int32_t, 1, 32), \
      {%- else %}
      D, \
      {%- endif %}
      FixedDivisor(div_round_up(B, num_packed_bags * OutputRowsPerThread)), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, indices, index_t, 1, 32), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, offsets, index_t, 1, 32), \
+      PTA_B(indices, index_t, 1, 32), \
+      PTA_B(offsets, index_t, 1, 32), \
      {%- if not nobag %}
      pooling_mode, \
      {%- endif %}
      row_alignment, \
      {%- if weighted %}
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, indice_weights, float, 1, 32), \
+      PTA_B(indice_weights, float, 1, 32), \
      {%- endif %}
      {%- if emb_weight_type == "FP8" %}
      fp8_exponent_bits, \
      fp8_exponent_bias, \
      {%- endif %}
      num_packed_bags, \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, output, output_t, 2, 32), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, lxu_cache_weights, uint8_t, 2, 64), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, lxu_cache_locations, int32_t, 1, 32) \
-    ); \
-    C10_CUDA_KERNEL_LAUNCH_CHECK(); \
+      PTA_B(output, output_t, 2, 32), \
+      PTA_B(lxu_cache_weights, uint8_t, 2, 64), \
+      PTA_B(lxu_cache_locations, int32_t, 1, 32) \
+    );
 {%- endmacro %}
 
 {%- macro construct_and_return_output_tensor() %}
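As an aside on the template above: the launch site is wrapped in an "X macro" (#ifdef X / #undef X / #define X(...)), so the same invocation body can be expanded for many compile-time parameter combinations and then redefined for the next embedding weight type. A minimal, generic sketch of that pattern follows; launch_impl, dispatch, and the parameter names here are hypothetical and not part of FBGEMM.

#include <cstdio>

// Hypothetical kernel-launch wrapper; in the real template this would be the
// FBGEMM_LAUNCH_KERNEL call with the full template argument list.
template <int OutputRowsPerThread, bool DeviceOnly>
void launch_impl(int batch_size) {
  std::printf(
      "launch: OutputRowsPerThread=%d DeviceOnly=%d B=%d\n",
      OutputRowsPerThread,
      DeviceOnly,
      batch_size);
}

void dispatch(int batch_size, bool device_only) {
// The invocation body is written once as X(...); each use below expands it
// with a different compile-time configuration, and the macro is undefined
// afterwards so another dispatch site can redefine it.
#ifdef X
#undef X
#endif
#define X(DeviceOnly, OutputRowsPerThread) \
  launch_impl<OutputRowsPerThread, DeviceOnly>(batch_size);

  if (device_only) {
    X(true, 4)
  } else {
    X(false, 2)
  }

#undef X
}

int main() {
  dispatch(1024, true);
  dispatch(2048, false);
  return 0;
}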
