Skip to content

[CPU] Enable DA8W4 on CPU #2128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0581451
[CPU] enable int8_dynamic_activation_int4_weight with Int4CPULayout
Xia-Weiwen Apr 25, 2025
dffbbab
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Apr 25, 2025
9fb7f77
Fix format issue
Xia-Weiwen Apr 25, 2025
35ece3b
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Apr 28, 2025
c5b6d87
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen May 12, 2025
8e80d03
Add Int8DynamicActInt4WeightCPULayout
Xia-Weiwen May 14, 2025
51249c3
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen May 15, 2025
3e20172
remove dispatch for t()
Xia-Weiwen May 16, 2025
e765664
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen May 21, 2025
4feac3f
Add cpp kernel for weight packing and GEMM
Xia-Weiwen May 23, 2025
0d85183
Register ATQ linear dispatch for da8w4 linear
Xia-Weiwen May 25, 2025
c42abdb
Fix issues with torch.compile
Xia-Weiwen May 26, 2025
e2815ce
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen May 26, 2025
8c5eebb
Fix DA8W4CPUAQTTensorImpl.get_plain
Xia-Weiwen May 26, 2025
2a26e15
Test DA8W4CPUAQTTensorImpl.get_plain in UT
Xia-Weiwen May 26, 2025
369000f
Skip UT if CPP kernel not built
Xia-Weiwen May 26, 2025
f6e87ba
Add AVX512_VNNI implementation for small M
Xia-Weiwen May 27, 2025
0a87ef0
improve performance
Xia-Weiwen Jun 3, 2025
e05b96a
Support symmetric quantization of activation
Xia-Weiwen Jun 4, 2025
fd6e4b1
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Jun 4, 2025
18335c6
Refine code
Xia-Weiwen Jun 4, 2025
66ab77f
Refine code
Xia-Weiwen Jun 5, 2025
2c5a799
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Jun 5, 2025
131660e
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Jun 11, 2025
75fbd6f
Put in a separate file
Xia-Weiwen Jun 14, 2025
24268fd
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Jun 14, 2025
4c0a739
Bug fix
Xia-Weiwen Jun 25, 2025
0815d96
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Jun 25, 2025
e3731f7
refine code
Xia-Weiwen Jun 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,20 +385,29 @@ def get_extensions():
extra_compile_args["cxx"].extend(
["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always"]
)
if (
use_cpu_kernels
and is_linux
and hasattr(torch._C._cpu, "_is_avx512_supported")
and torch._C._cpu._is_avx512_supported()
):
extra_compile_args["cxx"].extend(
[
"-DCPU_CAPABILITY_AVX512",
"-march=native",
"-mfma",
"-fopenmp",
]
)

if use_cpu_kernels and is_linux:
if (
hasattr(torch._C._cpu, "_is_avx512_supported")
and torch._C._cpu._is_avx512_supported()
):
extra_compile_args["cxx"].extend(
[
"-DCPU_CAPABILITY_AVX512",
"-march=native",
"-mfma",
"-fopenmp",
]
)
if (
hasattr(torch._C._cpu, "_is_avx512_vnni_supported")
and torch._C._cpu._is_avx512_vnni_supported()
):
extra_compile_args["cxx"].extend(
[
"-DCPU_CAPABILITY_AVX512_VNNI",
]
)

if debug_mode:
extra_compile_args["cxx"].append("-g")
Expand Down
68 changes: 68 additions & 0 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
AffineQuantizedTensor,
Int4CPULayout,
Int4XPULayout,
Int8DynamicActInt4WeightCPULayout,
PlainLayout,
QDQLayout,
TensorCoreTiledLayout,
Expand Down Expand Up @@ -70,6 +71,7 @@
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
TORCH_VERSION_AT_LEAST_2_7,
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_89,
is_sm_at_least_90,
Expand Down Expand Up @@ -695,6 +697,72 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
assert "_weight_int4pack_mm_for_cpu" in code[0]
assert "aten.mm.default" not in code[0]

@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"),
reason="cpp kernels not built",
)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Test only enabled for 2.7+")
@common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@common_utils.parametrize("x_dim", [2, 3])
@common_utils.parametrize("bias", [True, False])
@common_utils.parametrize("bs", [1, 160])
@common_utils.parametrize("sym_quant_a", [True, False])
def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
if sym_quant_a and not TORCH_VERSION_AT_LEAST_2_8:
# not supported until PT 2.8
return
device = "cpu"
m = ToyLinearModel(bias=bias).eval().to(dtype).to(device)
m2 = copy.deepcopy(m)
example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
if x_dim == 3:
example_inputs = (example_inputs[0].unsqueeze(0),)

with torch.no_grad():
# Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout
# is that the former packs two int4 weights into one int8, while the latter does not.
quantize_(
m,
Int8DynamicActivationInt4WeightConfig(
group_size=32,
layout=Int8DynamicActInt4WeightCPULayout(),
act_mapping_type=MappingType.SYMMETRIC
if sym_quant_a
else MappingType.ASYMMETRIC,
),
)
y, code = torch._inductor.utils.run_and_get_code(
torch.compile(m, fullgraph=True, dynamic=True),
*example_inputs,
)
# ensure the expected op is in the code
assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0]
quantize_(
m2,
int8_dynamic_activation_int4_weight(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you use the new API: Int8DynamicActivationInt4WeightConfig instead of int8_dynamic_activation_int4_weight?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Done.

group_size=32,
layout=PlainLayout(),
act_mapping_type=MappingType.SYMMETRIC
if sym_quant_a
else MappingType.ASYMMETRIC,
),
)
torch._dynamo.reset() # may segfault without this
y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs)
atol, rtol = 4e-7, 1e-5
if dtype == torch.bfloat16:
atol, rtol = 1e-2, 3e-3
elif dtype == torch.half:
atol, rtol = 6e-3, 2e-3
assert torch.allclose(y, y2, atol=atol, rtol=rtol)
# Test get_plain by dequantize()
dqw1 = m.linear1.weight.original_weight_tensor.dequantize()
dqw2 = m.linear2.weight.original_weight_tensor.dequantize()
dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize()
dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize()
assert torch.allclose(dqw1, dqw1_ref)
assert torch.allclose(dqw2, dqw2_ref)

# TODO(#1690): move to new config names
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Expand Down
Loading
Loading