Skip to content

Commit d86ec13

Browse files
committed
[Thresholding] Generalize data layouts for node execution
See fastmachinelearning/qonnx#143 for the similar generalization applied to QONNX MultiThreshold
1 parent 94f887b commit d86ec13

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

src/finn/custom_op/fpgadataflow/thresholding.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -243,16 +243,29 @@ def execute_node(self, context, graph):
243243
inp_values = context[node.input[0]]
244244
th_val = context[node.input[1]]
245245
out_bias = self.get_nodeattr("ActVal")
246-
# MT expects inputs to be in the shape (N,C,H,W) or (N, C)
247-
# if 4D then input values in context are (N,H,W,C) and need to
248-
# be transposed.
249-
# if 2D then inputs can be passed directly to MT function
250-
is_4d = len(inp_values.shape) == 4
251-
if is_4d:
252-
inp_values = np.transpose(inp_values, (0, 3, 1, 2))
246+
247+
# Consider the data layout for transposing the input into the format
248+
# accepted by the multithreshold function above, i.e, the channel
249+
# dimension is along the axis with index 1.
250+
data_layout = None
251+
# If there is no layout annotation, guess based on rank of the tensor
252+
# TODO: Currently there is no mechanism here to get the layout
253+
# annotation, we allways guess, but this matches the previous behavior.
254+
if len(inp_values.shape) < 5:
255+
# Maps tensor rank to layout annotation
256+
rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NHWC"}
257+
# Lookup the layout required by this input shape
258+
data_layout = rank_to_layout[len(inp_values.shape)]
259+
# Lookup the index of the channel dimension in the data layout
260+
# Note: Assumes there is at most one "C" which denotes the channel
261+
# dimension
262+
cdim = data_layout.index("C") if "C" in data_layout else 1
263+
# Rearrange the input to the expected (N, C, ...) layout
264+
inp_values = inp_values.swapaxes(cdim, 1)
253265
y = multithreshold(inp_values, th_val, out_bias=out_bias)
254-
if is_4d:
255-
y = y.transpose(0, 2, 3, 1)
266+
# Rearrange the output back to the original layout
267+
y = y.swapaxes(cdim, 1)
268+
256269
act = DataType[self.get_nodeattr("outputDataType")]
257270
if act == DataType["BIPOLAR"]:
258271
# binary to bipolar

0 commit comments

Comments
 (0)