Commit c966b46

Merge pull request #78 from iksnagreb/fix/transpose_into_quant
Fix FoldTransposeIntoQuantInit Transformation
2 parents: e62517a + 0351d9e
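
In brief, judging from the diff and the new test below: the transformation now fires only when all inputs of the upstream Quant/BipolarQuant node are initializers (checked by the new is_quant_init helper), transposes each of those initializers according to the Transpose node's optional perm attribute, and removes the now absorbed Transpose. Per the note in the test, this also resolves the shape/datatype inference breakage tracked in issue #77.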

2 files changed (+217 -41 lines)

src/qonnx/transformation/quant_constant_folding.py

Lines changed: 75 additions & 41 deletions
@@ -26,57 +26,91 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-import warnings
+# Protobuf onnx graph node type
+from onnx import NodeProto
 
+# QONNX wrapper of ONNX model graphs
+from qonnx.core.modelwrapper import ModelWrapper
+
+# QONNX graph transformations base class
 from qonnx.transformation.base import Transformation
+
+# Gets items from protobuf by name
 from qonnx.util.basic import get_by_name
 
 
+# Tests whether a node is a quant-init, i.e., a quantizer with only initializer
+# inputs
+def is_quant_init(node: NodeProto, model: ModelWrapper):
+    # Only handle existing Quant or BipolarQuant type nodes
+    if node is not None and node.op_type in {"Quant", "BipolarQuant"}:
+        # All inputs must have initializers, otherwise this is just a normal
+        # quant, but not a quant-init
+        return all(model.get_initializer(i) is not None for i in node.input)
+    # Did not match the operator type
+    return False
+
+
+# Transpose nodes can be folded into quantized initializers, i.e., Quant nodes
+# where *all* inputs are initializers. Initializers are constants and part of
+# the model graph and thus can be transposed offline.
 class FoldTransposeIntoQuantInit(Transformation):
     """
-    Fueses a Transpose node into the initalizer of a Quant node.
+    Fuses a Transpose node into the initializers of a Quant node.
     """
 
-    def apply(self, model):
+    # Applies the transform to a whole model graph
+    def apply(self, model: ModelWrapper):
+        # Get the model graph out of the model wrapper object
         graph = model.graph
-        node_ind = 0
+        # Keep track of whether the graph has been modified
         graph_modified = False
-        # Find transpose nodes, which have Quant node with initilizer upstream.
-        for n in graph.node:
-            node_ind += 1
-            if n.op_type == "Transpose":
-                predecessors = model.find_direct_predecessors(n)
-                # Check if we reached the top of the graph
-                if predecessors is None:
+        # Iterate all nodes in the graph keeping track of the index
+        for index, node in enumerate(graph.node):
+            # This transformation is triggered by finding a Transpose node
+            if node.op_type == "Transpose":
+                # Get the predecessors feeding into the transpose node
+                predecessors = model.find_direct_predecessors(node)
+                # The transform applies only to transpose with exactly one input
+                if predecessors is None or len(predecessors) != 1:
+                    # Note: Softly skip this node, maybe consider a hard failure
+                    # at least in case there are multiple inputs?
                     continue
-                predecessor = predecessors[0]
-                if predecessor.op_type == "Quant" or predecessor.op_type == "BipolarQuant":
-                    for inp in predecessor.input:
-                        if not isinstance(model.get_initializer(inp), type(None)):
-                            # Explicitly apply the transpose to the initializers
-                            # of the previous node
-                            target_tensor = model.get_initializer(inp)
-                            if target_tensor is None:
-                                warnings.warn(
-                                    f"Cannot fold transpose {n} into Quant/BipolarQuant node {predecessor}, "
-                                    f"due to not initialized tensor: {inp}. "
-                                    f"Exiting FoldTransposeIntoQuantInit transformation."
-                                )
-                                return model, False
-                            # Make sure the tensor has the correct shape
-                            perm = get_by_name(n.attribute, "perm")
-                            if perm is None:
-                                target_tensor = target_tensor.transpose()
-                                model.set_initializer(inp, target_tensor)
-                                graph_modified = True
-                            elif len(perm.ints) == len(target_tensor.shape):
-                                target_tensor = target_tensor.transpose(perm.ints)
-                                model.set_initializer(inp, target_tensor)
-                                graph_modified = True
-                    # Reconnect predecessor and delete transpose node
-                    predecessor.output[0] = n.output[0]
-                    graph.node.remove(n)
-
-        return model, graph_modified
-
+                # Check whether the predecessor is a quantizer with only
+                # initializer inputs
+                if is_quant_init(predecessors[0], model):
+                    # Alias to the single predecessor node
+                    quant_init = predecessors[0]
+                    # Get the (optional) permutation indices of the transpose in
+                    # case it is a multi-axis transpose
+                    perm = get_by_name(node.attribute, "perm")
+                    # Convert permutation indices to list of integers if it is
+                    # given
+                    perm = perm.ints if perm is not None else None
+                    # Transpose all(!) initializer inputs of the quant node
+                    for i in quant_init.input:
+                        # Get the initializer tensor
+                        # Note: No need to validate the presence of the
+                        # initializer here, as we already tested this as the
+                        # applicability condition above
+                        tensor = model.get_initializer(i)
+                        # Skip transposing the initializer if the number of
+                        # dimensions do not match
+                        if perm is not None and len(perm) != tensor.ndim:
+                            # Note: Soft skip ok or is this an error?
+                            continue
+                        # Transpose the tensor, optionally according to the
+                        # permutation indices (perm might be None)
+                        tensor = tensor.transpose(perm)
+                        # Reassign the transposed initializer tensor
+                        model.set_initializer(i, tensor)
+                        # The graph has been modified, this needs to be reported
+                        # back to the caller
+                        graph_modified = True
+                    # Rewire the graph to skip the transpose node
+                    quant_init.output[0] = node.output[0]
+                    # Remove the now absorbed transpose node
+                    graph.node.remove(node)
+        # Return the transformed model and indicate whether the graph actually
+        # has been transformed
         return model, graph_modified
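
For orientation, a minimal usage sketch of the transformation above follows; the file names are illustrative, not part of this PR. One detail worth noting: numpy's ndarray.transpose(None) reverses all axes, which matches the ONNX Transpose default when no perm attribute is present, so the single tensor.transpose(perm) call covers both cases.

# Minimal usage sketch (file names are illustrative, not part of the PR)
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit

model = ModelWrapper("model.onnx")
# ModelWrapper.transform() re-applies the transformation until apply()
# reports graph_modified=False, so chains of foldable Transposes are
# absorbed one after another
model = model.transform(FoldTransposeIntoQuantInit())
model.save("model-folded.onnx")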
New test file: Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# Set pytest parameters
import pytest

# Numpy for handling simulation of tensor operations
import numpy as np

# Helper for creating ONNX nodes
from onnx import NodeProto, TensorProto  # noqa
from onnx import helper as oh  # noqa

# QONNX wrapper of ONNX model graphs
from qonnx.core.modelwrapper import ModelWrapper  # noqa

# Execute QONNX model graphs
from qonnx.core.onnx_exec import execute_onnx  # noqa

# QONNX quantizer function modeling the behavior of the Quant operator
from qonnx.custom_op.general.quant import quant as quant_fn  # noqa

# QONNX graph transformations for inferring datatypes and shapes required by
# test setup
from qonnx.transformation.infer_datatypes import InferDataTypes  # noqa
from qonnx.transformation.infer_shapes import InferShapes  # noqa

# Graph transformation to be tested: Transposes the initializers to Quantizer if
# ALL inputs are initializers
from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit  # noqa

# QONNX utility for creating models from ONNX graphs
from qonnx.util.basic import qonnx_make_model  # noqa


@pytest.mark.parametrize("quant_init", [True, False])
@pytest.mark.parametrize("signed", [0, 1])
@pytest.mark.parametrize("narrow", [0, 1])
@pytest.mark.parametrize("rounding_mode", ["ROUND"])
@pytest.mark.parametrize("shape", [(16, 8, 12)])
@pytest.mark.parametrize(
    "perm",
    [
        # All axis permutations
        (0, 1, 2),
        (0, 2, 1),
        (1, 0, 2),
        (1, 2, 0),
        (2, 0, 1),
        (2, 1, 0),
    ],
)
@pytest.mark.parametrize("scale", [0.01])
@pytest.mark.parametrize("zeropoint", [0])
@pytest.mark.parametrize("bitwidth", [8])
# Tests the FoldTransposeIntoQuantInit transformation
def test_fold_transpose_into_quant_init(quant_init, signed, narrow, rounding_mode, shape, perm, scale, zeropoint, bitwidth):
    # Prepare the quantizer node attributes and input/output lists
    node_attrs = {
        # Type of the operation
        "op_type": "Quant",
        # This operator type is defined within QONNX
        "domain": "qonnx.custom_op.general",
        # List the inputs to the operator in order
        # Note: The proper input followed by initializers configuring the
        # quantizer
        "inputs": ["input", "scale", "zeropoint", "bitwidth"],
        # List the outputs of the operator in order
        # Note: Intermediate feeds to the next operator input
        "outputs": ["intermediate"],
        # Whether the quantization interval should be signed or not
        # (e.g. at 8b unsigned=[0, 255] vs signed=[-128, 127])
        "signed": signed,
        # When signed=1, whether to use narrow range or not
        # (e.g. at 8b regular=[-128, 127] vs narrow=[-127, 127])
        "narrow": narrow,
        # The rounding mode, which is used for the quant function
        "rounding_mode": rounding_mode,
    }
    # Create a dummy quantizer node
    quant = oh.make_node(**node_attrs, name="Quant")
    # Attach a Transpose operation to the quantizer
    transpose = oh.make_node("Transpose", ["intermediate"], ["output"], name="Transpose", perm=perm)
    # Derive the transposed shape
    transposed_shape = np.transpose(np.zeros(shape), perm).shape
    # Create tensor information for the input, intermediate and output tensor
    x = oh.make_tensor_value_info("input", TensorProto.FLOAT, shape)  # noqa
    y = oh.make_tensor_value_info("output", TensorProto.FLOAT, transposed_shape)
    # Create the initializer tensors for quantizer parameters
    s = oh.make_tensor_value_info("scale", TensorProto.FLOAT, (1,))
    z = oh.make_tensor_value_info("zeropoint", TensorProto.FLOAT, (1,))
    b = oh.make_tensor_value_info("bitwidth", TensorProto.FLOAT, (1,))
    # Create the graph connecting the nodes and tensors
    graph = oh.make_graph(
        [quant, transpose],
        "quant-transpose",
        [x, s, z, b],
        [y],
    )
    # Wrap the graph in a QONNX model wrapper
    model = ModelWrapper(qonnx_make_model(graph, producer_name="qonnx-tests"))
    # Add the actual initializers to the initializer tensors
    model.set_initializer("scale", np.array(scale))
    model.set_initializer("zeropoint", np.array(zeropoint))
    model.set_initializer("bitwidth", np.array(bitwidth))
    # Prepare the model graph by inferring all missing shape and datatype
    # information
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # Get a random dummy input for testing
    x = np.random.rand(*shape)  # noqa
    # Fill the execution context with dummy input data
    context = {"input": x}

    # Some test cases even turn the input into an initializer
    if quant_init:
        # Turn the model input into an initializer
        model.set_initializer("input", x)
        # Clear the execution context removing the input as it is now baked into
        # the model graph
        context = {}

    # Run the transformation to be tested
    model = model.transform(FoldTransposeIntoQuantInit())
    # Verify that shape and datatype inference still works
    # Note: This has been an issue, please see
    #   https://github.com/fastmachinelearning/qonnx/issues/77
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # For the case of quant-initializers there must not be a Transpose left
    # after transforming and contrariwise, the Transpose must remain in place if
    # there is non-initializer input.
    assert quant_init != ("Transpose" in [n.op_type for n in model.graph.node])

    # Execute the ONNX model
    o_produced = execute_onnx(model, context)["output"]
    # Use numpy and QONNX quantizer to generate expectation
    o_expected = np.transpose(
        quant_fn(x, np.array(scale), np.array(zeropoint), np.array(bitwidth), signed, narrow, rounding_mode), perm
    )

    # The output must match the "manual" execution using numpy
    assert np.allclose(o_produced, o_expected)
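
Since every argument of the test is a pytest parameter, a single combination can also be spot-checked directly in a Python session once the test module is imported; the values below are picked arbitrarily from the parametrization above.

# Run one parameter combination directly, without pytest collection
test_fold_transpose_into_quant_init(
    quant_init=True,
    signed=1,
    narrow=0,
    rounding_mode="ROUND",
    shape=(16, 8, 12),
    perm=(2, 0, 1),
    scale=0.01,
    zeropoint=0,
    bitwidth=8,
)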
