@@ -243,16 +243,29 @@ def execute_node(self, context, graph):
         inp_values = context[node.input[0]]
         th_val = context[node.input[1]]
         out_bias = self.get_nodeattr("ActVal")
-        # MT expects inputs to be in the shape (N,C,H,W) or (N, C)
-        # if 4D then input values in context are (N,H,W,C) and need to
-        # be transposed.
-        # if 2D then inputs can be passed directly to MT function
-        is_4d = len(inp_values.shape) == 4
-        if is_4d:
-            inp_values = np.transpose(inp_values, (0, 3, 1, 2))
+
+        # Consider the data layout for transposing the input into the format
+        # accepted by the multithreshold function above, i.e., the channel
+        # dimension is along the axis with index 1.
+        data_layout = None
+        # If there is no layout annotation, guess based on rank of the tensor
+        # TODO: Currently there is no mechanism here to get the layout
+        # annotation; we always guess, but this matches the previous behavior.
+        if len(inp_values.shape) < 5:
+            # Maps tensor rank to layout annotation
+            rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NHWC"}
+            # Lookup the layout required by this input shape
+            data_layout = rank_to_layout[len(inp_values.shape)]
+        # Lookup the index of the channel dimension in the data layout
+        # Note: Assumes there is at most one "C" which denotes the channel
+        # dimension
+        cdim = data_layout.index("C") if "C" in data_layout else 1
+        # Rearrange the input to the expected (N, C, ...) layout
+        inp_values = inp_values.swapaxes(cdim, 1)
         y = multithreshold(inp_values, th_val, out_bias=out_bias)
-        if is_4d:
-            y = y.transpose(0, 2, 3, 1)
+        # Rearrange the output back to the original layout
+        y = y.swapaxes(cdim, 1)
+
         act = DataType[self.get_nodeattr("outputDataType")]
         if act == DataType["BIPOLAR"]:
             # binary to bipolar
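For illustration, here is a minimal, self-contained sketch of the layout handling this diff introduces. The multithreshold_stub helper and its threshold semantics (counting, per channel, how many thresholds each value meets) are assumptions made for this example, not the real finn/qonnx multithreshold implementation; only the rank-to-layout guess and the swapaxes(cdim, 1) round trip mirror the diff.

import numpy as np

def multithreshold_stub(v, thresholds):
    # Hypothetical stand-in for the real multithreshold function (an
    # assumption for this sketch): counts, per channel, how many thresholds
    # each value meets or exceeds. Expects v in (N, C, ...) layout and
    # thresholds of shape (C, T).
    n, c = v.shape[0], v.shape[1]
    flat = v.reshape(n, c, -1)
    # Compare every value against every threshold of its channel
    hits = flat[:, :, :, None] >= thresholds[None, :, None, :]
    return hits.sum(axis=-1).reshape(v.shape).astype(np.float32)

# An NHWC activation tensor: batch 1, 2x2 spatial, 3 channels
x = np.arange(12, dtype=np.float32).reshape(1, 2, 2, 3)

# Guess the layout from the tensor rank, as the updated execute_node does
rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NHWC"}
layout = rank_to_layout[x.ndim]
cdim = layout.index("C") if "C" in layout else 1

# Move the channel axis to index 1, apply the thresholds, move it back
thresholds = np.array([[2.0], [5.0], [8.0]])  # one threshold per channel
y = multithreshold_stub(x.swapaxes(cdim, 1), thresholds).swapaxes(cdim, 1)
print(y.shape)  # (1, 2, 2, 3): same NHWC layout as the input

Note that swapaxes is its own inverse, so the same swapaxes(cdim, 1) call restores the original layout after thresholding, whereas the removed transpose-based code needed an explicit inverse permutation. Both the sketch and the diff assume an input of rank 4 or less whose guessed layout is not None; a rank-0 or rank-5-and-above input would leave data_layout as None and make the cdim lookup fail.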