Commit e5d5903

Merge pull request #149 from fastmachinelearning/feature/improved_chanlast_eltwiseops
Improved channels-last via elementwise op generalization
2 parents 9b22db4 + a18a186 commit e5d5903

File tree

6 files changed: +328 -49 lines changed

src/qonnx/custom_op/channels_last/batch_normalization.py

+1 -1

@@ -87,7 +87,7 @@ def verify_node(self):
         # verify number of attributes
         num_of_attr = 2
-        if len(node.attribute) == num_of_attr:
+        if len(node.attribute) >= num_of_attr:
             info_messages.append("The number of attributes is correct")
         else:
             info_messages.append(
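Note: the relaxed comparison lets verify_node accept nodes that carry more than the two expected attributes instead of flagging them. A minimal sketch of the behavioral difference, using a plain onnx.helper node as a stand-in (the extra training_mode attribute here is only illustrative):

from onnx import helper

node = helper.make_node(
    "BatchNormalization",
    inputs=["x", "scale", "bias", "mean", "var"],
    outputs=["y"],
    epsilon=1e-5,
    momentum=0.9,
    training_mode=0,  # a third attribute, beyond the two the check expects
)
num_of_attr = 2
print(len(node.attribute) == num_of_attr)  # False -> old check reports a mismatch
print(len(node.attribute) >= num_of_attr)  # True  -> relaxed check accepts the node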
Binary file not shown.

src/qonnx/transformation/channels_last.py

+195 -46
@@ -26,31 +26,110 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+import numpy as np
 import warnings
+from copy import deepcopy
 from onnx import TensorProto, helper

-from qonnx.analysis.topology import is_linear
+from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.custom_op import channels_last
 from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_first_args, to_channels_last_args
 from qonnx.transformation.base import Transformation
 from qonnx.transformation.fold_constants import FoldConstants
+from qonnx.transformation.general import SortGraph
 from qonnx.transformation.infer_shapes import InferShapes
 from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast
 from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
 from qonnx.util.basic import get_by_name
+from qonnx.util.onnx import is_eltwise_optype

 # Standard ONNX nodes which require a ChannelsLast data format to function properly
 _channelsLast_node_types = list(channels_last.custom_op.keys())

 # Nodes, which do not modify the shape of the tensor
 # And modify all values in the same way.
-_move_through_nodes = ["Quant", "Relu"]
+_move_through_nodes = ["Quant", "Relu", "Selu", "LeakyRelu", "Sigmoid", "Tanh"]

 # Nodes, which do not modify the shape of the tensor,
 # And modify all values in the same way, if the second tensor is a scalar.
 _move_through_nodes_if_scalar = ["Mul", "Div", "Sub", "Add"]


+def get_transpose_perms(transpose_node, model):
+    perm = get_by_name(transpose_node.attribute, "perm")
+    ndim = len(model.get_tensor_shape(transpose_node.input[0]))
+    if perm is None:
+        # default perm is to reverse the dim order
+        return list(range(ndim - 1, -1, -1))
+    else:
+        return list(perm.ints)
+
+
+def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrapper):
+    t0 = transpose_node.input[0]
+    t1 = transpose_node.output[0]
+    t2 = eltwise_node.output[0]
+    subgraph_inp_shape = model.get_tensor_shape(t0)
+    ndim_inp = len(subgraph_inp_shape)
+    perm = get_transpose_perms(transpose_node, model)
+
+    # before: t0 -> transpose -> t1 -> eltwise -> t2
+    # after: t0 -> eltwise -> t1 -> transpose -> t2
+    # find the eltwise inp index fed by transpose
+    transpose_in_ind = list(eltwise_node.input).index(t1)
+    # check all inputs for the eltwise op:
+    # we need to ensure those inputs get inverse-transposed
+    # to keep the graph semantics intact
+    for ind, eltwise_inp in enumerate(eltwise_node.input):
+        if ind == transpose_in_ind:
+            # the input that feeds from the original transpose
+            # node will be implicitly inverse-transposed, since we'll be
+            # moving that transpose node past the eltwise op
+            continue
+        inp_shape = model.get_tensor_shape(eltwise_inp)
+        ndim = len(inp_shape)
+        if ndim == 0:
+            # scalar input, always broadcastable, no action needed
+            continue
+        elif ndim == ndim_inp:
+            # input with matching dimensions, add inverse transpose
+            new_t_inp = model.make_new_valueinfo_name()
+            inv_perm = np.argsort(perm)
+            new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm)
+            t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape
+            model.set_tensor_shape(new_t_inp, t_shape)
+            eltwise_node.input[ind] = new_t_inp
+            model.graph.node.append(new_transpose_node)
+        else:
+            # input with non-matching dimensions, assume broadcastable
+            # first add Unsqueeze node to match number of dimensions
+            unsqueeze_param_name = model.make_new_valueinfo_name()
+            model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64))
+            unsqueeze_out_name = model.make_new_valueinfo_name()
+            new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name])
+            unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape
+            model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape)
+            model.graph.node.append(new_unsqueeze_node)
+            # now add inverse transpose
+            new_t_inp = model.make_new_valueinfo_name()
+            inv_perm = np.argsort(perm)
+            new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm)
+            t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape
+            model.set_tensor_shape(new_t_inp, t_shape)
+            eltwise_node.input[ind] = new_t_inp
+            model.graph.node.append(new_transpose_node)
+    # rewire to swap transpose and eltwise node order
+    eltwise_node.input[transpose_in_ind] = t0
+    eltwise_node.output[0] = t1
+    transpose_node.input[0] = t1
+    transpose_node.output[0] = t2
+    # t1 tensor shape changes to inp_shape
+    model.set_tensor_shape(t1, subgraph_inp_shape)
+    model = model.transform(SortGraph())
+    model = model.transform(FoldConstants())
+    return model
+
+
 class ConvertToChannelsLastAndClean(Transformation):
     """
     Converts data layout dependent nodes to ChannelsLast nodes and inserts transformations.
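Note: the core trick in move_transpose_past_eltwise is that np.argsort(perm) yields the inverse permutation, so inverse-transposing the other elementwise inputs compensates for moving the original Transpose downstream. A minimal NumPy-only sketch of that identity, with illustrative tensor names and shapes (not taken from the commit):

import numpy as np

perm = [0, 2, 3, 1]                 # e.g. channels-first -> channels-last
inv_perm = list(np.argsort(perm))   # [0, 3, 1, 2], the inverse permutation

x = np.random.rand(1, 3, 8, 8)      # tensor feeding the original Transpose (t0)
other = np.random.rand(1, 8, 8, 3)  # second eltwise input, shaped for the transposed layout

# before: Transpose, then eltwise
ref = np.transpose(x, perm) + other
# after: eltwise (with inverse-transposed second input), then Transpose
res = np.transpose(x + np.transpose(other, inv_perm), perm)
assert np.allclose(ref, res)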
@@ -67,8 +146,7 @@ def __init__(self, make_input_channels_last=False):
         super().__init__()
         self._make_input_channels_last = make_input_channels_last

-    def apply(self, model):
-        assert model.analysis(is_linear)["is_linear"], "Only linear and non-branching models are supported at this moment."
+    def apply(self, model: ModelWrapper):
         assert model.check_all_tensor_shapes_specified(), (
             "All tensor shapes must be specified. " "Consider running InferShapes."
         )
@@ -85,8 +163,9 @@ def apply(self, model):
         # Technically only required if something changed in the previous trafo
         model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos())

-        # Apply MoveChanLastDownStream
+        # Apply MoveChanLastDownStream and MoveTransposePastFork
         model = model.transform(MoveChanFirstDownstream())
+        model = model.transform(MoveTransposePastFork())

         # Run RemoveConsecutiveChanFirstAndChanLastTrafos again,
         # Technically only required if something changed in the previous trafo
@@ -218,9 +297,9 @@ def apply(self, model):
                 # Check the input shape and make sure we support it
                 input_shape = model.get_tensor_shape(n.input[0])
                 # Check that this is a "to chan first" trafo
-                perm_1 = get_by_name(n.attribute, "perm")
+                perm_1 = get_transpose_perms(n, model)
                 ndim = len(input_shape)
-                if list(to_channels_first_args(ndim)) == perm_1.ints:
+                if list(to_channels_first_args(ndim)) == perm_1:
                     successor_nodes = model.find_direct_successors(n)
                     if successor_nodes is None:
                         continue
@@ -229,8 +308,8 @@ def apply(self, model):
                     if successor_node.op_type == "Transpose":
                         # Check that this is a "to chan last" trafo,
                         # if so both can get removed.
-                        perm_2 = get_by_name(successor_node.attribute, "perm")
-                        if list(to_channels_last_args(ndim)) == perm_2.ints:
+                        perm_2 = get_transpose_perms(successor_node, model)
+                        if list(to_channels_last_args(ndim)) == perm_2:
                             # Connect original input to new output
                             input_tensor = n.input[0]
                             output_tensor_name = successor_node.output[0]
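Note: the switch to get_transpose_perms also covers Transpose nodes that omit the perm attribute; per the ONNX spec the default is to reverse the dimensions, which is exactly what the helper returns in that case. A small sketch of the two cases (node and tensor names are hypothetical):

import numpy as np
from onnx import helper

# explicit perm attribute: handled as before via perm.ints
t_explicit = helper.make_node("Transpose", ["a"], ["b"], perm=[0, 2, 3, 1])

# no perm attribute: ONNX semantics reverse the axes
t_default = helper.make_node("Transpose", ["a"], ["b"])
ndim = 4
default_perm = list(range(ndim - 1, -1, -1))  # [3, 2, 1, 0], what get_transpose_perms returns
assert np.transpose(np.empty((1, 3, 8, 8))).shape == (8, 8, 3, 1)  # NumPy's default matches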
@@ -257,7 +336,7 @@ class MoveChanLastUpstream(Transformation):
     Moves channel last transformations further upstream.
     """

-    def apply(self, model):
+    def apply(self, model: ModelWrapper):
         graph = model.graph
         node_ind = 0
         graph_modified = False
@@ -268,8 +347,8 @@ def apply(self, model):
                 # Check the input shape and make sure we support it
                 input_shape = model.get_tensor_shape(n.input[0])
                 ndim = len(input_shape)
-                perm = get_by_name(n.attribute, "perm")
-                if list(to_channels_last_args(ndim)) == perm.ints:
+                perm = get_transpose_perms(n, model)
+                if list(to_channels_last_args(ndim)) == perm:
                     predecessors = model.find_direct_predecessors(n)
                     # Check if we reached the top of the graph
                     if predecessors is None:
@@ -285,6 +364,10 @@ def apply(self, model):
                         if second_inp_shape == [1] or second_inp_shape == []:
                             move_through_valid |= True

+                    # don't move through if the predecessor output is a fork
+                    if model.is_fork_node(predecessor):
+                        move_through_valid = False
+
                     # Apply move through trafo if possible
                     if move_through_valid:
                         # Input tensors are always input 0
@@ -334,52 +417,29 @@ def apply(self, model):
         node_ind = 0
         graph_modified = False
         # Find transpose nodes, which are "to chan first" trafos
-        for n in graph.node:
+        for node in graph.node:
             node_ind += 1
-            if n.op_type == "Transpose":
+            if node.op_type == "Transpose":
                 # Check the input shape and make sure we support it
-                input_shape = model.get_tensor_shape(n.input[0])
+                input_shape = model.get_tensor_shape(node.input[0])
                 ndim = len(input_shape)
-                perm = get_by_name(n.attribute, "perm")
-                if list(to_channels_first_args(ndim)) == perm.ints:
+                perm = get_transpose_perms(node, model)
+                if list(to_channels_first_args(ndim)) == perm:
                     # Do not move the node, if it is at the top of the graph,
                     # this is a strange edge case, for 1D networks, where channels last and channels first trafos
                     # are identical.
-                    predecessors = model.find_direct_predecessors(n)
+                    predecessors = model.find_direct_predecessors(node)
                     if predecessors is None:
                         continue

-                    successors = model.find_direct_successors(n)
+                    successors = model.find_direct_successors(node)
                     if successors is None:
                         continue
                     successor = successors[0]
+                    transpose_node = node

-                    # Check if we can simply move through the next node
-                    move_through_valid = successor.op_type in _move_through_nodes
-                    # Check if we have a node, which applies a scalar change,
-                    # then we can also move through.
-                    if successor.op_type in _move_through_nodes_if_scalar:
-                        second_inp_shape = model.get_tensor_shape(successor.input[1])
-                        if second_inp_shape == [1] or second_inp_shape == []:
-                            move_through_valid |= True
-                    # Apply move through trafo if possible
-                    if move_through_valid:
-                        # Collect all tensors connecting n and successor
-                        # and surrounding nodes
-                        tensor_1 = n.input[0]
-                        tensor_2 = n.output[0]
-                        tensor_3 = successor.output[0]
-                        # Now connect the tensors to the nodes again,
-                        # but in different order
-                        successor.input[0] = tensor_1
-                        successor.output[0] = tensor_2
-                        n.input[0] = tensor_2
-                        n.output[0] = tensor_3
-
-                        # Change the shape of the middle tensor
-                        target_shape = model.get_tensor_shape(tensor_1)
-                        model.set_tensor_shape(tensor_2, target_shape)
-
+                    if is_eltwise_optype(successor.op_type):
+                        model = move_transpose_past_eltwise(transpose_node, successor, model)
                         graph_modified = True
         return model, graph_modified
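Note: for the downstream move rewritten above, the interesting new case is a second elementwise input of lower rank: move_transpose_past_eltwise first Unsqueezes it to full rank and then applies the inverse permutation. A NumPy-only sketch of why the rewritten graph computes the same result (shapes chosen purely for illustration):

import numpy as np

perm = [0, 3, 1, 2]                    # "to channels-first" Transpose being moved downstream
inv_perm = list(np.argsort(perm))      # [0, 2, 3, 1]

x = np.random.rand(1, 8, 8, 3)         # channels-last input of the Transpose (t0)
bias = np.random.rand(3, 1, 1)         # rank-3 input, broadcasts against the channels-first tensor

# before: Transpose to (1, 3, 8, 8), then eltwise Add
ref = np.transpose(x, perm) + bias
# after: Unsqueeze bias to rank 4, inverse-transpose it, Add in channels-last, then Transpose
bias_u = np.expand_dims(bias, axis=0)  # (1, 3, 1, 1), as the inserted Unsqueeze would produce
bias_t = np.transpose(bias_u, inv_perm)  # (1, 1, 1, 3), as the inserted inverse Transpose would produce
res = np.transpose(x + bias_t, perm)
assert np.allclose(ref, res)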

@@ -422,7 +482,7 @@ def apply(self, model):
                 input_shape = model.get_tensor_shape(transp_node.input[0])
                 # check if transpose converts ChannelsLast to ChannelsFirst
                 ndim = len(input_shape)
-                perms = get_by_name(transp_node.attribute, "perm").ints
+                perms = get_transpose_perms(transp_node, model)
                 if list(to_channels_first_args(ndim)) == perms:
                     producer = model.find_producer(transp_node.input[0])
                     consumer = model.find_consumer(n.output[0])
@@ -505,3 +565,92 @@ def apply(self, model):
                         into subsequent node"
                     )
         return model, graph_modified
+
+
+class MoveOpPastFork(Transformation):
+    """Move node operations past graph forks. Used when a node before a fork
+    can be merged with nodes in the branches
+    """
+
+    def __init__(self, op_name_list):
+        super().__init__()
+        self.ops_to_move = op_name_list
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        nodes = [n for n in graph.node]
+        node_ind = 0
+        for node in nodes:
+            node_ind += 1
+            if node.op_type in self.ops_to_move and model.is_fork_node(node) and not model.is_join_node(node):
+                # Restrict this transform to operations with constant parameters
+                # Assuming parameters is in input 1
+                if len(node.input) > 1:
+                    op_init_param = model.get_initializer(node.input[1])
+                else:
+                    op_init_param = None
+
+                # Check case when branches are empty and go
+                # to the same node
+                consumers = model.find_consumers(node.output[0])
+                assert len(consumers) > 1, "Must have >1 consumer"
+                unique_consumer = True
+                for consum_node in consumers[1:]:
+                    if consumers[0] != consum_node:
+                        unique_consumer = False
+                        break
+
+                if unique_consumer:
+                    continue
+
+                for consumer_node in consumers[1:]:
+                    # create new node
+                    new_output_tensor_name = model.make_new_valueinfo_name()
+                    if op_init_param is None:
+                        new_inp_list = [node.input[0]]
+                    else:
+                        new_param_name = model.make_new_valueinfo_name()
+                        new_inp_list = [node.input[0], new_param_name]
+                        model.set_initializer(new_param_name, op_init_param)
+                    new_node = deepcopy(node)
+                    new_node.input[:] = new_inp_list
+                    new_node.output[:] = [new_output_tensor_name]
+                    graph.node.insert(node_ind, new_node)
+                    node_ind += 1
+
+                    # change consumer input tensor
+                    graph.node.remove(consumer_node)
+                    for idx, consumer_input in enumerate(consumer_node.input):
+                        if consumer_input == node.output[0]:
+                            consumer_node.input[idx] = new_output_tensor_name
+                            break
+                    else:
+                        raise Exception("Consumer should have the current node output as input")
+
+                    graph.node.insert(node_ind, consumer_node)
+
+                graph_modified = True
+
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
+class MoveAddPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Add"])
+
+
+class MoveMulPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Mul"])
+
+
+class MoveLinearPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Add", "Mul"])
+
+
+class MoveTransposePastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Transpose"])
