diff --git a/hls4ml/converters/pytorch/reshape.py b/hls4ml/converters/pytorch/reshape.py index f7392ab8da..64b60c97b9 100644 --- a/hls4ml/converters/pytorch/reshape.py +++ b/hls4ml/converters/pytorch/reshape.py @@ -161,3 +161,86 @@ def handle_upsample(operation, layer_name, input_names, input_shapes, node, clas layer['align_corners'] = bool(class_object.align_corners) return layer, output_shape + + +@pytorch_handler('ConstantPad2d') +def parse_constantpad2d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + assert operation == 'ConstantPad2d' + + layer = {} + layer['class_name'] = 'ZeroPadding2D' + layer['name'] = layer_name + layer['inputs'] = input_names + + # PyTorch padding is (left, right, top, bottom) + padding = class_object.padding + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + elif isinstance(padding, (tuple, list)) and len(padding) == 4: + pad_left, pad_right, pad_top, pad_bottom = padding + else: + raise Exception(f'Unsupported padding format: {padding}') + + layer['pad_left'] = pad_left + layer['pad_right'] = pad_right + layer['pad_top'] = pad_top + layer['pad_bottom'] = pad_bottom + + # Only support zero padding for now + pad_value = getattr(class_object, 'value', 0) + if pad_value != 0: + raise Exception('Only zero padding is supported for ConstantPad2d in hls4ml') + + # Compute output shape + batch, channels, height, width = input_shapes[0] + out_height = height + pad_top + pad_bottom + out_width = width + pad_left + pad_right + output_shape = [batch, channels, out_height, out_width] + + # Add required attributes for hls4ml + layer['n_chan'] = channels + layer['in_height'] = height + layer['in_width'] = width + layer['out_height'] = out_height + layer['out_width'] = out_width + + return layer, output_shape + + +@pytorch_handler('ConstantPad1d') +def parse_constantpad1d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + assert operation == 'ConstantPad1d' + + layer = {} + layer['class_name'] = 'ZeroPadding1D' + layer['name'] = layer_name + layer['inputs'] = input_names + + # PyTorch padding is (left, right) + padding = class_object.padding + if isinstance(padding, int): + pad_left = pad_right = padding + elif isinstance(padding, (tuple, list)) and len(padding) == 2: + pad_left, pad_right = padding + else: + raise Exception(f'Unsupported padding format: {padding}') + + layer['pad_left'] = pad_left + layer['pad_right'] = pad_right + + # Only support zero padding for now + pad_value = getattr(class_object, 'value', 0) + if pad_value != 0: + raise Exception('Only zero padding is supported for ConstantPad1d in hls4ml') + + # Compute output shape + batch, channels, width = input_shapes[0] + out_width = width + pad_left + pad_right + output_shape = [batch, channels, out_width] + + # Add required attributes for hls4ml + layer['n_chan'] = channels + layer['in_width'] = width + layer['out_width'] = out_width + + return layer, output_shape diff --git a/test/pytest/test_pytorch_constpadmapping.py b/test/pytest/test_pytorch_constpadmapping.py new file mode 100644 index 0000000000..b4f602d711 --- /dev/null +++ b/test/pytest/test_pytorch_constpadmapping.py @@ -0,0 +1,44 @@ +import torch.nn as nn + +from hls4ml.converters import convert_from_pytorch_model +from hls4ml.utils.config import config_from_pytorch_model + + +def test_pytorch_constantpad_1d_2d(): + class Pad1DModel(nn.Module): + def __init__(self): + super().__init__() + self.pad = nn.ConstantPad1d((2, 3), 0) # pad 2 left, 3 right + + def forward(self, x): + return self.pad(x) + + class Pad2DModel(nn.Module): + def __init__(self): + super().__init__() + self.pad = nn.ConstantPad2d((1, 2, 3, 4), 0) # left, right, top, bottom + + def forward(self, x): + return self.pad(x) + + # 1D test: batch=1, channels=2, width=4, values 1,2,3,4 + model_1d = Pad1DModel() + model_1d.eval() + config_1d = config_from_pytorch_model(model_1d, (2, 4)) + hls_model_1d = convert_from_pytorch_model(model_1d, hls_config=config_1d) + print("1D Padding Model Layers:") + for layer in hls_model_1d.get_layers(): + print(f"{layer.name}: {layer.class_name}") + + # 2D test: batch=1, channels=1, height=2, width=4, values 1,2,3,4,5,6,7,8 + model_2d = Pad2DModel() + model_2d.eval() + config_2d = config_from_pytorch_model(model_2d, (1, 2, 4)) + hls_model_2d = convert_from_pytorch_model(model_2d, hls_config=config_2d) + print("2D Padding Model Layers:") + for layer in hls_model_2d.get_layers(): + print(f"{layer.name}: {layer.class_name}") + + # Write the HLS projects, cannot compile on Windows + hls_model_1d.write() + hls_model_2d.write()