Skip to content

Transpose error when decomposing the convolution backward #1772

Open
@ZihengJiang

Description

@ZihengJiang

FX Graph:

def forward(self, primals_1, primals_8, tangents_1):
    """Run the backward pass of a 2-D convolution, as traced by FX.

    Args:
        primals_1: convolution weight (``[16, 3, 5, 5]`` in the IR of this report).
        primals_8: convolution input (``[2, 3, 200, 200]`` in the IR of this report).
        tangents_1: gradient w.r.t. the convolution output
            (``[2, 16, 198, 198]`` in the IR of this report).

    Returns:
        ``[grad_weight, grad_bias, None, None, None, None, None, grad_input]``.
    """
    # aten.convolution_backward arguments: (grad_output, input, weight,
    # bias_sizes, stride, padding, dilation, transposed, output_padding,
    # groups, output_mask). output_mask requests all three gradients.
    convolution_backward = torch.ops.aten.convolution_backward.default(
        tangents_1, primals_8, primals_1, [16], [1, 1], [1, 1], [1, 1],
        False, [0, 0], 1, [True, True, True],
    )
    # Release references that are no longer needed (as emitted by FX).
    tangents_1 = primals_8 = primals_1 = None
    getitem = convolution_backward[0]  # grad_input
    getitem_1 = convolution_backward[1]  # grad_weight
    getitem_2 = convolution_backward[2]  # grad_bias
    convolution_backward = None
    return [getitem_1, getitem_2, None, None, None, None, None, getitem]

Converted Torch dialect IR (after decomposition):

// Torch-dialect IR obtained after decomposing the aten.convolution_backward
// call in the FX `forward` graph of this report.
module attributes {torch.debug_module_name = "GraphModule"} {
  // Arguments follow the FX forward signature: %arg0 = primals_1 (weight),
  // %arg1 = primals_8 (input), %arg2 = tangents_1 (grad_output).
  // Results: grad_weight (%10), grad_bias (%12), grad_input (%6).
  func.func @forward(%arg0: !torch.vtensor<[16,3,5,5],f32>, %arg1: !torch.vtensor<[2,3,200,200],f32>, %arg2: !torch.vtensor<[2,16,198,198],f32>) -> (!torch.vtensor<[16,3,5,5],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[2,3,200,200],f32>) {
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %none = torch.constant.none
    %false = torch.constant.bool false
    %true = torch.constant.bool true
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    // %0 = [1,1] is reused as stride, padding and dilation; %1 = [0,0] is the output_padding.
    %0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    // The decomposition only supports strides and dilations of 1; for this graph
    // these checks are already the constant `true`.
    torch.runtime.assert %true, "unimplemented: only strides of 1 supported."
    torch.runtime.assert %true, "unimplemented: only strides of 1 supported."
    torch.runtime.assert %true, "unimplemented: only dilations of 1 supported."
    torch.runtime.assert %true, "unimplemented: only dilations of 1 supported."
    // grad_input: flip the weight over its spatial dims (2, 3), swap dims 0 and 1,
    // then convolve grad_output with it.
    %2 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.flip %arg0, %2 : !torch.vtensor<[16,3,5,5],f32>, !torch.list<int> -> !torch.vtensor<[16,3,5,5],f32>
    // %4 = [3,3] is the grad_input padding (presumably kernel_size - 1 - padding = 5 - 1 - 1).
    %4 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    // NOTE(review): swapping dims 0 and 1 of [16,3,5,5] should give [3,16,5,5], but the
    // declared result type is unchanged -- this is the reported bug.
    %5 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32>
    %6 = torch.aten.convolution %arg2, %5, %none, %0, %4, %0, %false, %1, %int1 : !torch.vtensor<[2,16,198,198],f32>, !torch.vtensor<[16,3,5,5],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[2,3,200,200],f32>
    // grad_weight: convolve the transposed input with the transposed grad_output,
    // then transpose the result back.
    // NOTE(review): %7 should be [16,2,198,198] and %8 should be [3,2,200,200]; both keep
    // their operand's type. With those operands %9 would be [3,16,5,5], which %10
    // transposes back to the (correct) [16,3,5,5].
    %7 = torch.aten.transpose.int %arg2, %int0, %int1 : !torch.vtensor<[2,16,198,198],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,16,198,198],f32>
    %8 = torch.aten.transpose.int %arg1, %int0, %int1 : !torch.vtensor<[2,3,200,200],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,200,200],f32>
    %9 = torch.aten.convolution %8, %7, %none, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[2,3,200,200],f32>, !torch.vtensor<[2,16,198,198],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[16,3,5,5],f32>
    %10 = torch.aten.transpose.int %9, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32>
    // grad_bias: sum grad_output over dims 0, 2 and 3.
    %11 = torch.prim.ListConstruct %int0, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %12 = torch.aten.sum.dim_IntList %arg2, %11, %false, %none : !torch.vtensor<[2,16,198,198],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[16],f32>
    return %10, %12, %6 : !torch.vtensor<[16,3,5,5],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[2,3,200,200],f32>
  }
}

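It seems that the result type of the `torch.aten.transpose.int` ops created by the decomposition is not set correctly. The result keeps the operand's type instead of having dims 0 and 1 swapped. For example, transposing the `[16,3,5,5]` weight below should produce `!torch.vtensor<[3,16,5,5],f32>`. The transposes of `%arg2` and `%arg1` (`%7`, `%8`) have the same problem. The relevant code is here: https://github.com/llvm/torch-mlir/blob/main/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp#L1444-L1445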

    // Expected result type: !torch.vtensor<[3,16,5,5],f32> (operand dims 0 and 1 swapped).
    %5 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32>

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions