Skip to content

Commit 84f5325

Browse files
committed
Use cast instead of dyn_cast
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent 46cf563 commit 84f5325

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10482,9 +10482,9 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
1048210482
Value reductionValue = op.getReduction();
1048310483
Value logTargetValue = op.getLogTarget();
1048410484

10485-
auto selfTy = dyn_cast<ValueTensorType>(self.getType());
10486-
auto targetTy = dyn_cast<ValueTensorType>(target.getType());
10487-
auto outTy = dyn_cast<ValueTensorType>(op.getType());
10485+
auto selfTy = cast<ValueTensorType>(self.getType());
10486+
auto targetTy = cast<ValueTensorType>(target.getType());
10487+
auto outTy = cast<ValueTensorType>(op.getType());
1048810488

1048910489
if (!selfTy.hasSizes() || !targetTy.hasSizes() || !outTy.hasSizes()) {
1049010490
return rewriter.notifyMatchFailure(
@@ -10502,8 +10502,8 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
1050210502
return rewriter.notifyMatchFailure(
1050310503
op, "Expected a constant boolean value for logTargetBool");
1050410504

10505-
Value logOfTarget;
1050610505
// Default: target tensor is not in log space
10506+
Value logOfTarget;
1050710507
if (!logTargetBool) {
1050810508
logOfTarget = rewriter.create<AtenLogOp>(loc, targetTy, target);
1050910509
} else {
@@ -10515,7 +10515,7 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
1051510515
Value subValue = rewriter.create<AtenSubTensorOp>(loc, selfTy, logOfTarget,
1051610516
self, constOne);
1051710517

10518-
// target tensor is already in log space
10518+
// if target tensor is already in log space
1051910519
if (logTargetBool) {
1052010520
target = rewriter.create<AtenExpOp>(loc, targetTy, target);
1052110521
}
@@ -10528,6 +10528,7 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
1052810528
return rewriter.notifyMatchFailure(op,
1052910529
"reduction should be a constant int!");
1053010530
}
10531+
1053110532
Value loss;
1053210533
Value none = rewriter.create<ConstantNoneOp>(loc);
1053310534
// reduction: mean

0 commit comments

Comments
 (0)