Description
Prerequisites
Please make sure to check off these prerequisites before submitting a bug report.
- Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
- Check that the issue hasn't already been reported by searching the currently open issues.
- If there are steps to reproduce the problem, make sure to write them down below.
- If relevant, please include the hls4ml project files, which were created directly before and/or after the bug.
Quick summary
hls4ml's config_from_onnx_model fails on a model containing a Resize node with no ROI input.
Details
hls4ml's config_from_onnx_model fails on a Resize node that has no ROI input. The Resize node comes from converting a Brevitas QuantUpsample layer to QONNX.
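In ONNX, a skipped optional input is written as an empty string in the node's input list, so the remaining inputs keep their positions. The Resize operator's input slots are X, roi, scales and sizes, in that order (roi is optional from opset 13 on). The minimal pure-Python sketch below shows how the input list reported under Additional context maps onto those slots. The helper name describe_resize_inputs is hypothetical and not part of hls4ml or onnx.

```python
# Input slots of the ONNX Resize operator, in positional order.
RESIZE_INPUTS = ("X", "roi", "scales", "sizes")


def describe_resize_inputs(node_inputs):
    """Map each positional Resize input name to its slot (hypothetical helper).

    An empty string marks an omitted optional input, so it maps to None.
    Trailing optional inputs that are left out entirely (here: sizes) do not
    appear in the result at all, because zip stops at the shorter sequence.
    """
    slots = {}
    for slot, name in zip(RESIZE_INPUTS, node_inputs):
        slots[slot] = name if name else None  # None = omitted optional input
    return slots


print(describe_resize_inputs(['Quant_0_out0', '', 'Resize_0_param0']))
# {'X': 'Quant_0_out0', 'roi': None, 'scales': 'Resize_0_param0'}
```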
Steps to Reproduce
- Clone the hls4ml repository
- Check out the master branch at commit 77b8331
- Run the code below:
import torch.nn as nn
import torch.nn.functional as F
import brevitas.nn as qnn
from brevitas.export import export_qonnx
import torch
import qonnx
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.cleanup import cleanup_model
from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
from qonnx.transformation.qcdq_to_qonnx import QCDQToQuant
from qonnx.transformation.gemm_to_matmul import GemmToMatMul
import onnx
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="leaky_relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
class test_model(nn.Module):
    def __init__(self):
        super(test_model, self).__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4, return_quant_tensor=True)
        # Exported as an ONNX Resize node without a ROI input
        self.upsample = qnn.QuantUpsample(scale_factor=2)
        for m in self.modules():
            init_weights(m)
    def forward(self, x):
        x1 = self.quant_inp(x)
        x2 = self.upsample(x1)
        return x2
model = test_model()
export_qonnx(model, torch.randn(1, 1, 25, 25), export_path='qmodel.onnx')
model = ModelWrapper('qmodel.onnx')
model = cleanup_model(model)
model = model.transform(ConvertToChannelsLastAndClean())
model = model.transform(GemmToMatMul())
model = cleanup_model(model)
onnx.save(model.model, 'transformed_model.onnx')
import hls4ml
from hls4ml.converters import convert_from_onnx_model
from hls4ml.utils.config import config_from_onnx_model
config = hls4ml.utils.config.config_from_onnx_model(model)
Expected behavior
config_from_onnx_model successfully creates a configuration from the model.
Actual behavior
Warning: it is recommended to pass the backend to "config_from_onnx_model"
Output layers: ['Resize_0']
Input shape: [1, 25, 25]
Topology:
Layer name: Quant_0, layer type: Quant, current shape: [[1, 1, 25, 25]]
Traceback (most recent call last):
File "/home/user/project/ConversionQuantUpsample.py", line 54, in <module>
config = hls4ml.utils.config.config_from_onnx_model(
File "/home/user/project/.venv/lib/python3.10/site-packages/hls4ml/utils/config.py", line 492, in config_from_onnx_model
layer_list, _, _ = hls4ml.converters.parse_onnx_model(model)
File "/home/user/project/.venv/lib/python3.10/site-packages/hls4ml/converters/onnx_to_hls.py", line 244, in parse_onnx_model
input_shapes = get_input_shape(onnx_model.graph, node)
File "/home/user/project/.venv/lib/python3.10/site-packages/hls4ml/converters/onnx_to_hls.py", line 76, in get_input_shape
raise RuntimeError(f'Could not find the shape for input {inp}')
RuntimeError: Could not find the shape for input
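The error message ends abruptly because the name being looked up is the empty string. The sketch below is a hypothetical, simplified stand-in for a name-based shape lookup like the one in get_input_shape; it is not hls4ml's actual implementation. A plain dict stands in for whatever shape information the converter collects, and both shapes in it are assumed values chosen for illustration.

```python
# Hypothetical, simplified stand-in for a name-based shape lookup;
# this is NOT hls4ml's actual get_input_shape implementation.
KNOWN_SHAPES = {
    'Quant_0_out0': [1, 1, 25, 25],  # assumed shape of the Resize data input
    'Resize_0_param0': [4],          # assumed shape of the scales tensor
}


def get_input_shape_sketch(known_shapes, node_inputs):
    """Return the shape of every input name, failing on unknown names."""
    shapes = []
    for inp in node_inputs:
        if inp not in known_shapes:
            # Same message format as in the traceback above.
            raise RuntimeError(f'Could not find the shape for input {inp}')
        shapes.append(known_shapes[inp])
    return shapes


try:
    get_input_shape_sketch(KNOWN_SHAPES, ['Quant_0_out0', '', 'Resize_0_param0'])
    message = None
except RuntimeError as err:
    message = str(err)

print(repr(message))  # 'Could not find the shape for input '
```

With an empty input name, the f-string produces the message with nothing after "input", which matches the last line of the traceback.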
Optional
Additional context
When printing node.input for the Resize node, you get ['Quant_0_out0', '', 'Resize_0_param0']. The empty string '' in the second (ROI) position seems to be what causes the issue. In addition, when inspecting the model in Netron, the Resize node has no roi among its inputs.
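One possible direction for a fix, sketched here without testing it against hls4ml, is to treat an empty input name as an omitted optional input instead of looking up its shape. The function name and the shape values below are hypothetical and for illustration only.

```python
# Hypothetical sketch of a shape lookup that tolerates omitted optional inputs.
# Names and shapes are illustrative assumptions, not hls4ml code.
KNOWN_SHAPES = {
    'Quant_0_out0': [1, 1, 25, 25],  # assumed shape of the data input
    'Resize_0_param0': [4],          # assumed shape of the scales tensor
}


def get_input_shapes_allowing_omitted(known_shapes, node_inputs):
    """Look up input shapes, returning None for omitted optional inputs.

    ONNX marks an omitted optional input with an empty name. Keeping a None
    placeholder (instead of dropping the entry) preserves positional indices,
    so the scales shape stays at index 2.
    """
    shapes = []
    for inp in node_inputs:
        if inp == '':
            shapes.append(None)  # omitted optional input: no tensor, no shape
        elif inp in known_shapes:
            shapes.append(known_shapes[inp])
        else:
            raise RuntimeError(f'Could not find the shape for input {inp}')
    return shapes


print(get_input_shapes_allowing_omitted(KNOWN_SHAPES, ['Quant_0_out0', '', 'Resize_0_param0']))
# [[1, 1, 25, 25], None, [4]]
```

Whether the downstream Resize handling in hls4ml accepts a None entry is unverified. The alternative, filtering out empty names entirely, avoids None but shifts the positions of later inputs such as scales.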