Skip to content

Commit 7edde2e

Browse files
committed
Use cast instead of dyn_cast
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent 7fb518d commit 7edde2e

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10385,9 +10385,9 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
1038510385
Value reductionValue = op.getReduction();
1038610386
Value logTargetValue = op.getLogTarget();
1038710387

10388-
auto selfTy = dyn_cast<ValueTensorType>(self.getType());
10389-
auto targetTy = dyn_cast<ValueTensorType>(target.getType());
10390-
auto outTy = dyn_cast<ValueTensorType>(op.getType());
10388+
auto selfTy = cast<ValueTensorType>(self.getType());
10389+
auto targetTy = cast<ValueTensorType>(target.getType());
10390+
auto outTy = cast<ValueTensorType>(op.getType());
1039110391

1039210392
if (!selfTy.hasSizes() || !targetTy.hasSizes() || !outTy.hasSizes()) {
1039310393
return rewriter.notifyMatchFailure(
@@ -10405,8 +10405,8 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
1040510405
return rewriter.notifyMatchFailure(
1040610406
op, "Expected a constant boolean value for logTargetBool");
1040710407

10408-
Value logOfTarget;
1040910408
// Default: target tensor is not in log space
10409+
Value logOfTarget;
1041010410
if (!logTargetBool) {
1041110411
logOfTarget = rewriter.create<AtenLogOp>(loc, targetTy, target);
1041210412
} else {
@@ -10418,7 +10418,7 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
1041810418
Value subValue = rewriter.create<AtenSubTensorOp>(loc, selfTy, logOfTarget,
1041910419
self, constOne);
1042010420

10421-
// target tensor is already in log space
10421+
// if target tensor is already in log space
1042210422
if (logTargetBool) {
1042310423
target = rewriter.create<AtenExpOp>(loc, targetTy, target);
1042410424
}
@@ -10431,6 +10431,7 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
1043110431
return rewriter.notifyMatchFailure(op,
1043210432
"reduction should be a constant int!");
1043310433
}
10434+
1043410435
Value loss;
1043510436
Value none = rewriter.create<ConstantNoneOp>(loc);
1043610437
// reduction: mean

0 commit comments

Comments
 (0)