Commit 009bcf5

Merge pull request #1322 from NALozano1/const_pad_map
Add support for ConstantPad1d and ConstantPad2d layers w/ zero padding in PyTorch converter
2 parents 46b7a88 + c82c359 commit 009bcf5

2 files changed: +127 -0 lines changed
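In short, the new handlers translate PyTorch's ConstantPad padding convention into the attributes that hls4ml's ZeroPadding1D/ZeroPadding2D layers expect, and they only accept a fill value of zero. A rough sketch of the 2D correspondence, with illustrative numbers that are not taken from the commit:

# PyTorch: nn.ConstantPad2d(padding=(left, right, top, bottom), value=0)
# hls4ml:  ZeroPadding2D with pad_left / pad_right / pad_top / pad_bottom attributes
padding = (1, 2, 3, 4)  # (left, right, top, bottom)
pad_left, pad_right, pad_top, pad_bottom = padding

in_height, in_width = 2, 4  # example input spatial dimensions
out_height = in_height + pad_top + pad_bottom  # 2 + 3 + 4 = 9
out_width = in_width + pad_left + pad_right    # 4 + 1 + 2 = 7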

hls4ml/converters/pytorch/reshape.py

Lines changed: 83 additions & 0 deletions
@@ -161,3 +161,86 @@ def handle_upsample(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):

    layer['align_corners'] = bool(class_object.align_corners)

    return layer, output_shape


@pytorch_handler('ConstantPad2d')
def parse_constantpad2d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
    assert operation == 'ConstantPad2d'

    layer = {}
    layer['class_name'] = 'ZeroPadding2D'
    layer['name'] = layer_name
    layer['inputs'] = input_names

    # PyTorch padding is (left, right, top, bottom)
    padding = class_object.padding
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (tuple, list)) and len(padding) == 4:
        pad_left, pad_right, pad_top, pad_bottom = padding
    else:
        raise Exception(f'Unsupported padding format: {padding}')

    layer['pad_left'] = pad_left
    layer['pad_right'] = pad_right
    layer['pad_top'] = pad_top
    layer['pad_bottom'] = pad_bottom

    # Only support zero padding for now
    pad_value = getattr(class_object, 'value', 0)
    if pad_value != 0:
        raise Exception('Only zero padding is supported for ConstantPad2d in hls4ml')

    # Compute output shape
    batch, channels, height, width = input_shapes[0]
    out_height = height + pad_top + pad_bottom
    out_width = width + pad_left + pad_right
    output_shape = [batch, channels, out_height, out_width]

    # Add required attributes for hls4ml
    layer['n_chan'] = channels
    layer['in_height'] = height
    layer['in_width'] = width
    layer['out_height'] = out_height
    layer['out_width'] = out_width

    return layer, output_shape


@pytorch_handler('ConstantPad1d')
def parse_constantpad1d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
    assert operation == 'ConstantPad1d'

    layer = {}
    layer['class_name'] = 'ZeroPadding1D'
    layer['name'] = layer_name
    layer['inputs'] = input_names

    # PyTorch padding is (left, right)
    padding = class_object.padding
    if isinstance(padding, int):
        pad_left = pad_right = padding
    elif isinstance(padding, (tuple, list)) and len(padding) == 2:
        pad_left, pad_right = padding
    else:
        raise Exception(f'Unsupported padding format: {padding}')

    layer['pad_left'] = pad_left
    layer['pad_right'] = pad_right

    # Only support zero padding for now
    pad_value = getattr(class_object, 'value', 0)
    if pad_value != 0:
        raise Exception('Only zero padding is supported for ConstantPad1d in hls4ml')

    # Compute output shape
    batch, channels, width = input_shapes[0]
    out_width = width + pad_left + pad_right
    output_shape = [batch, channels, out_width]

    # Add required attributes for hls4ml
    layer['n_chan'] = channels
    layer['in_width'] = width
    layer['out_width'] = out_width

    return layer, output_shape
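To exercise the new 2D handler in isolation, one could call it directly with a ConstantPad2d module. This is a minimal sketch, assuming the pytorch_handler decorator leaves the function directly callable; the layer name, input names, and input shape are illustrative, and None is passed for the node, data_reader, and config arguments since this handler does not use them:

import torch.nn as nn

from hls4ml.converters.pytorch.reshape import parse_constantpad2d_layer

pad = nn.ConstantPad2d((1, 2, 3, 4), 0)
# Input shape follows the converter's (batch, channels, height, width) layout
layer, output_shape = parse_constantpad2d_layer(
    'ConstantPad2d', 'pad1', ['x'], [[None, 1, 2, 4]], None, pad, None, None
)
print(layer['class_name'], output_shape)  # expected: ZeroPadding2D [None, 1, 9, 7]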
New test file

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import torch.nn as nn

from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model


def test_pytorch_constantpad_1d_2d():
    class Pad1DModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.pad = nn.ConstantPad1d((2, 3), 0)  # pad 2 left, 3 right

        def forward(self, x):
            return self.pad(x)

    class Pad2DModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.pad = nn.ConstantPad2d((1, 2, 3, 4), 0)  # left, right, top, bottom

        def forward(self, x):
            return self.pad(x)

    # 1D test: batch=1, channels=2, width=4
    model_1d = Pad1DModel()
    model_1d.eval()
    config_1d = config_from_pytorch_model(model_1d, (2, 4))
    hls_model_1d = convert_from_pytorch_model(model_1d, hls_config=config_1d)
    print("1D Padding Model Layers:")
    for layer in hls_model_1d.get_layers():
        print(f"{layer.name}: {layer.class_name}")

    # 2D test: batch=1, channels=1, height=2, width=4
    model_2d = Pad2DModel()
    model_2d.eval()
    config_2d = config_from_pytorch_model(model_2d, (1, 2, 4))
    hls_model_2d = convert_from_pytorch_model(model_2d, hls_config=config_2d)
    print("2D Padding Model Layers:")
    for layer in hls_model_2d.get_layers():
        print(f"{layer.name}: {layer.class_name}")

    # Write the HLS projects, cannot compile on Windows
    hls_model_1d.write()
    hls_model_2d.write()
