Skip to content

Commit 235ad4f

Browse files
authored
Merge pull request #157 from fastmachinelearning/feature/remove_id_dropout
Remove Dropout with ratio=0 with RemoveIdentityOps
2 parents 279f9c3 + 69e44fc commit 235ad4f

File tree

2 files changed

+37
-15
lines changed

2 files changed

+37
-15
lines changed

src/qonnx/transformation/remove.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -107,40 +107,54 @@ def __init__(self, atol=1e-05):
107107
self.atol = atol
108108

109109
def apply(self, model):
110+
opset_version = model.model.opset_import[0].version
110111
graph = model.graph
111112
node_ind = 0
112113
graph_modified = False
113-
for n in graph.node:
114+
for node in graph.node:
114115
node_ind += 1
115-
if n.op_type in ["Add", "Sub"] and not model.is_fork_node(n) and not model.is_join_node(n):
116-
A = model.get_initializer(n.input[1])
116+
if node.op_type in ["Add", "Sub"] and not model.is_fork_node(node) and not model.is_join_node(node):
117+
A = model.get_initializer(node.input[1])
117118
if A is not None and np.isclose(A, np.zeros_like(A), atol=self.atol).all():
118-
remove_node_and_rewire(model, n)
119+
remove_node_and_rewire(model, node)
119120
graph_modified = True
120121
break
121122

122-
elif n.op_type in ["Mul", "Div"] and not model.is_fork_node(n) and not model.is_join_node(n):
123-
A = model.get_initializer(n.input[1])
123+
elif node.op_type in ["Mul", "Div"] and not model.is_fork_node(node) and not model.is_join_node(node):
124+
A = model.get_initializer(node.input[1])
124125
if A is not None and np.isclose(A, np.ones_like(A), atol=self.atol).all():
125-
remove_node_and_rewire(model, n)
126+
remove_node_and_rewire(model, node)
126127
graph_modified = True
127128
break
128-
elif n.op_type == "Pad" and not model.is_fork_node(n) and not model.is_join_node(n):
129-
pads = get_by_name(n.attribute, "pads")
129+
elif node.op_type == "Pad" and not model.is_fork_node(node) and not model.is_join_node(node):
130+
pads = get_by_name(node.attribute, "pads")
130131
if pads is not None:
131132
# older versions of Pad op specify pads as attribute
132133
pads = np.asarray(pads.ints, dtype=np.int64)
133134
else:
134135
# newer versions of Pad op specify pads as input
135-
pads = model.get_initializer(n.input[1])
136+
pads = model.get_initializer(node.input[1])
136137

137138
if (pads is not None) and (pads == 0).all():
138-
remove_node_and_rewire(model, n)
139+
remove_node_and_rewire(model, node)
139140
graph_modified = True
140141
break
141-
elif n.op_type == "Identity":
142-
remove_node_and_rewire(model, n)
142+
elif node.op_type == "Identity":
143+
remove_node_and_rewire(model, node)
143144
graph_modified = True
144145
break
146+
elif node.op_type == "Dropout":
147+
if opset_version < 12:
148+
dropout_ratio = get_by_name(node.attribute, "ratio")
149+
dropout_id_cond = not (dropout_ratio is None) and dropout_ratio.f == 0
150+
else:
151+
based_on_inplen = len(node.input) == 1
152+
based_on_ratio_inp = (not based_on_inplen) and model.get_initializer(node.input[1]) == 0
153+
dropout_id_cond = based_on_inplen or based_on_ratio_inp
154+
if dropout_id_cond:
155+
remove_node_and_rewire(model, node)
156+
graph_modified = True
157+
break
158+
145159
model = model.transform(InferShapes())
146160
return (model, graph_modified)

tests/transformation/test_remove_identity_ops.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141

4242

4343
def insert_identity_op(model, op, as_first_node, approx):
44+
kwargs = {}
45+
inp_ndims = 4 if as_first_node else 2
4446
if approx:
4547
zero_val = 0.000001
4648
one_val = 0.999999
@@ -53,6 +55,12 @@ def insert_identity_op(model, op, as_first_node, approx):
5355
val = np.asarray([one_val], dtype=np.float32)
5456
elif op in ["Identity"]:
5557
val = None
58+
elif op == "Pad":
59+
# opset 11 and above: padding specified as input and not attribute
60+
val = np.asarray([0] * 2 * inp_ndims, dtype=np.int64)
61+
elif op == "Dropout":
62+
val = None
63+
kwargs = {"ratio": 0.0}
5664
else:
5765
return
5866

@@ -62,7 +70,7 @@ def insert_identity_op(model, op, as_first_node, approx):
6270
else:
6371
model.set_initializer("value", val)
6472
inplist = ["inp" if as_first_node else "div_out", "value"]
65-
identity_node = helper.make_node(op, inplist, ["ident_out"])
73+
identity_node = helper.make_node(op, inplist, ["ident_out"], **kwargs)
6674
if as_first_node:
6775
graph.node.insert(0, identity_node)
6876
graph.node[1].input[0] = "ident_out"
@@ -74,7 +82,7 @@ def insert_identity_op(model, op, as_first_node, approx):
7482

7583

7684
# identity operations to be inserted
77-
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"])
85+
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity", "Pad", "Dropout"])
7886
@pytest.mark.parametrize("approx", [False, True])
7987
@pytest.mark.parametrize("as_first_node", [False, True])
8088
@pytest.mark.parametrize("fork_before_id", [False, True])

0 commit comments

Comments
 (0)