
[TorchToLinalg] Add lowering of torch.aten.pixel_unshuffle op #4278

Open · wants to merge 6 commits into main

Conversation

@alaa-ali (Contributor) commented Jul 18, 2025

This PR fixes the following issue:
Add lowering of the torch.aten.pixel_unshuffle op to the linalg dialect

The following code snippet reproduces the issue:

func.func @pixel_unshuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} {
  %int2 = torch.constant.int 2
  %0 = torch.aten.pixel_unshuffle %arg0, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
  return %0 : !torch.vtensor<[1,32,2,2],f32>
}
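
For reference, the same shape transformation can be checked in eager PyTorch (an illustrative snippet, not part of the PR):

import torch
import torch.nn.functional as F

# (N, C, H*r, W*r) -> (N, C*r*r, H, W) with r = 2
x = torch.randn(1, 8, 4, 4)
y = F.pixel_unshuffle(x, downscale_factor=2)
print(y.shape)  # torch.Size([1, 32, 2, 2])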

The decomposition is based on this specification:
https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pixel_unshuffle.html
and the PyTorch implementation can be found in aten/src/ATen/native/PixelShuffle.cpp (main branch):
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/PixelShuffle.cpp

With these code changes, torch.aten.pixel_unshuffle is lowered to the following:

module {
  func.func @main(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %int2 = torch.constant.int 2
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int3 = torch.constant.int 3
    %int4 = torch.constant.int 4
    %int5 = torch.constant.int 5
    %0 = torch.prim.ListConstruct %int0, %int1, %int3, %int5, %int2, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prims.split_dim %arg0, %int2, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,4],f32>
    %2 = torch.prims.split_dim %1, %int4, %int2 : !torch.vtensor<[1,8,2,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,2,2],f32>
    %3 = torch.aten.permute %2, %0 : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.list<int> -> !torch.vtensor<[1,8,2,2,2,2],f32>
    %4 = torch.prims.collapse %3, %int2, %int3 : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,4,2,2],f32>
    %5 = torch.prims.collapse %4, %int1, %int2 : !torch.vtensor<[1,8,4,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
    return %5 : !torch.vtensor<[1,32,2,2],f32>
  }
}
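
The same split/permute/collapse sequence can be mirrored in eager PyTorch to sanity-check the decomposition (a hand-written sketch using the shapes from the example above, not code from the PR):

import torch
import torch.nn.functional as F

r = 2
x = torch.randn(1, 8, 4, 4)       # (N, C, H*r, W*r)
t = x.unflatten(2, (2, r))        # split H*r into (H, r): (1, 8, 2, 2, 4)
t = t.unflatten(4, (2, r))        # split W*r into (W, r): (1, 8, 2, 2, 2, 2)
t = t.permute(0, 1, 3, 5, 2, 4)   # (N, C, r, r, H, W)
t = t.flatten(2, 3)               # collapse (r, r): (1, 8, 4, 2, 2)
t = t.flatten(1, 2)               # collapse (C, r*r): (1, 32, 2, 2)
assert torch.equal(t, F.pixel_unshuffle(x, 2))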

@ivangarcia44 (Contributor) left a comment


In general it looks good to me. My main feedback is:

  1. Can this be generalized for ND if the pixel shuffle operator allows ND?
  2. What is the expected batch dimension behavior? Can this be tested?

Thanks!

@ivangarcia44 (Contributor)

Are there tests for batch dimension omission and for more than one batch dimension? Thanks

@alaa-ali (Contributor Author) commented Aug 1, 2025

Are there tests for batch dimension omission and for more than one batch dimension? Thanks

This has been captured in e2e tests. Thanks for your feedback.

@ivangarcia44 (Contributor)

LGTM

This PR has tests for ranks 3, 4, and 5, which cover the cases of no batch dimension, one batch dimension, and two batch dimensions. All my concerns are addressed in the PR. Looks good to me.
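
For illustration, the three batch-dimension cases those tests cover can be exercised directly in eager PyTorch (hypothetical shapes, not the exact test inputs):

import torch
import torch.nn.functional as F

# rank 3: no batch dimension, (C, H*r, W*r)
print(F.pixel_unshuffle(torch.randn(8, 4, 4), 2).shape)        # (32, 2, 2)
# rank 4: one batch dimension, (N, C, H*r, W*r)
print(F.pixel_unshuffle(torch.randn(3, 8, 4, 4), 2).shape)     # (3, 32, 2, 2)
# rank 5: two batch dimensions, (N1, N2, C, H*r, W*r)
print(F.pixel_unshuffle(torch.randn(2, 3, 8, 4, 4), 2).shape)  # (2, 3, 32, 2, 2)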

@alaa-ali (Contributor Author) commented Aug 4, 2025

Hi everyone, a kind reminder to please provide feedback. This PR adds support for the torch.aten.pixel_unshuffle op.
Thank you

@rsuderman @zjgarvey @penguin-wwy @newling @sahas3 @ramiro050 @qedawkins @vivekkhandelwal1

// (*leading_dims, C, H*r, W*r),
//
// where leading_dims is of size N, then
// X = pixel_unshuffle(input, downscale_factor)
Member

Using r here instead of downscale_factor would be better for consistency. You could also move line 3731, which mentions that r is the downscale_factor, above this comment to help readability.

Comment on lines +3728 to +3729
// X = X.collapse(...) # shape (*leading_dims, C, r*r, H, W)
// X = X.collapse(...) # shape (*leading_dims, C*r*r, H, W)
Member

Do we need two collapses -- isn't collapsing directly to C*r*r sufficient?
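
For what it's worth, a quick eager-PyTorch check (illustrative only, using flatten in place of torch.prims.collapse) suggests the two-step collapse and a single collapse over the contiguous (C, r, r) dims agree:

import torch

t = torch.randn(1, 8, 2, 2, 2, 2)          # (N, C, r, r, H, W), post-permute
two_step = t.flatten(2, 3).flatten(1, 2)   # (N, C, r*r, H, W) -> (N, C*r*r, H, W)
one_step = t.flatten(1, 3)                 # collapse (C, r, r) directly
assert torch.equal(two_step, one_step)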

Comment on lines +3762 to +3781
auto getTypeFromShape = [inOptionalDType](auto &&vals) {
  // Get a vector of integers from a vector of Values.
  auto getIntShape = [](auto &&vals) {
    SmallVector<int64_t> shape;
    shape.reserve(vals.size());
    for (auto v : vals) {
      int64_t cst_val;
      if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
        shape.push_back(cst_val);
      } else {
        shape.push_back(kUnknownSize);
      }
    }
    return shape;
  };

  const auto intShape = getIntShape(vals);
  return ValueTensorType::get(vals[0].getContext(),
                              llvm::ArrayRef(intShape), inOptionalDType);
};
Member

These methods were added as utilities in #4259. Once that is merged, can you update your code to reuse the utilities?

Comment on lines +3785 to +3792
// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
// of 'create': if the dimension size is known, then the AtenSizeIntOp is
// folded to a ConstantOp.
auto getDimSize = [&](uint64_t i) -> Value {
  Value dim =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
  return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
};
Member

This is shared with #4259 as well. It would be good to move this into a utility method, and probably to move all these utilities to lib/Conversion/Utils/Utils.cpp as a shared location so they can be reused elsewhere in the code base too.

Comment on lines +3823 to +3824
SmallVector<Value> partiallyExpandedShape = leadingDims;
partiallyExpandedShape.append({inC, outH, factor, inW});
Member

Can you move this to just before its use, and also rename it to heightSplitShape for readability?
