
Commit 5b0479c

Upcast gradually when computing variance
Going all the way to f64 is undesirable, especially for low-precision tensors in bf16 or f8 variants. Upcast only to the next wider type, e.g., bf16->f32 or f8->bf16. This is consistent with what PyTorch appears to do internally.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
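As a minimal standalone sketch of the promotion rule described above (independent of the MLIR rewriter APIs used in the patch; the helper name and the width-to-type mapping strings are illustrative assumptions, not part of the commit):

#include <cstdio>

// Sketch of the gradual-upcast rule from the commit message: promote each
// low-precision float width to the next wider computation type instead of
// jumping straight to f64. The function name and type-name strings are
// illustrative assumptions, not the patch's API.
static const char *nextWiderFloat(unsigned bitwidth) {
  switch (bitwidth) {
  case 8:
    return "bf16"; // f8 variants -> bf16
  case 16:
    return "f32"; // f16/bf16 -> f32
  case 32:
    return "f64"; // f32 -> f64
  default:
    return "f64"; // 64-bit input is left unconverted by the guard in the patch
  }
}

int main() {
  for (unsigned bw : {8u, 16u, 32u, 64u})
    std::printf("%u-bit float input -> compute in %s\n", bw, nextWiderFloat(bw));
  return 0;
}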
1 parent 46c3888 commit 5b0479c

File tree: 1 file changed (+9, -3 lines)


lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 9 additions & 3 deletions
@@ -9322,9 +9322,15 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
         op, "support floating-point type input only");
   }
 
-  // Upcasting the input tensor to `F64` dtype for higher precision during the
-  // computation of the result.
-  if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
+  // Upcasting the input tensor to a double-bitwidth dtype for higher precision
+  // during the computation of the result.
+  unsigned bitwidth = inputTensorTy.getDtype().getIntOrFloatBitWidth();
+  if (bitwidth != 64) {
+    Type targetTy = rewriter.getF64Type();
+    if (bitwidth == 8)
+      targetTy = rewriter.getBF16Type();
+    else if (bitwidth == 16)
+      targetTy = rewriter.getF32Type();
     self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
     inputTensorTy = cast<BaseTensorType>(self.getType());
   }
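Note that in the hunk as rendered, targetTy is computed but the conversion call is left as a context line that still passes rewriter.getF64Type(). For the gradual upcast to take effect, the conversion would presumably target targetTy instead; this is a hedged reading of the intent, not a line present in this commit:

  // Assumed intent (not shown as changed in this commit's hunk): convert to
  // the gradually chosen target type rather than unconditionally to f64.
  self = convertTensorToDtype(rewriter, loc, self, targetTy);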
