
Commit f3690c8

Merge pull request #156 from fastmachinelearning/feature/fp16_fixes
Fixes for fp16 tensors and datatypes in ONNX
2 parents (d9269a9 + a27b4a6), commit f3690c8

File tree (6 files changed: 64 additions, 27 deletions)

src/qonnx/core/datatype.py
src/qonnx/core/modelwrapper.py
src/qonnx/custom_op/general/quant.py
src/qonnx/transformation/infer_datatypes.py
src/qonnx/util/basic.py
tests/transformation/test_infer_datatypes.py
src/qonnx/core/datatype.py

Lines changed: 1 addition & 1 deletion

@@ -168,7 +168,7 @@ def is_fixed_point(self):
         return False

     def get_hls_datatype_str(self):
-        return "float"
+        return "half"

     def to_numpy_dt(self):
         return np.float16
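
With this change, requesting the HLS type string for FLOAT16 yields the half-precision C type. A minimal sketch of the expected behaviour, assuming DataType["FLOAT16"] resolves to the class patched above:

import numpy as np
from qonnx.core.datatype import DataType

fp16 = DataType["FLOAT16"]
# HLS code generation now sees "half" instead of "float" for fp16 tensors
assert fp16.get_hls_datatype_str() == "half"
# the numpy container type is unchanged
assert fp16.to_numpy_dt() == np.float16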

src/qonnx/core/modelwrapper.py

Lines changed: 24 additions & 2 deletions

@@ -183,8 +183,30 @@ def get_tensor_datatype(self, tensor_name):
             ret = util.get_by_name(ret.quant_parameter_tensor_names, "finn_datatype", "key")
             if ret is not None:
                 return DataType[ret.value]
-        # TODO maybe use native ONNX tensor type instead of assuming fp32?
-        return DataType["FLOAT32"]
+        onnx_dtype_to_qonnx_dtype = {
+            TensorProto.FLOAT: "FLOAT32",
+            TensorProto.FLOAT16: "FLOAT16",
+            # TODO: dtypes below need testing to ensure they do not break FINN,
+            # since it normally assumes float32 containers for these dtypes
+            # TensorProto.UINT8 : "UINT8",
+            # TensorProto.INT8 : "INT8",
+            # TensorProto.UINT16 : "UINT16",
+            # TensorProto.INT16 : "INT16",
+            # TensorProto.UINT32 : "UINT32",
+            # TensorProto.INT32 : "INT32",
+            # TensorProto.UINT64 : "UINT64",
+            # TensorProto.INT64 : "INT64",
+        }
+        tensor_vi = self.get_tensor_valueinfo(tensor_name)
+        if tensor_vi is None:
+            # some initialized tensors don't get ValueInfo even after shape inference
+            _, onnx_dtype = self.get_initializer(tensor_name, return_dtype=True)
+        else:
+            onnx_dtype = tensor_vi.type.tensor_type.elem_type
+        if onnx_dtype in onnx_dtype_to_qonnx_dtype.keys():
+            return DataType[onnx_dtype_to_qonnx_dtype[onnx_dtype]]
+        else:
+            return DataType["FLOAT32"]

     def set_tensor_datatype(self, tensor_name, datatype):
         """Sets the QONNX DataType of tensor with given name."""

src/qonnx/custom_op/general/quant.py

Lines changed: 17 additions & 6 deletions

@@ -172,12 +172,23 @@ def get_nodeattr_types(self):
     def make_shape_compatible_op(self, model):
         """Returns a standard ONNX op which is compatible with this CustomOp
         for performing shape inference."""
-        return helper.make_node(
-            "Cast",
-            inputs=[self.onnx_node.input[0]],
-            outputs=[self.onnx_node.output[0]],
-            to=int(TensorProto.FLOAT),
-        )
+        node_out = self.onnx_node.output[0]
+        # preserve existing ONNX tensor type if it exists
+        node_out_vi = model.get_tensor_valueinfo(node_out)
+        if node_out_vi is None:
+            return helper.make_node(
+                "Cast",
+                inputs=[self.onnx_node.input[0]],
+                outputs=[node_out],
+                to=int(TensorProto.FLOAT),
+            )
+        else:
+            return helper.make_node(
+                "Cast",
+                inputs=[self.onnx_node.input[0]],
+                outputs=[node_out],
+                to=int(node_out_vi.type.tensor_type.elem_type),
+            )
         # For Quant the output shape should be the same as the input shape.
         # Get the output shape from the input
         out_shape = model.get_tensor_shape(self.onnx_node.input[0])
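
The shape-inference stand-in for Quant now preserves an existing output element type rather than always casting to float32. For a Quant output that already carries a float16 ValueInfo, the emitted stand-in would look roughly like the node below (an illustrative sketch; the tensor names are assumptions):

from onnx import TensorProto, helper

# stand-in Cast node targeting the pre-existing fp16 element type
cast_node = helper.make_node(
    "Cast",
    inputs=["quant_in"],
    outputs=["quant_out"],
    to=int(TensorProto.FLOAT16),
)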

src/qonnx/transformation/infer_datatypes.py

Lines changed: 5 additions & 4 deletions

@@ -38,9 +38,9 @@ def is_scaled_int(x):
     return x.is_integer() or x.is_fixed_point() or isinstance(x, ScaledIntType)


-def infer_mac_result_dtype(idtypes, possible_negation):
-    # will default to float32 unless specific cases detected
-    ret = DataType["FLOAT32"]
+def infer_mac_result_dtype(idtypes, odtype_orig, possible_negation):
+    # will default to original output dtype unless specific cases detected
+    ret = odtype_orig
     # result may be signed if:
     # - any of the operands are signed
     # - the operator itself may induce negation (like subtraction)

@@ -97,7 +97,8 @@ def _infer_node_datatype(model, node):
         model.set_tensor_datatype(node.output[0], DataType["BIPOLAR"])
     elif node.op_type in mac_like_optypes:
         possible_negation = node.op_type in ["Sub"]
-        odtype = infer_mac_result_dtype(idtypes, possible_negation=possible_negation)
+        odtype_orig = model.get_tensor_datatype(node.output[0])
+        odtype = infer_mac_result_dtype(idtypes, odtype_orig, possible_negation=possible_negation)
         model.set_tensor_datatype(node.output[0], odtype)
     elif node.op_type in ["Resize", "Upsample"]:
         mode = get_by_name(node.attribute, "mode").s
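
infer_mac_result_dtype now takes the original output datatype as an explicit fallback, so float16 accumulation results are no longer silently promoted to float32. A short sketch of the new signature in use, assuming fp16 operands fall through to the default return the same way fp32 does:

from qonnx.core.datatype import DataType
from qonnx.transformation.infer_datatypes import infer_mac_result_dtype

iu4 = DataType["UINT4"]
f16 = DataType["FLOAT16"]

# integer-only inputs are unaffected by the new fallback argument
assert infer_mac_result_dtype([iu4, iu4], None, possible_negation=False) == DataType["UINT32"]
# with a float operand, the original output dtype (here fp16) is kept
assert infer_mac_result_dtype([f16, iu4], f16, possible_negation=False) == f16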

src/qonnx/util/basic.py

Lines changed: 5 additions & 2 deletions

@@ -233,12 +233,15 @@ def gen_finn_dt_tensor(finn_dt, tensor_shape):
         int_dt = DataType["INT" + str(finn_dt.bitwidth())]
         tensor_values = np.random.randint(int_dt.min(), high=int_dt.max() + 1, size=tensor_shape)
         tensor_values = tensor_values * finn_dt.scale_factor()
-    elif finn_dt == DataType["FLOAT32"]:
+    elif finn_dt in [DataType["FLOAT32"], DataType["FLOAT16"]]:
         tensor_values = np.random.randn(*tensor_shape)
     else:
         raise ValueError("Datatype {} is not supported, no tensor could be generated".format(finn_dt))
     # always use float type as container
-    return tensor_values.astype(np.float32)
+    if finn_dt == DataType["FLOAT16"]:
+        return tensor_values.astype(np.float16)
+    else:
+        return tensor_values.astype(np.float32)


 def calculate_signed_dot_prod_range(dt_a, dt_b, len):
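
gen_finn_dt_tensor now keeps a float16 container when FLOAT16 test data is requested, while every other datatype still uses a float32 container. A small sketch:

import numpy as np
from qonnx.core.datatype import DataType
from qonnx.util.basic import gen_finn_dt_tensor

x16 = gen_finn_dt_tensor(DataType["FLOAT16"], (2, 3))
x32 = gen_finn_dt_tensor(DataType["FLOAT32"], (2, 3))
assert x16.dtype == np.float16  # fp16 data now stays in an fp16 container
assert x32.dtype == np.float32  # float32 container as before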

tests/transformation/test_infer_datatypes.py

Lines changed: 12 additions & 12 deletions

@@ -47,19 +47,19 @@ def test_infer_mac_dtype_result():
     si4 = DataType["SCALEDINT<4>"]
     si32 = DataType["SCALEDINT<32>"]
     # test several 2-input (e.g. weights, inputs) cases
-    assert infer_mac_result_dtype([iu4, iu4], False) == iu32
-    assert infer_mac_result_dtype([iu4, is4], False) == is32
-    assert infer_mac_result_dtype([iu4, iu4], True) == is32
-    assert infer_mac_result_dtype([iu4, fx4], False) == si32
-    assert infer_mac_result_dtype([fx4, si4], False) == si32
-    assert infer_mac_result_dtype([is4, si4], False) == si32
-    assert infer_mac_result_dtype([f32, iu4], False) == f32
-    assert infer_mac_result_dtype([f32, si4], False) == f32
+    assert infer_mac_result_dtype([iu4, iu4], None, False) == iu32
+    assert infer_mac_result_dtype([iu4, is4], None, False) == is32
+    assert infer_mac_result_dtype([iu4, iu4], None, True) == is32
+    assert infer_mac_result_dtype([iu4, fx4], None, False) == si32
+    assert infer_mac_result_dtype([fx4, si4], None, False) == si32
+    assert infer_mac_result_dtype([is4, si4], None, False) == si32
+    assert infer_mac_result_dtype([f32, iu4], f32, False) == f32
+    assert infer_mac_result_dtype([f32, si4], f32, False) == f32
     # test several 3-input (e.g. weights, inputs, biases) cases
-    assert infer_mac_result_dtype([iu4, iu4, iu4], False) == iu32
-    assert infer_mac_result_dtype([iu4, iu4, is4], False) == is32
-    assert infer_mac_result_dtype([is4, iu4, fx4], False) == si32
-    assert infer_mac_result_dtype([is4, iu4, f32], False) == f32
+    assert infer_mac_result_dtype([iu4, iu4, iu4], None, False) == iu32
+    assert infer_mac_result_dtype([iu4, iu4, is4], None, False) == is32
+    assert infer_mac_result_dtype([is4, iu4, fx4], None, False) == si32
+    assert infer_mac_result_dtype([is4, iu4, f32], f32, False) == f32


 def test_infer_datatypes():
