Open
Description
FX Graph:
def forward(self, primals_1, primals_8, tangents_1):
    """Compute the backward pass of a 2-D convolution captured as an FX graph.

    Args:
        self: Unused; present because FX graph modules expose ``forward`` as a
            method.
        primals_1: Convolution weight (``[16, 3, 5, 5]`` in the reported trace).
        primals_8: Convolution input (``[2, 3, 200, 200]`` in the reported
            trace).
        tangents_1: Gradient w.r.t. the convolution output
            (``[2, 16, 198, 198]`` in the reported trace).

    Returns:
        An 8-element list
        ``[grad_weight, grad_bias, None, None, None, None, None, grad_input]``.
        The ``None`` slots presumably belong to forward inputs that need no
        gradient. TODO confirm against the forward graph.
    """
    convolution_backward = torch.ops.aten.convolution_backward.default(
        tangents_1,  # grad_output
        primals_8,  # input
        primals_1,  # weight
        [16],  # bias_sizes
        [1, 1],  # stride
        [1, 1],  # padding
        [1, 1],  # dilation
        False,  # transposed
        [0, 0],  # output_padding
        1,  # groups
        [True, True, True],  # output_mask: grad_input, grad_weight, grad_bias
    )
    # Release the local references to the inputs once they are consumed,
    # exactly as the FX-generated code does.
    tangents_1 = primals_8 = primals_1 = None
    getitem = convolution_backward[0]  # grad_input
    getitem_1 = convolution_backward[1]  # grad_weight
    getitem_2 = convolution_backward[2]  # grad_bias
    convolution_backward = None
    return [getitem_1, getitem_2, None, None, None, None, None, getitem]
Converted Torch dialect IR:
// Torch-dialect IR for the FX backward graph, after aten.convolution_backward
// has been decomposed into flip / transpose / convolution / sum ops.
// NOTE(review): several torch.aten.transpose.int results below keep their
// operand's type instead of swapping dims 0 and 1, which makes the shape
// annotations inconsistent (see the per-op notes).
module attributes {torch.debug_module_name = "GraphModule"} {
// Returns (grad_weight, grad_bias, grad_input). The arguments presumably map
// positionally to the FX graph's (primals_1 = weight, primals_8 = input,
// tangents_1 = output gradient); confirm against the importer.
func.func @forward(%arg0: !torch.vtensor<[16,3,5,5],f32>, %arg1: !torch.vtensor<[2,3,200,200],f32>, %arg2: !torch.vtensor<[2,16,198,198],f32>) -> (!torch.vtensor<[16,3,5,5],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[2,3,200,200],f32>) {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%none = torch.constant.none
%false = torch.constant.bool false
%true = torch.constant.bool true
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
// %0 = [1, 1] is reused as stride, dilation and (for %9) padding; %1 = [0, 0]
// is the output_padding.
%0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// Stride/dilation guards emitted by the decomposition; their conditions have
// already been folded to %true.
torch.runtime.assert %true, "unimplemented: only strides of 1 supported."
torch.runtime.assert %true, "unimplemented: only strides of 1 supported."
torch.runtime.assert %true, "unimplemented: only dilations of 1 supported."
torch.runtime.assert %true, "unimplemented: only dilations of 1 supported."
// grad_input: flip the kernel over its spatial dims (2, 3), swap dims 0 and 1,
// then convolve the output gradient with padding [3, 3]
// (5 - 1 - 1 = 3 for the 5x5 kernel and padding 1 of the original conv).
%2 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.flip %arg0, %2 : !torch.vtensor<[16,3,5,5],f32>, !torch.list<int> -> !torch.vtensor<[16,3,5,5],f32>
%4 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// NOTE(review): swapping dims 0 and 1 of [16,3,5,5] should give [3,16,5,5];
// the result type below is left unchanged.
%5 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32>
%6 = torch.aten.convolution %arg2, %5, %none, %0, %4, %0, %false, %1, %int1 : !torch.vtensor<[2,16,198,198],f32>, !torch.vtensor<[16,3,5,5],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[2,3,200,200],f32>
// grad_weight: convolve transpose(input) with transpose(grad_output), then
// transpose the result back.
// NOTE(review): expected result type is [16,2,198,198].
%7 = torch.aten.transpose.int %arg2, %int0, %int1 : !torch.vtensor<[2,16,198,198],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,16,198,198],f32>
// NOTE(review): expected result type is [3,2,200,200].
%8 = torch.aten.transpose.int %arg1, %int0, %int1 : !torch.vtensor<[2,3,200,200],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,200,200],f32>
// NOTE(review): with correctly typed operands this yields [3,16,5,5]
// (200 + 2*1 - 198 + 1 = 5), which %10 transposes to [16,3,5,5].
%9 = torch.aten.convolution %8, %7, %none, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[2,3,200,200],f32>, !torch.vtensor<[2,16,198,198],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[16,3,5,5],f32>
%10 = torch.aten.transpose.int %9, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32>
// grad_bias: sum the output gradient over dims (0, 2, 3).
%11 = torch.prim.ListConstruct %int0, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%12 = torch.aten.sum.dim_IntList %arg2, %11, %false, %none : !torch.vtensor<[2,16,198,198],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[16],f32>
return %10, %12, %6 : !torch.vtensor<[16,3,5,5],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[2,3,200,200],f32>
}
}
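It seems the result type of `torch.aten.transpose.int` is not being set correctly by the decomposition (see https://github.com/llvm/torch-mlir/blob/main/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp#L1444-L1445). For example, swapping dims 0 and 1 of a `[16,3,5,5]` tensor should produce `[3,16,5,5]`, but the op below keeps the input type: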
// NOTE(review): with dims 0 and 1 swapped, the result type should be
// !torch.vtensor<[3,16,5,5],f32>.
%5 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32>