Skip to content

Commit e1ce5be

Browse files
q10facebook-github-bot
authored andcommitted
Prevent duplicate operator registrations
Summary: - This is a simple hack to prevent duplicator operator registrations Differential Revision: D76469487
1 parent 87a3770 commit e1ce5be

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
{%- if has_vbe_support %}
3333
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
3434
{%- endif %}
35+
#include "fbgemm_gpu/utils/torch_schema.h"
3536

3637
using Tensor = at::Tensor;
3738
using namespace fbgemm_gpu;
@@ -352,7 +353,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
352353
ndesc, wdesc, vdesc
353354
)
354355
%}
355-
{%- if not nobag %} {#-/*nobag schema is registered in cuda template*/#}
356+
357+
358+
if (!utils::torch::schemaExists("fbgemm::{{ embedding_codegen_forward_op }}_wrapper")) {
356359
m.def("{{ embedding_codegen_forward_op }}_wrapper("
357360
" Tensor host_weights, "
358361
" Tensor dev_weights, "
@@ -396,7 +399,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
396399
, {PT2_COMPLIANT_TAG}
397400
{%- endif %}
398401
);
399-
{%- endif %} {#-/*if not nobag*/#}
402+
}
400403
DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper", {{ embedding_codegen_forward_op }}_cpu_wrapper);
401404

402405
{%- else %} {#-/* backward */#}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <ATen/core/dispatch/Dispatcher.h>
10+
11+
namespace fbgemm_gpu::utils::torch {
12+
13+
inline bool schemaExists(const std::string& qualified_name) {
14+
return c10::Dispatcher::realSingleton()
15+
.findSchema({qualified_name, ""})
16+
.has_value();
17+
}
18+
19+
} // namespace fbgemm_gpu::utils::torch

0 commit comments

Comments
 (0)