Skip to content

Commit 487a7ce

Browse files
committed
Precompute some values as part of key generation / decoding
1 parent 2bd9076 commit 487a7ce

File tree

1 file changed

+95
-63
lines changed

1 file changed

+95
-63
lines changed

ml-dsa/src/lib.rs

Lines changed: 95 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,44 @@ pub struct SigningKey<P: ParameterSet> {
7878
s1: PolynomialVector<P::L>,
7979
s2: PolynomialVector<P::K>,
8080
t0: PolynomialVector<P::K>,
81+
82+
// Derived values
83+
s1_hat: NttVector<P::L>,
84+
s2_hat: NttVector<P::K>,
85+
t0_hat: NttVector<P::K>,
86+
A_hat: NttMatrix<P::K, P::L>,
8187
}
8288

8389
impl<P: ParameterSet> SigningKey<P> {
90+
fn new(
91+
rho: B32,
92+
K: B32,
93+
tr: B64,
94+
s1: PolynomialVector<P::L>,
95+
s2: PolynomialVector<P::K>,
96+
t0: PolynomialVector<P::K>,
97+
A_hat: Option<NttMatrix<P::K, P::L>>,
98+
) -> Self {
99+
let A_hat = A_hat.unwrap_or_else(|| NttMatrix::expand_a(&rho));
100+
let s1_hat = s1.ntt();
101+
let s2_hat = s2.ntt();
102+
let t0_hat = t0.ntt();
103+
104+
Self {
105+
rho,
106+
K,
107+
tr,
108+
s1,
109+
s2,
110+
t0,
111+
112+
s1_hat,
113+
s2_hat,
114+
t0_hat,
115+
A_hat,
116+
}
117+
}
118+
84119
/// Deterministically generate a signing key pair from the specified seed
85120
pub fn key_gen_internal(xi: &B32) -> (VerificationKey<P>, SigningKey<P>)
86121
where
@@ -97,32 +132,19 @@ impl<P: ParameterSet> SigningKey<P> {
97132
let K: B32 = h.squeeze_new();
98133

99134
// Sample private key components
100-
let A = NttMatrix::<P::K, P::L>::expand_a(&rho);
135+
let A_hat = NttMatrix::<P::K, P::L>::expand_a(&rho);
101136
let s1 = PolynomialVector::<P::L>::expand_s(&rhop, P::Eta::ETA, 0);
102137
let s2 = PolynomialVector::<P::K>::expand_s(&rhop, P::Eta::ETA, P::L::USIZE);
103138

104139
// Compute derived values
105-
let As1 = &A * &s1.ntt();
106-
let t = &As1.ntt_inverse() + &s2;
140+
let As1_hat = &A_hat * &s1.ntt();
141+
let t = &As1_hat.ntt_inverse() + &s2;
107142

108143
// Compress and encode
109144
let (t1, t0) = t.power2round();
110145

111-
let vk = VerificationKey {
112-
rho: rho.clone(),
113-
t1,
114-
};
115-
116-
let tr = H::default().absorb(&vk.encode()).squeeze_new();
117-
118-
let sk = Self {
119-
rho,
120-
K,
121-
tr,
122-
s1,
123-
s2,
124-
t0,
125-
};
146+
let vk = VerificationKey::new(rho, t1, Some(A_hat.clone()), None);
147+
let sk = Self::new(rho, K, vk.tr.clone(), s1, s2, t0, Some(A_hat));
126148

127149
(vk, sk)
128150
}
@@ -132,12 +154,6 @@ impl<P: ParameterSet> SigningKey<P> {
132154
where
133155
P: SignatureParams,
134156
{
135-
// TODO(RLB) pre-compute these and store them on the signing key struct
136-
let s1_hat = self.s1.ntt();
137-
let s2_hat = self.s2.ntt();
138-
let t0_hat = self.t0.ntt();
139-
let A_hat = NttMatrix::<P::K, P::L>::expand_a(&self.rho);
140-
141157
// Compute the message representative
142158
// XXX(RLB) Should the API represent this as an input?
143159
let mu: B64 = H::default().absorb(&self.tr).absorb(&Mp).squeeze_new();
@@ -152,7 +168,7 @@ impl<P: ParameterSet> SigningKey<P> {
152168
// Rejection sampling loop
153169
for kappa in (0..u16::MAX).step_by(P::L::USIZE) {
154170
let y = PolynomialVector::<P::L>::expand_mask::<P::Gamma1>(&rhopp, kappa);
155-
let w = (&A_hat * &y.ntt()).ntt_inverse();
171+
let w = (&self.A_hat * &y.ntt()).ntt_inverse();
156172
let w1 = w.high_bits::<P::Gamma2>();
157173

158174
let w1_tilde = P::encode_w1(&w1);
@@ -163,8 +179,8 @@ impl<P: ParameterSet> SigningKey<P> {
163179
let c = Polynomial::sample_in_ball(&c_tilde, P::TAU);
164180
let c_hat = c.ntt();
165181

166-
let cs1 = (&c_hat * &s1_hat).ntt_inverse();
167-
let cs2 = (&c_hat * &s2_hat).ntt_inverse();
182+
let cs1 = (&c_hat * &self.s1_hat).ntt_inverse();
183+
let cs2 = (&c_hat * &self.s2_hat).ntt_inverse();
168184

169185
let z = &y + &cs1;
170186
let r0 = (&w - &cs2).low_bits::<P::Gamma2>();
@@ -175,7 +191,7 @@ impl<P: ParameterSet> SigningKey<P> {
175191
continue;
176192
}
177193

178-
let ct0 = (&c_hat * &t0_hat).ntt_inverse();
194+
let ct0 = (&c_hat * &self.t0_hat).ntt_inverse();
179195
let h = Hint::<P>::new(-&ct0, &(&w - &cs2) + &ct0);
180196

181197
if ct0.infinity_norm() >= P::Gamma2::U32 || h.hamming_weight() > P::Omega::USIZE {
@@ -217,14 +233,15 @@ impl<P: ParameterSet> SigningKey<P> {
217233
P: SigningKeyParams,
218234
{
219235
let (rho, K, tr, s1_enc, s2_enc, t0_enc) = P::split_sk(enc);
220-
Self {
221-
rho: rho.clone(),
222-
K: K.clone(),
223-
tr: tr.clone(),
224-
s1: P::decode_s1(s1_enc),
225-
s2: P::decode_s2(s2_enc),
226-
t0: P::decode_t0(t0_enc),
227-
}
236+
Self::new(
237+
rho.clone(),
238+
K.clone(),
239+
tr.clone(),
240+
P::decode_s1(s1_enc),
241+
P::decode_s2(s2_enc),
242+
P::decode_t0(t0_enc),
243+
None,
244+
)
228245
}
229246
}
230247

@@ -233,31 +250,30 @@ impl<P: ParameterSet> SigningKey<P> {
233250
pub struct VerificationKey<P: ParameterSet> {
234251
rho: B32,
235252
t1: PolynomialVector<P::K>,
253+
254+
// Derived values
255+
A_hat: NttMatrix<P::K, P::L>,
256+
t1_2d_hat: NttVector<P::K>,
257+
tr: B64,
236258
}
237259

238-
impl<P: ParameterSet> VerificationKey<P> {
260+
impl<P: VerificationKeyParams> VerificationKey<P> {
239261
pub fn verify_internal(&self, Mp: &[u8], sigma: &Signature<P>) -> bool
240262
where
241-
P: VerificationKeyParams + SignatureParams,
263+
P: SignatureParams,
242264
{
243-
// TODO(RLB) pre-compute these and store them on the signing key struct
244-
let A_hat = NttMatrix::<P::K, P::L>::expand_a(&self.rho);
245-
let t1 = FieldElement(1 << 13) * &self.t1;
246-
let t1_hat = t1.ntt();
247-
let tr: B64 = H::default().absorb(&self.encode()).squeeze_new();
248-
249265
// Compute the message representative
250-
let mu: B64 = H::default().absorb(&tr).absorb(&Mp).squeeze_new();
266+
let mu: B64 = H::default().absorb(&self.tr).absorb(&Mp).squeeze_new();
251267

252268
// Reconstruct w
253269
let c = Polynomial::sample_in_ball(&sigma.c_tilde, P::TAU);
254270

255271
let z_hat = sigma.z.ntt();
256272
let c_hat = c.ntt();
257-
let Az_hat = &A_hat * &z_hat;
258-
let ct1_hat = &c_hat * &t1_hat;
273+
let Az_hat = &self.A_hat * &z_hat;
274+
let ct1_2d_hat = &c_hat * &self.t1_2d_hat;
259275

260-
let wp_approx = (&Az_hat - &ct1_hat).ntt_inverse();
276+
let wp_approx = (&Az_hat - &ct1_2d_hat).ntt_inverse();
261277
let w1p = sigma.h.use_hint(&wp_approx);
262278

263279
let w1p_tilde = P::encode_w1(&w1p);
@@ -266,29 +282,45 @@ impl<P: ParameterSet> VerificationKey<P> {
266282
.absorb(&w1p_tilde)
267283
.squeeze_new::<P::Lambda>();
268284

269-
let gamma1_threshold = P::Gamma1::U32 - P::BETA;
270285
sigma.c_tilde == cp_tilde
271286
}
272287

288+
fn encode_internal(rho: &B32, t1: &PolynomialVector<P::K>) -> EncodedVerificationKey<P> {
289+
let t1_enc = P::encode_t1(t1);
290+
P::concat_vk(rho.clone(), t1_enc)
291+
}
292+
293+
fn new(
294+
rho: B32,
295+
t1: PolynomialVector<P::K>,
296+
A_hat: Option<NttMatrix<P::K, P::L>>,
297+
enc: Option<EncodedVerificationKey<P>>,
298+
) -> Self {
299+
let A_hat = A_hat.unwrap_or_else(|| NttMatrix::expand_a(&rho));
300+
let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1));
301+
302+
let t1_2d_hat = (FieldElement(1 << 13) * &t1).ntt();
303+
let tr: B64 = H::default().absorb(&enc).squeeze_new();
304+
305+
Self {
306+
rho,
307+
t1,
308+
A_hat,
309+
t1_2d_hat,
310+
tr,
311+
}
312+
}
313+
273314
// Algorithm 22 pkEncode
274-
pub fn encode(&self) -> EncodedVerificationKey<P>
275-
where
276-
P: VerificationKeyParams,
277-
{
278-
let t1 = P::encode_t1(&self.t1);
279-
P::concat_vk(self.rho.clone(), t1)
315+
pub fn encode(&self) -> EncodedVerificationKey<P> {
316+
Self::encode_internal(&self.rho, &self.t1)
280317
}
281318

282319
// Algorithm 23 pkDecode
283-
pub fn decode(enc: &EncodedVerificationKey<P>) -> Self
284-
where
285-
P: VerificationKeyParams,
286-
{
320+
pub fn decode(enc: &EncodedVerificationKey<P>) -> Self {
287321
let (rho, t1_enc) = P::split_vk(enc);
288-
Self {
289-
rho: rho.clone(),
290-
t1: P::decode_t1(t1_enc),
291-
}
322+
let t1 = P::decode_t1(t1_enc);
323+
Self::new(rho.clone(), t1, None, Some(enc.clone()))
292324
}
293325
}
294326

0 commit comments

Comments
 (0)