Skip to content

Commit 3d7e820

Browse files
committed
Add shape & dtype inference
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent d6c3a9b commit 3d7e820

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8385,6 +8385,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
83858385
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
83868386
" return %0 : !torch.list<int>\n"
83878387
" }\n"
8388+
" func.func @\"__torch_mlir_shape_fn.aten.max_pool1d_with_indices\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
8389+
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
8390+
" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
8391+
" return %1 : !torch.tuple<list<int>, list<int>>\n"
8392+
" }\n"
83888393
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
83898394
" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
83908395
" return %0 : !torch.list<int>\n"
@@ -12806,6 +12811,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1280612811
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1280712812
" return %0#1 : !torch.int\n"
1280812813
" }\n"
12814+
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool1d_with_indices\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<int, int> {\n"
12815+
" %int4 = torch.constant.int 4\n"
12816+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12817+
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
12818+
" return %1 : !torch.tuple<int, int>\n"
12819+
" }\n"
1280912820
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
1281012821
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1281112822
" return %0#1 : !torch.int\n"

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,10 @@ def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: L
12521252
def aten〇max_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
12531253
return pool1d(self, kernel_size, stride, padding, ceil_mode)
12541254

1255+
def aten〇max_pool1d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> Tuple[List[int], List[int]]:
1256+
maxpool1d = indices = pool1d(self, kernel_size, stride, padding, ceil_mode)
1257+
return maxpool1d, indices
1258+
12551259
def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
12561260
return adaptive_avg_pool1d(self, output_size)
12571261

@@ -3497,6 +3501,10 @@ def aten〇max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: Lis
34973501
self_rank, self_dtype = self_rank_dtype
34983502
return self_dtype
34993503

3504+
def aten〇max_pool1d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> Tuple[int, int]:
3505+
self_rank, self_dtype = self_rank_dtype
3506+
return self_dtype, torch.int64
3507+
35003508
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
35013509
def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> int:
35023510
self_rank, self_dtype = self_rank_dtype

0 commit comments

Comments
 (0)