Skip to content

Commit 5831ec5

Browse files
committed
Simplify gen_matrix() and poly_rej_uniform_x4()
1. Simplify poly_rej_uniform_x4() API to explicitly take 4 objects to write, rather than a slice of a larger array. 2. Modify gen_matrix() in light of that. 3. Update test/bench_components_mlkem.c Signed-off-by: Rod Chapman <rodchap@amazon.com>
1 parent 65e4512 commit 5831ec5

File tree

5 files changed

+49
-34
lines changed

5 files changed

+49
-34
lines changed

mlkem/src/indcpa.c

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,9 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
221221
/* Sample 4 matrix entries a time. */
222222
for (i = 0; i < (MLKEM_K * MLKEM_K / 4) * 4; i += 4)
223223
{
224-
uint8_t x, y;
225-
226224
for (j = 0; j < 4; j++)
227225
{
226+
uint8_t x, y;
228227
x = (i + j) / MLKEM_K;
229228
y = (i + j) % MLKEM_K;
230229
if (transposed)
@@ -239,11 +238,7 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
239238
}
240239
}
241240

242-
/*
243-
* This call writes across mlk_polyvec boundaries for K=2 and K=3.
244-
* This is intentional and safe.
245-
*/
246-
mlk_poly_rej_uniform_x4(&a[i], seed_ext);
241+
mlk_poly_rej_uniform_x4(&a[i], &a[i + 1], &a[i + 2], &a[i + 3], seed_ext);
247242
}
248243

249244
/* For MLKEM_K == 3, sample the last entry individually. */

mlkem/src/sampling.c

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ __contract__(
146146
* - x4-batched version of `rej_uniform()` from the
147147
* reference implementation, leveraging x4-batched Keccak-f1600. */
148148
MLK_INTERNAL_API
149-
void mlk_poly_rej_uniform_x4(mlk_poly *vec,
149+
void mlk_poly_rej_uniform_x4(mlk_poly *vec0, mlk_poly *vec1, mlk_poly *vec2,
150+
mlk_poly *vec3,
150151
uint8_t seed[4][MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)])
151152
{
152153
/* Temporary buffers for XOF output before rejection sampling */
@@ -167,10 +168,10 @@ void mlk_poly_rej_uniform_x4(mlk_poly *vec,
167168
*/
168169
mlk_xof_x4_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &statex);
169170
buflen = MLKEM_GEN_MATRIX_NBLOCKS * MLK_XOF_RATE;
170-
ctr[0] = mlk_rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf[0], buflen);
171-
ctr[1] = mlk_rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf[1], buflen);
172-
ctr[2] = mlk_rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf[2], buflen);
173-
ctr[3] = mlk_rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf[3], buflen);
171+
ctr[0] = mlk_rej_uniform(vec0->coeffs, MLKEM_N, 0, buf[0], buflen);
172+
ctr[1] = mlk_rej_uniform(vec1->coeffs, MLKEM_N, 0, buf[1], buflen);
173+
ctr[2] = mlk_rej_uniform(vec2->coeffs, MLKEM_N, 0, buf[2], buflen);
174+
ctr[3] = mlk_rej_uniform(vec3->coeffs, MLKEM_N, 0, buf[3], buflen);
174175

175176
/*
176177
* So long as not all matrix entries have been generated, squeeze
@@ -180,20 +181,27 @@ void mlk_poly_rej_uniform_x4(mlk_poly *vec,
180181
while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N ||
181182
ctr[3] < MLKEM_N)
182183
__loop__(
183-
assigns(ctr, statex, memory_slice(vec, sizeof(mlk_poly) * 4), object_whole(buf[0]),
184-
object_whole(buf[1]), object_whole(buf[2]), object_whole(buf[3]))
184+
assigns(ctr, statex,
185+
memory_slice(vec0, sizeof(mlk_poly)),
186+
memory_slice(vec1, sizeof(mlk_poly)),
187+
memory_slice(vec2, sizeof(mlk_poly)),
188+
memory_slice(vec3, sizeof(mlk_poly)),
189+
object_whole(buf[0]),
190+
object_whole(buf[1]),
191+
object_whole(buf[2]),
192+
object_whole(buf[3]))
185193
invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N)
186194
invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N)
187-
invariant(array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q))
188-
invariant(array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q))
189-
invariant(array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q))
190-
invariant(array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q)))
195+
invariant(array_bound(vec0->coeffs, 0, ctr[0], 0, MLKEM_Q))
196+
invariant(array_bound(vec1->coeffs, 0, ctr[1], 0, MLKEM_Q))
197+
invariant(array_bound(vec2->coeffs, 0, ctr[2], 0, MLKEM_Q))
198+
invariant(array_bound(vec3->coeffs, 0, ctr[3], 0, MLKEM_Q)))
191199
{
192200
mlk_xof_x4_squeezeblocks(buf, 1, &statex);
193-
ctr[0] = mlk_rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf[0], buflen);
194-
ctr[1] = mlk_rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf[1], buflen);
195-
ctr[2] = mlk_rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf[2], buflen);
196-
ctr[3] = mlk_rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf[3], buflen);
201+
ctr[0] = mlk_rej_uniform(vec0->coeffs, MLKEM_N, ctr[0], buf[0], buflen);
202+
ctr[1] = mlk_rej_uniform(vec1->coeffs, MLKEM_N, ctr[1], buf[1], buflen);
203+
ctr[2] = mlk_rej_uniform(vec2->coeffs, MLKEM_N, ctr[2], buf[2], buflen);
204+
ctr[3] = mlk_rej_uniform(vec3->coeffs, MLKEM_N, ctr[3], buf[3], buflen);
197205
}
198206

199207
mlk_xof_x4_release(&statex);

mlkem/src/sampling.h

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ void mlk_poly_cbd3(mlk_poly *r, const uint8_t buf[3 * MLKEM_N / 4]);
6565
* Description: Generate four polynomials using rejection sampling
6666
* on (pseudo-)uniformly random bytes sampled from a seed.
6767
*
68-
* Arguments: - mlk_poly *vec:
69-
* Pointer to an array of 4 polynomials to be sampled.
68+
* Arguments: - mlk_poly *vec0, *vec1, *vec2, *vec3:
69+
* Pointers to 4 polynomials to be sampled.
7070
* - uint8_t seed[4][MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)]:
7171
* Pointer consecutive array of seed buffers of size
7272
* MLKEM_SYMBYTES + 2 each, plus padding for alignment.
@@ -75,16 +75,23 @@ void mlk_poly_cbd3(mlk_poly *r, const uint8_t buf[3 * MLKEM_N / 4]);
7575
*
7676
**************************************************/
7777
MLK_INTERNAL_API
78-
void mlk_poly_rej_uniform_x4(mlk_poly *vec,
78+
void mlk_poly_rej_uniform_x4(mlk_poly *vec0, mlk_poly *vec1, mlk_poly *vec2,
79+
mlk_poly *vec3,
7980
uint8_t seed[4][MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)])
8081
__contract__(
81-
requires(memory_no_alias(vec, sizeof(mlk_poly) * 4))
82+
requires(memory_no_alias(vec0, sizeof(mlk_poly)))
83+
requires(memory_no_alias(vec1, sizeof(mlk_poly)))
84+
requires(memory_no_alias(vec2, sizeof(mlk_poly)))
85+
requires(memory_no_alias(vec3, sizeof(mlk_poly)))
8286
requires(memory_no_alias(seed, 4 * MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)))
83-
assigns(memory_slice(vec, sizeof(mlk_poly) * 4))
84-
ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))
85-
ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q))
86-
ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q))
87-
ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q)));
87+
assigns(memory_slice(vec0, sizeof(mlk_poly)))
88+
assigns(memory_slice(vec1, sizeof(mlk_poly)))
89+
assigns(memory_slice(vec2, sizeof(mlk_poly)))
90+
assigns(memory_slice(vec3, sizeof(mlk_poly)))
91+
ensures(array_bound(vec0->coeffs, 0, MLKEM_N, 0, MLKEM_Q))
92+
ensures(array_bound(vec1->coeffs, 0, MLKEM_N, 0, MLKEM_Q))
93+
ensures(array_bound(vec2->coeffs, 0, MLKEM_N, 0, MLKEM_Q))
94+
ensures(array_bound(vec3->coeffs, 0, MLKEM_N, 0, MLKEM_Q)));
8895

8996
#define mlk_poly_rej_uniform MLK_NAMESPACE(poly_rej_uniform)
9097
/*************************************************

proofs/cbmc/poly_rej_uniform_x4/poly_rej_uniform_x4_harness.c

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88
void harness(void)
99
{
10-
mlk_poly *out;
10+
mlk_poly *out0;
11+
mlk_poly *out1;
12+
mlk_poly *out2;
13+
mlk_poly *out3;
1114
uint8_t(*seed)[MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)];
12-
mlk_poly_rej_uniform_x4(out, seed);
15+
mlk_poly_rej_uniform_x4(out0, out1, out2, out3, seed);
1316
}

test/bench_components_mlkem.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ static int bench(void)
6969
BENCH("mlk_poly_rej_uniform",
7070
mlk_poly_rej_uniform((mlk_poly *)data0, (uint8_t *)data1))
7171
BENCH("mlk_poly_rej_uniform_x4",
72-
mlk_poly_rej_uniform_x4((mlk_poly *)data0, (uint8_t(*)[64])data1))
72+
mlk_poly_rej_uniform_x4((mlk_poly *)data0, (mlk_poly *)data1,
73+
(mlk_poly *)data2, (mlk_poly *)data3,
74+
(uint8_t (*)[64])data4))
7375

7476
/* mlk_poly */
7577
/* mlk_poly_compress_du */

0 commit comments

Comments
 (0)