[CPU] Enable DA8W4 on CPU #2128
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2128
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit e3731f7 with merge base 4ebc9c0.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@leslie-fang-intel This PR is updated to use a new layout. Please review again. Thanks.
Hi @jerryzh168 Could you please review this PR? Thanks.
Hi @leslie-fang-intel Please review this PR again. I have also added the kernel code in this PR. It showed reasonable performance in internal benchmarks. Thanks.
Please also describe how we choose different implementations based on the CPU info.
I have added more details in the description. Thanks.
Hi @jerryzh168 Could you please review this PR? Thanks. It has changed a lot since your last review.
Hi @jerryzh168 Could you please review this PR? Thanks.
```python
@dataclass(frozen=True)
class Int8DynamicActInt4WeightCPULayout(Layout):
```
it looks like you can just reuse Int4CPULayout
can you move the layout and impl to a separate file?
Sure. Done.
```python
@register_layout(Int8DynamicActInt4WeightCPULayout)
class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl):
```
Oh I see. OK, if you need a separate impl then it makes sense to have a separate layout.
Yes, we need a different impl from W16W4 because the ISA (AMX and VNNI) requires different weight memory formats for computation in BF16 or INT8. Thanks.
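To make the backend selection described above concrete, here is a minimal Python sketch of ISA-based dispatch. `torch.backends.cpu.get_cpu_capability()` is a real API in recent PyTorch releases; the `_cpu_has_amx()` probe is a hypothetical stand-in, since the PR's actual selection happens inside the C++ kernel based on CPU feature detection.

```python
import torch

def _cpu_has_amx() -> bool:
    # Hypothetical probe; the real detection in this PR lives in C++.
    return False

def choose_gemm_backend() -> str:
    # Prefer AMX tiles when present, then AVX512 (with VNNI) INT8 dot
    # products, and fall back to a reference path otherwise.
    if _cpu_has_amx():
        return "amx"
    if torch.backends.cpu.get_cpu_capability() == "AVX512":
        return "avx512-vnni"
    return "reference"
```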
```python
int_data = (int_data + 8).to(torch.uint8)
if scale.dim() == 1:
    scale.unsqueeze_(-1)
scale = scale.to(torch.float)
if zero_point.dim() == 1:
    zero_point.unsqueeze_(-1)
zero_point = zero_point.to(torch.int8) + 8
```
Can you configure the dtypes and shapes of int_data, scale, and zero_point in the call to to_affine_quantized_intx?
Thanks for the suggestion. I have improved this part.
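For context, a minimal sketch of what the suggestion looks like, assuming torchao's to_affine_quantized_intx accepts scale_dtype and zero_point_dtype arguments (it does in recent versions); the shapes and values here are illustrative, not the PR's actual code. Requesting the dtypes up front removes the need for the post-hoc conversions shown in the diff above.

```python
import torch
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType

weight = torch.randn(128, 64)
group_size = 32
qweight = to_affine_quantized_intx(
    weight,
    mapping_type=MappingType.ASYMMETRIC,
    block_size=(1, group_size),
    target_dtype=torch.uint8,      # store int4 values as uint8 in [0, 15]
    quant_min=0,
    quant_max=15,
    scale_dtype=torch.float,       # requested up front, no post-hoc .to()
    zero_point_dtype=torch.int8,
)
```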
assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0] | ||
quantize_( | ||
m2, | ||
int8_dynamic_activation_int4_weight( |
Nit: can you use the new API Int8DynamicActivationInt4WeightConfig instead of int8_dynamic_activation_int4_weight?
Thanks. Done.
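A minimal sketch of the rename, assuming the config class takes the same layout argument as the old function and that the layout is exported from torchao.dtypes as in this PR (argument names may differ across torchao versions):

```python
import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.dtypes import Int8DynamicActInt4WeightCPULayout

m2 = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
quantize_(
    m2,
    Int8DynamicActivationInt4WeightConfig(layout=Int8DynamicActInt4WeightCPULayout()),
)
```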
torchao/quantization/quant_api.py (outdated)

```diff
@@ -728,9 +761,17 @@ def _int8_dynamic_activation_int4_weight_transform(
     quant_min = -8
     quant_max = 7

+    if isinstance(layout, Int8DynamicActInt4WeightCPULayout):
```
Can this happen in the kernel? We have dtype conversions like this, in torchao/dtypes/uintx/plain_layout.py, line 260 (at 2898903):

```python
w_vals_int8_t.to(input_tensor.dtype),
```
Thanks for the comment. I have moved this to _linear_int8_act_int4_weight_cpu_impl.
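For reference, a minimal sketch of the pattern the cited plain_layout.py line uses, i.e. casting the stored int8 weight to the activation dtype inside the linear impl at call time rather than at quantization time; the function and tensor names here are illustrative stand-ins, not the PR's actual code.

```python
import torch

def linear_impl_sketch(x: torch.Tensor, w_vals_int8_t: torch.Tensor) -> torch.Tensor:
    # The dtype conversion happens here, inside the impl, at dispatch time.
    return torch.mm(x, w_vals_int8_t.to(x.dtype))
```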
Summary
This PR enables DA8W4 on CPU:
- Int8DynamicActInt4WeightCPULayout and its implementation
- da8w4_linear_prepack_cpu for weight packing
- da8w4_linear_cpu for A8W4 GEMM

The feature supports symmetric and asymmetric quantization of activation. The ops and kernels won't be available unless USE_CPP_KERNELS=1 is set on Linux with an x86 CPU with AVX512. To get the best performance, one needs a CPU with AMX support.

Implementation details
The kernels use the at::cpublas brgemm utilities from PyTorch core if available.

Usage
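A minimal usage sketch based on the APIs discussed in this thread; the import paths, group size, and model are assumptions for illustration, so check the torchao docs for your version:

```python
import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.dtypes import Int8DynamicActInt4WeightCPULayout

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval()
quantize_(
    model,
    Int8DynamicActivationInt4WeightConfig(
        group_size=128,  # illustrative group size along the input dimension
        layout=Int8DynamicActInt4WeightCPULayout(),
    ),
)
x = torch.randn(4, 1024)
with torch.no_grad():
    y = model(x)
```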
Test plan
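A minimal sketch of a check in the spirit of the test snippet reviewed above: quantize a small model, compile it, and assert that the fused CPU op appears in the generated code. run_and_get_code is a real inductor test utility; the model and shapes are illustrative.

```python
import torch
from torch._inductor.utils import run_and_get_code
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.dtypes import Int8DynamicActInt4WeightCPULayout

m2 = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
quantize_(
    m2,
    Int8DynamicActivationInt4WeightConfig(layout=Int8DynamicActInt4WeightCPULayout()),
)
x = torch.randn(8, 64)
_, code = run_and_get_code(torch.compile(m2), x)
assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0]
```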