Commit fc26918

Merge pull request #8 from iksnagreb/feature/generalized_multi_threshold_layouts
Make quantized activation handlers data layout aware
2 parents 252af20 + c2905f7 commit fc26918
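The core of this change is deciding which axis of a tensor holds the channels before the thresholds are shaped: prefer a data layout annotation on the Quant node's input, guess a layout from the tensor rank when none is annotated, and otherwise fall back to axis 1 with a warning. A minimal standalone sketch of that decision logic, distilled from the diff below (the function name infer_channel_dim is illustrative, not part of the codebase):

    import warnings

    def infer_channel_dim(layout, shape):
        # Guess a layout from the tensor rank if no annotation is present
        if layout is None and len(shape) < 5:
            rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
            layout = rank_to_layout[len(shape)]
        # Use the annotated or guessed layout to locate the channel dimension
        if layout is not None and "C" in layout:
            return layout.index("C")
        # Otherwise fall back to the previous default assumption (axis 1)
        warnings.warn("No layout annotation: assuming channel dimension at index 1")
        return 1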

File tree

1 file changed: +92 −11 lines changed

src/finn/transformation/qonnx/qonnx_activation_handlers.py

Lines changed: 92 additions & 11 deletions
@@ -25,8 +25,8 @@
 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
 import numpy as np
+import warnings
 from abc import ABC, abstractmethod
 from onnx import TensorProto, helper
 from qonnx.core.modelwrapper import ModelWrapper
@@ -70,7 +70,7 @@ def _check_compatibility(self):
     @abstractmethod
     def _calculate_act_bias(self):
         """Calculate the activation bias,
-        which is introduced as an Add node behind the MultiTrheshold node.
+        which is introduced as an Add node behind the MultiThreshold node.
         """
         raise NotImplementedError()

@@ -82,7 +82,7 @@ def _calculate_thresholds(self):
     @abstractmethod
     def _calculate_act_scale(self):
         """Calculate the activation scale,
-        which is indroduced as a Mul node behind the Add node
+        which is introduced as a Mul node behind the Add node
         for the activation bias.
         """
         raise NotImplementedError()
@@ -139,6 +139,8 @@ def replace_quant_node(self):
         graph.value_info.append(thresh_tensor)
         model.set_initializer(thresh_tensor.name, thresholds)
 
+        data_layout = model.get_tensor_layout(n.input[0])
+
         # Insert MultiThreshold node
         outp_trans_node = helper.make_node(
             "MultiThreshold",
@@ -154,10 +156,15 @@ def replace_quant_node(self):
         mt_node = graph.node[running_node_index - 1]
         mt_inst = getCustomOp(mt_node)
 
+        # Inherit the data layout from the input tensor if available
+        if data_layout is not None:
+            # Convert list to string representation of the data layout
+            mt_inst.set_nodeattr("data_layout", "".join(data_layout))
+
         # Set scale and bias
         # If these values are scalar then they can be set as attributes
         # of the MultiThreshold node, if not they get inserted as adder and mul nodes
-        # behind the MultiTrheshold nodes.
+        # behind the MultiThreshold nodes.
         bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
         scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
         if scale_scalar and bias_scalar and self._q_node.op_type == "BipolarQuant":
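For context on the "".join(data_layout) above: qonnx layout annotations are lists of axis letters, so joining them yields the string form stored in the MultiThreshold node's data_layout attribute. A small illustration, assuming the usual qonnx DataLayout constants:

    from qonnx.core.data_layout import NHWC  # NHWC == ["N", "H", "W", "C"]

    data_layout = NHWC
    # String form that the MultiThreshold custom op receives as its data_layout attribute
    assert "".join(data_layout) == "NHWC"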
@@ -355,7 +362,7 @@ def _calculate_thresholds(self):
         act_node = self._model.find_direct_predecessors(self._q_node)
         act_node = act_node[0]
         if act_node.op_type == "Relu":
-            # Calculate thersholds, see: https://github.com/Xilinx/brevitas/blob/
+            # Calculate thresholds, see: https://github.com/Xilinx/brevitas/blob/
             # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
             # onnx/finn/handler/act.py#L21
             num_distinct_values = 2**bit_width
@@ -395,8 +402,46 @@ def _calculate_thresholds(self):
                     else:
                         thresholds[c][t] = step / selu_scale
 
+        # Get the shape of the input (should also be the output) tensor
+        # Note: Querying the input is more safe as we do not want to
+        # propagate shapes backwards by accident.
+        shape = self._model.get_tensor_shape(self._q_node.input[0])  # noqa
+        # First try to consider the tensor layout of the input for
+        # determining the number of output channels
+        layout = self._model.get_tensor_layout(self._q_node.input[0])
+        # If there is no layout annotation, guess based on rank of the
+        # tensor
+        # TODO: No support for Rank >= 5
+        if layout is None and len(shape) < 5:
+            # Maps tensor rank to layout annotation
+            rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
+            # Lookup the layout required by this input shape
+            layout = rank_to_layout[len(shape)]
+        # If there is a layout annotation, use this to determine the index
+        # of the channel dimension
+        if layout is not None and "C" in layout:  # noqa: Duplicate
+            # Lookup the index in list
+            cdim = layout.index("C")
+        # If no layout has been annotated or there is no channel dimension, fall
+        # back to the previous default assumption
+        else:
+            # Assume the channels to be in axis 1
+            cdim = 1
+            # Issue a warning to the user, so they are aware of this
+            warnings.warn(
+                f"No layout annotations for {self._q_node.input[0]}:"
+                f" Assuming channel dimension at index {cdim}"
+            )
+
         # ToDo: The index 1 needs to be changed to -1 for the channels last format
-        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
+        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]
+
+        assert (
+            thresholds.shape[0] == 1 or thresholds.shape[
+                0] == num_output_channels
+        ), """Quant node cannot be converted to MultiThreshold because only
+        per tensor or per channel quantization supported."""
+
         final_shape = (num_output_channels, num_thresholds)
         if thresholds.shape != final_shape:
             thresholds = np.broadcast_to(thresholds, final_shape)
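The assert plus np.broadcast_to above admit exactly two cases: per-tensor thresholds (a single row) that get replicated across channels, and per-channel thresholds that already have one row per channel. A quick numpy illustration with made-up sizes:

    import numpy as np

    num_output_channels, num_thresholds = 4, 3
    per_tensor = np.arange(num_thresholds).reshape(1, num_thresholds)
    # A single row of thresholds is replicated to one identical row per channel
    broadcast = np.broadcast_to(per_tensor, (num_output_channels, num_thresholds))
    assert broadcast.shape == (4, 3)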
@@ -417,12 +462,12 @@ def _remove_activation_node(self, multi_threshold_node):
         act_node = self._model.find_direct_predecessors(self._q_node)
         if act_node is None:
             raise RuntimeError(
-                "For handling of Relu activations a predecesor to " "the Quant node must exist."
+                "For handling of Relu activations a predecessor to " "the Quant node must exist."
             )
         act_node = act_node[0]
         if act_node.op_type not in self.valid_predecessor_op_types():
             raise RuntimeError(
-                "The predecesor of the Quant node must be Relu or Selu for handling "
+                "The predecessor of the Quant node must be Relu or Selu for handling "
                 "of activations."
             )

@@ -509,7 +554,7 @@ def _calculate_thresholds(self):
         else:
             raise RuntimeError("Got an unexpected quantizer node type")
 
-        # Calculate thersholds, see: https://github.com/Xilinx/brevitas/
+        # Calculate thresholds, see: https://github.com/Xilinx/brevitas/
         # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
         # export/onnx/finn/handler/act.py#L76
         if bit_width == 1.0:
@@ -537,13 +582,49 @@ def _calculate_thresholds(self):
             for t in range(num_thresholds):
                 thresholds[c][t] = min_threshold[c] + step[c] * t
 
-        # currently only per tensor or per channel quantization is supported
-        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
+        # Get the shape of the input (should also be the output) tensor
+        # Note: Querying the input is more safe as we do not want to
+        # propagate shapes backwards by accident.
+        shape = self._model.get_tensor_shape(self._q_node.input[0])
+        # First try to consider the tensor layout of the input for
+        # determining the number of output channels
+        layout = self._model.get_tensor_layout(self._q_node.input[0])  # noqa
+        # If there is no layout annotation, guess based on rank of the
+        # tensor
+        # TODO: No support for Rank >= 5
+        if layout is None and len(shape) < 5:
+            # Maps tensor rank to layout annotation
+            rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
+            # Lookup the layout required by this input shape
+            layout = rank_to_layout[len(shape)]
+        # If there is a layout annotation, use this to determine the index
+        # of the channel dimension
+        if layout is not None and "C" in layout:  # noqa: Duplicate
+            # Lookup the index in list
+            cdim = layout.index("C")
+        # If no layout has been annotated or there is no channel dimension,
+        # fall back to the previous default assumption
+        else:
+            # Assume the channels to be in axis 1
+            cdim = 1
+            # Issue a warning to the user, so they are aware of this
+            warnings.warn(
+                f"No layout annotations for {self._q_node.input[0]}:"
+                f" Assuming channel dimension at index {cdim}"
+            )
+
+        # ToDo: The index 1 needs to be changed to -1 for the channels last format
+        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]
+
         assert (
             thresholds.shape[0] == 1 or thresholds.shape[0] == num_output_channels
         ), """Quant node cannot be converted to MultiThreshold because only
         per tensor or per channel quantization supported."""
 
+        final_shape = (num_output_channels, num_thresholds)
+        if thresholds.shape != final_shape:
+            thresholds = np.broadcast_to(thresholds, final_shape)
+
         return thresholds
 
     def _calculate_act_scale(self):
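Not part of the commit itself, but one plausible way to give these handlers a layout annotation to inherit is to run qonnx's data layout inference before the QONNX-to-FINN conversion. A hedged sketch; the model path is a placeholder and the entry points are the standard qonnx/finn ones, not something this PR adds:

    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.transformation.infer_data_layouts import InferDataLayouts
    from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

    # "model.onnx" is a placeholder path for a QONNX model containing Quant nodes
    model = ModelWrapper("model.onnx")
    # Annotate tensor layouts so the activation handlers can locate the channel axis
    model = model.transform(InferDataLayouts())
    # Convert Quant/BipolarQuant activations into MultiThreshold nodes
    model = model.transform(ConvertQONNXtoFINN())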
