diff --git a/ml-dsa/src/algebra.rs b/ml-dsa/src/algebra.rs index b89bd283..85de202e 100644 --- a/ml-dsa/src/algebra.rs +++ b/ml-dsa/src/algebra.rs @@ -210,3 +210,53 @@ impl AlgebraExt for Vector { ) } } + +#[cfg(test)] +mod test { + use super::*; + + use crate::{MlDsa65, ParameterSet}; + + type Mod = ::TwoGamma2; + const MOD: u32 = Mod::U32; + const MOD_ELEM: Elem = Elem::new(MOD); + + #[test] + fn mod_plus_minus() { + for x in 0..MOD { + // BaseField::Q { + let x = Elem::new(x); + let x0 = x.mod_plus_minus::(); + + // Outputs from mod+- should be in the half-open interval (-gamma2, gamma2] + let positive_bound = x0.0 <= MOD / 2; + let negative_bound = x0.0 > BaseField::Q - MOD / 2; + assert!(positive_bound || negative_bound); + + // The output should be equivalent to the input, mod 2 * gamma2. We add 2 * gamma2 + // before comparing so that both values are "positive", avoiding interactions between + // the mod-Q and mod-M operations. + let xn = x + MOD_ELEM; + let x0n = x0 + MOD_ELEM; + assert_eq!(xn.0 % MOD, x0n.0 % MOD); + } + } + + #[test] + fn decompose() { + for x in 0..MOD { + let x = Elem::new(x); + let (x1, x0) = x.decompose::(); + + // The low-order output from decompose() is a mod+- output, optionally minus one. So + // they should be in the closed interval [-gamma2, gamma2]. + let positive_bound = x0.0 <= MOD / 2; + let negative_bound = x0.0 >= BaseField::Q - MOD / 2; + assert!(positive_bound || negative_bound); + + // The low-order and high-order outputs should combine to form the input. + let xx = (MOD * x1.0 + x0.0) % BaseField::Q; + assert_eq!(xx, x.0); + } + } +} diff --git a/ml-dsa/src/hint.rs b/ml-dsa/src/hint.rs index df9f321e..ced3cf34 100644 --- a/ml-dsa/src/hint.rs +++ b/ml-dsa/src/hint.rs @@ -22,18 +22,18 @@ fn use_hint(h: bool, r: Elem) -> Elem { let gamma2 = TwoGamma2::U32 / 2; if h && r0.0 <= gamma2 { Elem::new((r1.0 + 1) % m) - } else if h && r0.0 > BaseField::Q - gamma2 { + } else if h && r0.0 >= BaseField::Q - gamma2 { Elem::new((r1.0 + m - 1) % m) } else if h { // We use the Elem encoding even for signed integers. Since r0 is computed - // mod+- 2*gamma2, it is guaranteed to be in (gamma2, gamma2]. + // mod+- 2*gamma2 (possibly minus 1), it is guaranteed to be in [-gamma2, gamma2]. unreachable!(); } else { r1 } } -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct Hint

(pub Array, P::K>) where P: SignatureParams; @@ -116,7 +116,7 @@ where } fn monotonic(a: &[usize]) -> bool { - a.iter().enumerate().all(|(i, x)| i == 0 || a[i - 1] < *x) + a.iter().enumerate().all(|(i, x)| i == 0 || a[i - 1] <= *x) } pub fn bit_unpack(y: &EncodedHint

) -> Option { diff --git a/ml-dsa/src/lib.rs b/ml-dsa/src/lib.rs index 3efe91cf..96d856bf 100644 --- a/ml-dsa/src/lib.rs +++ b/ml-dsa/src/lib.rs @@ -89,7 +89,7 @@ pub use crate::util::B32; pub use signature::Error; /// An ML-DSA signature -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct Signature { c_tilde: Array, z: Vector, @@ -899,4 +899,42 @@ mod test { sign_verify_round_trip_test::(); sign_verify_round_trip_test::(); } + + fn many_round_trip_test

() + where + P: MlDsaParams, + { + use rand::Rng; + + const ITERATIONS: usize = 1000; + + let mut rng = rand::thread_rng(); + let mut seed = B32::default(); + + for _i in 0..ITERATIONS { + let seed_data: &mut [u8] = seed.as_mut(); + rng.fill(seed_data); + + let kp = P::key_gen_internal(&seed); + let sk = kp.signing_key; + let vk = kp.verifying_key; + + let M = b"Hello world"; + let rnd = Array([0u8; 32]); + let sig = sk.sign_internal(&[M], &rnd); + + let sig_enc = sig.encode(); + let sig_dec = Signature::

::decode(&sig_enc).unwrap(); + + assert_eq!(sig_dec, sig); + assert!(vk.verify_internal(&[M], &sig_dec)); + } + } + + #[test] + fn many_round_trip() { + many_round_trip_test::(); + many_round_trip_test::(); + many_round_trip_test::(); + } }