@@ -1615,18 +1615,31 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.tensorResultType(resultType)) {
           return failure();
         }
+        if (!resultType.hasSizes() || !resultType.hasDtype()) {
+          return failure();
+        }
+        auto inputTy = cast<Torch::ValueTensorType>(input.getType());
+        if (!inputTy || !inputTy.hasSizes()) {
+          return failure();
+        }
+        int64_t inputRank = inputTy.getSizes().size();
+
         Location loc = binder.getLoc();
         Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
         Value unBiased = rewriter.create<Torch::ConstantBoolOp>(loc, false);
         Value none = rewriter.create<Torch::ConstantNoneOp>(loc);

-        ArrayRef<int64_t> input_shape = resultType.getSizes();
-        SmallVector<int64_t> reduced_shape(input_shape);
+        ArrayRef<int64_t> output_shape = resultType.getSizes();
+        SmallVector<int64_t> reduced_shape(output_shape);
+
         for (int64_t i : axes) {
+          int64_t dim = Torch::toPositiveDim(i, inputRank);
+          if (!Torch::isValidDim(dim, inputRank)) {
+            return failure();
+          }
           reduced_shape[i] = 1;
         }
-
-        Torch::ValueTensorType meanOutTy = Torch::ValueTensorType::get(
+        Torch::ValueTensorType reducedOutTy = Torch::ValueTensorType::get(
             resultType.getContext(), reduced_shape, resultType.getDtype());
         SmallVector<Value> cstAxes;
         for (int64_t i : axes) {
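The guard added in the loop above normalizes each axis and rejects it when it is out of range, rather than letting a bad axis index reduced_shape out of bounds. For reference, the usual PyTorch dim convention that Torch::toPositiveDim / Torch::isValidDim implement looks like the following minimal sketch (a standalone reimplementation for illustration, not the library's actual source):

#include <cstdint>

// Map a possibly negative axis into [0, rank): -1 refers to the last
// dimension. Illustrative stand-in for Torch::toPositiveDim.
int64_t toPositiveDimSketch(int64_t dim, int64_t rank) {
  return dim >= 0 ? dim : dim + rank;
}

// A normalized dim is usable iff it lies in [0, rank). Illustrative
// stand-in for Torch::isValidDim.
bool isValidDimSketch(int64_t dim, int64_t rank) {
  return dim >= 0 && dim < rank;
}

With inputRank = 4, an ONNX axis of -1 normalizes to 3 and passes the check, while an axis of 4 or -5 now fails the match cleanly instead of crashing the importer.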
@@ -1638,29 +1651,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstAxes);
         Value mean = rewriter.create<Torch::AtenMeanDimOp>(
-            loc, meanOutTy, input, axes_list, keepDim, none);
-
+            loc, reducedOutTy, input, axes_list, keepDim, none);
         Value variance = rewriter.create<Torch::AtenVarDimOp>(
-            loc, meanOutTy, input, axes_list, unBiased, keepDim);
-
+            loc, reducedOutTy, input, axes_list, unBiased, keepDim);
         Value cstOne = rewriter.create<Torch::ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(1));
         Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
             loc, rewriter.getF64FloatAttr(1e-9));
         variance = rewriter.create<Torch::AtenAddScalarOp>(
-            loc, meanOutTy, variance, cstEps, cstOne);
-
-        Value sqrt =
-            rewriter.create<Torch::AtenSqrtOp>(loc, meanOutTy, variance);
-
-        Value subValue = rewriter.create<Torch::AtenSubTensorOp>(
+            loc, reducedOutTy, variance, cstEps, cstOne);
+        Value sqrtVar =
+            rewriter.create<Torch::AtenSqrtOp>(loc, reducedOutTy, variance);
+        Value inputMinusMean = rewriter.create<Torch::AtenSubTensorOp>(
             loc, resultType, input, mean, cstOne);
-
         Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
-            loc, resultType, subValue, sqrt);
+            loc, resultType, inputMinusMean, sqrtVar);

         rewriter.replaceOp(binder.op, meanVarNorm);
-
         return success();
       });
   patterns.onOp(
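Taken together, the ops emitted above (AtenMeanDimOp, biased AtenVarDimOp, add-epsilon, AtenSqrtOp, subtract, divide) compute the MeanVarianceNormalization formula Y = (X - mean(X, axes)) / sqrt(var(X, axes) + eps) with eps = 1e-9. A minimal standalone C++ sketch of the same arithmetic over one reduction group, assuming the group's values have been gathered into a flat vector (function name and layout are illustrative, not part of the patch):

#include <cmath>
#include <vector>

// Normalize one reduction group the way the lowering does: subtract
// the mean, then divide by sqrt(biased variance + eps). Assumes x is
// non-empty.
std::vector<double> meanVarNormalize(const std::vector<double> &x,
                                     double eps = 1e-9) {
  double mean = 0.0;
  for (double v : x)
    mean += v;
  mean /= static_cast<double>(x.size());

  // Population variance, matching unBiased = false in the lowering.
  double var = 0.0;
  for (double v : x)
    var += (v - mean) * (v - mean);
  var /= static_cast<double>(x.size());

  std::vector<double> y;
  y.reserve(x.size());
  for (double v : x)
    y.push_back((v - mean) / std::sqrt(var + eps));
  return y;
}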