Skip to content

Commit af4b04c

Browse files
committed
Translate all of the convolution unit tests of the original ones
1 parent b8043a0 commit af4b04c

File tree

1 file changed

+168
-4
lines changed

1 file changed

+168
-4
lines changed

src/convolution.rs

Lines changed: 168 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,28 @@ fn prepare<M: Modulus>() -> ButterflyCache<M> {
231231

232232
#[cfg(test)]
233233
mod tests {
234-
use crate::modint::{Mod998244353, Modulus, StaticModInt};
234+
use crate::{
235+
modint::{Mod998244353, Modulus, StaticModInt},
236+
RemEuclidU32,
237+
};
235238
use rand::{rngs::ThreadRng, Rng as _};
239+
use std::{
240+
convert::{TryFrom, TryInto as _},
241+
fmt,
242+
};
243+
244+
//https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L51-L71
245+
#[test]
246+
fn empty() {
247+
assert!(super::convolution_raw::<i32, Mod998244353>(&[], &[]).is_empty());
248+
assert!(super::convolution_raw::<i32, Mod998244353>(&[], &[1, 2]).is_empty());
249+
assert!(super::convolution_raw::<i32, Mod998244353>(&[1, 2], &[]).is_empty());
250+
assert!(super::convolution_raw::<i32, Mod998244353>(&[1], &[]).is_empty());
251+
assert!(super::convolution_raw::<i64, Mod998244353>(&[], &[]).is_empty());
252+
assert!(super::convolution_raw::<i64, Mod998244353>(&[], &[1, 2]).is_empty());
253+
assert!(super::convolution::<Mod998244353>(&[], &[]).is_empty());
254+
assert!(super::convolution::<Mod998244353>(&[], &[1.into(), 2.into()]).is_empty());
255+
}
236256

237257
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L73-L85
238258
#[test]
@@ -267,9 +287,119 @@ mod tests {
267287
test::<M2>(&mut rng);
268288
}
269289

290+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L120-L150
291+
#[test]
292+
fn simple_int() {
293+
simple_raw::<i32>();
294+
}
295+
296+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L152-L182
297+
#[test]
298+
fn simple_uint() {
299+
simple_raw::<u32>();
300+
}
301+
302+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L184-L214
303+
#[test]
304+
fn simple_ll() {
305+
simple_raw::<i64>();
306+
}
307+
308+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L216-L246
309+
#[test]
310+
fn simple_ull() {
311+
simple_raw::<u64>();
312+
}
313+
314+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L249-L279
315+
#[test]
316+
fn simple_int128() {
317+
simple_raw::<i128>();
318+
}
319+
320+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L281-L311
321+
#[test]
322+
fn simple_uint128() {
323+
simple_raw::<u128>();
324+
}
325+
326+
fn simple_raw<T>()
327+
where
328+
T: TryFrom<u32> + Copy + RemEuclidU32,
329+
T::Error: fmt::Debug,
330+
{
331+
const M1: u32 = 998_244_353;
332+
const M2: u32 = 924_844_033;
333+
334+
modulus!(M1, M2);
335+
336+
fn test<T, M>(rng: &mut ThreadRng)
337+
where
338+
T: TryFrom<u32> + Copy + RemEuclidU32,
339+
T::Error: fmt::Debug,
340+
M: Modulus,
341+
{
342+
let mut gen_raw_values = |n| gen_raw_values::<u32, Mod998244353>(rng, n);
343+
for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
344+
let (a, b) = (gen_raw_values(n), gen_raw_values(m));
345+
assert_eq!(
346+
conv_raw_naive::<_, M>(&a, &b),
347+
super::convolution_raw::<_, M>(&a, &b),
348+
);
349+
}
350+
}
351+
352+
let mut rng = rand::thread_rng();
353+
test::<T, M1>(&mut rng);
354+
test::<T, M2>(&mut rng);
355+
}
356+
357+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L315-L329
358+
#[test]
359+
fn conv_ll() {
360+
let mut rng = rand::thread_rng();
361+
for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
362+
let mut gen =
363+
|n: usize| -> Vec<_> { (0..n).map(|_| rng.gen_range(-500_000, 500_000)).collect() };
364+
let (a, b) = (gen(n), gen(m));
365+
assert_eq!(conv_i64_naive(&a, &b), super::convolution_i64(&a, &b));
366+
}
367+
}
368+
369+
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L331-L356
370+
#[test]
371+
fn conv_ll_bound() {
372+
const M1: u64 = 754_974_721; // 2^24
373+
const M2: u64 = 167_772_161; // 2^25
374+
const M3: u64 = 469_762_049; // 2^26
375+
const M2M3: u64 = M2 * M3;
376+
const M1M3: u64 = M1 * M3;
377+
const M1M2: u64 = M1 * M2;
378+
379+
modulus!(M1, M2, M3);
380+
381+
for i in -1000..=1000 {
382+
let a = vec![0u64.wrapping_sub(M1M2 + M1M3 + M2M3) as i64 + i];
383+
let b = vec![1];
384+
assert_eq!(a, super::convolution_i64(&a, &b));
385+
}
386+
387+
for i in 0..1000 {
388+
let a = vec![i64::min_value() + i];
389+
let b = vec![1];
390+
assert_eq!(a, super::convolution_i64(&a, &b));
391+
}
392+
393+
for i in 0..1000 {
394+
let a = vec![i64::max_value() - i];
395+
let b = vec![1];
396+
assert_eq!(a, super::convolution_i64(&a, &b));
397+
}
398+
}
399+
270400
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L358-L371
271401
#[test]
272-
fn conv641() {
402+
fn conv_641() {
273403
const M: u32 = 641;
274404
modulus!(M);
275405

@@ -281,7 +411,7 @@ mod tests {
281411

282412
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L373-L386
283413
#[test]
284-
fn conv18433() {
414+
fn conv_18433() {
285415
const M: u32 = 18433;
286416
modulus!(M);
287417

@@ -304,9 +434,43 @@ mod tests {
304434
c
305435
}
306436

437+
fn conv_raw_naive<T, M>(a: &[T], b: &[T]) -> Vec<T>
438+
where
439+
T: TryFrom<u32> + Copy + RemEuclidU32,
440+
T::Error: fmt::Debug,
441+
M: Modulus,
442+
{
443+
conv_naive::<M>(
444+
&a.iter().copied().map(Into::into).collect::<Vec<_>>(),
445+
&b.iter().copied().map(Into::into).collect::<Vec<_>>(),
446+
)
447+
.into_iter()
448+
.map(|x| x.val().try_into().unwrap())
449+
.collect()
450+
}
451+
452+
#[allow(clippy::many_single_char_names)]
453+
fn conv_i64_naive(a: &[i64], b: &[i64]) -> Vec<i64> {
454+
let (n, m) = (a.len(), b.len());
455+
let mut c = vec![0; n + m - 1];
456+
for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
457+
c[i + j] += a[i] * b[j];
458+
}
459+
c
460+
}
461+
307462
fn gen_values<M: Modulus>(rng: &mut ThreadRng, n: usize) -> Vec<StaticModInt<M>> {
463+
(0..n).map(|_| rng.gen_range(0, M::VALUE).into()).collect()
464+
}
465+
466+
fn gen_raw_values<T, M>(rng: &mut ThreadRng, n: usize) -> Vec<T>
467+
where
468+
T: TryFrom<u32>,
469+
T::Error: fmt::Debug,
470+
M: Modulus,
471+
{
308472
(0..n)
309-
.map(|_| StaticModInt::raw(rng.gen_range(0, M::VALUE)))
473+
.map(|_| rng.gen_range(0, M::VALUE).try_into().unwrap())
310474
.collect()
311475
}
312476
}

0 commit comments

Comments
 (0)