1
- from pathlib import Path
2
- import torch
3
1
import torch .nn as nn
2
+
4
3
from hls4ml .converters import convert_from_pytorch_model
5
4
from hls4ml .utils .config import config_from_pytorch_model
6
5
6
+
7
7
def test_pytorch_constantpad_1d_2d ():
8
8
class Pad1DModel (nn .Module ):
9
9
def __init__ (self ):
@@ -22,8 +22,6 @@ def forward(self, x):
22
22
return self .pad (x )
23
23
24
24
# 1D test: batch=1, channels=2, width=4, values 1,2,3,4
25
- x1d = torch .tensor ([[[1. , 2. , 3. , 4. ],
26
- [4. , 3. , 2. , 1. ]]]) # shape (1, 2, 4)
27
25
model_1d = Pad1DModel ()
28
26
model_1d .eval ()
29
27
config_1d = config_from_pytorch_model (model_1d , (2 , 4 ))
@@ -33,8 +31,6 @@ def forward(self, x):
33
31
print (f"{ layer .name } : { layer .class_name } " )
34
32
35
33
# 2D test: batch=1, channels=1, height=2, width=4, values 1,2,3,4,5,6,7,8
36
- x2d = torch .tensor ([[[[1. , 2. , 3. , 4. ],
37
- [5. , 6. , 7. , 8. ]]]]) # shape (1, 1, 2, 4)
38
34
model_2d = Pad2DModel ()
39
35
model_2d .eval ()
40
36
config_2d = config_from_pytorch_model (model_2d , (1 , 2 , 4 ))
@@ -45,4 +41,4 @@ def forward(self, x):
45
41
46
42
# Write the HLS projects, cannot compile on Windows
47
43
hls_model_1d .write ()
48
- hls_model_2d .write ()
44
+ hls_model_2d .write ()
0 commit comments