Skip to content

Commit b176362

Browse files
committed
[RemoveIdentityOps] add ratio=0 Dropout to identity ops
1 parent 763c0c1 commit b176362

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

src/qonnx/transformation/remove.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -107,40 +107,54 @@ def __init__(self, atol=1e-05):
107107
self.atol = atol
108108

109109
def apply(self, model):
110+
opset_version = model.model.opset_import[0].version
110111
graph = model.graph
111112
node_ind = 0
112113
graph_modified = False
113-
for n in graph.node:
114+
for node in graph.node:
114115
node_ind += 1
115-
if n.op_type in ["Add", "Sub"] and not model.is_fork_node(n) and not model.is_join_node(n):
116-
A = model.get_initializer(n.input[1])
116+
if node.op_type in ["Add", "Sub"] and not model.is_fork_node(node) and not model.is_join_node(node):
117+
A = model.get_initializer(node.input[1])
117118
if A is not None and np.isclose(A, np.zeros_like(A), atol=self.atol).all():
118-
remove_node_and_rewire(model, n)
119+
remove_node_and_rewire(model, node)
119120
graph_modified = True
120121
break
121122

122-
elif n.op_type in ["Mul", "Div"] and not model.is_fork_node(n) and not model.is_join_node(n):
123-
A = model.get_initializer(n.input[1])
123+
elif node.op_type in ["Mul", "Div"] and not model.is_fork_node(node) and not model.is_join_node(node):
124+
A = model.get_initializer(node.input[1])
124125
if A is not None and np.isclose(A, np.ones_like(A), atol=self.atol).all():
125-
remove_node_and_rewire(model, n)
126+
remove_node_and_rewire(model, node)
126127
graph_modified = True
127128
break
128-
elif n.op_type == "Pad" and not model.is_fork_node(n) and not model.is_join_node(n):
129-
pads = get_by_name(n.attribute, "pads")
129+
elif node.op_type == "Pad" and not model.is_fork_node(node) and not model.is_join_node(node):
130+
pads = get_by_name(node.attribute, "pads")
130131
if pads is not None:
131132
# older versions of Pad op specify pads as attribute
132133
pads = np.asarray(pads.ints, dtype=np.int64)
133134
else:
134135
# newer versions of Pad op specify pads as input
135-
pads = model.get_initializer(n.input[1])
136+
pads = model.get_initializer(node.input[1])
136137

137138
if (pads is not None) and (pads == 0).all():
138-
remove_node_and_rewire(model, n)
139+
remove_node_and_rewire(model, node)
139140
graph_modified = True
140141
break
141-
elif n.op_type == "Identity":
142-
remove_node_and_rewire(model, n)
142+
elif node.op_type == "Identity":
143+
remove_node_and_rewire(model, node)
143144
graph_modified = True
144145
break
146+
elif node.op_type == "Dropout":
147+
if opset_version < 12:
148+
dropout_ratio = get_by_name(node.attribute, "ratio")
149+
dropout_id_cond = not (dropout_ratio is None) and dropout_ratio.f == 0
150+
else:
151+
based_on_inplen = len(node.input) == 1
152+
based_on_ratio_inp = (not based_on_inplen) and model.get_initializer(node.input[1]) == 0
153+
dropout_id_cond = based_on_inplen or based_on_ratio_inp
154+
if dropout_id_cond:
155+
remove_node_and_rewire(model, node)
156+
graph_modified = True
157+
break
158+
145159
model = model.transform(InferShapes())
146160
return (model, graph_modified)

0 commit comments

Comments
 (0)