Skip to content

Commit 252af20

Browse files
authored
Merge pull request #20 from eki-project/feature/split-concat
Integrate Split and Concat Operators
2 parents 64282e5 + 8a70060 commit 252af20

17 files changed

+1705
-539
lines changed

src/finn/custom_op/fpgadataflow/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from finn.custom_op.fpgadataflow.lookup import Lookup
4343
from finn.custom_op.fpgadataflow.matrixvectoractivation import MVAU
4444
from finn.custom_op.fpgadataflow.pool import Pool
45+
from finn.custom_op.fpgadataflow.split import StreamingSplit
4546
from finn.custom_op.fpgadataflow.streamingdataflowpartition import (
4647
StreamingDataflowPartition,
4748
)
@@ -77,6 +78,7 @@
7778
custom_op["Lookup"] = Lookup
7879
custom_op["Pool"] = Pool
7980
custom_op["StreamingConcat"] = StreamingConcat
81+
custom_op["StreamingSplit"] = StreamingSplit
8082
custom_op["StreamingDataWidthConverter"] = StreamingDataWidthConverter
8183
custom_op["StreamingEltwise"] = StreamingEltwise
8284
custom_op["StreamingMaxPool"] = StreamingMaxPool

src/finn/custom_op/fpgadataflow/concat.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2828
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929

30+
import math
3031
import numpy as np
32+
import warnings
3133
from qonnx.core.datatype import DataType
3234
from qonnx.util.basic import roundup_to_integer_multiple
3335

@@ -36,17 +38,18 @@
3638

3739
class StreamingConcat(HWCustomOp):
3840
"""Abstraction layer for HW implementation of Concat.
39-
Only supports concatenating along the last axis."""
41+
Only supports concatenating along the last (channel) axis."""
4042

4143
def __init__(self, onnx_node, **kwargs):
4244
super().__init__(onnx_node, **kwargs)
4345

4446
def get_nodeattr_types(self):
4547
my_attrs = {
48+
"SIMD": ("i", True, 0),
4649
# number of elements from each stream to concat
47-
"ElemsPerStream": ("ints", True, []),
48-
# FINN DataTypes for inputs; output datatype inferred from input
49-
"inputDataType": ("s", True, ""),
50+
"ChannelsPerStream": ("ints", True, []),
51+
# FINN DataTypes for inputs; output datatype inferred from inputs
52+
"inputDataTypes": ("strings", True, [""]),
5053
# number of input vectors for non-concat axes, examples:
5154
# [1] is a single vector (like a FC layer with batch=1)
5255
# [4] is four vectors (like a FC layer with batch=4)
@@ -57,29 +60,36 @@ def get_nodeattr_types(self):
5760
return my_attrs
5861

5962
def get_n_inputs(self):
60-
return len(self.get_nodeattr("ElemsPerStream"))
63+
return len(self.get_nodeattr("ChannelsPerStream"))
6164

6265
def get_total_elems(self):
63-
elems_per_stream = self.get_nodeattr("ElemsPerStream")
66+
elems_per_stream = self.get_nodeattr("ChannelsPerStream")
6467
return int(np.sum(elems_per_stream))
6568

6669
def get_normal_input_shape(self, ind=0):
67-
elems_per_stream = self.get_nodeattr("ElemsPerStream")
70+
elems_per_stream = self.get_nodeattr("ChannelsPerStream")
6871
elems = elems_per_stream[ind]
6972
vecs = list(self.get_nodeattr("numInputVectors"))
7073
ishape = tuple(vecs + [elems])
7174
return ishape
7275

7376
def get_folded_input_shape(self, ind=0):
74-
return self.get_normal_input_shape(ind)
77+
simd = self.get_nodeattr("SIMD")
78+
folds = self.get_nodeattr("ChannelsPerStream")[ind] // simd
79+
vecs = list(self.get_nodeattr("numInputVectors"))
80+
return tuple(vecs + [folds, simd])
7581

7682
def get_normal_output_shape(self, ind=0):
7783
total_elems = self.get_total_elems()
7884
vecs = list(self.get_nodeattr("numInputVectors"))
7985
return tuple(vecs + [total_elems])
8086

8187
def get_folded_output_shape(self, ind=0):
82-
return self.get_normal_output_shape()
88+
total_elems = self.get_total_elems()
89+
simd = self.get_nodeattr("SIMD")
90+
folds = total_elems // simd
91+
vecs = list(self.get_nodeattr("numInputVectors"))
92+
return tuple(vecs + [folds, simd])
8393

8494
def make_shape_compatible_op(self, model):
8595
# check all input shapes
@@ -94,7 +104,16 @@ def infer_node_datatype(self, model):
94104
# check all input datatypes
95105
for i, inp in enumerate(self.onnx_node.input):
96106
idt = model.get_tensor_datatype(inp)
97-
assert idt == self.get_input_datatype()
107+
if idt != self.get_input_datatype(i):
108+
warn_str = "inputDataType changing for %s: %s -> %s " % (
109+
self.onnx_node.name,
110+
str(self.get_input_datatype(i)),
111+
str(idt),
112+
)
113+
warnings.warn(warn_str)
114+
old_datatypes_attr = self.get_nodeattr("inputDataTypes")
115+
old_datatypes_attr[i] = idt.name
116+
self.set_nodeattr("inputDataTypes", old_datatypes_attr)
98117
odt = self.get_output_datatype()
99118
model.set_tensor_datatype(self.onnx_node.output[0], odt)
100119

@@ -103,21 +122,37 @@ def verify_node(self):
103122

104123
def get_input_datatype(self, ind=0):
105124
# input dt identical for all inputs
106-
return DataType[self.get_nodeattr("inputDataType")]
125+
return DataType[self.get_nodeattr("inputDataTypes")[ind]]
107126

108127
def get_output_datatype(self, ind=0):
109-
return self.get_input_datatype()
128+
# infer output datatype from declared inputDataTypes
129+
min_input = 0
130+
max_input = 0
131+
for i in range(len(self.get_nodeattr("inputDataTypes"))):
132+
idt = self.get_input_datatype(i)
133+
if idt.min() < min_input:
134+
min_input = idt.min()
135+
if idt.max() > max_input:
136+
max_input = idt.max()
137+
# if the input range is always greater than 0, then acc_max <= 2^P - 1
138+
if min_input >= 0:
139+
out_bit_width = math.ceil(np.log2(max_input + 1))
140+
odt = DataType[f"UINT{out_bit_width}"]
141+
# if the input range is signed, then acc_min >= -2^{P-1} and acc_max <=
142+
# 2^{P - 1} - 1, which means 2^{P - 1} >= max(-acc_min, 1 + acc_max)
143+
else:
144+
max_abs_input = max(-min_input, 1 + max_input)
145+
out_bit_width = math.ceil(np.log2(max_abs_input) + 1)
146+
odt = DataType[f"INT{out_bit_width}"]
147+
return odt
110148

111149
def get_instream_width(self, ind=0):
112-
elems_per_stream = self.get_nodeattr("ElemsPerStream")
113-
elems = elems_per_stream[ind]
114-
ibits = self.get_input_datatype().bitwidth()
115-
return elems * ibits
150+
ibits = self.get_input_datatype(ind).bitwidth()
151+
return ibits * self.get_nodeattr("SIMD")
116152

117153
def get_outstream_width(self, ind=0):
118154
obits = self.get_output_datatype().bitwidth()
119-
total_elems = self.get_total_elems()
120-
out_width = total_elems * obits
155+
out_width = obits * self.get_nodeattr("SIMD")
121156
return out_width
122157

123158
def get_number_output_values(self):

src/finn/custom_op/fpgadataflow/hls/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from finn.custom_op.fpgadataflow.hls.lookup_hls import Lookup_hls
4444
from finn.custom_op.fpgadataflow.hls.matrixvectoractivation_hls import MVAU_hls
4545
from finn.custom_op.fpgadataflow.hls.pool_hls import Pool_hls
46+
from finn.custom_op.fpgadataflow.hls.split_hls import StreamingSplit_hls
4647
from finn.custom_op.fpgadataflow.hls.streamingdatawidthconverter_hls import (
4748
StreamingDataWidthConverter_hls,
4849
)
@@ -71,6 +72,7 @@
7172
custom_op["Lookup_hls"] = Lookup_hls
7273
custom_op["Pool_hls"] = Pool_hls
7374
custom_op["StreamingConcat_hls"] = StreamingConcat_hls
75+
custom_op["StreamingSplit_hls"] = StreamingSplit_hls
7476
custom_op["StreamingEltwise_hls"] = StreamingEltwise_hls
7577
custom_op["StreamingDataWidthConverter_hls"] = StreamingDataWidthConverter_hls
7678
custom_op["StreamingMaxPool_hls"] = StreamingMaxPool_hls

0 commit comments

Comments
 (0)