Skip to content

Commit 393d4e0

Browse files
authored
ml-dsa: Add support for external mu (#1023)
1 parent 41edbaf commit 393d4e0

File tree

1 file changed

+114
-4
lines changed

1 file changed

+114
-4
lines changed

ml-dsa/src/lib.rs

Lines changed: 114 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,12 +373,18 @@ impl<P: MlDsaParams> SigningKey<P> {
373373
// the concatenated M'.
374374
// XXX(RLB) Should the API represent this as an input?
375375
let mu = message_representative(&self.tr, Mp);
376+
self.raw_sign_mu(&mu, rnd)
377+
}
376378

379+
fn raw_sign_mu(&self, mu: &B64, rnd: &B32) -> Signature<P>
380+
where
381+
P: MlDsaParams,
382+
{
377383
// Compute the private random seed
378384
let rhopp: B64 = H::default()
379385
.absorb(&self.K)
380386
.absorb(rnd)
381-
.absorb(&mu)
387+
.absorb(mu)
382388
.squeeze_new();
383389

384390
// Rejection sampling loop
@@ -389,7 +395,7 @@ impl<P: MlDsaParams> SigningKey<P> {
389395

390396
let w1_tilde = P::encode_w1(&w1);
391397
let c_tilde = H::default()
392-
.absorb(&mu)
398+
.absorb(mu)
393399
.absorb(&w1_tilde)
394400
.squeeze_new::<P::Lambda>();
395401
let c = sample_in_ball(&c_tilde, P::TAU);
@@ -448,6 +454,24 @@ impl<P: MlDsaParams> SigningKey<P> {
448454
Ok(self.sign_internal(Mp, &rnd))
449455
}
450456

457+
/// This method reflects the randomized ML-DSA.Sign algorithm with a pre-computed μ.
458+
///
459+
/// # Errors
460+
///
461+
/// This method can return an opaque error if it fails to get enough randomness.
462+
// Algorithm 2 ML-DSA.Sign (optional pre-computed μ variant)
463+
#[cfg(feature = "rand_core")]
464+
pub fn sign_mu_randomized<R: TryCryptoRng + ?Sized>(
465+
&self,
466+
mu: &B64,
467+
rng: &mut R,
468+
) -> Result<Signature<P>, Error> {
469+
let mut rnd = B32::default();
470+
rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
471+
472+
Ok(self.raw_sign_mu(mu, &rnd))
473+
}
474+
451475
/// This method reflects the optional deterministic variant of the ML-DSA.Sign algorithm.
452476
///
453477
/// # Errors
@@ -458,6 +482,14 @@ impl<P: MlDsaParams> SigningKey<P> {
458482
self.raw_sign_deterministic(&[M], ctx)
459483
}
460484

485+
/// This method reflects the optional deterministic variant of the ML-DSA.Sign algorithm with a
486+
/// pre-computed μ.
487+
// Algorithm 2 ML-DSA.Sign (optional deterministic and pre-computed μ variant)
488+
pub fn sign_mu_deterministic(&self, mu: &B64) -> Signature<P> {
489+
let rnd = B32::default();
490+
self.raw_sign_mu(mu, &rnd)
491+
}
492+
461493
fn raw_sign_deterministic(&self, M: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
462494
if ctx.len() > 255 {
463495
return Err(Error::new());
@@ -653,7 +685,13 @@ impl<P: MlDsaParams> VerifyingKey<P> {
653685
{
654686
// Compute the message representative
655687
let mu = message_representative(&self.tr, Mp);
688+
self.raw_verify_mu(&mu, sigma)
689+
}
656690

691+
fn raw_verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool
692+
where
693+
P: MlDsaParams,
694+
{
657695
// Reconstruct w
658696
let c = sample_in_ball(&sigma.c_tilde, P::TAU);
659697

@@ -667,19 +705,25 @@ impl<P: MlDsaParams> VerifyingKey<P> {
667705

668706
let w1p_tilde = P::encode_w1(&w1p);
669707
let cp_tilde = H::default()
670-
.absorb(&mu)
708+
.absorb(mu)
671709
.absorb(&w1p_tilde)
672710
.squeeze_new::<P::Lambda>();
673711

674712
sigma.c_tilde == cp_tilde
675713
}
676714

677-
/// This algorithm reflect the ML-DSA.Verify algorithm from FIPS 204.
715+
/// This algorithm reflects the ML-DSA.Verify algorithm from FIPS 204.
678716
// Algorithm 3 ML-DSA.Verify
679717
pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
680718
self.raw_verify_with_context(&[M], ctx, sigma)
681719
}
682720

721+
/// This algorithm reflects the ML-DSA.Verify algorithm with a pre-computed μ from FIPS 204.
722+
// Algorithm 3 ML-DSA.Verify (optional pre-computed μ variant)
723+
pub fn verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool {
724+
self.raw_verify_mu(mu, sigma)
725+
}
726+
683727
fn raw_verify_with_context(&self, M: &[&[u8]], ctx: &[u8], sigma: &Signature<P>) -> bool {
684728
if ctx.len() > 255 {
685729
return false;
@@ -1062,4 +1106,70 @@ mod test {
10621106
many_round_trip_test::<MlDsa65>();
10631107
many_round_trip_test::<MlDsa87>();
10641108
}
1109+
1110+
#[test]
1111+
fn sign_mu_verify_mu_round_trip() {
1112+
fn sign_mu_verify_mu<P>()
1113+
where
1114+
P: MlDsaParams,
1115+
{
1116+
let kp = P::key_gen_internal(&Array::default());
1117+
let sk = kp.signing_key;
1118+
let vk = kp.verifying_key;
1119+
1120+
let M = b"Hello world";
1121+
let rnd = Array([0u8; 32]);
1122+
let mu = message_representative(&sk.tr, &[&[M]]);
1123+
let sig = sk.raw_sign_mu(&mu, &rnd);
1124+
1125+
assert!(vk.raw_verify_mu(&mu, &sig));
1126+
}
1127+
sign_mu_verify_mu::<MlDsa44>();
1128+
sign_mu_verify_mu::<MlDsa65>();
1129+
sign_mu_verify_mu::<MlDsa87>();
1130+
}
1131+
1132+
#[test]
1133+
fn sign_mu_verify_internal_round_trip() {
1134+
fn sign_mu_verify_internal<P>()
1135+
where
1136+
P: MlDsaParams,
1137+
{
1138+
let kp = P::key_gen_internal(&Array::default());
1139+
let sk = kp.signing_key;
1140+
let vk = kp.verifying_key;
1141+
1142+
let M = b"Hello world";
1143+
let rnd = Array([0u8; 32]);
1144+
let mu = message_representative(&sk.tr, &[&[M]]);
1145+
let sig = sk.raw_sign_mu(&mu, &rnd);
1146+
1147+
assert!(vk.verify_internal(&[M], &sig));
1148+
}
1149+
sign_mu_verify_internal::<MlDsa44>();
1150+
sign_mu_verify_internal::<MlDsa65>();
1151+
sign_mu_verify_internal::<MlDsa87>();
1152+
}
1153+
1154+
#[test]
1155+
fn sign_internal_verify_mu_round_trip() {
1156+
fn sign_internal_verify_mu<P>()
1157+
where
1158+
P: MlDsaParams,
1159+
{
1160+
let kp = P::key_gen_internal(&Array::default());
1161+
let sk = kp.signing_key;
1162+
let vk = kp.verifying_key;
1163+
1164+
let M = b"Hello world";
1165+
let rnd = Array([0u8; 32]);
1166+
let mu = message_representative(&sk.tr, &[&[M]]);
1167+
let sig = sk.sign_internal(&[M], &rnd);
1168+
1169+
assert!(vk.raw_verify_mu(&mu, &sig));
1170+
}
1171+
sign_internal_verify_mu::<MlDsa44>();
1172+
sign_internal_verify_mu::<MlDsa65>();
1173+
sign_internal_verify_mu::<MlDsa87>();
1174+
}
10651175
}

0 commit comments

Comments
 (0)