[Quant] Can quant not be decomposed on inductor? #2228

@shiyang-weng

Description

torch.ops.torchao.dequantize_affine is decomposed into convert_element_type and mul.
Inductor runs constant folding before pattern matching.
During constant folding, Inductor replaces the fp8 weight (and the dequantize ops in front of it) with a precomputed fp32 weight.
Is this expected?

Currently dequantize_affine is registered as a decomposition via register_decomposition.
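
For context, a minimal illustrative sketch (not the actual decomposition source) of what the decomposed dequantize reduces to, and why the constant folder can fold the fp8 weight away:

import torch

# Per-tensor case: dequantize_affine reduces to a dtype convert plus a multiply.
weight_fp8 = torch.randn(128, 13).to(torch.float8_e4m3fn)  # frozen graph constant
weight_scale = torch.tensor(0.01)                          # frozen graph constant
w_fp32 = weight_fp8.to(torch.float32) * weight_scale       # convert_element_type + mul
# With TORCHINDUCTOR_FREEZING=1 both inputs are graph constants, so Inductor's
# constant folder precomputes w_fp32 and keeps an fp32 weight; the fp8 q/dq
# pattern is gone before pattern matching runs.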

The sample test below reproduces the issue:

import os

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
os.environ["TORCH_COMPILE_DEBUG"] = "0"
os.environ["TORCHDYNAMO_PRINT_GUARD_FAILS"] = "0"

from typing import Callable, List, Optional, Union
import torch
from torch import nn
import torchao
#import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq

def dequantize_per_tensor(
        input: torch.Tensor,
        scale: torch.Tensor,
        output_dtype: torch.dtype
) -> torch.Tensor:
    res = torch.ops.torchao.dequantize_affine(
        input=input,
        block_size=input.shape,
        scale=scale,
        zero_point=torch.tensor(0),
        input_dtype=torch.float8_e4m3fn,
    )
    if output_dtype != torch.float:
        res = res.to(output_dtype)
    return res

def quantize_per_tensor(
        input: torch.Tensor,
        scale: torch.Tensor,
) -> torch.Tensor:
    return torch.ops.torchao.quantize_affine(
        input=input,
        block_size=input.shape,
        scale=scale,
        zero_point=torch.tensor(0),
        output_dtype=torch.float8_e4m3fn,
    )

class Perceptron(torch.nn.Module):
    def __init__(
        self,
        in_size: int,
        out_size: int,
        bias: bool = True,
        activation: Union[
            torch.nn.Module,
            Callable[[torch.Tensor], torch.Tensor],
        ] = torch.relu,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self._out_size = out_size
        self._in_size = in_size
        self._linear: nn.Linear = nn.Linear(
            self._in_size,
            self._out_size,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self._activation_fn: Callable[[torch.Tensor], torch.Tensor] = activation

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._activation_fn(self._linear(input))

class MLP(torch.nn.Module):
    def __init__(
        self,
        in_size: int,
        layer_sizes: List[int],
        bias: bool = True,
        activation: Union[
            str,
            Callable[[], torch.nn.Module],
            torch.nn.Module,
            Callable[[torch.Tensor], torch.Tensor],
        ] = torch.relu,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()

        if activation == "relu":
            activation = torch.relu
        elif activation == "sigmoid":
            activation = torch.sigmoid

        if not isinstance(activation, str):
            self._mlp: torch.nn.Module = torch.nn.Sequential(
                *[
                    Perceptron(
                        layer_sizes[i - 1] if i > 0 else in_size,
                        layer_sizes[i],
                        bias=bias,
                        activation=activation,
                        device=device,
                        dtype=dtype,
                    )
                    for i in range(len(layer_sizes))
                ]
            )
        else:
            raise ValueError(
                "This MLP only supports str activations of relu, sigmoid, and swish_layernorm"
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._mlp(input)

class DenseArch(nn.Module):
    def __init__(
        self,
        in_features: int,
        layer_sizes: List[int],
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.model: nn.Module = MLP(
            in_features, layer_sizes, bias=True, activation="relu", device=device
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.model(features)


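# inc_convert: replace every Linear / EmbeddingBag in the model with an FP8 Q/DQ
# wrapper. Weights are scaled and clamped into float8_e4m3fn here and dequantized
# again at forward time by the wrapper modules below.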
def inc_convert(model, dtype):
    model.eval()
    qtype = torch.float8_e4m3fn

    #from torch.ao.quantization.fx._decomposed import quantize_per_tensor, dequantize_per_tensor
    from torch.nn import functional as F

    class FP8QDQLinear(torch.nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.weight = torch.empty((out_features, in_features),)
            self.weight_scale = None
            self.scale = None
            self.bias = None

        def forward(self, input):
            weight = dequantize_per_tensor(
                self.weight.data,
                self.weight_scale,
                dtype,
            )
            q_input = quantize_per_tensor(
                input,
                self.scale,
            )

            dq_input = dequantize_per_tensor(
                q_input,
                self.scale,
                dtype
            )
            # out1 = torch._scaled_mm(q_input, self.weight.T, torch.tensor(self.scale), torch.tensor(self.weight_scale), bias=self.bias, out_dtype=torch.float8_e4m3fn)
            # out2 = torch.mm(dq_input, weight.T) + self.bias
            # out3 = torch.nn.functional.linear(dq_input, weight, self.bias)
            out = torch.nn.functional.linear(dq_input, weight, self.bias)

            return out


    class FP8QDQEmbeddingBag(torch.nn.Module):

        def __init__(self, weight_shape, max_norm, norm_type, scale_grad_by_freq, mode, sparse,
                     include_last_offset, padding_idx):
            super().__init__()
            #self.mod = mod
            self.max_norm = max_norm
            self.norm_type = norm_type
            self.scale_grad_by_freq = scale_grad_by_freq
            self.mode = mode
            self.sparse = sparse
            self.include_last_offset = include_last_offset
            self.padding_idx = padding_idx
            self.weight = torch.empty(weight_shape)
            self.weight_scale = None

        def forward(
            self,
            input,
            offsets=None,
            per_sample_weights=None,
        ):
            weight = dequantize_per_tensor(
                self.weight.data,
                self.weight_scale,
                dtype,
            )


            return F.embedding_bag(
                input,
                weight,
                offsets,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.mode,
                self.sparse,
                per_sample_weights,
                self.include_last_offset,
                self.padding_idx,
            )

    hook_handles = []
    import json
    from collections import namedtuple

    def generate_model_info(model):
        mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"])
        parent_child_mod_dict = {}

        def create_mod_info_recursion(parent):
            for name, mod in parent.named_children():
                parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent)
                create_mod_info_recursion(mod)

        create_mod_info_recursion(model)
        return parent_child_mod_dict
    parent_child_mod_dict = generate_model_info(model)

    with torch.no_grad():
        for i, (name, mod) in enumerate(model.named_modules()):
            mod_type_str = mod.__class__.__name__
            #print(mod_type_str)
            #continue
            if mod_type_str not in ["Linear", "EmbeddingBag"]:
                continue
            print(mod_type_str, name)
            param = mod.weight
            xmax = torch.max(param)
            weight_scale = xmax / torch.finfo(qtype).max
            setattr(mod, "weight_scale", weight_scale)
            q_param = torch.clamp((param / weight_scale), torch.finfo(qtype).min, torch.finfo(qtype).max).to(qtype)
            mod.weight.data = q_param
            if mod_type_str in ["Linear"]:
                scale = [1 / torch.finfo(qtype).max]
                assert len(scale) == 1
                #setattr(mod, "scale", scale[0])
                patched_mod = FP8QDQLinear(mod.in_features, mod.out_features)
                patched_mod.bias = mod.bias
                patched_mod.weight.data = q_param
                patched_mod.scale = torch.tensor(scale[0])
                patched_mod.weight_scale = torch.tensor(weight_scale.item())
            else:
                patched_mod = FP8QDQEmbeddingBag(
                    weight_shape=mod.weight.shape,
                    max_norm=mod.max_norm,
                    norm_type=mod.norm_type,
                    scale_grad_by_freq=mod.scale_grad_by_freq,
                    mode=mod.mode,
                    sparse=mod.sparse,
                    include_last_offset=mod.include_last_offset,
                    padding_idx=mod.padding_idx)
                patched_mod.weight_scale = weight_scale.item()
                patched_mod.weight.data = q_param

            parent = parent_child_mod_dict[mod].parent
            name = parent_child_mod_dict[mod].name
            setattr(parent, name, patched_mod)

def pt2e(model, inputs):
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
    from torch.export import export_for_training
    with torch.no_grad():
        out = model(*inputs)
        exported_model = export_for_training(
            model,
            inputs,
            strict=True
        ).module()
        quantizer = X86InductorQuantizer()
        quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
        prepared_model = prepare_pt2e(exported_model, quantizer)
        prepared_model(*inputs)
        converted_model = convert_pt2e(prepared_model)
        torch.ao.quantization.move_exported_model_to_eval(converted_model)
        converted_model(*inputs)
    return converted_model


import time

from torch._inductor import config as inductor_config
from torch._dynamo import config

config.error_on_recompile = True
#inductor_config.cpp_wrapper = True
inductor_config.max_autotune = False
inductor_config.freezing = True

inductor_config.aot_inductor.debug_compile = False


model = DenseArch(13,[512,256,128], "cpu")
example_inputs = (torch.randn(128, 13),)

print(model)
tmp0 = model.model._mlp[0]._linear(*example_inputs)
# tmp1 = model.model._mlp[0]._linear(*example_inputs)
import contextlib
ctx1 = contextlib.suppress()
ctx2 = torch.autocast("cpu", enabled=True, dtype=torch.bfloat16)
#dtype = torch.float
dtype = torch.float32
if dtype == torch.float32:
    ctx = ctx1
else:
    ctx = ctx2
with torch.no_grad(), ctx:
    qtype = torch.float8_e4m3fn
    refe = model(*example_inputs)
    if qtype == torch.int8:
        model = pt2e(model, example_inputs)
    else:
        inc_convert(model, dtype)
    test_eager = model(*example_inputs)
    model = torch.compile(model)
    model(*example_inputs)
    test = model(*example_inputs)
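
A small check one could append to the repro (not part of the original script) to see whether the constant-folded fp32 weight changes numerics between the eager Q/DQ run and the compiled run:

# Hypothetical follow-up check: compare the eager Q/DQ output with the compiled
# (constant-folded) output.
print("max abs diff:", (test - test_eager).abs().max().item())
torch.testing.assert_close(test, test_eager, rtol=1e-2, atol=1e-2)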
