@@ -231,8 +231,28 @@ fn prepare<M: Modulus>() -> ButterflyCache<M> {
231
231
232
232
#[ cfg( test) ]
233
233
mod tests {
234
- use crate :: modint:: { Mod998244353 , Modulus , StaticModInt } ;
234
+ use crate :: {
235
+ modint:: { Mod998244353 , Modulus , StaticModInt } ,
236
+ RemEuclidU32 ,
237
+ } ;
235
238
use rand:: { rngs:: ThreadRng , Rng as _} ;
239
+ use std:: {
240
+ convert:: { TryFrom , TryInto as _} ,
241
+ fmt,
242
+ } ;
243
+
244
+ //https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L51-L71
245
+ #[ test]
246
+ fn empty ( ) {
247
+ assert ! ( super :: convolution_raw:: <i32 , Mod998244353 >( & [ ] , & [ ] ) . is_empty( ) ) ;
248
+ assert ! ( super :: convolution_raw:: <i32 , Mod998244353 >( & [ ] , & [ 1 , 2 ] ) . is_empty( ) ) ;
249
+ assert ! ( super :: convolution_raw:: <i32 , Mod998244353 >( & [ 1 , 2 ] , & [ ] ) . is_empty( ) ) ;
250
+ assert ! ( super :: convolution_raw:: <i32 , Mod998244353 >( & [ 1 ] , & [ ] ) . is_empty( ) ) ;
251
+ assert ! ( super :: convolution_raw:: <i64 , Mod998244353 >( & [ ] , & [ ] ) . is_empty( ) ) ;
252
+ assert ! ( super :: convolution_raw:: <i64 , Mod998244353 >( & [ ] , & [ 1 , 2 ] ) . is_empty( ) ) ;
253
+ assert ! ( super :: convolution:: <Mod998244353 >( & [ ] , & [ ] ) . is_empty( ) ) ;
254
+ assert ! ( super :: convolution:: <Mod998244353 >( & [ ] , & [ 1 . into( ) , 2 . into( ) ] ) . is_empty( ) ) ;
255
+ }
236
256
237
257
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L73-L85
238
258
#[ test]
@@ -267,9 +287,119 @@ mod tests {
267
287
test :: < M2 > ( & mut rng) ;
268
288
}
269
289
290
+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L120-L150
291
+ #[ test]
292
+ fn simple_int ( ) {
293
+ simple_raw :: < i32 > ( ) ;
294
+ }
295
+
296
+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L152-L182
297
+ #[ test]
298
+ fn simple_uint ( ) {
299
+ simple_raw :: < u32 > ( ) ;
300
+ }
301
+
302
+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L184-L214
303
+ #[ test]
304
+ fn simple_ll ( ) {
305
+ simple_raw :: < i64 > ( ) ;
306
+ }
307
+
308
+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L216-L246
309
+ #[ test]
310
+ fn simple_ull ( ) {
311
+ simple_raw :: < u64 > ( ) ;
312
+ }
313
+
314
+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L249-L279
315
+ #[ test]
316
+ fn simple_int128 ( ) {
317
+ simple_raw :: < i128 > ( ) ;
318
+ }
319
+
320
+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L281-L311
321
+ #[ test]
322
+ fn simple_uint128 ( ) {
323
+ simple_raw :: < u128 > ( ) ;
324
+ }
325
+
326
+ fn simple_raw < T > ( )
327
+ where
328
+ T : TryFrom < u32 > + Copy + RemEuclidU32 ,
329
+ T :: Error : fmt:: Debug ,
330
+ {
331
+ const M1 : u32 = 998_244_353 ;
332
+ const M2 : u32 = 924_844_033 ;
333
+
334
+ modulus ! ( M1 , M2 ) ;
335
+
336
+ fn test < T , M > ( rng : & mut ThreadRng )
337
+ where
338
+ T : TryFrom < u32 > + Copy + RemEuclidU32 ,
339
+ T :: Error : fmt:: Debug ,
340
+ M : Modulus ,
341
+ {
342
+ let mut gen_raw_values = |n| gen_raw_values :: < u32 , Mod998244353 > ( rng, n) ;
343
+ for ( n, m) in ( 1 ..20 ) . flat_map ( |i| ( 1 ..20 ) . map ( move |j| ( i, j) ) ) {
344
+ let ( a, b) = ( gen_raw_values ( n) , gen_raw_values ( m) ) ;
345
+ assert_eq ! (
346
+ conv_raw_naive:: <_, M >( & a, & b) ,
347
+ super :: convolution_raw:: <_, M >( & a, & b) ,
348
+ ) ;
349
+ }
350
+ }
351
+
352
+ let mut rng = rand:: thread_rng ( ) ;
353
+ test :: < T , M1 > ( & mut rng) ;
354
+ test :: < T , M2 > ( & mut rng) ;
355
+ }
356
+
357
+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L315-L329
358
+ #[ test]
359
+ fn conv_ll ( ) {
360
+ let mut rng = rand:: thread_rng ( ) ;
361
+ for ( n, m) in ( 1 ..20 ) . flat_map ( |i| ( 1 ..20 ) . map ( move |j| ( i, j) ) ) {
362
+ let mut gen =
363
+ |n : usize | -> Vec < _ > { ( 0 ..n) . map ( |_| rng. gen_range ( -500_000 , 500_000 ) ) . collect ( ) } ;
364
+ let ( a, b) = ( gen ( n) , gen ( m) ) ;
365
+ assert_eq ! ( conv_i64_naive( & a, & b) , super :: convolution_i64( & a, & b) ) ;
366
+ }
367
+ }
368
+
369
+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L331-L356
370
+ #[ test]
371
+ fn conv_ll_bound ( ) {
372
+ const M1 : u64 = 754_974_721 ; // 2^24
373
+ const M2 : u64 = 167_772_161 ; // 2^25
374
+ const M3 : u64 = 469_762_049 ; // 2^26
375
+ const M2M3 : u64 = M2 * M3 ;
376
+ const M1M3 : u64 = M1 * M3 ;
377
+ const M1M2 : u64 = M1 * M2 ;
378
+
379
+ modulus ! ( M1 , M2 , M3 ) ;
380
+
381
+ for i in -1000 ..=1000 {
382
+ let a = vec ! [ 0u64 . wrapping_sub( M1M2 + M1M3 + M2M3 ) as i64 + i] ;
383
+ let b = vec ! [ 1 ] ;
384
+ assert_eq ! ( a, super :: convolution_i64( & a, & b) ) ;
385
+ }
386
+
387
+ for i in 0 ..1000 {
388
+ let a = vec ! [ i64 :: min_value( ) + i] ;
389
+ let b = vec ! [ 1 ] ;
390
+ assert_eq ! ( a, super :: convolution_i64( & a, & b) ) ;
391
+ }
392
+
393
+ for i in 0 ..1000 {
394
+ let a = vec ! [ i64 :: max_value( ) - i] ;
395
+ let b = vec ! [ 1 ] ;
396
+ assert_eq ! ( a, super :: convolution_i64( & a, & b) ) ;
397
+ }
398
+ }
399
+
270
400
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L358-L371
271
401
#[ test]
272
- fn conv641 ( ) {
402
+ fn conv_641 ( ) {
273
403
const M : u32 = 641 ;
274
404
modulus ! ( M ) ;
275
405
@@ -281,7 +411,7 @@ mod tests {
281
411
282
412
// https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L373-L386
283
413
#[ test]
284
- fn conv18433 ( ) {
414
+ fn conv_18433 ( ) {
285
415
const M : u32 = 18433 ;
286
416
modulus ! ( M ) ;
287
417
@@ -304,9 +434,43 @@ mod tests {
304
434
c
305
435
}
306
436
437
+ fn conv_raw_naive < T , M > ( a : & [ T ] , b : & [ T ] ) -> Vec < T >
438
+ where
439
+ T : TryFrom < u32 > + Copy + RemEuclidU32 ,
440
+ T :: Error : fmt:: Debug ,
441
+ M : Modulus ,
442
+ {
443
+ conv_naive :: < M > (
444
+ & a. iter ( ) . copied ( ) . map ( Into :: into) . collect :: < Vec < _ > > ( ) ,
445
+ & b. iter ( ) . copied ( ) . map ( Into :: into) . collect :: < Vec < _ > > ( ) ,
446
+ )
447
+ . into_iter ( )
448
+ . map ( |x| x. val ( ) . try_into ( ) . unwrap ( ) )
449
+ . collect ( )
450
+ }
451
+
452
+ #[ allow( clippy:: many_single_char_names) ]
453
+ fn conv_i64_naive ( a : & [ i64 ] , b : & [ i64 ] ) -> Vec < i64 > {
454
+ let ( n, m) = ( a. len ( ) , b. len ( ) ) ;
455
+ let mut c = vec ! [ 0 ; n + m - 1 ] ;
456
+ for ( i, j) in ( 0 ..n) . flat_map ( |i| ( 0 ..m) . map ( move |j| ( i, j) ) ) {
457
+ c[ i + j] += a[ i] * b[ j] ;
458
+ }
459
+ c
460
+ }
461
+
307
462
fn gen_values < M : Modulus > ( rng : & mut ThreadRng , n : usize ) -> Vec < StaticModInt < M > > {
463
+ ( 0 ..n) . map ( |_| rng. gen_range ( 0 , M :: VALUE ) . into ( ) ) . collect ( )
464
+ }
465
+
466
+ fn gen_raw_values < T , M > ( rng : & mut ThreadRng , n : usize ) -> Vec < T >
467
+ where
468
+ T : TryFrom < u32 > ,
469
+ T :: Error : fmt:: Debug ,
470
+ M : Modulus ,
471
+ {
308
472
( 0 ..n)
309
- . map ( |_| StaticModInt :: raw ( rng. gen_range ( 0 , M :: VALUE ) ) )
473
+ . map ( |_| rng. gen_range ( 0 , M :: VALUE ) . try_into ( ) . unwrap ( ) )
310
474
. collect ( )
311
475
}
312
476
}
0 commit comments