
[Quantized DeConv Support] Dynamically Quantized Deconvolutions with groups == 1 #11775


Open · wants to merge 1 commit into base: gh/mcr229/31/orig
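This PR enables dynamic quantization for transposed convolutions with groups == 1 in the XNNPACK quantizer. As a rough sketch of the flow it unlocks (assuming the standard pt2e export path; the capture API varies across PyTorch versions, and none of this snippet comes from the PR itself):

import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.ConvTranspose2d(3, 8, kernel_size=3, padding=1).eval()
example_inputs = (torch.randn(1, 3, 8, 8),)

# Dynamic, per-channel config; with this PR the "conv_transpose" pattern
# accepts it (previously only static quantization was annotated for it).
quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

# Capture, annotate, and convert (newer PyTorch uses export_for_training).
exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)
quantized = convert_pt2e(prepared)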
2 changes: 1 addition & 1 deletion backends/xnnpack/quantizer/xnnpack_quantizer.py
@@ -274,7 +274,7 @@ class XNNPACKQuantizer(Quantizer):
         QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
         QuantPattern("linear", True, False, LINEAR_TARGETS),
         QuantPattern("conv", True, False, CONV_TARGETS),
-        QuantPattern("conv_transpose", False, False, CONV_TARGETS),
+        QuantPattern("conv_transpose", True, False, CONV_TARGETS),
         QuantPattern("conv_relu", False, False, CONV_TARGETS),
         QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
         QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),
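(Judging from the neighboring entries, the second QuantPattern field flipped here appears to be the flag marking a pattern as dynamically quantizable, matching the already-enabled "linear" and "conv" patterns; the dataclass definition is outside this diff, so treat that reading as an inference.)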
82 changes: 52 additions & 30 deletions backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
@@ -4,7 +4,10 @@
 
 import torch
 import torch.nn.functional as F
-from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
+from executorch.backends.xnnpack.utils.utils import (
+    get_groups_from_conv,
+    is_depthwise_conv,
+)
 from torch._subclasses import FakeTensor
 from torch.fx import Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
@@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None:
     return decorator
 
 
+def change_quantization_config(
+    original_qspec,
+    dtype=None,
+    quant_min=None,
+    quant_max=None,
+    qscheme=None,
+    ch_axis=None,
+    is_dynamic=None,
+    observer_or_fake_quant_ctr=None,
+):
+    return QuantizationSpec(
+        dtype=dtype or original_qspec.dtype,
+        quant_min=quant_min or original_qspec.quant_min,
+        quant_max=quant_max or original_qspec.quant_max,
+        qscheme=qscheme or original_qspec.qscheme,
+        ch_axis=ch_axis or original_qspec.ch_axis,
+        is_dynamic=is_dynamic or original_qspec.is_dynamic,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr
+        or original_qspec.observer_or_fake_quant_ctr,
+    )
+
+
 def is_relu_node(node: Node) -> bool:
     """
     Check if a given node is a relu node
@@ -231,6 +256,9 @@ def _do_annotate_conv(
         if is_relu_node(user):
             continue
 
+        # Tracks conditions for whether or not to skip
+        skip = False
+
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
@@ -239,35 +267,33 @@
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
         weight_qspec = get_weight_qspec(quantization_config)
+        num_groups = get_groups_from_conv(conv_node)
+
+        # skip if transposed conv has more than 1 group
+        skip = skip or (is_conv_transpose and num_groups != 1)
+        print(f"{skip} conv transpose and num_groups")
+
         if is_conv_transpose:
             # transposed convs per output channel quantization
-            weight_qspec = QuantizationSpec(
-                dtype=weight_qspec.dtype,
-                quant_min=weight_qspec.quant_min,
-                quant_max=weight_qspec.quant_max,
-                qscheme=weight_qspec.qscheme,
-                ch_axis=1,
-                is_dynamic=False,
-                observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
-            )
-        input_qspec_map[weight] = weight_qspec
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
 
-        # Only annotate dynamically quantized conv if it's 2D and not depthwise
-        if (
+        input_qspec_map[weight] = weight_qspec
+        is_dynamic = (
             quantization_config
             and quantization_config.input_activation
             and quantization_config.input_activation.is_dynamic
-        ):
+        )
+
+        # Only annotate dynamically quantized conv if it's 2D and not depthwise
+        if is_dynamic:
             weight_val = weight.meta.get("val", None)
             weight_shape = getattr(weight_val, "shape", None)
 
             # Skip if not a 4D weight tensor (i.e. not conv2d)
-            if weight_shape is not None and len(weight_shape) != 4:
-                continue
-
+            skip = skip or (weight_shape is not None and len(weight_shape) != 4)
             # Skip if depthwise (default to groups=1 since it's not an arg)
-            if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
-                continue
+            skip = skip or (
+                not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False)
+            )
 
         # adding weight node to the partition as well
         partition = [conv_node, conv_node.args[1]]
@@ -277,7 +303,7 @@
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
             partition.append(bias)
 
-        if _is_annotated(partition):
+        if _is_annotated(partition) or skip:
             continue
 
         if filter_fn and any(not filter_fn(n) for n in partition):
@@ -324,17 +350,10 @@ def _do_annotate_conv_relu(
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
         weight_qspec = get_weight_qspec(quantization_config)
+        groups = get_groups_from_conv(conv_node)
         if is_conv_transpose:
             # transposed convs per output channel quantization
-            weight_qspec = QuantizationSpec(
-                dtype=weight_qspec.dtype,
-                quant_min=weight_qspec.quant_min,
-                quant_max=weight_qspec.quant_max,
-                qscheme=weight_qspec.qscheme,
-                ch_axis=1,
-                is_dynamic=False,
-                observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
-            )
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
         input_qspec_map[weight] = weight_qspec
 
         # adding weight node to the partition as well
@@ -347,6 +366,9 @@
         if _is_annotated(partition):
             continue
 
+        if is_conv_transpose and groups != 1:
+            continue
+
         if filter_fn and any(not filter_fn(n) for n in partition):
             continue
 
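Two remarks on the hunks above. The ch_axis=1 choice for transposed convolutions follows from PyTorch's weight layouts: regular conv weights put output channels on axis 0, transposed conv weights put them on axis 1. A quick illustrative check (not part of the PR):

import torch

# Conv2d weight: (C_out, C_in/groups, kH, kW); ConvTranspose2d weight:
# (C_in, C_out/groups, kH, kW). Per-output-channel quantization therefore
# needs ch_axis=0 for conv but ch_axis=1 for conv_transpose.
w_conv = torch.nn.Conv2d(3, 8, kernel_size=3).weight
w_deconv = torch.nn.ConvTranspose2d(3, 8, kernel_size=3).weight
assert w_conv.shape == (8, 3, 3, 3)    # C_out on axis 0
assert w_deconv.shape == (3, 8, 3, 3)  # C_out on axis 1

Also note that change_quantization_config falls back to the original qspec via "or", so falsy overrides (quant_min=0, is_dynamic=False, ch_axis=0) silently keep the original values; the call sites in this diff only pass ch_axis=1, where that behavior is harmless.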
81 changes: 49 additions & 32 deletions backends/xnnpack/test/ops/test_conv2d.py
@@ -174,14 +174,11 @@ def get_inputs(self):


 class Conv2dDQSeq(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, transpose=False):
         super().__init__()
-        self.first = torch.nn.Conv2d(
-            in_channels=3, out_channels=8, kernel_size=3, padding=1
-        )
-        self.second = torch.nn.Conv2d(
-            in_channels=8, out_channels=10, kernel_size=3, padding=1
-        )
+        op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d
+        self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1)
+        self.second = op(in_channels=8, out_channels=10, kernel_size=3, padding=1)
 
     def forward(self, x):
         y = self.first(x)
@@ -192,14 +189,11 @@ def get_inputs(self):


 class Conv2dDQParallel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, transpose=False):
         super().__init__()
-        self.first = torch.nn.Conv2d(
-            in_channels=3, out_channels=8, kernel_size=3, padding=1
-        )
-        self.second = torch.nn.Conv2d(
-            in_channels=3, out_channels=8, kernel_size=3, padding=1
-        )
+        op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d
+        self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1)
+        self.second = op(in_channels=3, out_channels=10, kernel_size=3, padding=1)
 
     def forward(self, x):
         first = self.first(x)
@@ -266,8 +260,7 @@ def _test_dq(
         )
 
         DynamicallyQuantizedPartitioner = XnnpackPartitioner(
-            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
-            per_op_mode=True,
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, per_op_mode=True
         )
 
         tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes)
@@ -349,11 +342,10 @@ def test_fp32_conv2d_depthwise(self):
         )
 
     def test_qs8_conv2d_depthwise(self):
-        for transpose in (True, False):
-            self._test(
-                Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose),
-                quant_config=get_symmetric_quantization_config(),
-            )
+        self._test(
+            Conv2d(groups=2, in_channels=2, out_channels=6),
+            quant_config=get_symmetric_quantization_config(),
+        )
 
     def test_fp32_conv2d_bn(self):
         class Conv2dBatchNorm(torch.nn.Module):
@@ -515,17 +507,14 @@ def forward(self, x):
             def get_inputs(self):
                 return (torch.randn(batches, in_channels, height, width) * 11,)
 
-        for transpose in (True, False):
-            for per_channel_quant in (False, True):
-                if transpose and per_channel_quant:
-                    continue
-                model = ModelConvReLU(transpose=transpose)
-                self._test(
-                    model,
-                    quant_config=get_symmetric_quantization_config(
-                        is_per_channel=per_channel_quant
-                    ),
-                )
+        for per_channel_quant in (False, True):
+            model = ModelConvReLU()
+            self._test(
+                model,
+                quant_config=get_symmetric_quantization_config(
+                    is_per_channel=per_channel_quant
+                ),
+            )
 
     def test_qs8_conv2d_relu_seq(self):
         class ConvReLUSeq(torch.nn.Module):
@@ -728,3 +717,31 @@ def test_dq_conv2d_parallel(self) -> None:
         model = Conv2dDQParallel()
         conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
         self._test_dq(model, conv_count)
+
+    def test_dq_conv2d_transpose(self) -> None:
+        model = Conv2d(
+            in_channels=3,
+            out_channels=10,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(0, 0),
+            batches=1,
+            width=8,
+            height=8,
+            transpose=True,
+        )
+        self._test_dq(model)
+
+    def test_dq_conv2d_transpose_seq(self) -> None:
+        model = Conv2dDQSeq(transpose=True)
+        conv_count = sum(
+            1 for m in model.modules() if type(m) is torch.nn.ConvTranspose2d
+        )
+        self._test_dq(model, conv_count)
+
+    def test_dq_conv2d_transpose_parallel(self) -> None:
+        model = Conv2dDQParallel(transpose=True)
+        conv_count = sum(
+            1 for m in model.modules() if type(m) is torch.nn.ConvTranspose2d
+        )
+        self._test_dq(model, conv_count)
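As a quick sanity sketch of the new test models (assuming the Conv2dDQSeq class shown above is in scope; output sizes follow the ConvTranspose2d shape formula):

import torch

# With kernel_size=3, stride=1, padding=1, a ConvTranspose2d preserves the
# spatial size: H_out = (H_in - 1) * 1 - 2 * 1 + 3 = H_in.
model = Conv2dDQSeq(transpose=True)
out = model(torch.randn(1, 3, 8, 8))
assert out.shape == (1, 10, 8, 8)  # channels run 3 -> 8 -> 10 through the two convs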
31 changes: 31 additions & 0 deletions backends/xnnpack/utils/utils.py
@@ -25,6 +25,7 @@
     is_lifted_tensor_constant,
     is_param,
 )
+from torchao.quantization.pt2e.utils import _is_conv_node, _is_conv_transpose_node
 
 
 ### XNNPACK Capture ###
@@ -160,6 +161,36 @@ def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
     return source_fn[1]
 
 
+def get_groups_from_conv(conv_node: torch.fx.Node) -> int:
+    if _is_conv_node(conv_node):
+        in_node = cast(torch.fx.Node, conv_node.args[0])
+        weight_node = cast(torch.fx.Node, conv_node.args[1])
+        # groups isn't given to us in the training graph so we deduce it from the weight shape
+        # and the input shape
+
+        # input shape is (N, C_in, H_in, W_in)
+        in_channels = in_node.meta["val"].shape[1]
+
+        # weight shape is (C_out, C_in/groups, kernel_size[0], kernel_size[1])
+        in_groups = weight_node.meta["val"].shape[1]
+
+        return in_channels // in_groups
+    elif _is_conv_transpose_node(conv_node):
+        weight_node = cast(torch.fx.Node, conv_node.args[1])
+        # groups isn't given to us in the training graph so we deduce it from the weight shape
+        # and the output shape
+
+        # weight shape is (C_in, C_out/groups, kernel_size[0], kernel_size[1])
+        out_groups = weight_node.meta["val"].shape[1]
+
+        # output shape is (N, C_out, H_out, W_out)
+        out_channels = conv_node.meta["val"].shape[1]
+
+        return out_channels // out_groups
+
+    raise RuntimeError(f"expected {conv_node} to be a conv or conv_transpose node")
+
+
 def is_depthwise_conv(
     kernel_shape: Tuple[int, ...], groups: int = 1, is_transpose: bool = False
 ) -> bool:
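The shape arithmetic in get_groups_from_conv can be sanity-checked against eager-mode modules (an illustrative sketch, independent of the FX node plumbing the function actually uses):

import torch

# Regular conv: weight is (C_out, C_in/groups, kH, kW),
# so groups = C_in / weight.shape[1].
conv = torch.nn.Conv2d(in_channels=6, out_channels=4, kernel_size=3, groups=2)
assert conv.weight.shape == (4, 3, 3, 3)
assert conv.in_channels // conv.weight.shape[1] == 2  # groups recovered

# Transposed conv: weight is (C_in, C_out/groups, kH, kW),
# so groups = C_out / weight.shape[1].
deconv = torch.nn.ConvTranspose2d(in_channels=6, out_channels=4, kernel_size=3, groups=2)
assert deconv.weight.shape == (6, 2, 3, 3)
assert deconv.out_channels // deconv.weight.shape[1] == 2  # groups recovered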