You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Fix channels_last_tagged_reshape_pass to handle mixed memory format tuple outputs (#11647)
### Summary
This PR fixes the `channels_last_tagged_reshape_pass.py` to properly
handle tuple outputs with mixed memory formats. Previously, the pass
only checked and converted the first element of tuple outputs, which
could lead to incorrect memory formats for other elements in the tuple.
This fix is important for models that return multiple outputs with
different memory format requirements, such as a mix of convolution
outputs (which should be in NHWC format) and linear outputs (which
should be in standard format).
### Test plan
I added a new test class `ThreeOutputsModel` that has three outputs with
different memory format requirements. I ensured that this test output
given NCHW and NHWC inputs would evaluate properly. I also created a
simpler 2 input class `ConvAddConvOutput` which operated on different
inputs and returned two different dim order outputs.
0 commit comments