Skip to content

Commit ece6abe

Browse files
In floor_sum, use Wrapping<u64> to handle overflows
1 parent 7703e17 commit ece6abe

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

src/internal_math.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// remove this after dependencies has been added
22
#![allow(dead_code)]
3-
use std::mem::swap;
3+
use std::{mem::swap, num::Wrapping as W};
44

55
/// # Arguments
66
/// * `m` `1 <= m`
@@ -243,12 +243,17 @@ pub(crate) fn primitive_root(m: i32) -> i32 {
243243
/// `sum_{i=0}^{n-1} floor((ai + b) / m) (mod 2^64)`
244244
/* const */
245245
#[allow(clippy::many_single_char_names)]
246-
pub(crate) fn floor_sum_unsigned(mut n: u64, mut m: u64, mut a: u64, mut b: u64) -> u64 {
247-
let mut ans = 0;
246+
pub(crate) fn floor_sum_unsigned(
247+
mut n: W<u64>,
248+
mut m: W<u64>,
249+
mut a: W<u64>,
250+
mut b: W<u64>,
251+
) -> W<u64> {
252+
let mut ans = W(0);
248253
loop {
249254
if a >= m {
250-
if n > 0 {
251-
ans += n * (n - 1) / 2 * (a / m);
255+
if n > W(0) {
256+
ans += n * (n - W(1)) / W(2) * (a / m);
252257
}
253258
a %= m;
254259
}

src/math.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -189,21 +189,24 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
189189
/// assert_eq!(math::floor_sum(6, 5, 4, 3), 13);
190190
/// ```
191191
#[allow(clippy::many_single_char_names)]
192-
pub fn floor_sum(n: i64, m: i64, mut a: i64, mut b: i64) -> i64 {
192+
pub fn floor_sum(n: i64, m: i64, a: i64, b: i64) -> i64 {
193+
use std::num::Wrapping as W;
193194
assert!((0..1i64 << 32).contains(&n));
194195
assert!((1..1i64 << 32).contains(&m));
195-
let mut ans = 0;
196+
let mut ans = W(0 as u64);
197+
let (wn, wm, mut wa, mut wb) = (W(n as u64), W(m as u64), W(a as u64), W(b as u64));
196198
if a < 0 {
197-
let a2 = internal_math::safe_mod(a, m);
198-
ans -= n * (n - 1) / 2 * ((a2 - a) / m);
199-
a = a2;
199+
let a2 = W(internal_math::safe_mod(a, m) as u64);
200+
ans -= wn * (wn - W(1)) / W(2) * ((a2 - wa) / wm);
201+
wa = a2;
200202
}
201203
if b < 0 {
202-
let b2 = internal_math::safe_mod(b, m);
203-
ans -= n * ((b2 - b) / m);
204-
b = b2;
204+
let b2 = W(internal_math::safe_mod(b, m) as u64);
205+
ans -= wn * ((b2 - wb) / wm);
206+
wb = b2;
205207
}
206-
ans + internal_math::floor_sum_unsigned(n as u64, m as u64, a as u64, b as u64) as i64
208+
let ret = ans + internal_math::floor_sum_unsigned(wn, wm, wa, wb);
209+
ret.0 as i64
207210
}
208211

209212
#[cfg(test)]

0 commit comments

Comments
 (0)