diff --git a/spec/std/int_spec.cr b/spec/std/int_spec.cr index e027996a3d59..26d76012052d 100644 --- a/spec/std/int_spec.cr +++ b/spec/std/int_spec.cr @@ -12,6 +12,19 @@ private macro it_converts_to_s(num, str, **opts) end end +private class IntEnumerable + include Enumerable(Int32) + + def initialize(@elements : Array(Int32)) + end + + def each(&) + @elements.each do |e| + yield e + end + end +end + describe "Int" do describe "#integer?" do {% for int in BUILTIN_INTEGER_TYPES %} @@ -1139,4 +1152,48 @@ describe "Int" do end end end + + describe "from_digits" do + it "returns Int composed from given digits" do + Int32.from_digits([9, 8, 7, 6, 5, 4, 3, 2, 1]).should eq(123456789) + end + + it "works with a base" do + Int32.from_digits([11, 7], 16).should eq(123) + Int32.from_digits([11, 7], base: 16).should eq(123) + end + + it "accepts digits as Enumerable" do + enumerable = IntEnumerable.new([11, 7]) + Int32.from_digits(enumerable, 16).should eq(123) + end + + it "raises for base less than 2" do + [-1, 0, 1].each do |base| + expect_raises(ArgumentError, "Invalid base #{base}") do + Int32.from_digits([1, 2, 3], base) + end + end + end + + it "raises for digits greater than base" do + expect_raises(ArgumentError, "Invalid digit 2 for base 2") do + Int32.from_digits([1, 0, 2], 2) + end + + expect_raises(ArgumentError, "Invalid digit 10 for base 2") do + Int32.from_digits([1, 0, 10], 2) + end + end + + it "raises for negative digits" do + expect_raises(ArgumentError, "Invalid digit -1") do + Int32.from_digits([1, 2, -1]) + end + end + + it "works properly for values close to the upper limit" do + UInt8.from_digits([5, 5, 2]).should eq(255) + end + end end diff --git a/src/int.cr b/src/int.cr index 8060dff4fac1..3e9bc921d6cd 100644 --- a/src/int.cr +++ b/src/int.cr @@ -2776,3 +2776,53 @@ struct UInt128 self end end + +# Returns a number for given digits and base. +# The digits are expected as an Enumerable with the least significant digit as the first element. +# +# Base must not be less than 2. +# +# All digits must be within 0...base. +# +# ``` +# Int32.from_digits([5, 4, 3, 2, 1]) # => 12345 +# Int32.from_digits([4, 6, 6, 0, 5], base: 7) # => 12345 +# Int32.from_digits([45, 23, 1], base: 100) # => 12345 +# +# Int32.from_digits([1], base: -2) # raises ArgumentError +# Int32.from_digits([-1]) # raises ArgumentError +# Int32.from_digits([3], base: 2) # raises ArgumentError +# ``` +{% for type in %w(Int8 Int16 Int32 Int64 Int128 UInt8 UInt16 UInt32 UInt64 UInt128) %} + def {{type.id}}.from_digits(digits : Enumerable(Int), base : Int = 10) : self + if base < 2 + raise ArgumentError.new("Invalid base #{base}") + end + + num : {{type.id}} = 0 + multiplier : {{type.id}} = 1 + first_element = true + + digits.each do |digit| + if digit < 0 + raise ArgumentError.new("Invalid digit #{digit}") + end + + if digit >= base + raise ArgumentError.new("Invalid digit #{digit} for base #{base}") + end + + # don't calculate multiplier upfront for the next digit + # to avoid overflow at the last iteration + if first_element + first_element = false + else + multiplier *= base + end + + num += digit * multiplier + end + + num + end +{% end %}