Skip to content

Commit b3e891d

Browse files
committed
Qualcomm AI Engine Direct - add pass for extra padding then maxpool2d
Summary: The padding value used in max_pool2d operations differs between PyTorch and QNN implementations: PyTorch pads with negative infinity, while QNN pads with zero. To ensure consistent max_pool2d output across both frameworks, we pad the tensor with the constant value in advance and then run max_pool2d without constant padding.

Test plan:
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_max_pool2d -b build-android -H ${HOST} -s ${SN} -m ${CHIPID}
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_max_pool2d -b build-android -H ${HOST} -s ${SN} -m ${CHIPID}
1 parent 47dc1de commit b3e891d

File tree

6 files changed

+185
-11
lines changed

6 files changed

+185
-11
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from .insert_reshape_for_reduce_ops import InsertReshapeForReduceOps
4141
from .layout_transform import LayoutTransform
4242
from .lift_constant_scalar_operands import LiftConstantScalarOperands
43+
from .recompose_pad_maxpool2d import RecomposePadMaxPool2d
4344
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
4445
from .recompose_rms_norm import RecomposeRmsNorm
4546
from .reduce_dynamic_range import ReduceDynamicRange
@@ -87,6 +88,7 @@
8788
InsertRequantize,
8889
LayoutTransform,
8990
LiftConstantScalarOperands,
91+
RecomposePadMaxPool2d,
9092
RecomposePixelUnshuffle,
9193
RecomposeRmsNorm,
9294
ReduceDynamicRange,

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
InsertReshapeForReduceOps,
4646
LayoutTransform,
4747
LiftConstantScalarOperands,
48+
RecomposePadMaxPool2d,
4849
RecomposePixelUnshuffle,
4950
RecomposeRmsNorm,
5051
ReduceDynamicRange,
@@ -93,13 +94,14 @@ def get_capture_program_passes():
9394
(ConvertBmmToMatmul, False),
9495
(DecomposeAny, True),
9596
(DecomposeColIm, True),
97+
(DecomposeMaxPool3d, True),
9698
(DecomposeMinMaxDim, True),
9799
(ExpandBroadcastTensorShape, True),
98100
(FixedLinearKeepDim, True),
99101
(FoldQDQ, True),
100102
(I64toI32, True),
101103
(LayoutTransform, True),
102-
(DecomposeMaxPool3d, True),
104+
(RecomposePadMaxPool2d, True),
103105
(RecomposePixelUnshuffle, True),
104106
(RecomposeRmsNorm, True),
105107
(Remove0DTensor, True),
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import operator
8+
from typing import cast, List
9+
10+
import torch
11+
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
15+
from torch._subclasses.fake_tensor import FakeTensorMode
16+
17+
18+
def add_fake_tensor_to_node(padding_node, input_shape, padding_args, dtype):
    """Attach a FakeTensor describing the padded output to ``padding_node``.

    Given an NCHW ``input_shape`` (batch, channels, height, width) and
    ``padding_args`` ordered as (pad_left, pad_right, pad_top, pad_bottom),
    build a FakeTensor with the post-padding shape and ``dtype``, store it in
    ``padding_node.meta["val"]``, and return it.
    """
    n, c, h, w = input_shape
    left, right, top, bottom = padding_args
    padded_shape = (n, c, h + top + bottom, w + left + right)

    # Allocate under FakeTensorMode so no real storage is materialized.
    with FakeTensorMode():
        fake_output = torch.empty(padded_shape, dtype=dtype)

    # fx nodes normally already carry a meta dict; create one defensively.
    if not hasattr(padding_node, "meta"):
        padding_node.meta = {}
    padding_node.meta["val"] = fake_output

    return fake_output
37+
38+
class RecomposePadMaxPool2d(ExportPass):
    """
    The padding value used in max_pool2d operations differs between PyTorch and
    QNN implementations. PyTorch uses negative infinity, while QNN uses zero.
    To ensure consistent max_pool2d output across both frameworks, we handle
    this by padding the tensor with a constant in advance and then doing
    max_pool2d without constant padding.

    Note that for the quantization flow, we set quant_min as the padding value.
    If, at runtime, there is a value smaller than quant_min, it could result in
    an accuracy drop.
    """

    def __init__(self):
        super(RecomposePadMaxPool2d, self).__init__()
        self.getitem = operator.getitem
        self.max_pool2d = exir_ops.edge.aten.max_pool2d_with_indices.default
        self.pad_op = exir_ops.edge.aten.constant_pad_nd.default

    def call(self, graph_module: torch.fx.GraphModule):
        """Rewrite padded max_pool2d nodes into constant_pad_nd + unpadded max_pool2d."""
        graph = graph_module.graph
        for node in graph.nodes:
            num_args = len(node.args)
            if not (
                node.op == "call_function"
                and node.target == self.max_pool2d
                and num_args > 3
            ):
                continue

            # Normalize padding to [left, right, top, bottom] as expected by
            # constant_pad_nd. Copy before normalizing: the original code used
            # `padding *= 2` on the list taken straight from node.args, which
            # mutated the graph node's arguments in place.
            padding = list(cast(List[int], node.args[3]))
            if len(padding) == 1:
                padding *= 2
            if padding[0] == 0 and padding[1] == 0:
                # Nothing to pad; QNN and PyTorch already agree.
                continue
            if len(padding) == 2:
                # [pad_h, pad_w] -> [left, right, top, bottom]
                padding = [padding[1], padding[1], padding[0], padding[0]]

            input_node = node.args[0]
            # kernel info (copied so node.args is never mutated)
            filter_size = list(cast(List[int], node.args[1]))
            if len(filter_size) == 1:
                filter_size *= 2
            # stride info
            stride = list(cast(List[int], node.args[2]))
            if len(stride) == 1:
                stride *= 2
            # dilation info
            dilation = [1, 1]
            if num_args > 4:
                dilation = list(cast(List[int], node.args[4]))
                # Bug fix: the original checked `len(padding) == 1` here, but
                # padding is always length 4 by this point, so a 1-element
                # dilation was never expanded to 2 elements.
                if len(dilation) == 1:
                    dilation *= 2

            ceil_mode = node.args[5] if num_args > 5 else False

            # Padding value: -inf matches PyTorch max_pool2d semantics. For
            # quantized graphs, use the dequantized representation of
            # quant_min (the smallest representable value) instead.
            quant_attrs = node.meta.get("quant_attrs")
            if quant_attrs:
                pad_value = (
                    quant_attrs["quant_min"] - quant_attrs["zero_point"]
                ) * quant_attrs["scale"]
            else:
                pad_value = float("-inf")

            with graph_module.graph.inserting_after(input_node):
                padding_node = graph.create_node(
                    "call_function",
                    self.pad_op,
                    (
                        input_node,
                        padding,
                        pad_value,
                    ),
                )
                add_fake_tensor_to_node(
                    padding_node,
                    input_node.meta["val"].shape,
                    padding,
                    input_node.meta["val"].dtype,
                )
                if quant_attrs:
                    padding_node.meta["quant_attrs"] = quant_attrs

            with graph_module.graph.inserting_after(padding_node):
                # max_pool2d over the pre-padded tensor; padding arg is (0, 0)
                # because the explicit pad node already applied it.
                maxpool2d_args = (
                    padding_node,
                    filter_size,
                    stride,
                    (0, 0),
                    dilation,
                    ceil_mode,
                )
                maxpool2d_node_tuple = graph.create_node(
                    "call_function",
                    self.max_pool2d,
                    maxpool2d_args,
                )
                if quant_attrs:
                    maxpool2d_node_tuple.meta["quant_attrs"] = quant_attrs
                # NOTE(review): the padded-input FakeTensor is reused as the
                # pooled output's val, so the recorded spatial dims are the
                # padded input's, not the true pool output's — presumably
                # downstream passes only rely on rank/dtype here; confirm.
                maxpool2d_node_tuple.meta["val"] = [None, None]
                maxpool2d_node_tuple.meta["val"][0] = padding_node.meta["val"]

                # Re-point the getitem users of the old tuple-producing node.
                for user in node.users.copy():
                    user.replace_input_with(node, maxpool2d_node_tuple)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def get_passes_dependency_for_capture_program():
7575
FoldQDQ,
7676
I64toI32,
7777
LayoutTransform,
78+
RecomposePadMaxPool2d,
7879
RecomposePixelUnshuffle,
7980
RecomposeRmsNorm,
8081
RemoveRedundancy,
@@ -105,6 +106,7 @@ def get_passes_dependency_for_capture_program():
105106
ExpandBroadcastTensorShape,
106107
FixedLinearKeepDim,
107108
],
109+
RecomposePadMaxPool2d: [DecomposeMaxPool3d, FoldQDQ],
108110
RecomposePixelUnshuffle: [RemoveRedundancy],
109111
RecomposeRmsNorm: [RemoveRedundancy],
110112
TagQuantIO: [LayoutTransform],

backends/qualcomm/tests/models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,14 +1463,14 @@ def forward(self, x):
14631463

14641464

14651465
class MaxPool2d(torch.nn.Module):
1466-
def __init__(self):
1466+
def __init__(self, kernel_size=3, stride=1, padding=1, ceil_mode=True):
14671467
super().__init__()
14681468
self.max_pool2d = torch.nn.MaxPool2d(
1469-
kernel_size=3,
1470-
stride=1,
1471-
padding=1,
1469+
kernel_size=kernel_size,
1470+
stride=stride,
1471+
padding=padding,
14721472
dilation=1,
1473-
ceil_mode=True,
1473+
ceil_mode=ceil_mode,
14741474
)
14751475

14761476
def forward(self, x):

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1403,9 +1403,16 @@ def test_qnn_backend_max_dim(self):
14031403
self.lower_module_and_test_output(module, sample_input)
14041404

14051405
def test_qnn_backend_max_pool2d(self):
1406-
module = MaxPool2d() # noqa: F405
1406+
modules = [
1407+
MaxPool2d(3, 1, 0, True), # noqa: F405
1408+
MaxPool2d(3, 1, 0, False), # noqa: F405
1409+
MaxPool2d(3, 1, 1, True), # noqa: F405
1410+
MaxPool2d(3, 1, 1, False), # noqa: F405
1411+
]
14071412
sample_input = (torch.randn(4, 3, 24, 24),)
1408-
self.lower_module_and_test_output(module, sample_input)
1413+
for i, module in enumerate(modules):
1414+
with self.subTest(i=i):
1415+
self.lower_module_and_test_output(module, sample_input)
14091416

14101417
def test_qnn_backend_max_pool3d(self):
14111418
# NOTE: The pad should be at most half of effective kernel size.
@@ -3661,10 +3668,24 @@ def test_qnn_backend_max_dim(self):
36613668
self.lower_module_and_test_output(module, sample_input)
36623669

36633670
def test_qnn_backend_max_pool2d(self):
3664-
module = MaxPool2d() # noqa: F405
3671+
modules = [
3672+
MaxPool2d(3, 1, 0, True), # noqa: F405
3673+
MaxPool2d(3, 1, 0, False), # noqa: F405
3674+
MaxPool2d(3, 1, 1, True), # noqa: F405
3675+
MaxPool2d(3, 1, 1, False), # noqa: F405
3676+
]
3677+
test_quants = [QuantDtype.use_8a8w, QuantDtype.use_16a4w, QuantDtype.use_16a8w]
36653678
sample_input = (torch.randn(4, 3, 24, 24),)
3666-
module = self.get_qdq_module(module, sample_input)
3667-
self.lower_module_and_test_output(module, sample_input)
3679+
test_pairs = [
3680+
(module, quant_type) # noqa: F405
3681+
for module, quant_type in itertools.product(modules, test_quants)
3682+
]
3683+
for i, (test_module, qtype) in enumerate(test_pairs):
3684+
with self.subTest(i=i):
3685+
qdq_module = self.get_qdq_module(
3686+
test_module, sample_input, quant_dtype=qtype
3687+
)
3688+
self.lower_module_and_test_output(qdq_module, sample_input)
36683689

36693690
def test_qnn_backend_max_pool3d(self):
36703691
# NOTE: The pad should be at most half of effective kernel size.

0 commit comments

Comments
 (0)