@@ -107,40 +107,54 @@ def __init__(self, atol=1e-05):
107
107
self .atol = atol
108
108
109
109
def apply (self , model ):
110
+ opset_version = model .model .opset_import [0 ].version
110
111
graph = model .graph
111
112
node_ind = 0
112
113
graph_modified = False
113
- for n in graph .node :
114
+ for node in graph .node :
114
115
node_ind += 1
115
- if n .op_type in ["Add" , "Sub" ] and not model .is_fork_node (n ) and not model .is_join_node (n ):
116
- A = model .get_initializer (n .input [1 ])
116
+ if node .op_type in ["Add" , "Sub" ] and not model .is_fork_node (node ) and not model .is_join_node (node ):
117
+ A = model .get_initializer (node .input [1 ])
117
118
if A is not None and np .isclose (A , np .zeros_like (A ), atol = self .atol ).all ():
118
- remove_node_and_rewire (model , n )
119
+ remove_node_and_rewire (model , node )
119
120
graph_modified = True
120
121
break
121
122
122
- elif n .op_type in ["Mul" , "Div" ] and not model .is_fork_node (n ) and not model .is_join_node (n ):
123
- A = model .get_initializer (n .input [1 ])
123
+ elif node .op_type in ["Mul" , "Div" ] and not model .is_fork_node (node ) and not model .is_join_node (node ):
124
+ A = model .get_initializer (node .input [1 ])
124
125
if A is not None and np .isclose (A , np .ones_like (A ), atol = self .atol ).all ():
125
- remove_node_and_rewire (model , n )
126
+ remove_node_and_rewire (model , node )
126
127
graph_modified = True
127
128
break
128
- elif n .op_type == "Pad" and not model .is_fork_node (n ) and not model .is_join_node (n ):
129
- pads = get_by_name (n .attribute , "pads" )
129
+ elif node .op_type == "Pad" and not model .is_fork_node (node ) and not model .is_join_node (node ):
130
+ pads = get_by_name (node .attribute , "pads" )
130
131
if pads is not None :
131
132
# older versions of Pad op specify pads as attribute
132
133
pads = np .asarray (pads .ints , dtype = np .int64 )
133
134
else :
134
135
# newer versions of Pad op specify pads as input
135
- pads = model .get_initializer (n .input [1 ])
136
+ pads = model .get_initializer (node .input [1 ])
136
137
137
138
if (pads is not None ) and (pads == 0 ).all ():
138
- remove_node_and_rewire (model , n )
139
+ remove_node_and_rewire (model , node )
139
140
graph_modified = True
140
141
break
141
- elif n .op_type == "Identity" :
142
- remove_node_and_rewire (model , n )
142
+ elif node .op_type == "Identity" :
143
+ remove_node_and_rewire (model , node )
143
144
graph_modified = True
144
145
break
146
+ elif node .op_type == "Dropout" :
147
+ if opset_version < 12 :
148
+ dropout_ratio = get_by_name (node .attribute , "ratio" )
149
+ dropout_id_cond = not (dropout_ratio is None ) and dropout_ratio .f == 0
150
+ else :
151
+ based_on_inplen = len (node .input ) == 1
152
+ based_on_ratio_inp = (not based_on_inplen ) and model .get_initializer (node .input [1 ]) == 0
153
+ dropout_id_cond = based_on_inplen or based_on_ratio_inp
154
+ if dropout_id_cond :
155
+ remove_node_and_rewire (model , node )
156
+ graph_modified = True
157
+ break
158
+
145
159
model = model .transform (InferShapes ())
146
160
return (model , graph_modified )
0 commit comments