
Commit 75fcf03

shiyang-weng authored and jansel committed
Add dont constant fold flag (pytorch#154945)
To support pytorch/ao#2228:

> What we want to do now is to enable FP8 quantization in PyTorch. Similar to INT8 quantization, we need to insert quantize and dequantize ops into the graph.
>
> However, we met problems with these q/dq ops both in PyTorch core and in Torchao.
>
> PyTorch core: the quantize_per_tensor op does not support FP8. We want to fix it via pytorch#153601, and, as you commented, the op is deprecated.
>
> Torchao: in the fusion pass in Inductor, we want to match the pattern fp8_weight -> torchao.dequantize_affine_float8 -> fp32_op and fuse it into fp8_weight -> weight_pack -> fp8_op. We have done so for INT8 PT2E quantization. However, the pattern-matching pass is applied after a constant-folding pass in Inductor:
> https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69C1-L74C1
>
> After constant_fold(gm), the pattern is folded into fp32_weight -> fp32_op. The original pattern can no longer be found, and the FP8 semantics are lost since the pattern is now entirely in fp32.
>
> For INT8, the int8_weight -> quantized_decomposed.dequantize_per_channel -> fp32_op pattern is not folded because quantized_decomposed.dequantize_per_channel is marked impure: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/constant_folding.py#L139C1-L149C1. But we cannot do the same for torchao.dequantize_affine_float8 because:
>
> - it is an op from Torchao, which is unknown to the constant folder, and
> - it is decomposed into smaller ops, so we cannot put it in the list as a single op.
>
> So an easy, short-term solution is to modify the ops in PyTorch core via pytorch#153601. However, if we want to resolve the issue within Torchao, we need to add a method to Inductor's constant folder that allows registration of impure ops.

Based on [Jansel's reply](pytorch/ao#2228 (comment)), this patch adds a dont-constant-fold flag.

Pull Request resolved: pytorch#154945
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@jansel.net>
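As a rough usage sketch of the new hook (mirroring the test added in this commit; the tiny module below and the choice of torch.ops.aten.mul.Tensor are just the illustration used there, not the FP8 ops from the motivation), a caller registers the op before compiling under freezing so the subgraph producing it is treated as impure and kept in the graph:

import torch
from torch._inductor import config
from torch._inductor.constant_folding import (
    add_dont_constant_fold,
    clear_dont_constant_fold,
)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.randn(5)
        self.s = torch.randn(5)

    def forward(self, x):
        # self.w * self.s is a constant expression; under freezing it would
        # normally be folded into a single fp32 tensor before pattern matching.
        return self.w * self.s + x


clear_dont_constant_fold()
# Treat aten.mul.Tensor as "don't constant fold": the w * s node then stays in
# the frozen graph instead of being replaced by a precomputed constant.
add_dont_constant_fold(torch.ops.aten.mul.Tensor)

with torch.no_grad(), config.patch({"freezing": True}):
    compiled = torch.compile(M())
    out = compiled(torch.rand(5))

The new test below checks exactly this behaviour: with the registration, the generated code still contains a fused mul+add kernel (cpp_fused_add_mul); without it, the multiply is folded away and only cpp_fused_add_0 remains.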
1 parent 14120ef commit 75fcf03

File tree

2 files changed, +49 -0 lines changed

test/inductor/test_torchinductor.py

Lines changed: 34 additions & 0 deletions
@@ -13550,6 +13550,40 @@ def test_special_polygamma(self):
         self.common(fn, (1, x))
         self.common(fn, (2, x))
 
+    @config.patch({"freezing": True})
+    def test_dont_constant_fold(self):
+        from torch._inductor.constant_folding import (
+            add_dont_constant_fold,
+            clear_dont_constant_fold,
+        )
+
+        m = 5
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = torch.randn(m)
+                self.s = torch.randn(m)
+
+            def forward(self, x):
+                return self.w * self.s + x
+
+        x = torch.rand(m)
+        mod = M()
+        for dont_constant_fold in [True, False]:
+            clear_dont_constant_fold()
+            if dont_constant_fold:
+                add_dont_constant_fold(torch.ops.aten.mul.Tensor)
+            with torch.no_grad():
+                refe_out = mod(x)
+                mod = torch.compile(mod)
+                test_out, (code,) = run_and_get_code(mod, x)
+                if dont_constant_fold:
+                    FileCheck().check("cpp_fused_add_mul").run(code)
+                else:
+                    FileCheck().check("cpp_fused_add_0").run(code)
+                self.assertEqual(refe_out, test_out)
+
 
 @dataclasses.dataclass
 class TestFailure:

torch/_inductor/constant_folding.py

Lines changed: 15 additions & 0 deletions
@@ -16,6 +16,18 @@
 MODULE_TAG = "_MAIN_MODULE"
 CONST_MODULE_TAG = "_CONST_MODULE"
 
+_dont_constant_fold: list[torch.fx.node.Target] = []
+
+
+def add_dont_constant_fold(op: torch.fx.node.Target) -> None:
+    global _dont_constant_fold
+    _dont_constant_fold.append(op)
+
+
+def clear_dont_constant_fold() -> None:
+    global _dont_constant_fold
+    _dont_constant_fold.clear()
+
 
 def replace_node_with_constant(
     gm: torch.fx.GraphModule,
@@ -146,6 +158,9 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
             # We only folding fp32_weight -> q
             # int8_weight and leave dq in graph to be fused
             return True
+
+        if node.target in _dont_constant_fold:
+            return True
         return False
 
     def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]:
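For the FP8 case in the motivation, the registration would live in the downstream library rather than in PyTorch core. A hedged sketch, not an actual torchao integration: the op handle is an assumption based on the name in the description above (which target actually appears in the graph depends on how torchao lowers the op), and the aten fallback exists only so the snippet runs without torchao installed:

import torch
from torch._inductor.constant_folding import add_dont_constant_fold

# Hypothetical: keep torchao's FP8 dequantize op out of constant folding so the
# fp8_weight -> dequantize_affine_float8 -> fp32_op pattern stays matchable.
try:
    dq_fp8 = torch.ops.torchao.dequantize_affine_float8.default
except (AttributeError, RuntimeError):
    dq_fp8 = torch.ops.aten.mul.Tensor  # stand-in purely for illustration

# ConstantFolder.is_impure() now returns True for this target, so a
# weight -> dequantize chain is not replaced by a precomputed fp32 constant
# during freezing and remains visible to later fusion passes.
add_dont_constant_fold(dq_fp8)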

0 commit comments
