Prerequisites
Please make sure to check off these prerequisites before submitting a bug report.
Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
Check that the issue hasn't already been reported, by checking the currently open issues.
If there are steps to reproduce the problem, make sure to write them down below.
If relevant, please include the hls4ml project files, which were created directly before and/or after the bug.
Quick summary
hls4ml config_from_onnx_model fails when using a Resize node with no ROI.
Details
hls4ml's config_from_onnx_model fails on a Resize node that has no ROI input, coming from the conversion of a Brevitas QuantUpsample layer to QONNX.
Steps to Reproduce
Clone the hls4ml repository
Check out the master branch at commit hash 77b8331
Run the code below:
import torch
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.export import export_qonnx

import onnx
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.cleanup import cleanup_model
from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
from qonnx.transformation.gemm_to_matmul import GemmToMatMul


def init_weights(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="leaky_relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4, return_quant_tensor=True)
        self.upsample = qnn.QuantUpsample(scale_factor=2)  # exported as an ONNX Resize node
        for m in self.modules():
            init_weights(m)

    def forward(self, x):
        x1 = self.quant_inp(x)
        x2 = self.upsample(x1)
        return x2


# Export the Brevitas model to QONNX and clean it up
model = TestModel()
export_qonnx(model, torch.randn(1, 1, 25, 25), export_path='qmodel.onnx')

model = ModelWrapper('qmodel.onnx')
model = cleanup_model(model)
model = model.transform(ConvertToChannelsLastAndClean())
model = model.transform(GemmToMatMul())
model = cleanup_model(model)
onnx.save(model.model, 'transformed_model.onnx')

# Parsing the model fails here
import hls4ml

config = hls4ml.utils.config.config_from_onnx_model(model)
Expected behavior
Successful creation of the config from the model.
Actual behavior
Warning: it is recommended to pass the backend to "config_from_onnx_model"
Output layers: ['Resize_0']
Input shape: [1, 25, 25]
Topology:
Layer name: Quant_0, layer type: Quant, current shape: [[1, 1, 25, 25]]
Traceback (most recent call last):
  File "/home/user/project/ConversionQuantUpsample.py", line 54, in <module>
    config = hls4ml.utils.config.config_from_onnx_model(
  File "/home/user/project/.venv/lib/python3.10/site-packages/hls4ml/utils/config.py", line 492, in config_from_onnx_model
    layer_list, _, _ = hls4ml.converters.parse_onnx_model(model)
  File "/home/user/project/.venv/lib/python3.10/site-packages/hls4ml/converters/onnx_to_hls.py", line 244, in parse_onnx_model
    input_shapes = get_input_shape(onnx_model.graph, node)
  File "/home/user/project/.venv/lib/python3.10/site-packages/hls4ml/converters/onnx_to_hls.py", line 76, in get_input_shape
    raise RuntimeError(f'Could not find the shape for input {inp}')
RuntimeError: Could not find the shape for input
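For reference, here is a standalone sketch of what appears to go wrong (a simplification written for this report, not hls4ml's actual get_input_shape code): the '' placeholder left where the optional ROI input was omitted matches no tensor name in the graph, so no shape can be resolved for it.

import onnx

m = onnx.load('transformed_model.onnx')
# Names with a resolvable shape: graph inputs, intermediate tensors
# (value_info), and weight initializers.
known = {t.name for t in list(m.graph.input) + list(m.graph.value_info)}
known |= {init.name for init in m.graph.initializer}

resize = next(n for n in m.graph.node if n.op_type == 'Resize')
print(list(resize.input))  # ['Quant_0_out0', '', 'Resize_0_param0']
for inp in resize.input:
    if inp not in known:
        print(f'no shape found for input {inp!r}')  # fires only for the '' ROI slot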
Optional
Additional context
Printing node.input for the Resize node gives ['Quant_0_out0', '', 'Resize_0_param0']. It seems the empty string '' is what causes the issue; in ONNX, an optional input that is skipped in the middle of the input list (here the ROI) is encoded as an empty string. In addition, when checking the model with Netron, the Resize node's inputs show no roi.
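A possible temporary workaround on the user side (a hypothetical sketch that monkeypatches hls4ml internals; it only assumes the get_input_shape(graph, node) signature visible in the traceback) is to drop empty optional input names before the shape lookup. Note this only unblocks the shape resolution; the Resize parsing itself may still need a proper fix in hls4ml.

import copy
import hls4ml.converters.onnx_to_hls as onnx_to_hls

_orig_get_input_shape = onnx_to_hls.get_input_shape

def _get_input_shape_skip_empty(graph, node):
    # ONNX encodes omitted optional inputs (here the Resize ROI) as ''.
    # Work on a copy so the original graph is left untouched.
    node = copy.deepcopy(node)
    real_inputs = [inp for inp in node.input if inp != '']
    del node.input[:]
    node.input.extend(real_inputs)
    return _orig_get_input_shape(graph, node)

onnx_to_hls.get_input_shape = _get_input_shape_skip_empty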