@@ -161,3 +161,86 @@ def handle_upsample(operation, layer_name, input_names, input_shapes, node, clas
161
161
layer ['align_corners' ] = bool (class_object .align_corners )
162
162
163
163
return layer , output_shape
164
+
165
+
166
+ @pytorch_handler ('ConstantPad2d' )
167
+ def parse_constantpad2d_layer (operation , layer_name , input_names , input_shapes , node , class_object , data_reader , config ):
168
+ assert operation == 'ConstantPad2d'
169
+
170
+ layer = {}
171
+ layer ['class_name' ] = 'ZeroPadding2D'
172
+ layer ['name' ] = layer_name
173
+ layer ['inputs' ] = input_names
174
+
175
+ # PyTorch padding is (left, right, top, bottom)
176
+ padding = class_object .padding
177
+ if isinstance (padding , int ):
178
+ pad_left = pad_right = pad_top = pad_bottom = padding
179
+ elif isinstance (padding , (tuple , list )) and len (padding ) == 4 :
180
+ pad_left , pad_right , pad_top , pad_bottom = padding
181
+ else :
182
+ raise Exception (f'Unsupported padding format: { padding } ' )
183
+
184
+ layer ['pad_left' ] = pad_left
185
+ layer ['pad_right' ] = pad_right
186
+ layer ['pad_top' ] = pad_top
187
+ layer ['pad_bottom' ] = pad_bottom
188
+
189
+ # Only support zero padding for now
190
+ pad_value = getattr (class_object , 'value' , 0 )
191
+ if pad_value != 0 :
192
+ raise Exception ('Only zero padding is supported for ConstantPad2d in hls4ml' )
193
+
194
+ # Compute output shape
195
+ batch , channels , height , width = input_shapes [0 ]
196
+ out_height = height + pad_top + pad_bottom
197
+ out_width = width + pad_left + pad_right
198
+ output_shape = [batch , channels , out_height , out_width ]
199
+
200
+ # Add required attributes for hls4ml
201
+ layer ['n_chan' ] = channels
202
+ layer ['in_height' ] = height
203
+ layer ['in_width' ] = width
204
+ layer ['out_height' ] = out_height
205
+ layer ['out_width' ] = out_width
206
+
207
+ return layer , output_shape
208
+
209
+
210
+ @pytorch_handler ('ConstantPad1d' )
211
+ def parse_constantpad1d_layer (operation , layer_name , input_names , input_shapes , node , class_object , data_reader , config ):
212
+ assert operation == 'ConstantPad1d'
213
+
214
+ layer = {}
215
+ layer ['class_name' ] = 'ZeroPadding1D'
216
+ layer ['name' ] = layer_name
217
+ layer ['inputs' ] = input_names
218
+
219
+ # PyTorch padding is (left, right)
220
+ padding = class_object .padding
221
+ if isinstance (padding , int ):
222
+ pad_left = pad_right = padding
223
+ elif isinstance (padding , (tuple , list )) and len (padding ) == 2 :
224
+ pad_left , pad_right = padding
225
+ else :
226
+ raise Exception (f'Unsupported padding format: { padding } ' )
227
+
228
+ layer ['pad_left' ] = pad_left
229
+ layer ['pad_right' ] = pad_right
230
+
231
+ # Only support zero padding for now
232
+ pad_value = getattr (class_object , 'value' , 0 )
233
+ if pad_value != 0 :
234
+ raise Exception ('Only zero padding is supported for ConstantPad1d in hls4ml' )
235
+
236
+ # Compute output shape
237
+ batch , channels , width = input_shapes [0 ]
238
+ out_width = width + pad_left + pad_right
239
+ output_shape = [batch , channels , out_width ]
240
+
241
+ # Add required attributes for hls4ml
242
+ layer ['n_chan' ] = channels
243
+ layer ['in_width' ] = width
244
+ layer ['out_width' ] = out_width
245
+
246
+ return layer , output_shape
0 commit comments