
Commit 01a187d

Alexander Balaban authored and facebook-github-bot committed
TBE CPU nobag dispatch and forward pass kernel impl (#4302)
Summary:
Pull Request resolved: #4302
X-link: facebookresearch/FBGEMM#1378

This diff introduces a simple forward-pass kernel for CPU to cover TBE with pooling mode NONE.

Differential Revision: D75464152
1 parent b9c9c89 commit 01a187d
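
For orientation, the sketch below shows how the newly registered op can be reached directly through the PyTorch dispatcher, mirroring the schema registered in embedding_forward_split_nobag_cpu.cpp and the typed lookup done by the PT2 CPU wrapper later in this diff. The free function call_tbe_nobag_forward_cpu and the inline argument comments are illustrative only, not part of the commit:

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/script.h>

using Tensor = at::Tensor;

// Illustrative caller (not part of the commit). The argument layout follows
// the new kernel: T tables flattened into `weights`, per-table start offsets
// in `weights_offsets`, and `offsets` holding T * B + 1 boundaries that
// segment `indices` per (table, sample).
Tensor call_tbe_nobag_forward_cpu(
    const Tensor& weights,          // flattened embedding tables (CPU)
    const Tensor& weights_offsets,  // [T], int64
    int64_t D,                      // embedding dimension, shared by all tables
    const Tensor& hash_size_cumsum, // [T + 1], int64 cumulative row counts
    const Tensor& indices,          // [N] row ids
    const Tensor& offsets,          // [T * B + 1] segment boundaries
    int64_t output_dtype) {         // SparseType value: FP32, FP16, or BF16
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow(
              "fbgemm::split_embedding_nobag_codegen_forward_cpu", "")
          .typed<Tensor(
              const Tensor&,
              const Tensor&,
              c10::SymInt,
              const Tensor&,
              const Tensor&,
              const Tensor&,
              int64_t)>();
  // Pooling mode NONE: the result is [N, D]; row i is a dtype-cast copy of
  // the embedding row selected by indices[i], with no reduction over the bag.
  return op.call(
      weights, weights_offsets, D, hash_size_cumsum, indices, offsets, output_dtype);
}

The typed<> signature above mirrors the nobag branch of the PT2 wrapper template in this diff; the registered schema itself declares D as a plain int.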

File tree

4 files changed (+296, -7 lines)

fbgemm_gpu/cmake/tbe_sources.py

Lines changed: 1 addition & 0 deletions
@@ -319,6 +319,7 @@
 static_cpu_files_common = [
     "codegen/utils/embedding_bounds_check_host_cpu.cpp",
     "codegen/training/forward/embedding_forward_split_cpu.cpp",
+    "codegen/training/forward/embedding_forward_split_nobag_cpu.cpp",
     "codegen/training/pt2/pt2_autograd_utils.cpp",
 ]

fbgemm_gpu/codegen/training/forward/embedding_forward_split_nobag_cpu.cpp

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/Parallel.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <torch/script.h>
+
+#include "fbgemm_gpu/embedding_common.h"
+#include "fbgemm_gpu/utils/cpu_utils.h"
+#include "fbgemm_gpu/utils/dispatch_macros.h"
+#include "fbgemm_gpu/utils/ops_utils.h"
+
+#if FBGEMM_GPU_MEMCHECK
+#define FBGEMM_MEM_CHECK_ONLY
+#else
+#define FBGEMM_MEM_CHECK_ONLY maybe_unused
+#endif
+
+using Tensor = at::Tensor;
+using namespace fbgemm_gpu;
+
+template <
+    typename weights_t,
+    typename index_t,
+    typename offset_t,
+    typename output_t>
+void split_embedding_nobag_codegen_forward_cpu_kernel(
+    const Tensor& weights,
+    const Tensor& weights_offsets,
+    int64_t D,
+    const Tensor& hash_size_cumsum,
+    const Tensor& indices,
+    const Tensor& offsets,
+    const Tensor& output) {
+  TORCH_CHECK(weights.is_contiguous());
+  Tensor indices_contig = indices.contiguous();
+  Tensor offsets_contig = offsets.contiguous();
+
+  const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
+  const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
+  const auto indices_data = indices.data_ptr<index_t>();
+  const auto offsets_data = offsets.data_ptr<offset_t>();
+  const auto weights_data = weights.data_ptr<weights_t>();
+  auto output_data = output.data_ptr<output_t>();
+
+  int64_t T = weights_offsets.size(0);
+  int64_t B = (offsets.size(0) - 1) / T;
+  TORCH_CHECK_GE(B, 0);
+
+  at::parallel_for(0, T, 0, [&](int64_t t_begin, int64_t t_end) {
+    for (const auto t : c10::irange(t_begin, t_end)) {
+      int64_t hash_size = 0;
+      int64_t t_temp = static_cast<int64_t>(t) + 1;
+      do {
+        hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
+        ++t_temp;
+      } while (hash_size == 0);
+
+      const auto table_begin = weights_offsets_data[t];
+
+      bool success = true;
+      at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
+        for (const auto b : c10::irange(b_begin, b_end)) {
+          const auto indices_start = offsets_data[t * B + b];
+          const auto indices_end = offsets_data[t * B + b + 1];
+          for (auto i = indices_start; i < indices_end; ++i) {
+            const auto idx = indices_data[i];
+            if (idx < 0 || idx >= hash_size) {
+              success = false;
+              continue;
+            }
+            const auto embedding_offset = table_begin + idx * D;
+            for (const auto d : c10::irange(D)) {
+              output_data[i * D + d] =
+                  static_cast<output_t>(weights_data[embedding_offset + d]);
+            }
+          }
+        }
+      });
+
+      if (!success) {
+        fbgemm_gpu::report_embedding_error(
+            static_cast<int>(t),
+            static_cast<int>(B),
+            0,
+            static_cast<int>(B),
+            offsets_data,
+            indices_data,
+            hash_size);
+      }
+    }
+  });
+}
+
+Tensor split_embedding_nobag_codegen_forward_cpu(
+    const Tensor& weights,
+    const Tensor& weights_offsets,
+    int64_t D,
+    const Tensor& hash_size_cumsum,
+    const Tensor& indices,
+    const Tensor& offsets,
+    int64_t output_dtype) {
+  int64_t num_indices = indices.size(0);
+  auto options = weights.options();
+  if (output_dtype == static_cast<int64_t>(SparseType::FP32)) {
+    options = weights.options().dtype(at::kFloat);
+  } else if (output_dtype == static_cast<int64_t>(SparseType::FP16)) {
+    options = weights.options().dtype(at::kHalf);
+  } else if (output_dtype == static_cast<int64_t>(SparseType::BF16)) {
+    options = weights.options().dtype(at::kBFloat16);
+  }
+  Tensor output = at::empty({num_indices, D}, options);
+
+  // Dispatch based on indices, offsets, and output types
+  FBGEMM_DISPATCH_FLOAT_AND_HALF(
+      output.scalar_type(), "split_embedding_nobag_cpu_forward_1", [&]() {
+        using output_t = scalar_t;
+
+        FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
+            weights.scalar_type(), "split_embedding_nobag_cpu_forward_2", [&] {
+              using weights_t = scalar_t;
+
+              AT_DISPATCH_INDEX_TYPES(
+                  offsets.scalar_type(),
+                  "split_embedding_nobag_cpu_forward_3",
+                  [&] {
+                    using offset_t = index_t;
+
+                    AT_DISPATCH_INDEX_TYPES(
+                        indices.scalar_type(),
+                        "split_embedding_nobag_cpu_forward_4",
+                        [&] {
+                          split_embedding_nobag_codegen_forward_cpu_kernel<
+                              weights_t,
+                              index_t,
+                              offset_t,
+                              output_t>(
+                              weights,
+                              weights_offsets,
+                              D,
+                              hash_size_cumsum,
+                              indices,
+                              offsets,
+                              output);
+                        });
+                  });
+            });
+      });
+
+  return output;
+}
+
+Tensor split_embedding_nobag_codegen_forward_cpu_meta(
+    const Tensor& weights,
+    const Tensor& /* weights_offsets */,
+    int64_t D,
+    const Tensor& /* hash_size_cumsum */,
+    const Tensor& indices,
+    const Tensor& /* offsets */,
+    int64_t output_dtype) {
+  c10::SymInt num_indices = indices.sym_size(0);
+  auto dtype = weights.options();
+  if (output_dtype == static_cast<int64_t>(SparseType::FP32)) {
+    dtype = weights.options().dtype(at::kFloat);
+  } else if (output_dtype == static_cast<int64_t>(SparseType::FP16)) {
+    dtype = weights.options().dtype(at::kHalf);
+  } else if (output_dtype == static_cast<int64_t>(SparseType::BF16)) {
+    dtype = weights.options().dtype(at::kBFloat16);
+  }
+  return at::empty_symint({num_indices, D}, dtype);
+}
+
+namespace {
+
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  m.def(
+      "split_embedding_nobag_codegen_forward_cpu(Tensor weights, "
+      " Tensor weights_offsets, "
+      " int D, "
+      " Tensor hash_size_cumsum, "
+      " Tensor indices, "
+      " Tensor offsets, "
+      " int output_dtype) -> Tensor");
+
+  DISPATCH_TO_CPU(
+      "split_embedding_nobag_codegen_forward_cpu",
+      split_embedding_nobag_codegen_forward_cpu);
+
+  DISPATCH_TO_META(
+      "split_embedding_nobag_codegen_forward_cpu",
+      split_embedding_nobag_codegen_forward_cpu_meta);
+}
+} // namespace
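
As a sanity check on the semantics of the kernel above: with a single table, weights_offsets[0] is 0 and the nobag forward degenerates to a dtype-cast row gather. A minimal reference, assuming in-range indices and a weights tensor already viewed as [num_rows, D] (reference_single_table is a hypothetical helper, not part of the commit):

#include <ATen/ATen.h>

// For T = 1 the loop above copies weight row indices[i] into output row i and
// casts it to the requested output dtype, i.e. a plain row gather.
at::Tensor reference_single_table(
    const at::Tensor& weights_2d,   // [num_rows, D] single table, 2-D view
    const at::Tensor& indices,      // [N] int32/int64 row ids, all in range
    at::ScalarType output_dtype) {  // e.g. at::kFloat
  return weights_2d.index_select(/*dim=*/0, indices).to(output_dtype);
}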

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp

Lines changed: 51 additions & 5 deletions
@@ -100,26 +100,35 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
       feature_requires_grad);
 }
 {%- endif %}
+
 {%- for weighted in [True, False] %}
 {%- set wdesc = "weighted" if weighted else "unweighted" %}
+{%- for nobag in ([False] if (weighted or vbe) else [True, False]) %}
+{%- set ndesc = "_nobag" if nobag else "" %}

 {% if is_forward %}
 {#-/* PT2 wrapper function for forward CPU */#}
-Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
+Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
     const Tensor& host_weights,
     const Tensor& /*dev_weights*/,
     const Tensor& /*uvm_weights*/,
     const Tensor& /*lxu_cache_weights*/,
     const Tensor& /*weights_placements*/,
     const Tensor& weights_offsets,
+    {%- if nobag %}
+    const c10::SymInt D,
+    {%- else %}
     const Tensor& D_offsets,
     const c10::SymInt total_D,
     const c10::SymInt /*max_D*/,
+    {%- endif %}
     const Tensor& hash_size_cumsum,
     const Tensor& indices,
     const Tensor& offsets,
+    {%- if not nobag %}
     const int64_t pooling_mode,
     const Tensor& indice_weights,
+    {%- endif %}
     const Tensor& /*lxu_cache_locations*/,
     const Tensor& /*uvm_cache_stats*/,
     {%- if vbe %}
@@ -142,11 +151,34 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
       offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B_int, D_offsets.numel() - 1);
   });
   {%- endif %}
+  {%- set op = "split_embedding{}_codegen_forward_cpu".format(
+          ndesc
+      )
+  %}
   static auto op =
       torch::Dispatcher::singleton()
-          .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
+          .findSchemaOrThrow("fbgemm::{{ op }}", "")
          .typed<Tensor(
-              Tensor, Tensor, Tensor, c10::SymInt, Tensor, Tensor, Tensor, int64_t, Tensor, int64_t
+              {%- if nobag %}
+              const Tensor&, /*weights*/
+              const Tensor&, /*weights_offsets*/
+              c10::SymInt, /*D*/
+              const Tensor&, /*hash_size_cumsum*/
+              const Tensor&, /*indices*/
+              const Tensor&, /*offsets*/
+              int64_t /*output_dtype*/
+              {%- else %}
+              Tensor, /*weights*/
+              Tensor, /*weights_offsets*/
+              Tensor, /*D_offsets*/
+              c10::SymInt, /*total_D*/
+              Tensor, /*hash_size_cumsum*/
+              Tensor, /*indices*/
+              Tensor, /*offsets*/
+              int64_t, /*pooling_mode*/
+              Tensor, /*indice_weights*/
+              int64_t /*output_dtype*/
+              {%- endif %}
          )>();
   {%- if vbe %}
   // TODO: remove this after vbe is implemented for CPU kernel
@@ -189,18 +221,25 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
   return op.call(
       host_weights,
       weights_offsets,
+      {%- if nobag %}
+      D,
+      {%- else %}
       D_offsets,
       total_D,
+      {%- endif %}
       hash_size_cumsum,
       indices,
       offsets,
+      {%- if not nobag %}
       pooling_mode,
       indice_weights,
+      {%- endif %}
       output_dtype);
   {%- endif %}
 }
 {% else %}
 {#-/* PT2 wrapper function for backward CPU */#}
+{%- if not nobag %}
 Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
     const Tensor& grad_output,
     const Tensor& host_weights,
@@ -296,18 +335,22 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
   );
   return Tensor();
 }
+{% endif %} {#-/*if not nobag*/#}
 {% endif %}
+{%- endfor %} {#-/*for nobag*/#}
 {%- endfor %} {#-/*for weighted*/#}


 namespace {
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   {%- for weighted in [True, False] %}
   {%- set wdesc = "weighted" if weighted else "unweighted" %}
+  {%- for nobag in ([False] if (weighted or vbe) else [True, False]) %}
+  {%- set ndesc = "_nobag" if nobag else "" %}

   {%- if is_forward %}
-  {%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}{}_pt2".format(
-      wdesc, vdesc
+  {%- set embedding_codegen_forward_op = "split_embedding{}_codegen_forward_{}{}_pt2".format(
+      ndesc, wdesc, vdesc
   )
   %}

@@ -360,6 +403,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper", {{ embedding_codegen_forward_op }}_cpu_wrapper);

   {%- else %} {#-/* backward */#}
+  {%- if not nobag %}
   {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}{}_pt2".format(
       optimizer, wdesc, vdesc
   )
@@ -410,7 +454,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   {%- endif %}
   ") -> Tensor");
   DISPATCH_TO_CPU("{{ embedding_codegen_backward_op }}_wrapper", {{ embedding_codegen_backward_op }}_cpu_wrapper);
+  {%- endif %} {#-/*if not nobag*/#}
   {%- endif %} {#-/*if is_forward*/#}
+  {%- endfor %} {#-/*for nobag*/#}
   {%- endfor %} {#-/*for weighted*/#}

   {%- if is_forward %}
