/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/Parallel.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/script.h>

#include <atomic>

#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/utils/cpu_utils.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/utils/ops_utils.h"

// Used as [[FBGEMM_MEM_CHECK_ONLY]]: expands to maybe_unused when memory
// checking is disabled so memcheck-only locals do not trigger warnings.
#ifdef FBGEMM_GPU_MEMCHECK
#define FBGEMM_MEM_CHECK_ONLY
#else
#define FBGEMM_MEM_CHECK_ONLY maybe_unused
#endif

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

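// Reference CPU kernel for the sequence ("nobag") embedding forward: for each
// position i in `indices`, the D-dimensional row of the owning table is
// copied into row i of `output`. No pooling across a bag is performed.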
template <
    typename weights_t,
    typename index_t,
    typename offset_t,
    typename output_t>
void split_embedding_nobag_codegen_forward_cpu_kernel(
    const Tensor& weights,
    const Tensor& weights_offsets,
    int64_t D,
    const Tensor& hash_size_cumsum,
    const Tensor& indices,
    const Tensor& offsets,
    const Tensor& output) {
  TORCH_CHECK(weights.is_contiguous());
  Tensor indices_contig = indices.contiguous();
  Tensor offsets_contig = offsets.contiguous();

  const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
  const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
  // Read through the contiguous copies so the raw pointers stay valid even if
  // the caller passed non-contiguous indices/offsets.
  const auto indices_data = indices_contig.data_ptr<index_t>();
  const auto offsets_data = offsets_contig.data_ptr<offset_t>();
  const auto weights_data = weights.data_ptr<weights_t>();
  auto output_data = output.data_ptr<output_t>();

  // T tables, B samples per table; offsets is expected to hold T * B + 1
  // entries.
  const int64_t T = weights_offsets.size(0);
  const int64_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK_GE(B, 0);

  // Parallelize across tables; each table then parallelizes across its batch.
  at::parallel_for(0, T, 0, [&](int64_t t_begin, int64_t t_end) {
    for (const auto t : c10::irange(t_begin, t_end)) {
      // Skip over zero-sized entries in hash_size_cumsum to obtain the row
      // count (hash size) for table t.
      int64_t hash_size = 0;
      int64_t t_temp = t + 1;
      do {
        hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
        ++t_temp;
      } while (hash_size == 0);

      const auto table_begin = weights_offsets_data[t];

      // Written concurrently by the inner parallel_for, so keep it atomic.
      std::atomic<bool> success{true};
      at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
        for (const auto b : c10::irange(b_begin, b_end)) {
          const auto indices_start = offsets_data[t * B + b];
          const auto indices_end = offsets_data[t * B + b + 1];
          for (auto i = indices_start; i < indices_end; ++i) {
            const auto idx = indices_data[i];
            if (idx < 0 || idx >= hash_size) {
              // Out-of-bounds index: record the failure and skip this row.
              success.store(false, std::memory_order_relaxed);
              continue;
            }
            // No pooling in the "nobag" case: copy the looked-up row straight
            // into output row i.
            const auto embedding_offset = table_begin + idx * D;
            for (const auto d : c10::irange(D)) {
              output_data[i * D + d] =
                  static_cast<output_t>(weights_data[embedding_offset + d]);
            }
          }
        }
      });

      if (!success.load()) {
        fbgemm_gpu::report_embedding_error(
            static_cast<int>(t),
            static_cast<int>(B),
            0,
            static_cast<int>(B),
            offsets_data,
            indices_data,
            hash_size);
      }
    }
  });
}

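// Host entry point: allocates the [num_indices, D] output in the requested
// dtype and dispatches to the typed kernel above.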
Tensor split_embedding_nobag_codegen_forward_cpu(
    const Tensor& weights,
    const Tensor& weights_offsets,
    int64_t D,
    const Tensor& hash_size_cumsum,
    const Tensor& indices,
    const Tensor& offsets,
    int64_t output_dtype) {
  int64_t num_indices = indices.size(0);
  // The output dtype defaults to the weights dtype unless explicitly
  // overridden via output_dtype.
  auto options = weights.options();
  if (output_dtype == static_cast<int64_t>(SparseType::FP32)) {
    options = weights.options().dtype(at::kFloat);
  } else if (output_dtype == static_cast<int64_t>(SparseType::FP16)) {
    options = weights.options().dtype(at::kHalf);
  } else if (output_dtype == static_cast<int64_t>(SparseType::BF16)) {
    options = weights.options().dtype(at::kBFloat16);
  }
  Tensor output = at::empty({num_indices, D}, options);

  // Dispatch on the output, weights, offsets, and indices types. BFloat16 is
  // included in the output dispatch since BF16 is an accepted output_dtype
  // above.
  FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
      output.scalar_type(), "split_embedding_nobag_cpu_forward_1", [&] {
        using output_t = scalar_t;

        FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
            weights.scalar_type(), "split_embedding_nobag_cpu_forward_2", [&] {
              using weights_t = scalar_t;

              AT_DISPATCH_INDEX_TYPES(
                  offsets.scalar_type(),
                  "split_embedding_nobag_cpu_forward_3",
                  [&] {
                    using offset_t = index_t;

                    AT_DISPATCH_INDEX_TYPES(
                        indices.scalar_type(),
                        "split_embedding_nobag_cpu_forward_4",
                        [&] {
                          split_embedding_nobag_codegen_forward_cpu_kernel<
                              weights_t,
                              index_t,
                              offset_t,
                              output_t>(
                              weights,
                              weights_offsets,
                              D,
                              hash_size_cumsum,
                              indices,
                              offsets,
                              output);
                        });
                  });
            });
      });

  return output;
}

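// Meta backend implementation: computes only the output shape and dtype
// (SymInt-friendly), which is what tracing and compilation need.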
Tensor split_embedding_nobag_codegen_forward_cpu_meta(
    const Tensor& weights,
    const Tensor& /* weights_offsets */,
    int64_t D,
    const Tensor& /* hash_size_cumsum */,
    const Tensor& indices,
    const Tensor& /* offsets */,
    int64_t output_dtype) {
  c10::SymInt num_indices = indices.sym_size(0);
  // Mirror the dtype selection of the CPU implementation, but only produce an
  // empty tensor of the right symbolic shape.
  auto options = weights.options();
  if (output_dtype == static_cast<int64_t>(SparseType::FP32)) {
    options = weights.options().dtype(at::kFloat);
  } else if (output_dtype == static_cast<int64_t>(SparseType::FP16)) {
    options = weights.options().dtype(at::kHalf);
  } else if (output_dtype == static_cast<int64_t>(SparseType::BF16)) {
    options = weights.options().dtype(at::kBFloat16);
  }
  return at::empty_symint({num_indices, D}, options);
}

namespace {

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def(
      "split_embedding_nobag_codegen_forward_cpu(Tensor weights, "
      "    Tensor weights_offsets, "
      "    int D, "
      "    Tensor hash_size_cumsum, "
      "    Tensor indices, "
      "    Tensor offsets, "
      "    int output_dtype) -> Tensor");

  DISPATCH_TO_CPU(
      "split_embedding_nobag_codegen_forward_cpu",
      split_embedding_nobag_codegen_forward_cpu);

  DISPATCH_TO_META(
      "split_embedding_nobag_codegen_forward_cpu",
      split_embedding_nobag_codegen_forward_cpu_meta);
}
} // namespace
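
// Illustrative call site (a sketch, not part of this file): once the fragment
// above is registered, the operator can be looked up and invoked through the
// PyTorch dispatcher. Tensor shapes are inferred from the kernel:
// weights_offsets and hash_size_cumsum are int64 tensors of length T and at
// least T + 1, offsets has T * B + 1 entries, and indices has one entry per
// output row.
//
//   auto op = c10::Dispatcher::singleton()
//                 .findSchemaOrThrow(
//                     "fbgemm::split_embedding_nobag_codegen_forward_cpu", "")
//                 .typed<at::Tensor(
//                     const at::Tensor&,
//                     const at::Tensor&,
//                     int64_t,
//                     const at::Tensor&,
//                     const at::Tensor&,
//                     const at::Tensor&,
//                     int64_t)>();
//   at::Tensor output = op.call(
//       weights,
//       weights_offsets,
//       D,
//       hash_size_cumsum,
//       indices,
//       offsets,
//       static_cast<int64_t>(fbgemm_gpu::SparseType::FP32));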