Description
torch.ops.torchao.dequantize_affine is decomposed into convert_element_type and mul.
Inductor runs constant_fold before pattern matching.
During constant folding, Inductor replaces the fp8 weight and some of the preceding operations with an fp32 weight.
Is this expected behavior?
Currently this decomposition is registered via register_decomposition.
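For reference, here is a minimal sketch of what that decomposed subgraph effectively computes for a per-tensor fp8 weight with a zero zero_point (decomposed_dequant is an illustrative name, not the actual registered decomposition):

import torch

def decomposed_dequant(weight_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # convert_element_type: upcast the fp8 weight to fp32
    w = torch.ops.prims.convert_element_type(weight_fp8, torch.float32)
    # mul: apply the per-tensor scale
    return w * scale

Since both inputs (the frozen fp8 weight and its scale) are constants after freezing, this whole subgraph is constant-foldable, which is presumably why the folder materializes an fp32 weight before the pattern matcher sees the dequantize pattern.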
The following sample test reproduces the issue:
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
os.environ["TORCH_COMPILE_DEBUG"] = "0"
os.environ["TORCHDYNAMO_PRINT_GUARD_FAILS"] = "0"
from typing import Callable, List, Optional, Union
import torch
from torch import nn
import torchao
#import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
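# Per-tensor fp8 (e4m3) QDQ helpers built on the torchao affine quantization ops.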
def dequantize_per_tensor(
input: torch.Tensor,
scale: torch.Tensor,
output_dtype: torch.dtype
) -> torch.Tensor:
res = torch.ops.torchao.dequantize_affine(
input=input,
block_size=input.shape,
scale=scale,
zero_point=torch.tensor(0),
input_dtype=torch.float8_e4m3fn,
)
if output_dtype != torch.float:
res = res.to(output_dtype)
return res
def quantize_per_tensor(
input: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
return torch.ops.torchao.quantize_affine(
input=input,
block_size=input.shape,
scale=scale,
zero_point=torch.tensor(0),
output_dtype=torch.float8_e4m3fn,
)
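# Toy model for the repro: a stack of Linear + ReLU layers (Perceptron -> MLP -> DenseArch).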
class Perceptron(torch.nn.Module):
def __init__(
self,
in_size: int,
out_size: int,
bias: bool = True,
activation: Union[
torch.nn.Module,
Callable[[torch.Tensor], torch.Tensor],
] = torch.relu,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()
self._out_size = out_size
self._in_size = in_size
self._linear: nn.Linear = nn.Linear(
self._in_size,
self._out_size,
bias=bias,
device=device,
dtype=dtype,
)
self._activation_fn: Callable[[torch.Tensor], torch.Tensor] = activation
def forward(self, input: torch.Tensor) -> torch.Tensor:
return self._activation_fn(self._linear(input))
class MLP(torch.nn.Module):
def __init__(
self,
in_size: int,
layer_sizes: List[int],
bias: bool = True,
activation: Union[
str,
Callable[[], torch.nn.Module],
torch.nn.Module,
Callable[[torch.Tensor], torch.Tensor],
] = torch.relu,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()
if activation == "relu":
activation = torch.relu
elif activation == "sigmoid":
activation = torch.sigmoid
if not isinstance(activation, str):
self._mlp: torch.nn.Module = torch.nn.Sequential(
*[
Perceptron(
layer_sizes[i - 1] if i > 0 else in_size,
layer_sizes[i],
bias=bias,
activation=activation,
device=device,
dtype=dtype,
)
for i in range(len(layer_sizes))
]
)
else:
            raise ValueError(
                "This MLP only supports the str activation functions "
                "'relu', 'sigmoid', and 'swish_layernorm'"
            )
def forward(self, input: torch.Tensor) -> torch.Tensor:
return self._mlp(input)
class DenseArch(nn.Module):
def __init__(
self,
in_features: int,
layer_sizes: List[int],
device: Optional[torch.device] = None,
) -> None:
super().__init__()
self.model: nn.Module = MLP(
in_features, layer_sizes, bias=True, activation="relu", device=device
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
return self.model(features)
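# Weight-only fp8 conversion: quantize Linear/EmbeddingBag weights to fp8 and swap in QDQ wrapper modules.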
def inc_convert(model, dtype):
model.eval()
qtype = torch.float8_e4m3fn
#from torch.ao.quantization.fx._decomposed import quantize_per_tensor, dequantize_per_tensor
from torch.nn import functional as F
class FP8QDQLinear(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = torch.empty((out_features, in_features),)
self.weight_scale = None
self.scale = None
self.bias = None
def forward(self, input):
weight = dequantize_per_tensor(
self.weight.data,
self.weight_scale,
dtype,
)
q_input = quantize_per_tensor(
input,
self.scale,
)
dq_input = dequantize_per_tensor(
q_input,
self.scale,
dtype
)
# out1 = torch._scaled_mm(q_input, self.weight.T, torch.tensor(self.scale), torch.tensor(self.weight_scale), bias=self.bias, out_dtype=torch.float8_e4m3fn)
# out2 = torch.mm(dq_input, weight.T) + self.bias
# out3 = torch.nn.functional.linear(dq_input, weight, self.bias)
out = torch.nn.functional.linear(dq_input, weight, self.bias)
return out
class FP8QDQEmbeddingBag(torch.nn.Module):
def __init__(self, weight_shape, max_norm, norm_type, scale_grad_by_freq, mode, sparse,
include_last_offset, padding_idx):
super().__init__()
#self.mod = mod
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.mode = mode
self.sparse = sparse
self.include_last_offset = include_last_offset
self.padding_idx = padding_idx
self.weight = torch.empty(weight_shape)
self.weight_scale = None
def forward(
self,
input,
offsets=None,
per_sample_weights=None,
):
weight = dequantize_per_tensor(
self.weight.data,
self.weight_scale,
dtype,
)
return F.embedding_bag(
input,
weight,
offsets,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.mode,
self.sparse,
per_sample_weights,
self.include_last_offset,
self.padding_idx,
)
hook_handles = []
import json
from collections import namedtuple
def generate_model_info(model):
mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"])
parent_child_mod_dict = {}
def create_mod_info_recursion(parent):
for name, mod in parent.named_children():
parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent)
create_mod_info_recursion(mod)
create_mod_info_recursion(model)
return parent_child_mod_dict
parent_child_mod_dict = generate_model_info(model)
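    # Quantize each Linear/EmbeddingBag weight to fp8 and replace the module with its QDQ wrapper.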
with torch.no_grad():
for i, (name, mod) in enumerate(model.named_modules()):
mod_type_str = mod.__class__.__name__
#print(mod_type_str)
#continue
if mod_type_str not in ["Linear", "EmbeddingBag"]:
continue
print(mod_type_str, name)
param = mod.weight
xmax = torch.max(param)
weight_scale = xmax / torch.finfo(qtype).max
setattr(mod, "weight_scale", weight_scale)
q_param = torch.clamp((param / weight_scale), torch.finfo(qtype).min, torch.finfo(qtype).max).to(qtype)
mod.weight.data = q_param
if mod_type_str in ["Linear"]:
scale = [1 / torch.finfo(qtype).max]
assert len(scale) == 1
#setattr(mod, "scale", scale[0])
patched_mod = FP8QDQLinear(mod.in_features, mod.out_features)
patched_mod.bias = mod.bias
patched_mod.weight.data = q_param
patched_mod.scale = torch.tensor(scale[0])
patched_mod.weight_scale = torch.tensor(weight_scale.item())
else:
patched_mod = FP8QDQEmbeddingBag(
weight_shape=mod.weight.shape,
max_norm=mod.max_norm,
norm_type=mod.norm_type,
scale_grad_by_freq=mod.scale_grad_by_freq,
mode=mod.mode,
sparse=mod.sparse,
include_last_offset=mod.include_last_offset,
padding_idx=mod.padding_idx)
                patched_mod.weight_scale = torch.tensor(weight_scale.item())
patched_mod.weight.data = q_param
parent = parent_child_mod_dict[mod].parent
name = parent_child_mod_dict[mod].name
setattr(parent, name, patched_mod)
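# Alternative int8 path via PT2E and X86InductorQuantizer (only used when qtype is torch.int8).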
def pt2e(model, inputs):
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.export import export_for_training
with torch.no_grad():
out = model(*inputs)
exported_model = export_for_training(
model,
            inputs,
strict=True
).module()
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepared_model = prepare_pt2e(exported_model, quantizer)
prepared_model(*inputs)
converted_model = convert_pt2e(prepared_model)
torch.ao.quantization.move_exported_model_to_eval(converted_model)
converted_model(*inputs)
return converted_model
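# Driver: build the model, run an eager reference, apply the fp8 QDQ conversion, then compile with freezing.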
import time
from torch._inductor import config as inductor_config
from torch._dynamo import config
config.error_on_recompile = True
#inductor_config.cpp_wrapper = True
inductor_config.max_autotune = False
inductor_config.freezing = True
inductor_config.aot_inductor.debug_compile = False
model = DenseArch(13, [512, 256, 128], device="cpu")
example_inputs = (torch.randn(128, 13),)
print(model)
tmp0 = model.model._mlp[0]._linear(*example_inputs)
# tmp1 = model.model._mlp[0]._linear(*example_inputs)
import contextlib
ctx1 = contextlib.suppress()
ctx2 = torch.autocast("cpu", enabled=True, dtype=torch.bfloat16)
#dtype = torch.float
dtype = torch.float32
if dtype == torch.float32:
ctx = ctx1
else:
ctx = ctx2
with torch.no_grad(), ctx:
qtype = torch.float8_e4m3fn
refe = model(*example_inputs)
if qtype == torch.int8:
model = pt2e(model, example_inputs)
else:
inc_convert(model, dtype)
test_eager = model(*example_inputs)
model = torch.compile(model)
model(*example_inputs)
test = model(*example_inputs)
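To inspect the folded graph, one option is to flip TORCH_COMPILE_DEBUG to "1" in the script above (or enable the Inductor post-grad graph logs, e.g. TORCH_LOGS=post_grad_graphs); the linear's weight constant can be seen there as fp32 rather than float8_e4m3fn, matching the behavior described above.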