Skip to content

Simplify gen_matrix() and poly_rej_uniform_x4() #1112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions mlkem/src/indcpa.c
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,9 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
/* Sample 4 matrix entries a time. */
for (i = 0; i < (MLKEM_K * MLKEM_K / 4) * 4; i += 4)
{
uint8_t x, y;

for (j = 0; j < 4; j++)
{
uint8_t x, y;
x = (i + j) / MLKEM_K;
y = (i + j) % MLKEM_K;
if (transposed)
Expand All @@ -239,11 +238,7 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
}
}

/*
* This call writes across mlk_polyvec boundaries for K=2 and K=3.
* This is intentional and safe.
*/
mlk_poly_rej_uniform_x4(&a[i], seed_ext);
mlk_poly_rej_uniform_x4(&a[i], &a[i + 1], &a[i + 2], &a[i + 3], seed_ext);
}

/* For MLKEM_K == 3, sample the last entry individually. */
Expand Down
38 changes: 23 additions & 15 deletions mlkem/src/sampling.c
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ __contract__(
* - x4-batched version of `rej_uniform()` from the
* reference implementation, leveraging x4-batched Keccak-f1600. */
MLK_INTERNAL_API
void mlk_poly_rej_uniform_x4(mlk_poly *vec,
void mlk_poly_rej_uniform_x4(mlk_poly *vec0, mlk_poly *vec1, mlk_poly *vec2,
mlk_poly *vec3,
uint8_t seed[4][MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)])
{
/* Temporary buffers for XOF output before rejection sampling */
Expand All @@ -167,10 +168,10 @@ void mlk_poly_rej_uniform_x4(mlk_poly *vec,
*/
mlk_xof_x4_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &statex);
buflen = MLKEM_GEN_MATRIX_NBLOCKS * MLK_XOF_RATE;
ctr[0] = mlk_rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf[0], buflen);
ctr[1] = mlk_rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf[1], buflen);
ctr[2] = mlk_rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf[2], buflen);
ctr[3] = mlk_rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf[3], buflen);
ctr[0] = mlk_rej_uniform(vec0->coeffs, MLKEM_N, 0, buf[0], buflen);
ctr[1] = mlk_rej_uniform(vec1->coeffs, MLKEM_N, 0, buf[1], buflen);
ctr[2] = mlk_rej_uniform(vec2->coeffs, MLKEM_N, 0, buf[2], buflen);
ctr[3] = mlk_rej_uniform(vec3->coeffs, MLKEM_N, 0, buf[3], buflen);

/*
* So long as not all matrix entries have been generated, squeeze
Expand All @@ -180,20 +181,27 @@ void mlk_poly_rej_uniform_x4(mlk_poly *vec,
while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N ||
ctr[3] < MLKEM_N)
__loop__(
assigns(ctr, statex, memory_slice(vec, sizeof(mlk_poly) * 4), object_whole(buf[0]),
object_whole(buf[1]), object_whole(buf[2]), object_whole(buf[3]))
assigns(ctr, statex,
memory_slice(vec0, sizeof(mlk_poly)),
memory_slice(vec1, sizeof(mlk_poly)),
memory_slice(vec2, sizeof(mlk_poly)),
memory_slice(vec3, sizeof(mlk_poly)),
object_whole(buf[0]),
object_whole(buf[1]),
object_whole(buf[2]),
object_whole(buf[3]))
invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N)
invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N)
invariant(array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q))
invariant(array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q))
invariant(array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q))
invariant(array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q)))
invariant(array_bound(vec0->coeffs, 0, ctr[0], 0, MLKEM_Q))
invariant(array_bound(vec1->coeffs, 0, ctr[1], 0, MLKEM_Q))
invariant(array_bound(vec2->coeffs, 0, ctr[2], 0, MLKEM_Q))
invariant(array_bound(vec3->coeffs, 0, ctr[3], 0, MLKEM_Q)))
{
mlk_xof_x4_squeezeblocks(buf, 1, &statex);
ctr[0] = mlk_rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf[0], buflen);
ctr[1] = mlk_rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf[1], buflen);
ctr[2] = mlk_rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf[2], buflen);
ctr[3] = mlk_rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf[3], buflen);
ctr[0] = mlk_rej_uniform(vec0->coeffs, MLKEM_N, ctr[0], buf[0], buflen);
ctr[1] = mlk_rej_uniform(vec1->coeffs, MLKEM_N, ctr[1], buf[1], buflen);
ctr[2] = mlk_rej_uniform(vec2->coeffs, MLKEM_N, ctr[2], buf[2], buflen);
ctr[3] = mlk_rej_uniform(vec3->coeffs, MLKEM_N, ctr[3], buf[3], buflen);
}

mlk_xof_x4_release(&statex);
Expand Down
25 changes: 16 additions & 9 deletions mlkem/src/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ void mlk_poly_cbd3(mlk_poly *r, const uint8_t buf[3 * MLKEM_N / 4]);
* Description: Generate four polynomials using rejection sampling
* on (pseudo-)uniformly random bytes sampled from a seed.
*
* Arguments: - mlk_poly *vec:
* Pointer to an array of 4 polynomials to be sampled.
* Arguments: - mlk_poly *vec0, *vec1, *vec2, *vec3:
* Pointers to 4 polynomials to be sampled.
* - uint8_t seed[4][MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)]:
* Pointer consecutive array of seed buffers of size
* MLKEM_SYMBYTES + 2 each, plus padding for alignment.
Expand All @@ -75,16 +75,23 @@ void mlk_poly_cbd3(mlk_poly *r, const uint8_t buf[3 * MLKEM_N / 4]);
*
**************************************************/
MLK_INTERNAL_API
void mlk_poly_rej_uniform_x4(mlk_poly *vec,
void mlk_poly_rej_uniform_x4(mlk_poly *vec0, mlk_poly *vec1, mlk_poly *vec2,
mlk_poly *vec3,
uint8_t seed[4][MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)])
__contract__(
requires(memory_no_alias(vec, sizeof(mlk_poly) * 4))
requires(memory_no_alias(vec0, sizeof(mlk_poly)))
requires(memory_no_alias(vec1, sizeof(mlk_poly)))
requires(memory_no_alias(vec2, sizeof(mlk_poly)))
requires(memory_no_alias(vec3, sizeof(mlk_poly)))
requires(memory_no_alias(seed, 4 * MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)))
assigns(memory_slice(vec, sizeof(mlk_poly) * 4))
ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))
ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q))
ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q))
ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q)));
assigns(memory_slice(vec0, sizeof(mlk_poly)))
assigns(memory_slice(vec1, sizeof(mlk_poly)))
assigns(memory_slice(vec2, sizeof(mlk_poly)))
assigns(memory_slice(vec3, sizeof(mlk_poly)))
ensures(array_bound(vec0->coeffs, 0, MLKEM_N, 0, MLKEM_Q))
ensures(array_bound(vec1->coeffs, 0, MLKEM_N, 0, MLKEM_Q))
ensures(array_bound(vec2->coeffs, 0, MLKEM_N, 0, MLKEM_Q))
ensures(array_bound(vec3->coeffs, 0, MLKEM_N, 0, MLKEM_Q)));

#define mlk_poly_rej_uniform MLK_NAMESPACE(poly_rej_uniform)
/*************************************************
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

void harness(void)
{
mlk_poly *out;
mlk_poly *out0;
mlk_poly *out1;
mlk_poly *out2;
mlk_poly *out3;
uint8_t(*seed)[MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)];
mlk_poly_rej_uniform_x4(out, seed);
mlk_poly_rej_uniform_x4(out0, out1, out2, out3, seed);
}
4 changes: 3 additions & 1 deletion test/bench_components_mlkem.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ static int bench(void)
BENCH("mlk_poly_rej_uniform",
mlk_poly_rej_uniform((mlk_poly *)data0, (uint8_t *)data1))
BENCH("mlk_poly_rej_uniform_x4",
mlk_poly_rej_uniform_x4((mlk_poly *)data0, (uint8_t(*)[64])data1))
mlk_poly_rej_uniform_x4((mlk_poly *)data0, (mlk_poly *)data1,
(mlk_poly *)data2, (mlk_poly *)data3,
(uint8_t(*)[64])data4))

/* mlk_poly */
/* mlk_poly_compress_du */
Expand Down