Skip to content

Commit 2bd9076

Browse files
committed
Pass signature verification tests
1 parent 8267eef commit 2bd9076

File tree

6 files changed

+91
-47
lines changed

6 files changed

+91
-47
lines changed

ml-dsa/src/algebra.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,12 @@ impl From<NttPolynomial> for Array<FieldElement, U256> {
686686
// Algorithm 41 NTT
687687
impl Polynomial {
688688
pub fn ntt(&self) -> NttPolynomial {
689-
let mut w = self.0.clone();
689+
// XXX let mut w = self.0.clone();
690+
let mut w: Array<FieldElement, U256> = self
691+
.0
692+
.iter()
693+
.map(|x| FieldElement(x.0 % FieldElement::Q))
694+
.collect();
690695

691696
let mut m = 0;
692697
for len in [128, 64, 32, 16, 8, 4, 2, 1] {

ml-dsa/src/hint.rs

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ where
9898
let mut y: EncodedHint<P> = Default::default();
9999
let mut index = 0;
100100
let omega = P::Omega::USIZE;
101-
102101
for i in 0..P::K::U8 {
103102
let i_usize: usize = i.into();
104103
for j in 0..256 {
@@ -114,32 +113,37 @@ where
114113
y
115114
}
116115

116+
fn monotonic(a: &[usize]) -> bool {
117+
a.iter().enumerate().all(|(i, x)| i == 0 || a[i - 1] < *x)
118+
}
119+
117120
pub fn bit_unpack(y: &EncodedHint<P>) -> Option<Self> {
121+
let (indices, cuts) = P::split_hint(y);
122+
let cuts: Array<usize, P::K> = cuts.iter().map(|x| usize::from(*x)).collect();
123+
124+
let indices: Array<usize, P::Omega> = indices.iter().map(|x| usize::from(*x)).collect();
125+
let max_cut: usize = cuts.iter().cloned().max().unwrap().into();
126+
if !Self::monotonic(&cuts)
127+
|| max_cut > indices.len()
128+
|| indices[max_cut..].iter().cloned().max().unwrap_or(0) > 0
129+
{
130+
return None;
131+
}
132+
118133
let mut h = Self::default();
119-
let mut index = 0;
120-
let omega = P::Omega::USIZE;
134+
let mut start = 0;
135+
for (i, &end) in cuts.iter().enumerate() {
136+
let indices = &indices[start..end];
121137

122-
for i in 0..P::K::U8 {
123-
let i_usize: usize = i.into();
124-
let end: usize = y[omega + i_usize].into();
125-
if end < index || end > omega {
138+
if !Self::monotonic(indices) {
126139
return None;
127140
}
128141

129-
let start = index;
130-
while index < end {
131-
if index > start && y[index - 1] >= y[index] {
132-
return None;
133-
}
134-
135-
let j: usize = y[index].into();
136-
h.0[i_usize][j] = true;
137-
index += 1;
142+
for &j in indices {
143+
h.0[i][j] = true;
138144
}
139-
}
140145

141-
if y[index..omega].iter().any(|x| *x != 0) {
142-
return None;
146+
start = end;
143147
}
144148

145149
Some(h)

ml-dsa/src/lib.rs

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,16 @@ impl<P: SignatureParams> Signature<P> {
5656
// Algorithm 27 sigDecode
5757
pub fn decode(enc: &EncodedSignature<P>) -> Option<Self> {
5858
let (c_tilde, z, h) = P::split_sig(&enc);
59-
Some(Self {
60-
c_tilde: c_tilde.clone(),
61-
z: P::decode_z(z),
62-
h: Hint::bit_unpack(h)?,
63-
})
59+
60+
let c_tilde = c_tilde.clone();
61+
let z = P::decode_z(z);
62+
let h = Hint::bit_unpack(h)?;
63+
64+
if z.infinity_norm() >= P::GAMMA1_MINUS_BETA {
65+
return None;
66+
}
67+
68+
Some(Self { c_tilde, z, h })
6469
}
6570
}
6671

@@ -164,24 +169,27 @@ impl<P: ParameterSet> SigningKey<P> {
164169
let z = &y + &cs1;
165170
let r0 = (&w - &cs2).low_bits::<P::Gamma2>();
166171

167-
let gamma1_threshold = P::Gamma1::U32 - P::BETA;
168-
let gamma2_threshold = P::Gamma2::U32 - P::BETA;
169-
if r0.infinity_norm() > gamma2_threshold || z.infinity_norm() > gamma1_threshold {
172+
if z.infinity_norm() >= P::GAMMA1_MINUS_BETA
173+
|| r0.infinity_norm() >= P::GAMMA2_MINUS_BETA
174+
{
170175
continue;
171176
}
172177

173178
let ct0 = (&c_hat * &t0_hat).ntt_inverse();
174179
let h = Hint::<P>::new(-&ct0, &(&w - &cs2) + &ct0);
175180

176-
if ct0.infinity_norm() > P::Gamma2::U32 || h.hamming_weight() > P::Omega::USIZE {
181+
if ct0.infinity_norm() >= P::Gamma2::U32 || h.hamming_weight() > P::Omega::USIZE {
177182
continue;
178183
}
179184

180185
let z = z.mod_plus_minus(FieldElement(FieldElement::Q));
181186
return Signature { c_tilde, z, h };
182187
}
183188

184-
// TODO(RLB) Make this method fallible
189+
// XXX(RLB) We could be more parsimonious about the number of iterations here, and still
190+
// have an overwhelming probability of success.
191+
// XXX(RLB) I still don't love panicking. Maybe we should expose the fact that this method
192+
// can fail?
185193
panic!("Rejection sampling failed to find a valid signature");
186194
}
187195

@@ -228,17 +236,17 @@ pub struct VerificationKey<P: ParameterSet> {
228236
}
229237

230238
impl<P: ParameterSet> VerificationKey<P> {
231-
pub fn verify(&self, Mp: &[u8], sigma: &Signature<P>) -> bool
239+
pub fn verify_internal(&self, Mp: &[u8], sigma: &Signature<P>) -> bool
232240
where
233241
P: VerificationKeyParams + SignatureParams,
234242
{
235243
// TODO(RLB) pre-compute these and store them on the signing key struct
236244
let A_hat = NttMatrix::<P::K, P::L>::expand_a(&self.rho);
237-
let t1_hat = (FieldElement(1 << 13) * &self.t1).ntt();
245+
let t1 = FieldElement(1 << 13) * &self.t1;
246+
let t1_hat = t1.ntt();
238247
let tr: B64 = H::default().absorb(&self.encode()).squeeze_new();
239248

240249
// Compute the message representative
241-
// XXX(RLB) might need to run bytes_to_bits()?
242250
let mu: B64 = H::default().absorb(&tr).absorb(&Mp).squeeze_new();
243251

244252
// Reconstruct w
@@ -259,7 +267,7 @@ impl<P: ParameterSet> VerificationKey<P> {
259267
.squeeze_new::<P::Lambda>();
260268

261269
let gamma1_threshold = P::Gamma1::U32 - P::BETA;
262-
return sigma.z.infinity_norm() < gamma1_threshold && sigma.c_tilde == cp_tilde;
270+
sigma.c_tilde == cp_tilde
263271
}
264272

265273
// Algorithm 22 pkEncode
@@ -354,6 +362,12 @@ mod test {
354362
let sk_bytes = sk.encode();
355363
let sk2 = SigningKey::<P>::decode(&sk_bytes);
356364
assert!(sk == sk2);
365+
366+
let sig = sk.sign_internal(&[0, 1, 2, 3], (&[0u8; 32]).into());
367+
let sig_bytes = sig.encode();
368+
println!("sig_bytes: {:?}", hex::encode(&sig_bytes));
369+
let sig2 = Signature::<P>::decode(&sig_bytes).unwrap();
370+
assert!(sig == sig2);
357371
}
358372

359373
#[test]
@@ -370,14 +384,13 @@ mod test {
370384
let mut rng = rand::thread_rng();
371385

372386
let seed: [u8; 32] = rng.gen();
373-
let (_pk, sk) = SigningKey::<P>::key_gen_internal(&seed.into());
387+
let (pk, sk) = SigningKey::<P>::key_gen_internal(&seed.into());
374388

375389
let rnd: [u8; 32] = rng.gen();
376390
let Mp = b"Hello world";
377-
let _sig = sk.sign_internal(Mp, &rnd.into());
391+
let sig = sk.sign_internal(Mp, &rnd.into());
378392

379-
// TODO(RLB) Re-enable and debug
380-
// assert!(pk.verify(Mp, &sig));
393+
assert!(pk.verify_internal(Mp, &sig));
381394
}
382395

383396
#[test]

ml-dsa/src/param.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,11 @@ pub trait SignatureParams: ParameterSet {
396396
type HintSize: ArraySize;
397397
type SignatureSize: ArraySize;
398398

399+
const GAMMA1_MINUS_BETA: u32;
400+
const GAMMA2_MINUS_BETA: u32;
401+
402+
fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>);
403+
399404
fn encode_w1(t1: &PolynomialVector<Self::K>) -> EncodedW1<Self>;
400405
fn decode_w1(enc: &EncodedW1<Self>) -> PolynomialVector<Self::K>;
401406

@@ -415,6 +420,8 @@ pub trait SignatureParams: ParameterSet {
415420
pub type EncodedCTilde<P> = Array<u8, <P as ParameterSet>::Lambda>;
416421
pub type EncodedW1<P> = Array<u8, <P as SignatureParams>::W1Size>;
417422
pub type EncodedZ<P> = Array<u8, <P as SignatureParams>::ZSize>;
423+
pub type EncodedHintIndices<P> = Array<u8, <P as ParameterSet>::Omega>;
424+
pub type EncodedHintCuts<P> = Array<u8, <P as ParameterSet>::K>;
418425
pub type EncodedHint<P> = Array<u8, <P as SignatureParams>::HintSize>;
419426
pub type EncodedSignature<P> = Array<u8, <P as SignatureParams>::SignatureSize>;
420427

@@ -435,7 +442,7 @@ where
435442
+ Rem<P::L, Output = U0>,
436443
// Hint
437444
P::Omega: Add<P::K>,
438-
Sum<P::Omega, P::K>: ArraySize,
445+
Sum<P::Omega, P::K>: ArraySize + Sub<P::Omega, Output = P::K>,
439446
// Signature
440447
P::Lambda: Add<Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
441448
Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>:
@@ -459,6 +466,13 @@ where
459466
type HintSize = Sum<P::Omega, P::K>;
460467
type SignatureSize = Sum<Sum<P::Lambda, Self::ZSize>, Self::HintSize>;
461468

469+
const GAMMA1_MINUS_BETA: u32 = P::Gamma1::U32 - P::BETA;
470+
const GAMMA2_MINUS_BETA: u32 = P::Gamma2::U32 - P::BETA;
471+
472+
fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>) {
473+
y.split_ref()
474+
}
475+
462476
fn encode_w1(w1: &PolynomialVector<Self::K>) -> EncodedW1<Self> {
463477
SimpleBitPack::<Self::W1Bits>::pack(w1)
464478
}

ml-dsa/tests/README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ The current copies of these files were taken from commit [65370b8] of that repo.
1414
The actual tests to be performed are described in the [ACVP documentation].
1515

1616
[NIST ACVP repository]: https://github.com/usnistgov/ACVP-Server/
17-
[keyGen]: https://github.com/usnistgov/ACVP-Server/blob/master/gen-val/json-files/ML-DSA-keyGen-FIPS204
18-
[sigGen]: https://github.com/usnistgov/ACVP-Server/blob/master/gen-val/json-files/ML-DSA-sigGen-FIPS204
19-
[sigVer]: https://github.com/usnistgov/ACVP-Server/blob/master/gen-val/json-files/ML-DSA-sigVer-FIPS204
20-
[65370b8]: https://github.com/usnistgov/ACVP-Server/commit/65370b861b96efd30dfe0daae607bde26a78a5c8
21-
[ACVP documentation]: https://github.com/usnistgov/ACVP/tree/65370b861b96efd30dfe0daae607bde26a78a5c8/src/ml-dsa/sections
17+
[keyGen]: https://github.com/usnistgov/ACVP-Server/blob/65370b8/gen-val/json-files/ML-DSA-keyGen-FIPS204
18+
[sigGen]: https://github.com/usnistgov/ACVP-Server/blob/65370b8/gen-val/json-files/ML-DSA-sigGen-FIPS204
19+
[sigVer]: https://github.com/usnistgov/ACVP-Server/blob/65370b8/gen-val/json-files/ML-DSA-sigVer-FIPS204
20+
[65370b8]: https://github.com/usnistgov/ACVP-Server/commit/65370b8
21+
[ACVP documentation]: https://github.com/usnistgov/ACVP/tree/master/src/ml-dsa/sections

ml-dsa/tests/sig-ver.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,13 @@ fn verify<P: VerificationKeyParams + SignatureParams>(tg: &acvp::TestGroup, tc:
3131

3232
// Import the signature
3333
let sig_bytes = EncodedSignature::<P>::try_from(tc.signature.as_slice()).unwrap();
34-
let sig = Signature::<P>::decode(&sig_bytes).unwrap();
34+
let sig = Signature::<P>::decode(&sig_bytes);
3535

36-
// Verify the signature
37-
assert!(pk.verify(tc.message.as_slice(), &sig));
36+
// Verify the signature if it successfully decoded
37+
let test_passed = sig
38+
.map(|sig| pk.verify_internal(tc.message.as_slice(), &sig))
39+
.unwrap_or_default();
40+
assert_eq!(test_passed, tc.test_passed);
3841
}
3942

4043
mod acvp {
@@ -77,6 +80,11 @@ mod acvp {
7780
#[serde(rename = "tcId")]
7881
pub id: usize,
7982

83+
#[serde(rename = "testPassed")]
84+
pub test_passed: bool,
85+
86+
pub reason: String,
87+
8088
#[serde(with = "hex::serde")]
8189
pub message: Vec<u8>,
8290

0 commit comments

Comments
 (0)