diff --git a/util/src/main/scala/Money.scala b/util/src/main/scala/Money.scala index 68274b18..4dfa897f 100644 --- a/util/src/main/scala/Money.scala +++ b/util/src/main/scala/Money.scala @@ -4,7 +4,6 @@ import language.implicitConversions import java.math.MathContext import java.text.NumberFormat import java.util.{Currency, Locale} - import cats.Monoid import cats.data.ValidatedNel import cats.syntax.validated._ @@ -351,8 +350,7 @@ case class HighPrecisionMoney private ( val `type`: String = TypeName - lazy val amount: BigDecimal = - (BigDecimal(preciseAmount) * factor(fractionDigits)).setScale(fractionDigits) + lazy val amount: BigDecimal = preciseAmountToAmount(preciseAmount, fractionDigits) def withFractionDigits(fd: Int)(implicit mode: RoundingMode): HighPrecisionMoney = { val scaledAmount = amount.setScale(fd, mode) @@ -474,6 +472,8 @@ case class HighPrecisionMoney private ( } object HighPrecisionMoney { + import MoneyRounding._ + object ImplicitsDecimal { final implicit class HighPrecisionMoneyNotation(val amount: BigDecimal) extends AnyVal { def EUR: HighPrecisionMoney = HighPrecisionMoney.EUR(amount) @@ -566,6 +566,9 @@ object HighPrecisionMoney { private def amountToPreciseAmount(amount: BigDecimal, fractionDigits: Int): Long = (amount * Money.cachedCentPower(fractionDigits)).toLong + def preciseAmountToAmount(preciseAmount: Long, fractionDigits: Int): BigDecimal = + (BigDecimal(preciseAmount) * factor(fractionDigits)).setScale(fractionDigits) + def fromDecimalAmount(amount: BigDecimal, fractionDigits: Int, currency: Currency)(implicit mode: RoundingMode): HighPrecisionMoney = { val scaledAmount = amount.setScale(fractionDigits, mode) @@ -606,12 +609,9 @@ object HighPrecisionMoney { centAmount: Option[Long]): ValidatedNel[String, HighPrecisionMoney] = for { fd <- validateFractionDigits(fractionDigits, currency) - amount = BigDecimal(preciseAmount) * factor(fd) - scaledAmount = amount.setScale(fd, BigDecimal.RoundingMode.UNNECESSARY) - ca <- validateCentAmount(scaledAmount, centAmount, currency) + ca <- validateCentAmount(preciseAmount, fractionDigits, centAmount, currency) // TODO: revisit this part! the rounding mode might be dynamic and configured elsewhere - actualCentAmount = ca.getOrElse( - roundToCents(scaledAmount, currency)(BigDecimal.RoundingMode.HALF_EVEN)) + actualCentAmount = ca.getOrElse(roundHalfEven(preciseAmount, fractionDigits, currency)) } yield HighPrecisionMoney(preciseAmount, fd, actualCentAmount, currency) private def validateFractionDigits( @@ -625,13 +625,14 @@ object HighPrecisionMoney { fractionDigits.validNel private def validateCentAmount( - amount: BigDecimal, + preciseAmount: Long, + fractionDigits: Int, centAmount: Option[Long], currency: Currency): ValidatedNel[String, Option[Long]] = centAmount match { case Some(actual) => - val min = roundToCents(amount, currency)(RoundingMode.FLOOR) - val max = roundToCents(amount, currency)(RoundingMode.CEILING) + val min = roundFloor(preciseAmount, fractionDigits, currency) + val max = roundCeiling(preciseAmount, fractionDigits, currency) if (actual < min || actual > max) s"centAmount must be correctly rounded preciseAmount (a number between $min and $max).".invalidNel diff --git a/util/src/main/scala/MoneyRounding.scala b/util/src/main/scala/MoneyRounding.scala new file mode 100644 index 00000000..8408646c --- /dev/null +++ b/util/src/main/scala/MoneyRounding.scala @@ -0,0 +1,84 @@ +package io.sphere.util + +import java.util.Currency +import scala.annotation.tailrec + +/** This object contains rounding algorithms for the Money classes. So far we used BigDecimal for + * this purpose, but BigDecimal is slower and consumes more memory than this approach. + */ +object MoneyRounding { + + private def pow10(n: Int): Long = Math.pow(10, n).toLong + + /** @return + * Floor rounded (preciseAmount, fractionDigits) to the cent value of the given currency + */ + def roundFloor(preciseAmount: Long, fractionDigits: Int, currency: Currency): Long = + if (preciseAmount < 0L) { + val power = pow10(fractionDigits - currency.getDefaultFractionDigits) + val floor = preciseAmount / power + val remainder = preciseAmount % power + if (remainder == 0L) floor else floor - 1L + } else + preciseAmount / pow10(fractionDigits - currency.getDefaultFractionDigits) + + /** @return + * Ceiling rounded (preciseAmount, fractionDigits) to the cent value of the given currency + */ + def roundCeiling(preciseAmount: Long, fractionDigits: Int, currency: Currency): Long = + if (preciseAmount < 0L) + preciseAmount / pow10(fractionDigits - currency.getDefaultFractionDigits) + else { + val power = pow10(fractionDigits - currency.getDefaultFractionDigits) + val floor = preciseAmount / power + val remainder = preciseAmount % power + if (remainder == 0L) floor else floor + 1L + } + + private def getFractionDigits( + fractionWithoutLeadingZeros: Long, + fractionDigits: Int): List[Int] = { + @tailrec + def loop(remainder: Long, acc: List[Int]): List[Int] = { + val lastDigit = (remainder % 10L).toInt + val newRemainder = remainder / 10L + val newAcc = lastDigit :: acc + if (newRemainder == 0L) newAcc + else loop(newRemainder, newAcc) + } + val digits = loop(fractionWithoutLeadingZeros, List.empty) + + if (digits.length < fractionDigits) List.fill(fractionDigits - digits.length)(0) ::: digits + else digits + } + + /** @return + * Half even rounded (preciseAmount, fractionDigits) to the cent value of the given currency + */ + def roundHalfEven(preciseAmount: Long, fractionDigits: Int, currency: Currency): Long = { + val centFractionDigits = fractionDigits - currency.getDefaultFractionDigits + val power = pow10(centFractionDigits) + val integer = preciseAmount / power + val fraction = preciseAmount % power + + // Eg: 3 for 123.456 + val leastSignificantDigitOfInt = integer % 10L + + val fractionDigitsList = getFractionDigits(fraction, centFractionDigits) + + // Eg: 4 for 123.456 + val mostSignificantDigitOfFraction :: rest = fractionDigitsList + + if (mostSignificantDigitOfFraction == 5 && rest.forall(_ == 0)) + if (leastSignificantDigitOfInt % 2L == 0L) integer else integer + 1L + else if (mostSignificantDigitOfFraction >= 5) + integer + 1L + else if (mostSignificantDigitOfFraction == -5 && rest.forall(_ == 0)) + if (leastSignificantDigitOfInt % 2L == 0L) integer else integer - 1L + else if (mostSignificantDigitOfFraction <= -5) + integer - 1L + else + integer + } + +} diff --git a/util/src/test/scala/DomainObjectsGen.scala b/util/src/test/scala/DomainObjectsGen.scala index 97536bfd..9cef5c33 100644 --- a/util/src/test/scala/DomainObjectsGen.scala +++ b/util/src/test/scala/DomainObjectsGen.scala @@ -18,7 +18,8 @@ object DomainObjectsGen { val highPrecisionMoney: Gen[HighPrecisionMoney] = for { money <- money - } yield HighPrecisionMoney.fromMoney(money, money.currency.getDefaultFractionDigits) + fractionDigits <- Gen.oneOf(money.currency.getDefaultFractionDigits to 10) + } yield HighPrecisionMoney.fromMoney(money, fractionDigits) val baseMoney: Gen[BaseMoney] = Gen.oneOf(money, highPrecisionMoney) diff --git a/util/src/test/scala/MoneyRoundingSpec.scala b/util/src/test/scala/MoneyRoundingSpec.scala new file mode 100644 index 00000000..c1228f13 --- /dev/null +++ b/util/src/test/scala/MoneyRoundingSpec.scala @@ -0,0 +1,81 @@ +package io.sphere.util + +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.must.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks +import org.scalatest.prop.TableDrivenPropertyChecks.Table +import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks + +import java.util.Currency +import scala.math.BigDecimal.RoundingMode + +class MoneyRoundingSpec extends AnyFunSpec with Matchers with ScalaCheckDrivenPropertyChecks { + val Euro: Currency = Currency.getInstance("EUR") + val ZWL: Currency = Currency.getInstance("ZWL") + val JPY: Currency = Currency.getInstance("JPY") + + private implicit val genConfig: PropertyCheckConfiguration = + PropertyCheckConfiguration(minSuccessful = 50) + + describe("Money Rounding") { + it("roundFloor should behave similarly to BigDecimal rounding") { + forAll(DomainObjectsGen.highPrecisionMoney) { h => + val bdRes = HighPrecisionMoney.roundToCents(h.amount, h.currency)(RoundingMode.FLOOR) + val longRes = + MoneyRounding.roundFloor(h.preciseAmount, h.fractionDigits, h.currency) + bdRes must be(longRes) + } + } + + it("roundCeiling should behave similarly to BigDecimal rounding") { + forAll(DomainObjectsGen.highPrecisionMoney) { h => + val bdRes = HighPrecisionMoney.roundToCents(h.amount, h.currency)(RoundingMode.CEILING) + val longRes = + MoneyRounding.roundCeiling(h.preciseAmount, h.fractionDigits, h.currency) + bdRes must be(longRes) + } + } + + it("roundHalfEven should behave similarly to BigDecimal rounding") { + // I used random generated values later, but I needed these very specific values too to check the + // edge cases of the half even rounding + val data = Table( + ("preciseAmount", "fraction", "currency"), + (1119L, 3, Euro), + (1111L, 3, Euro), + (1115L, 3, Euro), + (1125L, 3, Euro), + (112500L, 5, Euro), + (11250001L, 7, Euro), + (11000004L, 7, Euro), + (11249999L, 7, Euro), + (-1119L, 3, Euro), + (-1111L, 3, Euro), + (-1115L, 3, Euro), + (-1125L, 3, Euro), + (-112500L, 5, Euro), + (5721482481806080960L, 6, ZWL), + (123L, 0, JPY) + ) + + TableDrivenPropertyChecks.forAll(data) { (preciseAmount, fd, cur) => + val amount = HighPrecisionMoney.preciseAmountToAmount(preciseAmount, fd) + val bdRes = HighPrecisionMoney.roundToCents(amount, cur)(RoundingMode.HALF_EVEN) + + val longRes = + MoneyRounding.roundHalfEven(preciseAmount, fd, cur) + + bdRes must be(longRes) + } + + forAll(DomainObjectsGen.highPrecisionMoney) { h => + val bdRes = HighPrecisionMoney.roundToCents(h.amount, h.currency)(RoundingMode.HALF_EVEN) + + val longRes = + MoneyRounding.roundHalfEven(h.preciseAmount, h.fractionDigits, h.currency) + + bdRes must be(longRes) + } + } + } +}