
Commit 5eb6d68

spcyppt authored and facebook-github-bot committed
Fix int_nbit inference int8 nobag kernel meta function
Summary: **TLDR;** Fix int8 nobag in the TBE inference meta function so that:
- the output shape is `{total_L, D + kINT8QparamsBytes}`
- `kINT8QparamsBytes` = 4

**Detail**

For nobag int8, the output shape should be `{total_L, D + kINT8QparamsBytes}`, since the `total_L` dimension already includes `T`; the `T *` factor was unintentionally added in D36018114.

`kINT8QparamsBytes` is 4 on CPU, since half-precision qparams are used, but 8 is used on CUDA. Our meta implementation follows the CUDA implementation, which mismatches the CPU one.

This diff removes `T *` from the output shape and changes `kINT8QparamsBytes` to 4 in the meta implementation to match CPU and production. There has been no issue so far because our meta function is not being used and the int8 nobag CUDA kernel is not currently used in production. The CUDA kernel changes will be in the next diff.

----

Note that the currently used meta function is [fbgemm_int_nbit_split_embedding_codegen_lookup_function_meta](https://www.internalfb.com/code/fbsource/[d4f61c30f747f0a8c2e6d806904bc8ef3ee5ea42]/fbcode/caffe2/torch/fb/model_transform/splitting/split_dispatcher.py?lines=231%2C423), which has different logic for the int8 and nobag cases. The discrepancy has not been an issue because:

- Nobag
  - split_dispatcher: D = average D
  - FBGEMM: D = max(max_D of each dtype)
  -> The embedding dimensions are the same, so average D = max D.
- Int8 pooled
  - split_dispatcher: `[B, total_D]`
  - FBGEMM: `[B, total_D + T * 8]`
  -> This is not being used in prod.

This will become a problem if embedding dimensions are mixed, or if int8 pooled is used.

Reviewed By: q10

Differential Revision: D75808485

fbshipit-source-id: 0765ca258c04c45234938f9b6d13837635b1fa93
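To make the shape fix concrete, here is a small arithmetic sketch of the old versus new output columns; all numbers (`T`, `D`, `total_L`) are illustrative assumptions, not values from the codebase:

```python
# Illustrative numbers only -- not taken from the FBGEMM sources.
T = 10        # number of embedding tables
D = 128       # embedding dimension (identical across tables here)
total_L = 50  # total indices across all tables; already accounts for T

# Before the fix, the meta function used the CUDA constant (8 bytes)
# and multiplied by T on top of that:
old_cols = D + T * 8  # 208 columns
# After the fix, it matches the CPU kernel (4 bytes, no T factor):
new_cols = D + 4      # 132 columns

print((total_L, old_cols))  # (50, 208) -- stale meta output shape
print((total_L, new_cols))  # (50, 132) -- matches CPU/production
```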
1 parent d473af0 commit 5eb6d68

File tree

1 file changed: +2, −1 lines

fbgemm_gpu/fbgemm_gpu/sparse_ops.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -420,6 +420,7 @@ def int_nbit_split_embedding_codegen_lookup_function_meta(
     kINT8QparamsBytes = 8

     if pooling_mode == PoolingMode.NONE:
+        kINT8QparamsBytes = 4
         D = max(
             [
                 max_int2_D,
@@ -435,7 +436,7 @@ def int_nbit_split_embedding_codegen_lookup_function_meta(
         torch._check(D > 0)
         adjusted_D = D
         if SparseType.from_int(output_dtype_int) == SparseType.INT8:
-            adjusted_D += T * kINT8QparamsBytes
+            adjusted_D += kINT8QparamsBytes
         output = dev_weights.new_empty([total_L, adjusted_D], dtype=output_dtype)
         return output
```

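For a quick sanity check of the fixed shape logic, the patched lines can be mirrored on a meta tensor. This is a hedged sketch: `int8_nobag_output_meta` is an illustrative stand-in, not the real FBGEMM meta function, which takes many more arguments (weight splits, offsets, pooling mode, and so on):

```python
import torch

# Sketch of the post-fix shape logic for the int8 nobag path.
def int8_nobag_output_meta(dev_weights: torch.Tensor, total_L: int, D: int) -> torch.Tensor:
    kINT8QparamsBytes = 4  # matches the CPU kernel: qparams stored in half precision
    adjusted_D = D + kINT8QparamsBytes  # no `T *` factor: total_L already covers T
    return dev_weights.new_empty([total_L, adjusted_D], dtype=torch.uint8)

weights = torch.empty(0, dtype=torch.uint8, device="meta")
out = int8_nobag_output_meta(weights, total_L=50, D=128)
assert out.shape == (50, 132)  # agrees with the CPU kernel's int8 nobag output
```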