Skip to content

Commit cc85809

Browse files
committed
Add tests for ConstantPad1d and ConstantPad2d layers in PyTorch converter
1 parent ecffb66 commit cc85809

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from pathlib import Path
2+
import torch
3+
import torch.nn as nn
4+
from hls4ml.converters import convert_from_pytorch_model
5+
from hls4ml.utils.config import config_from_pytorch_model
6+
7+
def test_pytorch_constantpad_1d_2d():
8+
class Pad1DModel(nn.Module):
9+
def __init__(self):
10+
super().__init__()
11+
self.pad = nn.ConstantPad1d((2, 3), 0) # pad 2 left, 3 right
12+
13+
def forward(self, x):
14+
return self.pad(x)
15+
16+
class Pad2DModel(nn.Module):
17+
def __init__(self):
18+
super().__init__()
19+
self.pad = nn.ConstantPad2d((1, 2, 3, 4), 0) # left, right, top, bottom
20+
21+
def forward(self, x):
22+
return self.pad(x)
23+
24+
# 1D test: batch=1, channels=2, width=4, values 1,2,3,4
25+
x1d = torch.tensor([[[1., 2., 3., 4.],
26+
[4., 3., 2., 1.]]]) # shape (1, 2, 4)
27+
model_1d = Pad1DModel()
28+
model_1d.eval()
29+
config_1d = config_from_pytorch_model(model_1d, (2, 4))
30+
hls_model_1d = convert_from_pytorch_model(model_1d, hls_config=config_1d)
31+
print("1D Padding Model Layers:")
32+
for layer in hls_model_1d.get_layers():
33+
print(f"{layer.name}: {layer.class_name}")
34+
35+
# 2D test: batch=1, channels=1, height=2, width=4, values 1,2,3,4,5,6,7,8
36+
x2d = torch.tensor([[[[1., 2., 3., 4.],
37+
[5., 6., 7., 8.]]]]) # shape (1, 1, 2, 4)
38+
model_2d = Pad2DModel()
39+
model_2d.eval()
40+
config_2d = config_from_pytorch_model(model_2d, (1, 2, 4))
41+
hls_model_2d = convert_from_pytorch_model(model_2d, hls_config=config_2d)
42+
print("2D Padding Model Layers:")
43+
for layer in hls_model_2d.get_layers():
44+
print(f"{layer.name}: {layer.class_name}")
45+
46+
# Write the HLS projects, cannot compile on Windows
47+
hls_model_1d.write()
48+
hls_model_2d.write()

0 commit comments

Comments
 (0)