diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 3bfc35c09d1b..24ba5cb6ecfe 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -9322,9 +9322,16 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
         op, "support floating-point type input only");
   }
 
-  // Upcasting the input tensor to `F64` dtype for higher precision during the
-  // computation of the result.
-  if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
+  // Upcast the input tensor to a dtype with twice its bitwidth (8-bit -> BF16,
+  // 16-bit -> F32, otherwise F64) for higher precision during the computation
+  // of the result.
+  unsigned bitwidth = inputTensorTy.getDtype().getIntOrFloatBitWidth();
+  if (bitwidth != 64) {
+    Type targetTy = rewriter.getF64Type();
+    if (bitwidth == 8)
+      targetTy = rewriter.getBF16Type();
+    else if (bitwidth == 16)
+      targetTy = rewriter.getF32Type();
-    self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
+    self = convertTensorToDtype(rewriter, loc, self, targetTy);
     inputTensorTy = cast<BaseTensorType>(self.getType());
   }