Fix int_nbit inference int8 nobag kernel meta function
Summary:
**TLDR;**
Fix int8 nobag in the TBE inference meta function such that:
- output shape is `{total_L, D + kINT8QparamsBytes}`
- `kINT8QparamsBytes = 4`
**Detail**
For nobag int8, the output shape should be `{total_L, D + kINT8QparamsBytes}`, since `total_L` dimension already includes `T`. `T *` was unintentionally added in D36018114.
`kINT8QparamsBytes` is 4 on CPU, since half-precision qparams are used; however, 8 is used on CUDA.
Our meta implementation follows the CUDA implementation, which mismatches the CPU one.
This diff removes `T *` from the output shape and changes `kINT8QparamsBytes` to 4 in the meta implementation to match the CPU and production behavior.
There has been no issue so far because our meta function is not being used and the int8 nobag CUDA kernel is not currently used in production.
CUDA kernel changes will be in the next diff.
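As a rough sketch of the fix described above (the helper name is illustrative, not the actual FBGEMM code), the corrected nobag int8 output shape for the meta function would be computed as:

```python
# Illustrative sketch of the corrected int8 nobag output-shape logic;
# names are hypothetical, not the actual FBGEMM implementation.
kINT8QparamsBytes = 4  # matches CPU: half-precision qparams per row

def int8_nobag_output_shape(total_L: int, D: int) -> tuple:
    # total_L already accounts for all T tables, so T must not appear
    # in the row count; each row holds D bytes of embedding data plus
    # the quantization parameter bytes.
    return (total_L, D + kINT8QparamsBytes)

# Example: 100 total indices across all tables, embedding dim 64
print(int8_nobag_output_shape(100, 64))  # -> (100, 68)
```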
----
Note that the currently used meta function is [fbgemm_int_nbit_split_embedding_codegen_lookup_function_meta](https://www.internalfb.com/code/fbsource/[d4f61c30f747f0a8c2e6d806904bc8ef3ee5ea42]/fbcode/caffe2/torch/fb/model_transform/splitting/split_dispatcher.py?lines=231%2C423), which has different logic for the int8 and nobag cases.
The discrepancy has not been an issue because:
- Nobag
  - split_dispatcher: D = average D
  - FBGEMM: D = max(max_D of each dtype)
  - → When all embedding dimensions are equal, average D = max D, so the two agree.
- Int8 Pooled
  - split_dispatcher: output shape is [B, total_D]
  - FBGEMM: output shape is [B, total_D + T * 8]
  - → Not an issue because int8 pooled is not used in production.
This will become a problem if mixed embedding dimensions are used, or if int8 pooled is adopted in production.
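To make the pooled int8 discrepancy concrete, here is a hypothetical sketch (illustrative function names and numbers, not FBGEMM code) comparing the two shape computations for B batches, total pooled dimension total_D, and T tables:

```python
# Hypothetical comparison of the two int8 pooled output shapes described
# above; these helpers are illustrative, not the actual implementations.

def pooled_int8_shape_split_dispatcher(B: int, total_D: int) -> tuple:
    # split_dispatcher omits the per-table qparam bytes entirely.
    return (B, total_D)

def pooled_int8_shape_fbgemm(B: int, total_D: int, T: int) -> tuple:
    # FBGEMM's CUDA path reserves 8 qparam bytes per table.
    return (B, total_D + T * 8)

B, total_D, T = 32, 256, 4
print(pooled_int8_shape_split_dispatcher(B, total_D))  # -> (32, 256)
print(pooled_int8_shape_fbgemm(B, total_D, T))         # -> (32, 288)
```

The two results differ by T * 8 columns, which is harmless only as long as the pooled int8 path stays out of production.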
Reviewed By: q10
Differential Revision: D75808485
fbshipit-source-id: 0765ca258c04c45234938f9b6d13837635b1fa93