27
27
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
28
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
29
30
+ import math
30
31
import numpy as np
32
+ import warnings
31
33
from qonnx .core .datatype import DataType
32
34
from qonnx .util .basic import roundup_to_integer_multiple
33
35
36
38
37
39
class StreamingConcat (HWCustomOp ):
38
40
"""Abstraction layer for HW implementation of Concat.
39
- Only supports concatenating along the last axis."""
41
+ Only supports concatenating along the last (channel) axis."""
40
42
41
43
def __init__ (self , onnx_node , ** kwargs ):
42
44
super ().__init__ (onnx_node , ** kwargs )
43
45
44
46
def get_nodeattr_types (self ):
45
47
my_attrs = {
48
+ "SIMD" : ("i" , True , 0 ),
46
49
# number of elements from each stream to concat
47
- "ElemsPerStream " : ("ints" , True , []),
48
- # FINN DataTypes for inputs; output datatype inferred from input
49
- "inputDataType " : ("s " , True , "" ),
50
+ "ChannelsPerStream " : ("ints" , True , []),
51
+ # FINN DataTypes for inputs; output datatype inferred from inputs
52
+ "inputDataTypes " : ("strings " , True , [ "" ] ),
50
53
# number of input vectors for non-concat axes, examples:
51
54
# [1] is a single vector (like a FC layer with batch=1)
52
55
# [4] is four vectors (like a FC layer with batch=4)
@@ -57,29 +60,36 @@ def get_nodeattr_types(self):
57
60
return my_attrs
58
61
59
62
def get_n_inputs (self ):
60
- return len (self .get_nodeattr ("ElemsPerStream " ))
63
+ return len (self .get_nodeattr ("ChannelsPerStream " ))
61
64
62
65
def get_total_elems (self ):
63
- elems_per_stream = self .get_nodeattr ("ElemsPerStream " )
66
+ elems_per_stream = self .get_nodeattr ("ChannelsPerStream " )
64
67
return int (np .sum (elems_per_stream ))
65
68
66
69
def get_normal_input_shape (self , ind = 0 ):
67
- elems_per_stream = self .get_nodeattr ("ElemsPerStream " )
70
+ elems_per_stream = self .get_nodeattr ("ChannelsPerStream " )
68
71
elems = elems_per_stream [ind ]
69
72
vecs = list (self .get_nodeattr ("numInputVectors" ))
70
73
ishape = tuple (vecs + [elems ])
71
74
return ishape
72
75
73
76
def get_folded_input_shape (self , ind = 0 ):
74
- return self .get_normal_input_shape (ind )
77
+ simd = self .get_nodeattr ("SIMD" )
78
+ folds = self .get_nodeattr ("ChannelsPerStream" )[ind ] // simd
79
+ vecs = list (self .get_nodeattr ("numInputVectors" ))
80
+ return tuple (vecs + [folds , simd ])
75
81
76
82
def get_normal_output_shape (self , ind = 0 ):
77
83
total_elems = self .get_total_elems ()
78
84
vecs = list (self .get_nodeattr ("numInputVectors" ))
79
85
return tuple (vecs + [total_elems ])
80
86
81
87
def get_folded_output_shape (self , ind = 0 ):
82
- return self .get_normal_output_shape ()
88
+ total_elems = self .get_total_elems ()
89
+ simd = self .get_nodeattr ("SIMD" )
90
+ folds = total_elems // simd
91
+ vecs = list (self .get_nodeattr ("numInputVectors" ))
92
+ return tuple (vecs + [folds , simd ])
83
93
84
94
def make_shape_compatible_op (self , model ):
85
95
# check all input shapes
@@ -94,7 +104,16 @@ def infer_node_datatype(self, model):
94
104
# check all input datatypes
95
105
for i , inp in enumerate (self .onnx_node .input ):
96
106
idt = model .get_tensor_datatype (inp )
97
- assert idt == self .get_input_datatype ()
107
+ if idt != self .get_input_datatype (i ):
108
+ warn_str = "inputDataType changing for %s: %s -> %s " % (
109
+ self .onnx_node .name ,
110
+ str (self .get_input_datatype (i )),
111
+ str (idt ),
112
+ )
113
+ warnings .warn (warn_str )
114
+ old_datatypes_attr = self .get_nodeattr ("inputDataTypes" )
115
+ old_datatypes_attr [i ] = idt .name
116
+ self .set_nodeattr ("inputDataTypes" , old_datatypes_attr )
98
117
odt = self .get_output_datatype ()
99
118
model .set_tensor_datatype (self .onnx_node .output [0 ], odt )
100
119
@@ -103,21 +122,37 @@ def verify_node(self):
103
122
104
123
def get_input_datatype (self , ind = 0 ):
105
124
# input dt identical for all inputs
106
- return DataType [self .get_nodeattr ("inputDataType" ) ]
125
+ return DataType [self .get_nodeattr ("inputDataTypes" )[ ind ] ]
107
126
108
127
def get_output_datatype (self , ind = 0 ):
109
- return self .get_input_datatype ()
128
+ # infer output datatype from declared inputDataTypes
129
+ min_input = 0
130
+ max_input = 0
131
+ for i in range (len (self .get_nodeattr ("inputDataTypes" ))):
132
+ idt = self .get_input_datatype (i )
133
+ if idt .min () < min_input :
134
+ min_input = idt .min ()
135
+ if idt .max () > max_input :
136
+ max_input = idt .max ()
137
+ # if the input range is always greater than 0, then acc_max <= 2^P - 1
138
+ if min_input >= 0 :
139
+ out_bit_width = math .ceil (np .log2 (max_input + 1 ))
140
+ odt = DataType [f"UINT{ out_bit_width } " ]
141
+ # if the input range is signed, then acc_min >= -2^{P-1} and acc_max <=
142
+ # 2^{P - 1} - 1, which means 2^{P - 1} >= max(-acc_min, 1 + acc_max)
143
+ else :
144
+ max_abs_input = max (- min_input , 1 + max_input )
145
+ out_bit_width = math .ceil (np .log2 (max_abs_input ) + 1 )
146
+ odt = DataType [f"INT{ out_bit_width } " ]
147
+ return odt
110
148
111
149
def get_instream_width (self , ind = 0 ):
112
- elems_per_stream = self .get_nodeattr ("ElemsPerStream" )
113
- elems = elems_per_stream [ind ]
114
- ibits = self .get_input_datatype ().bitwidth ()
115
- return elems * ibits
150
+ ibits = self .get_input_datatype (ind ).bitwidth ()
151
+ return ibits * self .get_nodeattr ("SIMD" )
116
152
117
153
def get_outstream_width (self , ind = 0 ):
118
154
obits = self .get_output_datatype ().bitwidth ()
119
- total_elems = self .get_total_elems ()
120
- out_width = total_elems * obits
155
+ out_width = obits * self .get_nodeattr ("SIMD" )
121
156
return out_width
122
157
123
158
def get_number_output_values (self ):
0 commit comments