@@ -10385,9 +10385,9 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
10385
10385
Value reductionValue = op.getReduction ();
10386
10386
Value logTargetValue = op.getLogTarget ();
10387
10387
10388
- auto selfTy = dyn_cast <ValueTensorType>(self.getType ());
10389
- auto targetTy = dyn_cast <ValueTensorType>(target.getType ());
10390
- auto outTy = dyn_cast <ValueTensorType>(op.getType ());
10388
+ auto selfTy = cast <ValueTensorType>(self.getType ());
10389
+ auto targetTy = cast <ValueTensorType>(target.getType ());
10390
+ auto outTy = cast <ValueTensorType>(op.getType ());
10391
10391
10392
10392
if (!selfTy.hasSizes () || !targetTy.hasSizes () || !outTy.hasSizes ()) {
10393
10393
return rewriter.notifyMatchFailure (
@@ -10405,8 +10405,8 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
10405
10405
return rewriter.notifyMatchFailure (
10406
10406
op, " Expected a constant boolean value for logTargetBool" );
10407
10407
10408
- Value logOfTarget;
10409
10408
// Default: target tensor is not in log space
10409
+ Value logOfTarget;
10410
10410
if (!logTargetBool) {
10411
10411
logOfTarget = rewriter.create <AtenLogOp>(loc, targetTy, target);
10412
10412
} else {
@@ -10418,7 +10418,7 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
10418
10418
Value subValue = rewriter.create <AtenSubTensorOp>(loc, selfTy, logOfTarget,
10419
10419
self, constOne);
10420
10420
10421
- // target tensor is already in log space
10421
+ // if target tensor is already in log space
10422
10422
if (logTargetBool) {
10423
10423
target = rewriter.create <AtenExpOp>(loc, targetTy, target);
10424
10424
}
@@ -10431,6 +10431,7 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
10431
10431
return rewriter.notifyMatchFailure (op,
10432
10432
" reduction should be a constant int!" );
10433
10433
}
10434
+
10434
10435
Value loss;
10435
10436
Value none = rewriter.create <ConstantNoneOp>(loc);
10436
10437
// reduction: mean
0 commit comments