
Commit 5f789c0

Fix channels_last_tagged_reshape_pass to handle mixed memory format tuple outputs (#11647)
### Summary

This PR fixes `channels_last_tagged_reshape_pass.py` to properly handle tuple outputs with mixed memory formats. Previously, the pass only checked and converted the first element of a tuple output, which could leave the remaining elements with incorrect memory formats. The fix matters for models that return multiple outputs with different memory format requirements, such as a mix of convolution outputs (which should be in NHWC format) and linear outputs (which should stay in the standard contiguous format).

### Test plan

Added a new test class, `ThreeOutputsModel`, whose three outputs have different memory format requirements, and verified that it evaluates correctly with both NCHW and NHWC (channels-last) inputs. Also added a simpler class, `ConvAddConvOutput`, which returns two outputs with different dim orders.
1 parent 057558f commit 5f789c0
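For context, the pass decides between the two dim orders by calling `Tensor.is_contiguous()` on each output's `meta["val"]`. A quick eager-mode sketch of that check (illustration only, not part of this commit):

```python
import torch

conv = torch.nn.Conv2d(3, 16, 3)
x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
y = conv(x)  # eager convolution propagates the channels-last layout

print(y.is_contiguous())                                   # False: not standard NCHW
print(y.is_contiguous(memory_format=torch.channels_last))  # True: NHWC
```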

2 files changed: +62, −20 lines


backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 15 additions & 20 deletions
@@ -91,18 +91,10 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool:
         return not self.is_nhwc_node(node)

     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
-        return (
-            node.target in self.memory_sensitive_ops_nhwc
-            or node.name == "output"
-            and not node.args[0][0].meta["val"].is_contiguous()
-        )
+        return node.target in self.memory_sensitive_ops_nhwc

     def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
-        return (
-            node.target in self.memory_sensitive_ops_nchw
-            or node.name == "output"
-            and node.args[0][0].meta["val"].is_contiguous()
-        )
+        return node.target in self.memory_sensitive_ops_nchw

     def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         # There are two conditions that must be met for a node to be able to
@@ -380,18 +372,21 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 # This node has no inputs so we don't need to change anything
                 continue

-            if self.requires_nhwc_input(node):
+            # Need a special case for the output node because it can have multiple output dim orders, as we can output a tuple of multiple nodes
+            if node.op == "output":
+                out_tuple = node.args[0]
+                for out_node in out_tuple:
+                    if out_node.meta["val"].is_contiguous():
+                        self.input_to_nchw(graph_module, out_node, node)
+                    else:
+                        self.input_to_nhwc(graph_module, out_node, node)
+            elif self.requires_nhwc_input(node):
                 # Nodes which enter this branch are ones that require their
                 # first input to be nhwc. This makes this node's output nhwc too
-                # Currently, all nodes like this should have all of their other
-                # inputs as nchw, so fail if this is not true
-                if node.name == "output":
-                    self.input_to_nhwc(graph_module, node.args[0][0], node)
-                else:
-                    self.input_to_nhwc(graph_module, node.args[0], node)
-
-                for input_node in node.all_input_nodes[1:]:
-                    if self.is_nhwc_node(input_node):
+
+                self.input_to_nhwc(graph_module, node.args[0], node)
+                for input_node in node.all_input_nodes:
+                    if input_node.op == "placeholder" and self.is_nhwc_node(input_node):
                         raise AssertionError(
                             f"Expected {input_node} to be NCHW in channels last reshape pass"
                         )
backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 47 additions & 0 deletions
@@ -335,3 +335,50 @@ def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
             )
             .run_method_and_compare_outputs()
         )
+
+    class ConvAddConvOutput(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 16, 3)
+            self.conv2 = torch.nn.Conv2d(16, 16, 3)
+
+        def forward(self, x):
+            y = self.conv1(x)
+            z = torch.add(y, 1.0)
+            out1 = self.conv2(z)
+            out2 = z
+            return out1, out2
+
+    ConvAddConvOutputModule = ConvAddConvOutput()
+
+    def test_conv_add_conv_output(self):
+        x = torch.randn(1, 3, 8, 8)
+
+        self.run_tester(self.ConvAddConvOutput().eval(), (x,))
+
+        x_cl = x.to(memory_format=torch.channels_last)
+        self.run_tester(self.ConvAddConvOutput().eval(), (x_cl,))
+
+    class ThreeOutputsModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+            self.linear = torch.nn.Linear(6, 6)
+
+        def forward(self, x):
+            conv1_out = self.conv1(x)
+            conv2_out = self.conv2(x)
+            linear_out = self.linear(x)
+
+            return linear_out, conv1_out, conv2_out
+
+    ThreeOutputsModelModule = ThreeOutputsModel()
+
+    def test_three_outputs_model(self):
+        x = torch.randn(1, 3, 6, 6)
+
+        self.run_tester(self.ThreeOutputsModelModule.eval(), (x,))
+
+        x_cl = x.to(memory_format=torch.channels_last)
+        self.run_tester(self.ThreeOutputsModelModule.eval(), (x_cl,))
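Each new test runs the model with both a standard NCHW input and a channels-last copy, since results should not depend on the input layout. A rough eager-mode analogue of that expectation (illustrative only, not from the test suite):

```python
import torch

model = torch.nn.Conv2d(3, 16, 3).eval()
x = torch.randn(1, 3, 8, 8)
x_cl = x.to(memory_format=torch.channels_last)  # same values, NHWC strides

with torch.no_grad():
    y = model(x)
    y_cl = model(x_cl)

# Layouts differ, but values should agree up to float tolerance.
assert torch.allclose(y, y_cl, atol=1e-5)
```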
