# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
import numpy as np
+ import warnings
from abc import ABC, abstractmethod
from onnx import TensorProto, helper
from qonnx.core.modelwrapper import ModelWrapper
@@ -70,7 +70,7 @@ def _check_compatibility(self):
    @abstractmethod
    def _calculate_act_bias(self):
        """Calculate the activation bias,
-        which is introduced as an Add node behind the MultiTrheshold node.
+        which is introduced as an Add node behind the MultiThreshold node.
        """
        raise NotImplementedError()
@@ -82,7 +82,7 @@ def _calculate_thresholds(self):
    @abstractmethod
    def _calculate_act_scale(self):
        """Calculate the activation scale,
-        which is indroduced as a Mul node behind the Add node
+        which is introduced as a Mul node behind the Add node
        for the activation bias.
        """
        raise NotImplementedError()
@@ -139,6 +139,8 @@ def replace_quant_node(self):
        graph.value_info.append(thresh_tensor)
        model.set_initializer(thresh_tensor.name, thresholds)

+        data_layout = model.get_tensor_layout(n.input[0])
+
        # Insert MultiThreshold node
        outp_trans_node = helper.make_node(
            "MultiThreshold",
@@ -154,10 +156,15 @@ def replace_quant_node(self):
        mt_node = graph.node[running_node_index - 1]
        mt_inst = getCustomOp(mt_node)

+        # Inherit the data layout from the input tensor if available
+        if data_layout is not None:
+            # Convert list to string representation of the data layout
+            mt_inst.set_nodeattr("data_layout", "".join(data_layout))
+
        # Set scale and bias
        # If these values are scalar then they can be set as attributes
        # of the MultiThreshold node, if not they get inserted as adder and mul nodes
-        # behind the MultiTrheshold nodes.
+        # behind the MultiThreshold nodes.
        bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
        scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
        if scale_scalar and bias_scalar and self._q_node.op_type == "BipolarQuant":
@@ -355,7 +362,7 @@ def _calculate_thresholds(self):
        act_node = self._model.find_direct_predecessors(self._q_node)
        act_node = act_node[0]
        if act_node.op_type == "Relu":
-            # Calculate thersholds, see: https://github.com/Xilinx/brevitas/blob/
+            # Calculate thresholds, see: https://github.com/Xilinx/brevitas/blob/
            # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
            # onnx/finn/handler/act.py#L21
            num_distinct_values = 2 ** bit_width
@@ -395,8 +402,46 @@ def _calculate_thresholds(self):
                else:
                    thresholds[c][t] = step / selu_scale

+        # Get the shape of the input (should also be the output) tensor
+        # Note: Querying the input is more safe as we do not want to
+        # propagate shapes backwards by accident.
+        shape = self._model.get_tensor_shape(self._q_node.input[0])  # noqa
+        # First try to consider the tensor layout of the input for
+        # determining the number of output channels
+        layout = self._model.get_tensor_layout(self._q_node.input[0])
+        # If there is no layout annotation, guess based on rank of the
+        # tensor
+        # TODO: No support for Rank >= 5
+        if layout is None and len(shape) < 5:
+            # Maps tensor rank to layout annotation
+            rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
+            # Lookup the layout required by this input shape
+            layout = rank_to_layout[len(shape)]
+        # If there is a layout annotation, use this to determine the index
+        # of the channel dimension
+        if layout is not None and "C" in layout:  # noqa: Duplicate
+            # Lookup the index in list
+            cdim = layout.index("C")
+        # If no layout has been annotated or there is no channel dimension, fall
+        # back to the previous default assumption
+        else:
+            # Assume the channels to be in axis 1
+            cdim = 1
+            # Issue a warning to the user, so they are aware of this
+            warnings.warn(
+                f"No layout annotations for {self._q_node.input[0]}:"
+                f" Assuming channel dimension at index {cdim}"
+            )
+
        # ToDo: The index 1 needs to be changed to -1 for the channels last format
-        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
+        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]
+
+        assert (
+            thresholds.shape[0] == 1 or thresholds.shape[0] == num_output_channels
+        ), """Quant node cannot be converted to MultiThreshold because only
+        per tensor or per channel quantization supported."""
+
        final_shape = (num_output_channels, num_thresholds)
        if thresholds.shape != final_shape:
            thresholds = np.broadcast_to(thresholds, final_shape)
@@ -417,12 +462,12 @@ def _remove_activation_node(self, multi_threshold_node):
        act_node = self._model.find_direct_predecessors(self._q_node)
        if act_node is None:
            raise RuntimeError(
-                "For handling of Relu activations a predecesor to " "the Quant node must exist."
+                "For handling of Relu activations a predecessor to " "the Quant node must exist."
            )
        act_node = act_node[0]
        if act_node.op_type not in self.valid_predecessor_op_types():
            raise RuntimeError(
-                "The predecesor of the Quant node must be Relu or Selu for handling "
+                "The predecessor of the Quant node must be Relu or Selu for handling "
                "of activations."
            )
@@ -509,7 +554,7 @@ def _calculate_thresholds(self):
        else:
            raise RuntimeError("Got an unexpected quantizer node type")

-        # Calculate thersholds, see: https://github.com/Xilinx/brevitas/
+        # Calculate thresholds, see: https://github.com/Xilinx/brevitas/
        # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
        # export/onnx/finn/handler/act.py#L76
        if bit_width == 1.0:
@@ -537,13 +582,49 @@ def _calculate_thresholds(self):
            for t in range(num_thresholds):
                thresholds[c][t] = min_threshold[c] + step[c] * t

-        # currently only per tensor or per channel quantization is supported
-        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
+        # Get the shape of the input (should also be the output) tensor
+        # Note: Querying the input is more safe as we do not want to
+        # propagate shapes backwards by accident.
+        shape = self._model.get_tensor_shape(self._q_node.input[0])
+        # First try to consider the tensor layout of the input for
+        # determining the number of output channels
+        layout = self._model.get_tensor_layout(self._q_node.input[0])  # noqa
+        # If there is no layout annotation, guess based on rank of the
+        # tensor
+        # TODO: No support for Rank >= 5
+        if layout is None and len(shape) < 5:
+            # Maps tensor rank to layout annotation
+            rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
+            # Lookup the layout required by this input shape
+            layout = rank_to_layout[len(shape)]
+        # If there is a layout annotation, use this to determine the index
+        # of the channel dimension
+        if layout is not None and "C" in layout:  # noqa: Duplicate
+            # Lookup the index in list
+            cdim = layout.index("C")
+        # If no layout has been annotated or there is no channel dimension,
+        # fall back to the previous default assumption
+        else:
+            # Assume the channels to be in axis 1
+            cdim = 1
+            # Issue a warning to the user, so they are aware of this
+            warnings.warn(
+                f"No layout annotations for {self._q_node.input[0]}:"
+                f" Assuming channel dimension at index {cdim}"
+            )
+
+        # ToDo: The index 1 needs to be changed to -1 for the channels last format
+        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]
+
        assert (
            thresholds.shape[0] == 1 or thresholds.shape[0] == num_output_channels
        ), """Quant node cannot be converted to MultiThreshold because only
        per tensor or per channel quantization supported."""

+        final_shape = (num_output_channels, num_thresholds)
+        if thresholds.shape != final_shape:
+            thresholds = np.broadcast_to(thresholds, final_shape)
+
        return thresholds

    def _calculate_act_scale(self):
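
Note: below is a minimal, self-contained sketch (not part of the diff) of the channel-dimension lookup both hunks add. The helper name infer_channel_dim is illustrative only and not part of the qonnx API; in the actual change the shape and layout come from ModelWrapper.get_tensor_shape and ModelWrapper.get_tensor_layout on the Quant node's input.

import warnings


def infer_channel_dim(shape, layout=None):
    # Guess a layout annotation from the tensor rank if none is given
    if layout is None and len(shape) < 5:
        rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
        layout = rank_to_layout[len(shape)]
    # Use the layout annotation to locate the channel dimension
    if layout is not None and "C" in layout:
        return layout.index("C")
    # Fall back to the previous default of axis 1 and warn the user
    warnings.warn(f"No layout annotation for shape {shape}: assuming channel dimension at index 1")
    return 1


# Example: an NCHW feature map keeps its channels at index 1,
# while a rank-3 tensor is treated as NWC with channels at index 2.
assert infer_channel_dim((1, 64, 32, 32), "NCHW") == 1
assert infer_channel_dim((1, 32, 64)) == 2

With the channel dimension known, the thresholds array (per-tensor shape (1, num_thresholds) or per-channel shape (num_output_channels, num_thresholds)) is broadcast with np.broadcast_to to the final (num_output_channels, num_thresholds) shape that the MultiThreshold node expects.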