Skip to content

Commit 338213f

Browse files
committed
Add type checks & allow MVN expansion by default
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent a42b9a3 commit 338213f

File tree

2 files changed

+27
-18
lines changed

2 files changed

+27
-18
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,18 +1615,31 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
16151615
binder.tensorResultType(resultType)) {
16161616
return failure();
16171617
}
1618+
if (!resultType.hasSizes() || !resultType.hasDtype()) {
1619+
return failure();
1620+
}
1621+
auto inputTy = cast<Torch::ValueTensorType>(input.getType());
1622+
if (!inputTy || !inputTy.hasSizes()) {
1623+
return failure();
1624+
}
1625+
int64_t inputRank = inputTy.getSizes().size();
1626+
16181627
Location loc = binder.getLoc();
16191628
Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
16201629
Value unBiased = rewriter.create<Torch::ConstantBoolOp>(loc, false);
16211630
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
16221631

1623-
ArrayRef<int64_t> input_shape = resultType.getSizes();
1624-
SmallVector<int64_t> reduced_shape(input_shape);
1632+
ArrayRef<int64_t> output_shape = resultType.getSizes();
1633+
SmallVector<int64_t> reduced_shape(output_shape);
1634+
16251635
for (int64_t i : axes) {
1636+
int64_t dim = Torch::toPositiveDim(i, inputRank);
1637+
if (!Torch::isValidDim(dim, inputRank)) {
1638+
return failure();
1639+
}
16261640
reduced_shape[i] = 1;
16271641
}
1628-
1629-
Torch::ValueTensorType meanOutTy = Torch::ValueTensorType::get(
1642+
Torch::ValueTensorType reducedOutTy = Torch::ValueTensorType::get(
16301643
resultType.getContext(), reduced_shape, resultType.getDtype());
16311644
SmallVector<Value> cstAxes;
16321645
for (int64_t i : axes) {
@@ -1638,29 +1651,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
16381651
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
16391652
cstAxes);
16401653
Value mean = rewriter.create<Torch::AtenMeanDimOp>(
1641-
loc, meanOutTy, input, axes_list, keepDim, none);
1642-
1654+
loc, reducedOutTy, input, axes_list, keepDim, none);
16431655
Value variance = rewriter.create<Torch::AtenVarDimOp>(
1644-
loc, meanOutTy, input, axes_list, unBiased, keepDim);
1645-
1656+
loc, reducedOutTy, input, axes_list, unBiased, keepDim);
16461657
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
16471658
loc, rewriter.getI64IntegerAttr(1));
16481659
Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
16491660
loc, rewriter.getF64FloatAttr(1e-9));
16501661
variance = rewriter.create<Torch::AtenAddScalarOp>(
1651-
loc, meanOutTy, variance, cstEps, cstOne);
1652-
1653-
Value sqrt =
1654-
rewriter.create<Torch::AtenSqrtOp>(loc, meanOutTy, variance);
1655-
1656-
Value subValue = rewriter.create<Torch::AtenSubTensorOp>(
1662+
loc, reducedOutTy, variance, cstEps, cstOne);
1663+
Value sqrtVar =
1664+
rewriter.create<Torch::AtenSqrtOp>(loc, reducedOutTy, variance);
1665+
Value inputMinusMean = rewriter.create<Torch::AtenSubTensorOp>(
16571666
loc, resultType, input, mean, cstOne);
1658-
16591667
Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
1660-
loc, resultType, subValue, sqrt);
1668+
loc, resultType, inputMinusMean, sqrtVar);
16611669

16621670
rewriter.replaceOp(binder.op, meanVarNorm);
1663-
16641671
return success();
16651672
});
16661673
patterns.onOp(

python/torch_mlir/extras/onnx_importer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ class Config:
103103
function_expansion_allowlists_by_domain: Optional[Dict[str, set[str]]] = field(
104104
default_factory=lambda: {
105105
# Default domain (ONNX built-in ops)
106-
"": {}
106+
"": {
107+
"MeanVarianceNormalization",
108+
}
107109
}
108110
)
109111

0 commit comments

Comments
 (0)