diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py
index 04d5393748..a99cd87a8f 100644
--- a/hls4ml/model/optimizer/passes/quant_opt.py
+++ b/hls4ml/model/optimizer/passes/quant_opt.py
@@ -343,8 +343,16 @@ def transform(self, model, node):
 
         rescale = scale
         rebias = -bias * scale
+
+        # The precision of the scale is important for overall model accuracy, so it is increased here.
+        # This is a crude workaround and needs a better solution.
+        frac_bits = node.get_attr('bitwidth') * 2
+        scale_precision, scale_quantizer = _calculate_precision_quantizer(frac_bits, 0, signed, narrow, rounding_mode)
+
         attributes_rescale['scale_data'] = np.broadcast_to(rescale, inshape)
         attributes_rescale['bias_data'] = np.broadcast_to(rebias, inshape)
+        attributes_rescale['scale_quantizer'] = scale_quantizer
+        attributes_rescale['scale_precision'] = scale_precision
 
         rescale_node = model.make_node(
             ApplyAlpha, rescale_name, attributes_rescale, [x for x in node.inputs], [x for x in node.outputs]
diff --git a/test/pytest/test_qonnx.py b/test/pytest/test_qonnx.py
index f48f268626..7c8e81e0f5 100644
--- a/test/pytest/test_qonnx.py
+++ b/test/pytest/test_qonnx.py
@@ -2,16 +2,27 @@
 import urllib
 from pathlib import Path
 
+# To test the workflow from Brevitas
+import brevitas.nn as qnn
 import numpy as np
 import pytest
 import qonnx.core.onnx_exec as oxe
 import qonnx.util.cleanup
 import qonnx.util.to_channels_last
+import torch
+from brevitas.export import export_qonnx
+from brevitas.quant import (
+    Int8ActPerTensorFixedPoint,
+    Int8ActPerTensorFloat,
+    Int8WeightPerTensorFixedPoint,
+    Int8WeightPerTensorFloat,
+)
 
 # To conveniently run QONNX inference
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
 from qonnx.transformation.gemm_to_matmul import GemmToMatMul
+from torch.nn import Module
 
 import hls4ml
 
@@ -432,3 +443,70 @@ def test_simple_model(model_name, io_type, backend, request):
     y_hls4ml = hls_model.predict(X)
 
     np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1)
+
+
+# Test the Brevitas -> QONNX -> hls4ml workflow
+quants = {
+    'Int8WeightPerTensorFloat': Int8WeightPerTensorFloat,
+    'Int8WeightPerTensorFixedPoint': Int8WeightPerTensorFixedPoint,
+    'Int8ActPerTensorFloat': Int8ActPerTensorFloat,
+    'Int8ActPerTensorFixedPoint': Int8ActPerTensorFixedPoint,
+}
+
+
+class QuantModelLinear(Module):
+    def __init__(self, weight_quant, act_quant):
+        super().__init__()
+        self.lin1 = qnn.QuantLinear(4, 4, bias=True, weight_quant=quants[weight_quant], input_quant=quants[act_quant])
+        self.relu1 = qnn.QuantReLU(act_quant=quants[act_quant])
+
+    def forward(self, x):
+        out = self.relu1(self.lin1(x))
+        return out
+
+
+backend = 'Vivado'
+io_type = 'io_parallel'
+
+
+# FixedPoint quantizers give power-of-2 quantization scales, Float quantizers give non-power-of-2 scales
+@pytest.mark.parametrize('backend', ['Vitis'])
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
+@pytest.mark.parametrize('quant_type', ['Float', 'FixedPoint'])
+def test_brevitas_workflow(backend, io_type, quant_type):
+
+    weight_quant = f'Int8WeightPerTensor{quant_type}'
+    act_quant = f'Int8ActPerTensor{quant_type}'
+
+    model = QuantModelLinear(weight_quant, act_quant)
+
+    x = torch.rand(1, 4)
+
+    output_path = 'brevitas_onnx.onnx'
+    _ = export_qonnx(model, input_t=x, export_path=output_path)
+
+    modelQONNX = ModelWrapper(output_path)
+    modelQONNX = qonnx.util.cleanup.cleanup_model(modelQONNX)
+    modelQONNX = modelQONNX.transform(ConvertToChannelsLastAndClean())
+    modelQONNX = modelQONNX.transform(GemmToMatMul())
+    modelQONNX = qonnx.util.cleanup.cleanup_model(modelQONNX)
+
+    pytorch_prediction = model(x).detach().numpy()
+
+    configQONNX = hls4ml.utils.config.config_from_onnx_model(
+        modelQONNX, granularity='name', backend=backend, default_precision='fixed<16,6>'
+    )
+    # Modify the config as desired
+    hls_modelQONNX = hls4ml.converters.convert_from_onnx_model(
+        modelQONNX,
+        output_dir=str(test_root_path / f'hls4mlprj_onnx_brevitas_{quant_type.lower()}_{io_type}_{backend}'),
+        io_type=io_type,
+        backend=backend,
+        hls_config=configQONNX,
+    )
+    print(hls_modelQONNX.output_vars)
+    hls_modelQONNX.compile()
+
+    hls_predictionQONNX = np.reshape(hls_modelQONNX.predict(x.detach().numpy()), pytorch_prediction.shape)
+
+    np.testing.assert_allclose(pytorch_prediction, hls_predictionQONNX, rtol=0.0, atol=0.05)