
[Quant] Can quant not be decomposed on inductor? #2228


Open
shiyang-weng opened this issue May 20, 2025 · 14 comments · May be fixed by #2299 or #2379

Comments

@shiyang-weng

torch.ops.torchao.dequantize_affine is decomposed into convert_element_type and mul.
Inductor runs constant_fold before pattern matching.
During constant folding, Inductor replaces the fp8 weight and some of the preceding operations with an fp32 weight.
Is this expected?

Right now the decomposition is registered via register_decomposition.
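
Schematically (a sketch, not torchao's actual decomposition code), the decomposed form looks like this:

import torch

# Sketch of the decomposed dequantize: a dtype conversion (convert_element_type)
# followed by a multiply by the scale.
w = torch.randn(4, 4).to(torch.float8_e4m3fn)   # constant fp8 weight
s = torch.tensor(0.1)                           # constant scale
w_fp32 = w.to(torch.float32) * s                # convert_element_type + mul
# Both inputs are constants, so Inductor's constant folder can precompute w_fp32;
# that is how the fp8 weight ends up replaced by an fp32 weight before pattern matching.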

The following sample test reproduces the issue:

import os

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
os.environ["TORCH_COMPILE_DEBUG"] = "0"
os.environ["TORCHDYNAMO_PRINT_GUARD_FAILS"] = "0"

from typing import Callable, List, Optional, Union
import torch
from torch import nn
import torchao
#import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq

def dequantize_per_tensor(
        input: torch.Tensor,
        scale: torch.Tensor,
        output_dtype: torch.dtype
) -> torch.Tensor:
    res = torch.ops.torchao.dequantize_affine(
        input=input,
        block_size=input.shape,
        scale=scale,
        zero_point=torch.tensor(0),
        input_dtype=torch.float8_e4m3fn,
    )
    if output_dtype != torch.float:
        res = res.to(output_dtype)
    return res

def quantize_per_tensor(
        input: torch.Tensor,
        scale: torch.Tensor,
) -> torch.Tensor:
    return torch.ops.torchao.quantize_affine(
        input=input,
        block_size=input.shape,
        scale=scale,
        zero_point=torch.tensor(0),
        output_dtype=torch.float8_e4m3fn,
    )

class Perceptron(torch.nn.Module):
    def __init__(
        self,
        in_size: int,
        out_size: int,
        bias: bool = True,
        activation: Union[
            torch.nn.Module,
            Callable[[torch.Tensor], torch.Tensor],
        ] = torch.relu,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self._out_size = out_size
        self._in_size = in_size
        self._linear: nn.Linear = nn.Linear(
            self._in_size,
            self._out_size,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self._activation_fn: Callable[[torch.Tensor], torch.Tensor] = activation

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._activation_fn(self._linear(input))

class MLP(torch.nn.Module):
    def __init__(
        self,
        in_size: int,
        layer_sizes: List[int],
        bias: bool = True,
        activation: Union[
            str,
            Callable[[], torch.nn.Module],
            torch.nn.Module,
            Callable[[torch.Tensor], torch.Tensor],
        ] = torch.relu,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()

        if activation == "relu":
            activation = torch.relu
        elif activation == "sigmoid":
            activation = torch.sigmoid

        if not isinstance(activation, str):
            self._mlp: torch.nn.Module = torch.nn.Sequential(
                *[
                    Perceptron(
                        layer_sizes[i - 1] if i > 0 else in_size,
                        layer_sizes[i],
                        bias=bias,
                        activation=activation,
                        device=device,
                        dtype=dtype,
                    )
                    for i in range(len(layer_sizes))
                ]
            )
        else:
            raise ValueError(
                "This MLP only supports the str activation functions "
                "'relu', 'sigmoid', and 'swish_layernorm'"
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._mlp(input)

class DenseArch(nn.Module):
    def __init__(
        self,
        in_features: int,
        layer_sizes: List[int],
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.model: nn.Module = MLP(
            in_features, layer_sizes, bias=True, activation="relu", device=device
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.model(features)


def inc_convert(model, dtype):
    model.eval()
    qtype = torch.float8_e4m3fn

    #from torch.ao.quantization.fx._decomposed import quantize_per_tensor, dequantize_per_tensor
    from torch.nn import functional as F

    class FP8QDQLinear(torch.nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.weight = torch.empty((out_features, in_features),)
            self.weight_scale = None
            self.scale = None
            self.bias = None

        def forward(self, input):
            weight = dequantize_per_tensor(
                self.weight.data,
                self.weight_scale,
                dtype,
            )
            q_input = quantize_per_tensor(
                input,
                self.scale,
            )

            dq_input = dequantize_per_tensor(
                q_input,
                self.scale,
                dtype
            )
            # out1 = torch._scaled_mm(q_input, self.weight.T, torch.tensor(self.scale), torch.tensor(self.weight_scale), bias=self.bias, out_dtype=torch.float8_e4m3fn)
            # out2 = torch.mm(dq_input, weight.T) + self.bias
            # out3 = torch.nn.functional.linear(dq_input, weight, self.bias)
            out = torch.nn.functional.linear(dq_input, weight, self.bias)

            return out


    class FP8QDQEmbeddingBag(torch.nn.Module):

        def __init__(self, weight_shape, max_norm, norm_type, scale_grad_by_freq, mode, sparse,
                     include_last_offset, padding_idx):
            super().__init__()
            #self.mod = mod
            self.max_norm = max_norm
            self.norm_type = norm_type
            self.scale_grad_by_freq = scale_grad_by_freq
            self.mode = mode
            self.sparse = sparse
            self.include_last_offset = include_last_offset
            self.padding_idx = padding_idx
            self.weight = torch.empty(weight_shape)
            self.weight_scale = None

        def forward(
            self,
            input,
            offsets=None,
            per_sample_weights=None,
        ):
            weight = dequantize_per_tensor(
                self.weight.data,
                self.weight_scale,
                dtype,
            )


            return F.embedding_bag(
                input,
                weight,
                offsets,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.mode,
                self.sparse,
                per_sample_weights,
                self.include_last_offset,
                self.padding_idx,
            )

    hook_handles = []
    import json
    from collections import namedtuple

    def generate_model_info(model):
        mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"])
        parent_child_mod_dict = {}

        def create_mod_info_recursion(parent):
            for name, mod in parent.named_children():
                parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent)
                create_mod_info_recursion(mod)

        create_mod_info_recursion(model)
        return parent_child_mod_dict
    parent_child_mod_dict = generate_model_info(model)

    with torch.no_grad():
        for i, (name, mod) in enumerate(model.named_modules()):
            mod_type_str = mod.__class__.__name__
            #print(mod_type_str)
            #continue
            if mod_type_str not in ["Linear", "EmbeddingBag"]:
                continue
            print(mod_type_str, name)
            param = mod.weight
            xmax = torch.max(param)
            weight_scale = xmax / torch.finfo(qtype).max
            setattr(mod, "weight_scale", weight_scale)
            q_param = torch.clamp((param / weight_scale), torch.finfo(qtype).min, torch.finfo(qtype).max).to(qtype)
            mod.weight.data = q_param
            if mod_type_str in ["Linear"]:
                scale = [1 / torch.finfo(qtype).max]
                assert len(scale) == 1
                #setattr(mod, "scale", scale[0])
                patched_mod = FP8QDQLinear(mod.in_features, mod.out_features)
                patched_mod.bias = mod.bias
                patched_mod.weight.data = q_param
                patched_mod.scale = torch.tensor(scale[0])
                patched_mod.weight_scale = torch.tensor(weight_scale.item())
            else:
                patched_mod = FP8QDQEmbeddingBag(
                    weight_shape=mod.weight.shape,
                    max_norm=mod.max_norm,
                    norm_type=mod.norm_type,
                    scale_grad_by_freq=mod.scale_grad_by_freq,
                    mode=mod.mode,
                    sparse=mod.sparse,
                    include_last_offset=mod.include_last_offset,
                    padding_idx=mod.padding_idx)
                patched_mod.weight_scale = weight_scale.item()
                patched_mod.weight.data = q_param

            parent = parent_child_mod_dict[mod].parent
            name = parent_child_mod_dict[mod].name
            setattr(parent, name, patched_mod)

def pt2e(model, inputs):
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
    from torch.export import export_for_training
    with torch.no_grad():
        out = model(*inputs)
        exported_model = export_for_training(
            model,
            inputs,
            strict=True
        ).module()
        quantizer = X86InductorQuantizer()
        quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
        prepared_model = prepare_pt2e(exported_model, quantizer)
        prepared_model(*inputs)
        converted_model = convert_pt2e(prepared_model)
        torch.ao.quantization.move_exported_model_to_eval(converted_model)
        converted_model(*inputs)
    return converted_model


import time

from torch._inductor import config as inductor_config
from torch._dynamo import config

config.error_on_recompile = True
#inductor_config.cpp_wrapper = True
inductor_config.max_autotune = False
inductor_config.freezing = True

inductor_config.aot_inductor.debug_compile = False


model = DenseArch(13, [512, 256, 128], "cpu")
example_inputs = (torch.randn(128, 13),)

print(model)
tmp0 = model.model._mlp[0]._linear(*example_inputs)
# tmp1 = model.model._mlp[0]._linear(*example_inputs)
import contextlib
ctx1 = contextlib.suppress()
ctx2 = torch.autocast("cpu", enabled=True, dtype=torch.bfloat16)
#dtype = torch.float
dtype = torch.float32
if dtype == torch.float32:
    ctx = ctx1
else:
    ctx = ctx2
with torch.no_grad(), ctx:
    qtype = torch.float8_e4m3fn
    refe = model(*example_inputs)
    if qtype == torch.int8:
        model = pt2e(model, example_inputs)
    else:
        inc_convert(model, dtype)
    test_eager = model(*example_inputs)
    model = torch.compile(model)
    model(*example_inputs)
    test = model(*example_inputs)
@jerryzh168
Contributor

jerryzh168 commented May 20, 2025

Yeah, we use

def _register_custom_op(lib):
to prevent the op from being decomposed during export, while still letting it be decomposed in Inductor.

Do you want the op to be preserved in Inductor?
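
Schematically, the pattern looks like the sketch below: define the op via torch.library so export keeps it as a single node, then register a decomposition that is applied later in the compile stack. The namespace and op name (mylib, my_dequant) are hypothetical; this is not torchao's actual registration code.

import torch
from torch.library import Library, impl
from torch._decomp import register_decomposition

_lib = Library("mylib", "DEF")  # hypothetical namespace, for illustration only
_lib.define("my_dequant(Tensor x, Tensor scale) -> Tensor")

@impl(_lib, "my_dequant", "CompositeExplicitAutograd")
def _my_dequant(x, scale):
    # eager/export-time implementation; export keeps torch.ops.mylib.my_dequant as one node
    return x.to(torch.float32) * scale

# Decomposition applied later in the stack: the op becomes convert_element_type + mul,
# which is what shows up in the Inductor graphs described in this issue.
@register_decomposition([torch.ops.mylib.my_dequant.default])
def _my_dequant_decomp(x, scale):
    return x.to(torch.float32) * scale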

@shiyang-weng
Author

yeah we use

ao/torchao/utils.py

Line 180 in 96aec6a

def _register_custom_op(lib):
to prevent the op from being decomposed during export, while still letting it be decomposed in Inductor.
Do you want the op to be preserved in Inductor?

Yes. The issue is that the fp8 weight gets replaced by an fp32 weight during constant_fold.
Or do we have another way to avoid this?

Also, the quant/dequant decomposition makes the pattern more complicated to match. Can we avoid decomposing here?

@shiyang-weng
Author

Hi @jerryzh168, I'm not sure whether removing the decomposition here would cause other issues.
Can we consider landing the change in PyTorch first and then migrating it over?
My PRs on the PyTorch side:
pytorch/pytorch#153602
pytorch/pytorch#153601

@shiyang-weng
Author

Hi @jerryzh168, do you have any suggestions?

@Xia-Weiwen
Collaborator

Xia-Weiwen commented May 27, 2025

Hi @jerryzh168, please let me explain the whole story.
What we want to do now is to enable FP8 quantization in PyTorch. Similar to INT8 quantization, we need to insert quantize and dequantize ops into the graph.
However, we ran into problems with these q/dq ops in both PyTorch core and Torchao.

PyTorch core:

  • The quantize_per_tensor op does not support FP8. We want to fix it via pytorch/pytorch#153601, and, as you commented, the op is deprecated.

Torchao:

  • In the fusion pass in Inductor, we want to match the pattern fp8_weight -> torchao.dequantize_affine_float8 -> fp32_op and fuse it into fp8_weight -> weight_pack -> fp8_op, as we already do for INT8 PT2E quantization. However, the pattern matching pass is applied after a constant folding pass in Inductor: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69C1-L74C1. After constant_fold(gm), the pattern is folded into fp32_weight -> fp32_op; the original pattern can no longer be found and the FP8 semantics are lost because the graph is entirely in fp32.
  • For INT8, the int8_weight -> quantized_decomposed.dequantize_per_channel -> fp32_op pattern is not folded because quantized_decomposed.dequantize_per_channel is marked impure: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/constant_folding.py#L139C1-L149C1. We cannot do the same for torchao.dequantize_affine_float8 because it is an op from Torchao, unknown to the constant folder, and it is decomposed into smaller ops, so we cannot put it in that list as a single op.

So, we think an easy, short-term solution is to modify the ops in PyTorch core via pytorch/pytorch#153601.
However, if we want to resolve the issue within Torchao, we need to

  • add a method to the constant folder in Inductor to allow registration of impure ops, and
  • avoid decomposition of torchao.dequantize_affine_float8 and register this op as impure so that it won't be constant-folded.

Do you think the short-term solution makes sense? And for the Torchao solution, do you have more comments or concerns? We are looking forward to your suggestions. Thanks.

@jerryzh168
Contributor

@Xia-Weiwen thanks for the clear summary.

I have duplicated the constant_fold code in torchao:

from .constant_fold import constant_fold

Would it be enough for you to add torchao.dequantize_affine_float8 there?

I agree that, for the longer term, Inductor should allow registration of impure ops. cc @eellison @jansel

As for avoiding decomposition of torchao.dequantize_affine_float8: I don't think we've done that before; in the INT8 path we explicitly decompose it for Inductor, right? What changed for float8?

@jansel

jansel commented May 27, 2025

Is dequantize impure? What is it mutating?

IMO this op should be decomposed in inductor. You can register the decomp in the same place the op is defined.

@jerryzh168
Contributor

jerryzh168 commented May 27, 2025

@jansel Technically it's not, but we may need to preserve the dequantize op so it can be fused with other ops into a quantized op that takes an integer tensor as input. Is there a different way to specify this?

@jansel

jansel commented May 27, 2025

Impure isn't what you are looking for. Impure means the op mutates one of its inputs, so when we functionalize we need to introduce more copies (which might increase memory usage if Inductor can't optimize the copies away).

Ops will be preserved if you don't write a decomp for them, which forces them to be ExternKernels and prevents fusion with other ops.

@jerryzh168
Contributor

Ops will be preserved if you don't write a decomp for them, which forces them to be ExternKernels and prevents fusion with other ops.

What about constant folding? What prevents an op from being constant-folded, other than marking it impure? I think that's the original reason we marked these ops as impure.

@jansel

jansel commented May 28, 2025

I don't believe we have a dont-constant-fold flag (correct me if I'm wrong, @eellison), though maybe we should.

@Xia-Weiwen
Collaborator

Xia-Weiwen commented May 28, 2025

Thanks for your replies.

I have duplicated the constant_fold code in torchao:

@jerryzh168 If I understand correctly, the duplicated code is used in convert_pt2e in Torchao. However, what we were talking about is the constant-folding pass in Inductor here: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69. So I don't think adding something to the Torchao code can resolve the issue.
We don't want the dequantize op decomposed because once it's decomposed, the op is gone and it becomes difficult to tell the constant folder not to fold such patterns. What do you think? Thanks.

@jansel In the quantization scenario there are patterns like constant_quantized_weight -> dequantize -> fp32_op -> .... During lowering with torch.compile, we want to fuse such patterns into constant_quantized_weight -> quantized_op -> .... The fusion is done via https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L72C1-L75C1. However, the constant folding pass is applied before the fusion pass: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69. So we need some mechanism to prevent the quantization pattern from being folded into constant_fp32_weight -> fp32_op; otherwise the quantization fusion won't be applied and the quantization semantics are lost. Do you have any suggestions? Thanks.
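
For concreteness, here is a minimal, self-contained sketch of the ordering problem (illustrative only; this is neither the reproducer above nor Inductor's actual pass code, and the exact behavior depends on the PyTorch version): with freezing enabled, the all-constant chain fp8_weight -> convert -> mul is eligible for constant folding before any quantization fusion pattern can match it.

import torch
from torch._inductor import config as inductor_config

w_fp8 = (torch.randn(8, 8) * 0.1).to(torch.float8_e4m3fn)  # constant quantized weight
scale = torch.tensor(0.05)

class M(torch.nn.Module):
    def forward(self, x):
        # dequantize written out in its decomposed form: convert + mul
        w = w_fp8.to(torch.float32) * scale
        return torch.nn.functional.linear(x, w)

inductor_config.freezing = True
with torch.no_grad():
    compiled = torch.compile(M())
    compiled(torch.randn(2, 8))
# With freezing, constant folding runs before the quantization fusion patterns,
# so the fp8 weight can be replaced by a precomputed fp32 weight and the
# fp8_weight -> dequantize -> fp32_op pattern is no longer visible to the matcher.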

@jerryzh168
Contributor

@Xia-Weiwen

We don't want the dequantize op decomposed because once it's decomposed, the op is gone and it becomes difficult to tell the constant folder not to fold such patterns. What do you think? Thanks.

This makes sense. How did it work before? Also, as Jason mentioned, if you don't register a decomposition for the op, it won't be decomposed. Maybe we could try adding an option to skip the registration here:

register_decomposition([op])(fn)

Do you want to test this out with the new affine quant ops?
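
Schematically, the opt-out could look like the sketch below. The helper and flag names are hypothetical, not torchao's actual API; the point is just to gate the register_decomposition call.

from torch._decomp import register_decomposition

def maybe_register_decomposition(op, fn, decompose_in_inductor: bool = True):
    # Hypothetical helper: `op` is the already-defined custom op, `fn` its reference impl.
    if decompose_in_inductor:
        register_decomposition([op])(fn)  # current behavior: the op is decomposed for Inductor
    # else: the op stays a single node, so the constant folder and the pattern
    # matcher both see torchao.dequantize_affine itself (and can be taught to skip it)
    return fn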

@shiyang-weng
Author

I will do it. The plan is to:

  1. add a _dont_constant_fold mechanism in PyTorch
  2. add an option in torchao to skip the decomposition

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Jun 5, 2025
To support pytorch/ao#2228
> What we want to do now is to enable FP8 quantization in PyTorch. And similar as INT8 quantization, we need to insert quantize and dequantize ops into the graph.
>
> However we met problems with these q/dq ops both in the PyTorch core and Torchao.
>
> PyTorch core:
>
> The quantize_per_tensor op does not support FP8. We want to fix it via #153601. And as you commented, the op is deprecated.
> Torchao:
>
> In the fusion pass in Inductor, we want to match the pattern fp8_weight -> torchao.dequantize_affine_float8 -> fp32_op and fuse it as fp8_weight -> weight_pack -> fp8_op. We have done so for INT8 PT2E quantization. However, the pattern matching pass is applied after a constant folding pass in Inductor:
> https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69C1-L74C1
> After constant_fold(gm), the pattern will be folded as fp32_weight -> fp32_op. Then the original pattern cannot be found any more and the FP8 semantics is lost since the pattern is entirely in fp32 now.
> For INT8, the int8_weight -> quantized_decomposed.dequantize_per_channel -> fp32_op pattern won't be folded because we mark quantized_decomposed.dequantize_per_channel impure so that it won't be folded: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/constant_folding.py#L139C1-L149C1 . But for the torchao.dequantize_affine_float8, we cannot do this because
> It is an op from Torchao, which is unknown to the constant folder
> It is decomposed to smaller ops, so we cannot put it in the list as a single op.
> So, we think an easy and short-term solution is to modify the ops in PyTorch core via #153601.
> However, if we want to resolve the issue with Torchao, we need to
> Add a method in the constant folder in Inductor to allow registration of impure ops

Based on [Jansel's reply](pytorch/ao#2228 (comment)), this patch adds a don't-constant-fold flag.

Pull Request resolved: #154945
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@jansel.net>
@shiyang-weng shiyang-weng linked a pull request Jun 9, 2025 that will close this issue
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Jun 10, 2025
Pull Request resolved: #154945
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@jansel.net>