@@ -10482,9 +10482,9 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
10482
10482
Value reductionValue = op.getReduction();
10483
10483
Value logTargetValue = op.getLogTarget();
10484
10484
10485
- auto selfTy = dyn_cast <ValueTensorType>(self.getType());
10486
- auto targetTy = dyn_cast <ValueTensorType>(target.getType());
10487
- auto outTy = dyn_cast <ValueTensorType>(op.getType());
10485
+ auto selfTy = cast <ValueTensorType>(self.getType());
10486
+ auto targetTy = cast <ValueTensorType>(target.getType());
10487
+ auto outTy = cast <ValueTensorType>(op.getType());
10488
10488
10489
10489
if (!selfTy.hasSizes() || !targetTy.hasSizes() || !outTy.hasSizes()) {
10490
10490
return rewriter.notifyMatchFailure(
@@ -10502,8 +10502,8 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
10502
10502
return rewriter.notifyMatchFailure(
10503
10503
op, "Expected a constant boolean value for logTargetBool");
10504
10504
10505
- Value logOfTarget;
10506
10505
// Default: target tensor is not in log space
10506
+ Value logOfTarget;
10507
10507
if (!logTargetBool) {
10508
10508
logOfTarget = rewriter.create<AtenLogOp>(loc, targetTy, target);
10509
10509
} else {
@@ -10515,7 +10515,7 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
10515
10515
Value subValue = rewriter.create<AtenSubTensorOp>(loc, selfTy, logOfTarget,
10516
10516
self, constOne);
10517
10517
10518
- // target tensor is already in log space
10518
+ // if target tensor is already in log space
10519
10519
if (logTargetBool) {
10520
10520
target = rewriter.create<AtenExpOp>(loc, targetTy, target);
10521
10521
}
@@ -10528,6 +10528,7 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
10528
10528
return rewriter.notifyMatchFailure(op,
10529
10529
"reduction should be a constant int!");
10530
10530
}
10531
+
10531
10532
Value loss;
10532
10533
Value none = rewriter.create<ConstantNoneOp>(loc);
10533
10534
// reduction: mean
0 commit comments