 // clang-format off
 {%- set wdesc = "weighted" if weighted else "unweighted" %}
 #include "fbgemm_gpu/embedding_forward_template_helpers.cuh"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/tensor_accessor_builder.h"
 #include "fbgemm_gpu/config/feature_gates.h"

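Context for the hunk below: the newly included header provides `FBGEMM_LAUNCH_KERNEL`, which folds the raw triple-chevron launch and the trailing `C10_CUDA_KERNEL_LAUNCH_CHECK()` into a single call, and `PTA_B(tensor, type, dims, alignment)` replaces `MAKE_PTA_WITH_NAME(name, tensor, type, dims, alignment)`, apparently deriving the memcheck name from the launcher context, which is why the `FBGEMM_GPU_MEMCHECK` block disappears. A minimal sketch of the before/after launch pattern, using a hypothetical `toy_kernel` rather than the real templated kernel:

```cuda
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

// Hypothetical toy kernel, used only to illustrate the two launch styles.
__global__ void toy_kernel(float* out, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = 1.0f;
  }
}

void launch_both_ways(float* out, int n) {
  const dim3 grid((n + 255) / 256);
  const dim3 block(256);

  // Old style (the removed lines below): raw launch, then a separate check.
  toy_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, n);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // New style (the added lines below): the launcher macro takes the kernel,
  // the launch configuration, the stream, and then the kernel arguments, and
  // performs the launch plus the error check internally.
  // (In the hunk below the kernel expression is additionally parenthesized,
  // because its template argument list contains commas that would otherwise
  // split the macro arguments.)
  FBGEMM_LAUNCH_KERNEL(
      toy_kernel, grid, block, 0, at::cuda::getCurrentCUDAStream(), out, n);
}
```

Folding the check into the launcher removes the easy-to-forget `C10_CUDA_KERNEL_LAUNCH_CHECK()` at every call site.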
@@ -63,51 +64,46 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
 {%- macro define_kernel_invocation(emb_weight_type) %}
 {%- set func_name = "nbit::" + emb_weight_type + "_split_embedding" + ("_nobag" if nobag else "") + "_codegen_forward_" + wdesc + "_kernel_small_L" %}

-  #ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name_{{ emb_weight_type }} = "{{ func_name }}_{{ emb_weight_type }}";
-  #endif
-
   #ifdef X
   #undef X
   #endif

-  // Define {{ emb_weight_type }} kernel invocation macro
   #define X(DeviceOnly, PackedMode, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \
-    {{ func_name }}<index_t, output_t, OutputRowsPerThread, kWarpsPerBlock, InputRowsInFlight, MinNum128BRows, MaxNum128BRows, DeviceOnly, PackedMode><<< \
+    FBGEMM_LAUNCH_KERNEL( \
+      ({{ func_name }}<index_t, output_t, OutputRowsPerThread, kWarpsPerBlock, InputRowsInFlight, MinNum128BRows, MaxNum128BRows, DeviceOnly, PackedMode>), \
       nbit::div_round_up(T * nbit::div_round_up(B, num_packed_bags * OutputRowsPerThread), kWarpsPerBlock), \
       dim3(kWarpSize, kWarpsPerBlock), \
       0, \
-      at::cuda::getCurrentCUDAStream()>>>( \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, dev_weights, uint8_t, 1, 64), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, uvm_weights, uint8_t, 1, 64), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, weights_placements, int32_t, 1, 32), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, weights_offsets, int64_t, 1, 32), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, weights_tys, uint8_t, 1, 32), \
+      at::cuda::getCurrentCUDAStream(), \
+      PTA_B(dev_weights, uint8_t, 1, 64), \
+      PTA_B(uvm_weights, uint8_t, 1, 64), \
+      PTA_B(weights_placements, int32_t, 1, 32), \
+      PTA_B(weights_offsets, int64_t, 1, 32), \
+      PTA_B(weights_tys, uint8_t, 1, 32), \
 {%- if not nobag %}
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, D_offsets, int32_t, 1, 32), \
+      PTA_B(D_offsets, int32_t, 1, 32), \
 {%- else %}
       D, \
 {%- endif %}
       FixedDivisor(div_round_up(B, num_packed_bags * OutputRowsPerThread)), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, indices, index_t, 1, 32), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, offsets, index_t, 1, 32), \
+      PTA_B(indices, index_t, 1, 32), \
+      PTA_B(offsets, index_t, 1, 32), \
 {%- if not nobag %}
       pooling_mode, \
 {%- endif %}
       row_alignment, \
 {%- if weighted %}
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, indice_weights, float, 1, 32), \
+      PTA_B(indice_weights, float, 1, 32), \
 {%- endif %}
 {%- if emb_weight_type == "FP8" %}
       fp8_exponent_bits, \
       fp8_exponent_bias, \
 {%- endif %}
       num_packed_bags, \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, output, output_t, 2, 32), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, lxu_cache_weights, uint8_t, 2, 64), \
-      MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, lxu_cache_locations, int32_t, 1, 32) \
-  ); \
-  C10_CUDA_KERNEL_LAUNCH_CHECK(); \
+      PTA_B(output, output_t, 2, 32), \
+      PTA_B(lxu_cache_weights, uint8_t, 2, 64), \
+      PTA_B(lxu_cache_locations, int32_t, 1, 32) \
+  );
 {%- endmacro %}

 {%- macro construct_and_return_output_tensor() %}
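A closing note on the launch geometry, which this diff leaves unchanged: each block is `dim3(kWarpSize, kWarpsPerBlock)` threads, and the grid holds `div_round_up(T * div_round_up(B, num_packed_bags * OutputRowsPerThread), kWarpsPerBlock)` blocks, i.e. one warp per group of `num_packed_bags * OutputRowsPerThread` bags, per table. A standalone sketch of that arithmetic with illustrative values (all numbers made up, not taken from this patch):

```cuda
#include <cstdio>

// Mirrors nbit::div_round_up from the template above.
constexpr int div_round_up(const int a, const int b) {
  return (a + b - 1) / b;
}

int main() {
  // Illustrative values only:
  const int T = 10;                   // embedding tables
  const int B = 512;                  // bags (pooled outputs) per table
  const int num_packed_bags = 1;      // bags packed into one warp's pass
  const int OutputRowsPerThread = 4;  // output rows handled per warp slot
  const int kWarpsPerBlock = 4;

  // Each table needs ceil(B / (num_packed_bags * OutputRowsPerThread)) warps;
  // the grid packs kWarpsPerBlock of those warps into each block.
  const int warps = T * div_round_up(B, num_packed_bags * OutputRowsPerThread);
  const int blocks = div_round_up(warps, kWarpsPerBlock);
  std::printf("%d warps -> %d blocks\n", warps, blocks);  // 1280 warps -> 320 blocks
  return 0;
}
```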