# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+ import numpy as np
import warnings
+ from copy import deepcopy
from onnx import TensorProto, helper

- from qonnx.analysis.topology import is_linear
+ from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op import channels_last
from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_first_args, to_channels_last_args
from qonnx.transformation.base import Transformation
from qonnx.transformation.fold_constants import FoldConstants
+ from qonnx.transformation.general import SortGraph
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast
from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
from qonnx.util.basic import get_by_name
+ from qonnx.util.onnx import is_eltwise_optype

# Standard ONNX nodes which require a ChannelsLast data format to function properly
_channelsLast_node_types = list(channels_last.custom_op.keys())

# Nodes, which do not modify the shape of the tensor
# And modify all values in the same way.
- _move_through_nodes = ["Quant", "Relu"]
+ _move_through_nodes = ["Quant", "Relu", "Selu", "LeakyRelu", "Sigmoid", "Tanh"]

# Nodes, which do not modify the shape of the tensor,
# And modify all values in the same way, if the second tensor is a scalar.
_move_through_nodes_if_scalar = ["Mul", "Div", "Sub", "Add"]

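As an illustrative aside (editor's addition, not part of the diff): the scalar move-through rule works because an elementwise op with a scalar operand commutes with a Transpose, for example:

    import numpy as np

    x = np.random.rand(1, 3, 4, 4)
    # adding a scalar before or after the transpose gives the same result
    assert np.allclose(np.transpose(x, (0, 2, 3, 1)) + 3.0, np.transpose(x + 3.0, (0, 2, 3, 1)))
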
+ def get_transpose_perms(transpose_node, model):
+     perm = get_by_name(transpose_node.attribute, "perm")
+     ndim = len(model.get_tensor_shape(transpose_node.input[0]))
+     if perm is None:
+         # default perm is to reverse the dim order
+         return list(range(ndim - 1, -1, -1))
+     else:
+         return list(perm.ints)
+
+
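A minimal sketch (editor's addition, not part of the commit) of the default-perm path: when a Transpose node carries no explicit perm attribute, ONNX semantics reverse the dimension order, which is what get_transpose_perms returns.

    from onnx import TensorProto, helper
    from qonnx.core.modelwrapper import ModelWrapper

    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 3, 4, 4])
    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [4, 4, 3, 1])
    transpose = helper.make_node("Transpose", ["inp"], ["out"])  # no perm attribute
    model = ModelWrapper(helper.make_model(helper.make_graph([transpose], "t", [inp], [out])))
    assert get_transpose_perms(transpose, model) == [3, 2, 1, 0]
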
+ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrapper):
+     t0 = transpose_node.input[0]
+     t1 = transpose_node.output[0]
+     t2 = eltwise_node.output[0]
+     subgraph_inp_shape = model.get_tensor_shape(t0)
+     ndim_inp = len(subgraph_inp_shape)
+     perm = get_transpose_perms(transpose_node, model)
+
+     # before: t0 -> transpose -> t1 -> eltwise -> t2
+     # after: t0 -> eltwise -> t1 -> transpose -> t2
+     # find the eltwise inp index fed by transpose
+     transpose_in_ind = list(eltwise_node.input).index(t1)
+     # check all inputs for the eltwise op:
+     # we need to ensure those inputs get inverse-transposed
+     # to keep the graph semantics intact
+     for ind, eltwise_inp in enumerate(eltwise_node.input):
+         if ind == transpose_in_ind:
+             # the input that feeds from the original transpose
+             # node will be implicitly inverse-transposed, since we'll be
+             # moving that transpose node past the eltwise op
+             continue
+         inp_shape = model.get_tensor_shape(eltwise_inp)
+         ndim = len(inp_shape)
+         if ndim == 0:
+             # scalar input, always broadcastable, no action needed
+             continue
+         elif ndim == ndim_inp:
+             # input with matching dimensions, add inverse transpose
+             new_t_inp = model.make_new_valueinfo_name()
+             inv_perm = np.argsort(perm)
+             new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm)
+             t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape
+             model.set_tensor_shape(new_t_inp, t_shape)
+             eltwise_node.input[ind] = new_t_inp
+             model.graph.node.append(new_transpose_node)
+         else:
+             # input with non-matching dimensions, assume broadcastable
+             # first add Unsqueeze node to match number of dimensions
+             unsqueeze_param_name = model.make_new_valueinfo_name()
+             model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64))
+             unsqueeze_out_name = model.make_new_valueinfo_name()
+             new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name])
+             unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape
+             model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape)
+             model.graph.node.append(new_unsqueeze_node)
+             # now add inverse transpose
+             new_t_inp = model.make_new_valueinfo_name()
+             inv_perm = np.argsort(perm)
+             new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm)
+             t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape
+             model.set_tensor_shape(new_t_inp, t_shape)
+             eltwise_node.input[ind] = new_t_inp
+             model.graph.node.append(new_transpose_node)
+     # rewire to swap transpose and eltwise node order
+     eltwise_node.input[transpose_in_ind] = t0
+     eltwise_node.output[0] = t1
+     transpose_node.input[0] = t1
+     transpose_node.output[0] = t2
+     # t1 tensor shape changes to inp_shape
+     model.set_tensor_shape(t1, subgraph_inp_shape)
+     model = model.transform(SortGraph())
+     model = model.transform(FoldConstants())
+     return model
+
+
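To see why the rewiring preserves semantics (editor's sketch, not part of the commit): moving the Transpose past an elementwise op only requires the other, non-transposed inputs to be inverse-transposed, and the inverse permutation is simply np.argsort of the original perm:

    import numpy as np

    perm = [0, 2, 3, 1]          # channels-first -> channels-last
    inv_perm = np.argsort(perm)  # array([0, 3, 1, 2]), undoes perm
    x = np.random.rand(1, 3, 4, 4)
    b = np.random.rand(1, 4, 4, 3)
    # before: Add(Transpose(x, perm), b)
    # after:  Transpose(Add(x, Transpose(b, inv_perm)), perm)
    assert np.allclose(np.transpose(x, perm) + b,
                       np.transpose(x + np.transpose(b, inv_perm), perm))
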
class ConvertToChannelsLastAndClean(Transformation):
    """
    Converts data layout dependent nodes to ChannelsLast nodes and inserts transformations.
@@ -67,8 +146,7 @@ def __init__(self, make_input_channels_last=False):
        super().__init__()
        self._make_input_channels_last = make_input_channels_last

-     def apply(self, model):
-         assert model.analysis(is_linear)["is_linear"], "Only linear and non-branching models are supported at this moment."
+     def apply(self, model: ModelWrapper):
        assert model.check_all_tensor_shapes_specified(), (
            "All tensor shapes must be specified. " "Consider running InferShapes."
        )
@@ -85,8 +163,9 @@ def apply(self, model):
        # Technically only required if something changed in the previous trafo
        model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos())

-         # Apply MoveChanLastDownStream
+         # Apply MoveChanLastDownStream and MoveTransposePastFork
        model = model.transform(MoveChanFirstDownstream())
+         model = model.transform(MoveTransposePastFork())

        # Run RemoveConsecutiveChanFirstAndChanLastTrafos again,
        # Technically only required if something changed in the previous trafo
@@ -218,9 +297,9 @@ def apply(self, model):
                # Check the input shape and make sure we support it
                input_shape = model.get_tensor_shape(n.input[0])
                # Check that this is a "to chan first" trafo
-                 perm_1 = get_by_name(n.attribute, "perm")
+                 perm_1 = get_transpose_perms(n, model)
                ndim = len(input_shape)
-                 if list(to_channels_first_args(ndim)) == perm_1.ints:
+                 if list(to_channels_first_args(ndim)) == perm_1:
                    successor_nodes = model.find_direct_successors(n)
                    if successor_nodes is None:
                        continue
@@ -229,8 +308,8 @@ def apply(self, model):
                    if successor_node.op_type == "Transpose":
                        # Check that this is a "to chan last" trafo,
                        # if so both can get removed.
-                         perm_2 = get_by_name(successor_node.attribute, "perm")
-                         if list(to_channels_last_args(ndim)) == perm_2.ints:
+                         perm_2 = get_transpose_perms(successor_node, model)
+                         if list(to_channels_last_args(ndim)) == perm_2:
                            # Connect original input to new output
                            input_tensor = n.input[0]
                            output_tensor_name = successor_node.output[0]
@@ -257,7 +336,7 @@ class MoveChanLastUpstream(Transformation):
    Moves channel last transformations further upstream.
    """

-     def apply(self, model):
+     def apply(self, model: ModelWrapper):
        graph = model.graph
        node_ind = 0
        graph_modified = False
@@ -268,8 +347,8 @@ def apply(self, model):
                # Check the input shape and make sure we support it
                input_shape = model.get_tensor_shape(n.input[0])
                ndim = len(input_shape)
-                 perm = get_by_name(n.attribute, "perm")
-                 if list(to_channels_last_args(ndim)) == perm.ints:
+                 perm = get_transpose_perms(n, model)
+                 if list(to_channels_last_args(ndim)) == perm:
                    predecessors = model.find_direct_predecessors(n)
                    # Check if we reached the top of the graph
                    if predecessors is None:
@@ -285,6 +364,10 @@ def apply(self, model):
                        if second_inp_shape == [1] or second_inp_shape == []:
                            move_through_valid |= True

+                     # don't move through if the predecessor output is a fork
+                     if model.is_fork_node(predecessor):
+                         move_through_valid = False
+
                    # Apply move through trafo if possible
                    if move_through_valid:
                        # Input tensors are always input 0
@@ -334,52 +417,29 @@ def apply(self, model):
        node_ind = 0
        graph_modified = False
        # Find transpose nodes, which are "to chan first" trafos
-         for n in graph.node:
+         for node in graph.node:
            node_ind += 1
-             if n.op_type == "Transpose":
+             if node.op_type == "Transpose":
                # Check the input shape and make sure we support it
-                 input_shape = model.get_tensor_shape(n.input[0])
+                 input_shape = model.get_tensor_shape(node.input[0])
                ndim = len(input_shape)
-                 perm = get_by_name(n.attribute, "perm")
-                 if list(to_channels_first_args(ndim)) == perm.ints:
+                 perm = get_transpose_perms(node, model)
+                 if list(to_channels_first_args(ndim)) == perm:
                    # Do not move the node, if it is at the top of the graph,
                    # this is a strange edge case, for 1D networks, where channels last and channels first trafos
                    # are identical.
-                     predecessors = model.find_direct_predecessors(n)
+                     predecessors = model.find_direct_predecessors(node)
                    if predecessors is None:
                        continue

-                     successors = model.find_direct_successors(n)
+                     successors = model.find_direct_successors(node)
                    if successors is None:
                        continue
                    successor = successors[0]
+                     transpose_node = node

-                     # Check if we can simply move through the next node
-                     move_through_valid = successor.op_type in _move_through_nodes
-                     # Check if we have a node, which applies a scalar change,
-                     # then we can also move through.
-                     if successor.op_type in _move_through_nodes_if_scalar:
-                         second_inp_shape = model.get_tensor_shape(successor.input[1])
-                         if second_inp_shape == [1] or second_inp_shape == []:
-                             move_through_valid |= True
-                     # Apply move through trafo if possible
-                     if move_through_valid:
-                         # Collect all tensors connecting n and successor
-                         # and surrounding nodes
-                         tensor_1 = n.input[0]
-                         tensor_2 = n.output[0]
-                         tensor_3 = successor.output[0]
-                         # Now connect the tensors to the nodes again,
-                         # but in different order
-                         successor.input[0] = tensor_1
-                         successor.output[0] = tensor_2
-                         n.input[0] = tensor_2
-                         n.output[0] = tensor_3
-
-                         # Change the shape of the middle tensor
-                         target_shape = model.get_tensor_shape(tensor_1)
-                         model.set_tensor_shape(tensor_2, target_shape)
-
+                     if is_eltwise_optype(successor.op_type):
+                         model = move_transpose_past_eltwise(transpose_node, successor, model)
                        graph_modified = True
        return model, graph_modified

@@ -422,7 +482,7 @@ def apply(self, model):
                    input_shape = model.get_tensor_shape(transp_node.input[0])
                    # check if transpose converts ChannelsLast to ChannelsFirst
                    ndim = len(input_shape)
-                     perms = get_by_name(transp_node.attribute, "perm").ints
+                     perms = get_transpose_perms(transp_node, model)
                    if list(to_channels_first_args(ndim)) == perms:
                        producer = model.find_producer(transp_node.input[0])
                        consumer = model.find_consumer(n.output[0])
@@ -505,3 +565,92 @@ def apply(self, model):
                        into subsequent node"
                    )
        return model, graph_modified
+
+
+ class MoveOpPastFork(Transformation):
+     """Move node operations past graph forks. Used when a node before a fork
+     can be merged with nodes in the branches
+     """
+
+     def __init__(self, op_name_list):
+         super().__init__()
+         self.ops_to_move = op_name_list
+
+     def apply(self, model):
+         graph = model.graph
+         graph_modified = False
+         nodes = [n for n in graph.node]
+         node_ind = 0
+         for node in nodes:
+             node_ind += 1
+             if node.op_type in self.ops_to_move and model.is_fork_node(node) and not model.is_join_node(node):
+                 # Restrict this transform to operations with constant parameters
+                 # Assuming parameters is in input 1
+                 if len(node.input) > 1:
+                     op_init_param = model.get_initializer(node.input[1])
+                 else:
+                     op_init_param = None
+
+                 # Check case when branches are empty and go
+                 # to the same node
+                 consumers = model.find_consumers(node.output[0])
+                 assert len(consumers) > 1, "Must have >1 consumer"
+                 unique_consumer = True
+                 for consum_node in consumers[1:]:
+                     if consumers[0] != consum_node:
+                         unique_consumer = False
+                         break
+
+                 if unique_consumer:
+                     continue
+
+                 for consumer_node in consumers[1:]:
+                     # create new node
+                     new_output_tensor_name = model.make_new_valueinfo_name()
+                     if op_init_param is None:
+                         new_inp_list = [node.input[0]]
+                     else:
+                         new_param_name = model.make_new_valueinfo_name()
+                         new_inp_list = [node.input[0], new_param_name]
+                         model.set_initializer(new_param_name, op_init_param)
+                     new_node = deepcopy(node)
+                     new_node.input[:] = new_inp_list
+                     new_node.output[:] = [new_output_tensor_name]
+                     graph.node.insert(node_ind, new_node)
+                     node_ind += 1
+
+                     # change consumer input tensor
+                     graph.node.remove(consumer_node)
+                     for idx, consumer_input in enumerate(consumer_node.input):
+                         if consumer_input == node.output[0]:
+                             consumer_node.input[idx] = new_output_tensor_name
+                             break
+                     else:
+                         raise Exception("Consumer should have the current node output as input")
+
+                     graph.node.insert(node_ind, consumer_node)
+
+                 graph_modified = True
+
+         model = model.transform(InferShapes())
+         return (model, graph_modified)
+
+
+ class MoveAddPastFork(MoveOpPastFork):
+     def __init__(self):
+         super().__init__(["Add"])
+
+
+ class MoveMulPastFork(MoveOpPastFork):
+     def __init__(self):
+         super().__init__(["Mul"])
+
+
+ class MoveLinearPastFork(MoveOpPastFork):
+     def __init__(self):
+         super().__init__(["Add", "Mul"])
+
+
+ class MoveTransposePastFork(MoveOpPastFork):
+     def __init__(self):
+         super().__init__(["Transpose"])
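A minimal usage sketch (editor's addition; the file names are hypothetical and the module path qonnx.transformation.channels_last is assumed) showing how the conversion is typically invoked on a wrapped model:

    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean

    model = ModelWrapper("model.onnx")  # hypothetical channels-first input model
    # convert layout-dependent nodes to channels-last and clean up leftover transposes
    model = model.transform(ConvertToChannelsLastAndClean(make_input_channels_last=True))
    model.save("model_channels_last.onnx")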