[Quant] Can quant not be decomposed on inductor? #2228
Yeah, we use it (see Line 180 in 96aec6a). Do you want the op to be preserved in Inductor?
Yes. There is an issue where the fp8 weight gets folded into an fp32 weight during constant_fold, and the quant/dequant decomposition makes the pattern matching complicated. Can we avoid decomposing here?
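For concreteness, here is a minimal, hedged sketch (the scale value and shapes are made up) of what the dequant looks like once it is decomposed: just a dtype conversion plus a multiply, which is both easy to constant-fold and awkward to pattern-match as a single unit.

```python
import torch

# Decomposed affine FP8 dequant: convert_element_type followed by mul.
# After this chain runs, nothing in the result records that the data
# originally came from an fp8 tensor.
w_fp8 = torch.randn(16, 16).to(torch.float8_e4m3fn)
scale = torch.tensor(0.05)

w_fp32 = w_fp8.to(torch.float32) * scale
print(w_fp32.dtype)  # torch.float32
```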
Hi @jerryzh168, I'm not sure whether removing the decomposition here would cause any other issues.
Hi @jerryzh168, do you have any suggestions?
Hi @jerryzh168, please let me explain the whole story.

PyTorch core: the quantize_per_tensor op does not support FP8. We want to fix that via pytorch/pytorch#153601, but as you commented, the op is deprecated.

Torchao: in the Inductor fusion pass we want to match the pattern fp8_weight -> torchao.dequantize_affine_float8 -> fp32_op and fuse it into fp8_weight -> weight_pack -> fp8_op, as we already do for INT8 PT2E quantization. However, the pattern-matching pass runs after Inductor's constant folding, which folds the pattern into fp32_weight -> fp32_op, so the original pattern can no longer be found and the FP8 semantics are lost. The INT8 pattern survives because quantized_decomposed.dequantize_per_channel is marked impure in the constant folder; we cannot do the same for torchao.dequantize_affine_float8 because it is a Torchao op unknown to the constant folder and it is decomposed into smaller ops.

So, we think an easy and short-term solution is to modify the ops in PyTorch core via pytorch/pytorch#153601.
Do you think the short-term solution makes sense? And for the solution on the Torchao side, do you have more comments or concerns? We are looking forward to your suggestions. Thanks.
@Xia-Weiwen thanks for the clear summary. I have duplicated the constant_fold code in torchao: […]. Could we add torchao.dequantize_affine_float8 there?
I agree that for the longer term, Inductor should allow registration of impure ops. cc @eellison @jansel.
Is dequantize impure? What is it mutating? IMO this op should be decomposed in inductor. You can register the decomp in the same place the op is defined.
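To make that concrete, here is a minimal sketch of that style of registration, under assumptions: `mylib::dequant_f8` and its signature are hypothetical stand-ins for the real torchao op, and the decomposition body mirrors the convert + mul behavior described above.

```python
import torch
from torch._decomp import register_decomposition

# Hypothetical custom op standing in for the torchao dequant op.
@torch.library.custom_op("mylib::dequant_f8", mutates_args=())
def dequant_f8(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return t.to(torch.float32) * scale

@dequant_f8.register_fake
def _(t, scale):
    return torch.empty_like(t, dtype=torch.float32)

# Registering a decomposition next to the definition: compilers that
# consult the decomposition table can then break the op into
# convert_element_type + mul. Skipping this registration is what keeps
# the op as a single node.
@register_decomposition(torch.ops.mylib.dequant_f8.default)
def dequant_f8_decomp(t, scale):
    return t.to(torch.float32) * scale
```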
@jansel technically it's not, but we may need to preserve the dequantize op so it can be fused with other ops into a quantized op that takes an integer tensor as input. Is there a different way to specify this?
Impure isn't what you are looking for. Impure means the op mutates one of its inputs, so when we functionalize we need to introduce more copies (which might increase memory usage if inductor can't optimize the copies away). Ops will be preserved if you don't write a decomp for them, which forces them to be ExternKernels and prevents fusion with other ops.
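A tiny illustration of the pure/impure distinction being described (using plain tensor ops, nothing quantization-specific):

```python
import torch

# "Impure" here means the op mutates one of its inputs.
x = torch.zeros(3)
x.add_(1.0)      # dispatches to aten.add_: in-place, mutates x -> impure
print(x)         # tensor([1., 1., 1.])

y = torch.zeros(3)
z = y.add(1.0)   # dispatches to aten.add: out-of-place, y untouched -> pure
print(y, z)
```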
What about constant folding? What prevents an op from being constant folded (other than marking it impure)? I think that's the original reason we marked these ops as impure.
I don't believe we have a dont-constant-fold flag (correct me if I'm wrong @eellison), though maybe we should.
Thanks for your replies.
@jerryzh168 If I understand correctly, the duplicate code is used in […].
@jansel There are patterns like […].
This makes sense. How did it work before? Also, as Jason mentioned, if you don't register a decomposition for it, it won't be decomposed; maybe we could try adding an option to skip the registration here: Line 228 in 4d5f657.
I will do it. Plan to: […]
For support pytorch/ao#2228

> What we want to do now is to enable FP8 quantization in PyTorch. And similar to INT8 quantization, we need to insert quantize and dequantize ops into the graph.
>
> However, we met problems with these q/dq ops both in PyTorch core and Torchao.
>
> PyTorch core:
>
> The quantize_per_tensor op does not support FP8. We want to fix it via #153601. And as you commented, the op is deprecated.
>
> Torchao:
>
> In the fusion pass in Inductor, we want to match the pattern fp8_weight -> torchao.dequantize_affine_float8 -> fp32_op and fuse it as fp8_weight -> weight_pack -> fp8_op. We have done so for INT8 PT2E quantization. However, the pattern matching pass is applied after a constant folding pass in Inductor:
> https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69C1-L74C1
> After constant_fold(gm), the pattern will be folded as fp32_weight -> fp32_op. Then the original pattern cannot be found any more and the FP8 semantics are lost since the pattern is entirely in fp32 now.
> For INT8, the int8_weight -> quantized_decomposed.dequantize_per_channel -> fp32_op pattern won't be folded because we mark quantized_decomposed.dequantize_per_channel impure so that it won't be folded:
> https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/constant_folding.py#L139C1-L149C1
> But for torchao.dequantize_affine_float8, we cannot do this because:
> - It is an op from Torchao, which is unknown to the constant folder.
> - It is decomposed to smaller ops, so we cannot put it in the list as a single op.
>
> So, we think an easy and short-term solution is to modify the ops in PyTorch core via #153601.
> However, if we want to resolve the issue with Torchao, we need to:
> - Add a method in the constant folder in Inductor to allow registration of impure ops.

Based on [Jansel's reply](pytorch/ao#2228 (comment)), add a dont-constant-fold flag in this patch.

Pull Request resolved: #154945
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
Co-authored-by: Jason Ansel <jansel@jansel.net>
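For reference, here is a hedged paraphrase (not a verbatim copy) of the kind of check the constant-folding link above points at: the folder skips nodes whose target is one of the known dequantize ops, which is why the INT8 weight pattern survives freezing while the fp8 one does not.

```python
import torch
# Importing this module registers the quantized_decomposed ops.
import torch.ao.quantization.fx._decomposed  # noqa: F401

# Ops the constant folder should leave alone (paraphrased skip list).
_DONT_FOLD = {
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
}

def is_impure(node: torch.fx.Node) -> bool:
    """Return True if `node` must not be constant folded."""
    return node.op == "call_function" and node.target in _DONT_FOLD
```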
torch.ops.torchao.dequantize_affine is decomposed into convert_element_type and mul.
Inductor runs constant_fold before pattern matching.
During constant_fold, Inductor replaces the fp8 weight and some preceding operations with an fp32 weight.
Is this expected?
Currently the decomposition is registered in torchao via register_decomposition.
This sample test can reproduce the issue (a minimal sketch follows below).
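Since the sample test itself is not shown here, the following is a minimal hedged sketch of a reproducer under assumptions: the fp8 weight and scale are made up, the dequant is written out as the decomposed convert + mul chain, and freezing is enabled so Inductor constant-folds before its pattern-matching passes.

```python
import torch
import torch._inductor.config as inductor_config

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # fp8 "weight" plus a scale, standing in for a quantized checkpoint
        self.register_buffer("w_fp8", torch.randn(16, 16).to(torch.float8_e4m3fn))
        self.register_buffer("scale", torch.tensor(0.05))

    def forward(self, x):
        # decomposed dequant: convert_element_type + mul
        w = self.w_fp8.to(torch.float32) * self.scale
        return torch.nn.functional.linear(x, w)

inductor_config.freezing = True   # freezing triggers constant folding
m = M().eval()
x = torch.randn(4, 16)
with torch.no_grad():
    out = torch.compile(m)(x)
# In the frozen graph the convert + mul chain over the constant buffers is
# folded into a single fp32 weight, so an fp8-aware fusion pass can no
# longer find the original pattern.
```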