Commit aaab34a

fix softmax parsing in pytorch and add test

1 parent: afed23b

2 files changed: +23 −5 lines

hls4ml/converters/pytorch/core.py (14 additions, 5 deletions)
@@ -61,17 +61,21 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, node
             layer['class_name'] = 'ThresholdedReLU'
             layer['activation'] = 'ThresholdedReLU'
             if layer['activ_param'] < 0:
-                raise Exception('negative threshold values not supported')
-
-        if hasattr(node, 'dim'):
+                raise Exception('negative threshold values not supported')
+        if hasattr(class_object, 'dim'):
             layer['axis'] = class_object.dim
+        if layer['class_name'] == 'Softmax' and layer['axis'] is None:
+            layer['axis'] = -1
+        if 'IOType' in config:
+            if layer['class_name'] == 'Softmax' and config['IOType'] == 'io_stream' and layer['axis'] != -1:
+                raise Exception('dim needs to be -1 for io_stream')
     else:
         if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
             layer['class_name'] = 'Activation'
         if layer['class_name'] == 'LeakyReLU':
             layer['activ_param'] = node.kwargs['negative_slope']
         if layer['class_name'] == 'ELU':
-            layer['activ_param'] = node.kwargs['alpha']
+            layer['activ_param'] = node.kwargs['alpha']
         if layer['class_name'] == 'Threshold':
             layer['activ_param'] = node.args[1]
             if layer['activ_param'] < 0:
@@ -80,7 +84,12 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, node
             layer['activation'] = 'ThresholdedReLU'
         if 'dim' in node.kwargs:
             layer['axis'] = node.kwargs['dim']
-
+        if layer['class_name'] == 'Softmax' and layer['axis'] is None:
+            layer['axis'] = -1
+        if 'IOType' in config:
+            if layer['class_name'] == 'Softmax' and config['IOType'] == 'io_stream' and layer['axis'] != -1:
+                raise Exception('dim needs to be -1 for io_stream')
+
     output_shape = input_shapes[0]
     return layer, output_shape
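The key change in the first hunk is `hasattr(node, 'dim')` → `hasattr(class_object, 'dim')`: for a module activation such as `nn.Softmax`, the `dim` argument is stored on the module instance, not on the traced graph node, so the old check never fired. A minimal standalone sketch with plain `torch.fx` (hls4ml uses its own tracer, so this is only an illustration of where `dim` lives):

import torch.fx
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Softmax(dim=-1))
traced = torch.fx.symbolic_trace(model)

for node in traced.graph.nodes:
    if node.op == 'call_module':
        class_object = traced.get_submodule(node.target)
        if isinstance(class_object, nn.Softmax):
            print(hasattr(node, 'dim'))          # False: the fx Node carries no 'dim'
            print(hasattr(class_object, 'dim'))  # True: the nn.Softmax instance does
            print(class_object.dim)              # -1

For the functional form (`nn.functional.softmax`), the second hunk keeps reading `node.kwargs`, and the new `layer['axis'] is None` fallback sets the axis to -1 whenever no dim was recorded.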

test/pytest/test_pytorch_api.py (9 additions, 0 deletions)
@@ -63,6 +63,7 @@ def test_linear(backend, io_type):
 @pytest.mark.parametrize(
     "activation_function",
     [
+        nn.Softmax(dim=-1),
         nn.ReLU(),
         nn.Tanh(),
         nn.LeakyReLU(negative_slope=1.0),
@@ -74,6 +75,7 @@ def test_linear(backend, io_type):
 )
 @pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
 @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
+
 def test_activations(activation_function, backend, io_type):
     model = torch.nn.Sequential(nn.Linear(1, 1), activation_function).to()
     model.eval()
@@ -118,6 +120,12 @@ def __init__(self):
     def forward(self, x):
         return nn.functional.relu(x)
 
+class SoftmaxModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return nn.functional.softmax(x,dim=-1)
 
 class TanHModel(nn.Module):
     def __init__(self):
@@ -162,6 +170,7 @@ def forward(self, x):
 @pytest.mark.parametrize(
     "activation_function",
     [
+        SoftmaxModel(),
         ReLuModel(),
         TanHModel(),
         LeakyReLuModel(),