From 624f56de14889235aaa871e59d0abb41683d8088 Mon Sep 17 00:00:00 2001 From: Rod Chapman Date: Tue, 8 Jul 2025 13:08:19 +0100 Subject: [PATCH] Simplify gen_matrix() and poly_rej_uniform_x4() 1. Simplify poly_rej_uniform_x4() API to explicitly take 4 objects to write, rather than a slice of a larger array. 2. Modify gen_matrix() in light of that. 3. Update test/bench_components_mlkem.c Signed-off-by: Rod Chapman --- mlkem/src/indcpa.c | 9 +---- mlkem/src/sampling.c | 38 +++++++++++-------- mlkem/src/sampling.h | 25 +++++++----- .../poly_rej_uniform_x4_harness.c | 7 +++- test/bench_components_mlkem.c | 4 +- 5 files changed, 49 insertions(+), 34 deletions(-) diff --git a/mlkem/src/indcpa.c b/mlkem/src/indcpa.c index 672c7e892..556f7692e 100644 --- a/mlkem/src/indcpa.c +++ b/mlkem/src/indcpa.c @@ -221,10 +221,9 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES], /* Sample 4 matrix entries a time. */ for (i = 0; i < (MLKEM_K * MLKEM_K / 4) * 4; i += 4) { - uint8_t x, y; - for (j = 0; j < 4; j++) { + uint8_t x, y; x = (i + j) / MLKEM_K; y = (i + j) % MLKEM_K; if (transposed) @@ -239,11 +238,7 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES], } } - /* - * This call writes across mlk_polyvec boundaries for K=2 and K=3. - * This is intentional and safe. - */ - mlk_poly_rej_uniform_x4(&a[i], seed_ext); + mlk_poly_rej_uniform_x4(&a[i], &a[i + 1], &a[i + 2], &a[i + 3], seed_ext); } /* For MLKEM_K == 3, sample the last entry individually. */ diff --git a/mlkem/src/sampling.c b/mlkem/src/sampling.c index be5d931a7..01eecfe47 100644 --- a/mlkem/src/sampling.c +++ b/mlkem/src/sampling.c @@ -146,7 +146,8 @@ __contract__( * - x4-batched version of `rej_uniform()` from the * reference implementation, leveraging x4-batched Keccak-f1600. */ MLK_INTERNAL_API -void mlk_poly_rej_uniform_x4(mlk_poly *vec, +void mlk_poly_rej_uniform_x4(mlk_poly *vec0, mlk_poly *vec1, mlk_poly *vec2, + mlk_poly *vec3, uint8_t seed[4][MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)]) { /* Temporary buffers for XOF output before rejection sampling */ @@ -167,10 +168,10 @@ void mlk_poly_rej_uniform_x4(mlk_poly *vec, */ mlk_xof_x4_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &statex); buflen = MLKEM_GEN_MATRIX_NBLOCKS * MLK_XOF_RATE; - ctr[0] = mlk_rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf[0], buflen); - ctr[1] = mlk_rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf[1], buflen); - ctr[2] = mlk_rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf[2], buflen); - ctr[3] = mlk_rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf[3], buflen); + ctr[0] = mlk_rej_uniform(vec0->coeffs, MLKEM_N, 0, buf[0], buflen); + ctr[1] = mlk_rej_uniform(vec1->coeffs, MLKEM_N, 0, buf[1], buflen); + ctr[2] = mlk_rej_uniform(vec2->coeffs, MLKEM_N, 0, buf[2], buflen); + ctr[3] = mlk_rej_uniform(vec3->coeffs, MLKEM_N, 0, buf[3], buflen); /* * So long as not all matrix entries have been generated, squeeze @@ -180,20 +181,27 @@ void mlk_poly_rej_uniform_x4(mlk_poly *vec, while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N || ctr[3] < MLKEM_N) __loop__( - assigns(ctr, statex, memory_slice(vec, sizeof(mlk_poly) * 4), object_whole(buf[0]), - object_whole(buf[1]), object_whole(buf[2]), object_whole(buf[3])) + assigns(ctr, statex, + memory_slice(vec0, sizeof(mlk_poly)), + memory_slice(vec1, sizeof(mlk_poly)), + memory_slice(vec2, sizeof(mlk_poly)), + memory_slice(vec3, sizeof(mlk_poly)), + object_whole(buf[0]), + object_whole(buf[1]), + object_whole(buf[2]), + object_whole(buf[3])) invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N) invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N) - invariant(array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q)) - invariant(array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q)) - invariant(array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q)) - invariant(array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q))) + invariant(array_bound(vec0->coeffs, 0, ctr[0], 0, MLKEM_Q)) + invariant(array_bound(vec1->coeffs, 0, ctr[1], 0, MLKEM_Q)) + invariant(array_bound(vec2->coeffs, 0, ctr[2], 0, MLKEM_Q)) + invariant(array_bound(vec3->coeffs, 0, ctr[3], 0, MLKEM_Q))) { mlk_xof_x4_squeezeblocks(buf, 1, &statex); - ctr[0] = mlk_rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf[0], buflen); - ctr[1] = mlk_rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf[1], buflen); - ctr[2] = mlk_rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf[2], buflen); - ctr[3] = mlk_rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf[3], buflen); + ctr[0] = mlk_rej_uniform(vec0->coeffs, MLKEM_N, ctr[0], buf[0], buflen); + ctr[1] = mlk_rej_uniform(vec1->coeffs, MLKEM_N, ctr[1], buf[1], buflen); + ctr[2] = mlk_rej_uniform(vec2->coeffs, MLKEM_N, ctr[2], buf[2], buflen); + ctr[3] = mlk_rej_uniform(vec3->coeffs, MLKEM_N, ctr[3], buf[3], buflen); } mlk_xof_x4_release(&statex); diff --git a/mlkem/src/sampling.h b/mlkem/src/sampling.h index 2cf43c889..99aa70deb 100644 --- a/mlkem/src/sampling.h +++ b/mlkem/src/sampling.h @@ -65,8 +65,8 @@ void mlk_poly_cbd3(mlk_poly *r, const uint8_t buf[3 * MLKEM_N / 4]); * Description: Generate four polynomials using rejection sampling * on (pseudo-)uniformly random bytes sampled from a seed. * - * Arguments: - mlk_poly *vec: - * Pointer to an array of 4 polynomials to be sampled. + * Arguments: - mlk_poly *vec0, *vec1, *vec2, *vec3: + * Pointers to 4 polynomials to be sampled. * - uint8_t seed[4][MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)]: * Pointer consecutive array of seed buffers of size * MLKEM_SYMBYTES + 2 each, plus padding for alignment. @@ -75,16 +75,23 @@ void mlk_poly_cbd3(mlk_poly *r, const uint8_t buf[3 * MLKEM_N / 4]); * **************************************************/ MLK_INTERNAL_API -void mlk_poly_rej_uniform_x4(mlk_poly *vec, +void mlk_poly_rej_uniform_x4(mlk_poly *vec0, mlk_poly *vec1, mlk_poly *vec2, + mlk_poly *vec3, uint8_t seed[4][MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)]) __contract__( - requires(memory_no_alias(vec, sizeof(mlk_poly) * 4)) + requires(memory_no_alias(vec0, sizeof(mlk_poly))) + requires(memory_no_alias(vec1, sizeof(mlk_poly))) + requires(memory_no_alias(vec2, sizeof(mlk_poly))) + requires(memory_no_alias(vec3, sizeof(mlk_poly))) requires(memory_no_alias(seed, 4 * MLK_ALIGN_UP(MLKEM_SYMBYTES + 2))) - assigns(memory_slice(vec, sizeof(mlk_poly) * 4)) - ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) - ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) - ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) - ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q))); + assigns(memory_slice(vec0, sizeof(mlk_poly))) + assigns(memory_slice(vec1, sizeof(mlk_poly))) + assigns(memory_slice(vec2, sizeof(mlk_poly))) + assigns(memory_slice(vec3, sizeof(mlk_poly))) + ensures(array_bound(vec0->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec1->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec2->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec3->coeffs, 0, MLKEM_N, 0, MLKEM_Q))); #define mlk_poly_rej_uniform MLK_NAMESPACE(poly_rej_uniform) /************************************************* diff --git a/proofs/cbmc/poly_rej_uniform_x4/poly_rej_uniform_x4_harness.c b/proofs/cbmc/poly_rej_uniform_x4/poly_rej_uniform_x4_harness.c index 19a45b175..024ff1da4 100644 --- a/proofs/cbmc/poly_rej_uniform_x4/poly_rej_uniform_x4_harness.c +++ b/proofs/cbmc/poly_rej_uniform_x4/poly_rej_uniform_x4_harness.c @@ -7,7 +7,10 @@ void harness(void) { - mlk_poly *out; + mlk_poly *out0; + mlk_poly *out1; + mlk_poly *out2; + mlk_poly *out3; uint8_t(*seed)[MLK_ALIGN_UP(MLKEM_SYMBYTES + 2)]; - mlk_poly_rej_uniform_x4(out, seed); + mlk_poly_rej_uniform_x4(out0, out1, out2, out3, seed); } diff --git a/test/bench_components_mlkem.c b/test/bench_components_mlkem.c index 63418bb9e..26858d627 100644 --- a/test/bench_components_mlkem.c +++ b/test/bench_components_mlkem.c @@ -69,7 +69,9 @@ static int bench(void) BENCH("mlk_poly_rej_uniform", mlk_poly_rej_uniform((mlk_poly *)data0, (uint8_t *)data1)) BENCH("mlk_poly_rej_uniform_x4", - mlk_poly_rej_uniform_x4((mlk_poly *)data0, (uint8_t(*)[64])data1)) + mlk_poly_rej_uniform_x4((mlk_poly *)data0, (mlk_poly *)data1, + (mlk_poly *)data2, (mlk_poly *)data3, + (uint8_t(*)[64])data4)) /* mlk_poly */ /* mlk_poly_compress_du */