Skip to content

Add support for ConstantPad1d and ConstantPad2d layers in PyTorch con… #1322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions hls4ml/converters/pytorch/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,86 @@ def handle_upsample(operation, layer_name, input_names, input_shapes, node, clas
layer['align_corners'] = bool(class_object.align_corners)

return layer, output_shape


@pytorch_handler('ConstantPad2d')
def parse_constantpad2d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert operation == 'ConstantPad2d'

layer = {}
layer['class_name'] = 'ZeroPadding2D'
layer['name'] = layer_name
layer['inputs'] = input_names

# PyTorch padding is (left, right, top, bottom)
padding = class_object.padding
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif isinstance(padding, (tuple, list)) and len(padding) == 4:
pad_left, pad_right, pad_top, pad_bottom = padding
else:
raise Exception(f'Unsupported padding format: {padding}')

layer['pad_left'] = pad_left
layer['pad_right'] = pad_right
layer['pad_top'] = pad_top
layer['pad_bottom'] = pad_bottom

# Only support zero padding for now
pad_value = getattr(class_object, 'value', 0)
if pad_value != 0:
raise Exception('Only zero padding is supported for ConstantPad2d in hls4ml')

# Compute output shape
batch, channels, height, width = input_shapes[0]
out_height = height + pad_top + pad_bottom
out_width = width + pad_left + pad_right
output_shape = [batch, channels, out_height, out_width]

# Add required attributes for hls4ml
layer['n_chan'] = channels
layer['in_height'] = height
layer['in_width'] = width
layer['out_height'] = out_height
layer['out_width'] = out_width

return layer, output_shape


@pytorch_handler('ConstantPad1d')
def parse_constantpad1d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert operation == 'ConstantPad1d'

layer = {}
layer['class_name'] = 'ZeroPadding1D'
layer['name'] = layer_name
layer['inputs'] = input_names

# PyTorch padding is (left, right)
padding = class_object.padding
if isinstance(padding, int):
pad_left = pad_right = padding
elif isinstance(padding, (tuple, list)) and len(padding) == 2:
pad_left, pad_right = padding
else:
raise Exception(f'Unsupported padding format: {padding}')

layer['pad_left'] = pad_left
layer['pad_right'] = pad_right

# Only support zero padding for now
pad_value = getattr(class_object, 'value', 0)
if pad_value != 0:
raise Exception('Only zero padding is supported for ConstantPad1d in hls4ml')

# Compute output shape
batch, channels, width = input_shapes[0]
out_width = width + pad_left + pad_right
output_shape = [batch, channels, out_width]

# Add required attributes for hls4ml
layer['n_chan'] = channels
layer['in_width'] = width
layer['out_width'] = out_width

return layer, output_shape
44 changes: 44 additions & 0 deletions test/pytest/test_pytorch_constpadmapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch.nn as nn

from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model


def test_pytorch_constantpad_1d_2d():
class Pad1DModel(nn.Module):
def __init__(self):
super().__init__()
self.pad = nn.ConstantPad1d((2, 3), 0) # pad 2 left, 3 right

def forward(self, x):
return self.pad(x)

class Pad2DModel(nn.Module):
def __init__(self):
super().__init__()
self.pad = nn.ConstantPad2d((1, 2, 3, 4), 0) # left, right, top, bottom

def forward(self, x):
return self.pad(x)

# 1D test: batch=1, channels=2, width=4, values 1,2,3,4
model_1d = Pad1DModel()
model_1d.eval()
config_1d = config_from_pytorch_model(model_1d, (2, 4))
hls_model_1d = convert_from_pytorch_model(model_1d, hls_config=config_1d)
print("1D Padding Model Layers:")
for layer in hls_model_1d.get_layers():
print(f"{layer.name}: {layer.class_name}")

# 2D test: batch=1, channels=1, height=2, width=4, values 1,2,3,4,5,6,7,8
model_2d = Pad2DModel()
model_2d.eval()
config_2d = config_from_pytorch_model(model_2d, (1, 2, 4))
hls_model_2d = convert_from_pytorch_model(model_2d, hls_config=config_2d)
print("2D Padding Model Layers:")
for layer in hls_model_2d.get_layers():
print(f"{layer.name}: {layer.class_name}")

# Write the HLS projects, cannot compile on Windows
hls_model_1d.write()
hls_model_2d.write()