Skip to content

CBMC: Refine bounds for input and output of base multiplication #906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 90 additions & 67 deletions mlkem/src/indcpa.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
* Implements @[FIPS203, Algorithm 13 (K-PKE.KeyGen), L19]
*
**************************************************/
static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], mlk_polyvec pk,
static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES],
const mlk_polyvec *pk,
const uint8_t seed[MLKEM_SYMBYTES])
{
mlk_assert_bound_2d(pk, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
Expand All @@ -83,7 +84,7 @@ static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], mlk_polyvec pk,
* Implements @[FIPS203, Algorithm 14 (K-PKE.Encrypt), L2-3]
*
**************************************************/
static void mlk_unpack_pk(mlk_polyvec pk, uint8_t seed[MLKEM_SYMBYTES],
static void mlk_unpack_pk(mlk_polyvec *pk, uint8_t seed[MLKEM_SYMBYTES],
const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES])
{
mlk_polyvec_frombytes(pk, packedpk);
Expand All @@ -108,7 +109,8 @@ static void mlk_unpack_pk(mlk_polyvec pk, uint8_t seed[MLKEM_SYMBYTES],
* Implements @[FIPS203, Algorithm 13 (K-PKE.KeyGen), L20]
*
**************************************************/
static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], mlk_polyvec sk)
static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES],
const mlk_polyvec *sk)
{
mlk_assert_bound_2d(sk, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
mlk_polyvec_tobytes(r, sk);
Expand All @@ -128,7 +130,7 @@ static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], mlk_polyvec sk)
* Implements @[FIPS203, Algorithm 15 (K-PKE.Decrypt), L5]
*
**************************************************/
static void mlk_unpack_sk(mlk_polyvec sk,
static void mlk_unpack_sk(mlk_polyvec *sk,
const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES])
{
mlk_polyvec_frombytes(sk, packedsk);
Expand All @@ -149,8 +151,8 @@ static void mlk_unpack_sk(mlk_polyvec sk,
* Implements @[FIPS203, Algorithm 14 (K-PKE.Encrypt), L22-23]
*
**************************************************/
static void mlk_pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], mlk_polyvec b,
mlk_poly *v)
static void mlk_pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES],
const mlk_polyvec *b, mlk_poly *v)
{
mlk_polyvec_compress_du(r, b);
mlk_poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v);
Expand All @@ -170,7 +172,7 @@ static void mlk_pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], mlk_polyvec b,
* Implements @[FIPS203, Algorithm 15 (K-PKE.Decrypt), L1-4]
*
**************************************************/
static void mlk_unpack_ciphertext(mlk_polyvec b, mlk_poly *v,
static void mlk_unpack_ciphertext(mlk_polyvec *b, mlk_poly *v,
const uint8_t c[MLKEM_INDCPA_BYTES])
{
mlk_polyvec_decompress_du(b, c);
Expand Down Expand Up @@ -201,7 +203,7 @@ __contract__(
*
* Not static for benchmarking */
MLK_INTERNAL_API
void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
void mlk_gen_matrix(mlk_polymat *a, const uint8_t seed[MLKEM_SYMBYTES],
int transposed)
{
unsigned i, j;
Expand Down Expand Up @@ -238,7 +240,11 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
}
}

mlk_poly_rej_uniform_x4(&a[i], &a[i + 1], &a[i + 2], &a[i + 3], seed_ext);
mlk_poly_rej_uniform_x4(&a->vec[i / MLKEM_K].vec[i % MLKEM_K],
&a->vec[(i + 1) / MLKEM_K].vec[(i + 1) % MLKEM_K],
&a->vec[(i + 2) / MLKEM_K].vec[(i + 2) % MLKEM_K],
&a->vec[(i + 3) / MLKEM_K].vec[(i + 3) % MLKEM_K],
seed_ext);
}

/* For MLKEM_K == 3, sample the last entry individually. */
Expand All @@ -259,7 +265,7 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
seed_ext[0][MLKEM_SYMBYTES + 1] = x;
}

mlk_poly_rej_uniform(&a[i], seed_ext[0]);
mlk_poly_rej_uniform(&a->vec[i / MLKEM_K].vec[i % MLKEM_K], seed_ext[0]);
i++;
}

Expand All @@ -271,7 +277,8 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
*/
for (i = 0; i < MLKEM_K * MLKEM_K; i++)
{
mlk_poly_permute_bitrev_to_custom(a[i].coeffs);
mlk_poly_permute_bitrev_to_custom(
a->vec[i / MLKEM_K].vec[i % MLKEM_K].coeffs);
}

/* Specification: Partially implements
Expand All @@ -289,31 +296,42 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
* - mlk_polymat a: Input matrix. Must be in NTT domain
* and have coefficients of absolute value < 4096.
* - mlk_polyvec v: Input polynomial vector. Must be in NTT
* domain.
* domain and have coefficients of absolute value
* < MLK_NTT_BOUND.
* - mlk_polyvec vc: Mulcache for v, computed via
* mlk_polyvec_mulcache_compute().
* mlk_polyvec_mulcache_compute(). Must have coefficients
* of absolute value < MLKEM_Q.
*
* Specification: Implements @[FIPS203, Section 2.4.7, Eq (2.12), (2.13)]
*
**************************************************/
static void mlk_matvec_mul(mlk_polyvec out, const mlk_polymat a,
const mlk_polyvec v, const mlk_polyvec_mulcache vc)
static void mlk_matvec_mul(mlk_polyvec *out, const mlk_polymat *a,
const mlk_polyvec *v, const mlk_polyvec_mulcache *vc)
__contract__(
requires(memory_no_alias(out, sizeof(mlk_polyvec)))
requires(memory_no_alias(a, sizeof(mlk_polymat)))
requires(memory_no_alias(v, sizeof(mlk_polyvec)))
requires(memory_no_alias(vc, sizeof(mlk_polyvec_mulcache)))
requires(forall(k0, 0, MLKEM_K * MLKEM_K,
array_bound(a[k0].coeffs, 0, MLKEM_N, 0, MLKEM_UINT12_LIMIT)))
assigns(object_whole(out)))
requires(forall(k0, 0, MLKEM_K,
forall(k1, 0, MLKEM_K,
array_bound(a->vec[k0].vec[k1].coeffs, 0, MLKEM_N, 0, MLKEM_UINT12_LIMIT))))
requires(forall(k1, 0, MLKEM_K,
array_abs_bound(v->vec[k1].coeffs, 0, MLKEM_N, MLK_NTT_BOUND)))
requires(forall(k2, 0, MLKEM_K,
array_abs_bound(vc->vec[k2].coeffs, 0, MLKEM_N/2, MLKEM_Q)))
assigns(memory_slice(out, sizeof(mlk_polyvec)))
ensures(forall(k3, 0, MLKEM_K,
array_abs_bound(out->vec[k3].coeffs, 0, MLKEM_N, INT16_MAX/2))))
{
unsigned i;
for (i = 0; i < MLKEM_K; i++)
__loop__(
assigns(i, object_whole(out))
invariant(i <= MLKEM_K))
assigns(i, memory_slice(out, sizeof(mlk_polyvec)))
invariant(i <= MLKEM_K)
invariant(forall(k, 0, i,
array_abs_bound(out->vec[k].coeffs, 0, MLKEM_N, INT16_MAX/2))))
{
mlk_polyvec_basemul_acc_montgomery_cached(&out[i], &a[MLKEM_K * i], v, vc);
mlk_polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a->vec[i], v, vc);
}
}

Expand Down Expand Up @@ -352,47 +370,49 @@ void mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
*/
MLK_CT_TESTING_DECLASSIFY(publicseed, MLKEM_SYMBYTES);

mlk_gen_matrix(a, publicseed, 0 /* no transpose */);
mlk_gen_matrix(&a, publicseed, 0 /* no transpose */);

#if MLKEM_K == 2
mlk_poly_getnoise_eta1_4x(&skpv[0], &skpv[1], &e[0], &e[1], noiseseed, 0, 1,
2, 3);
mlk_poly_getnoise_eta1_4x(&skpv.vec[0], &skpv.vec[1], &e.vec[0], &e.vec[1],
noiseseed, 0, 1, 2, 3);
#elif MLKEM_K == 3
/*
* Only the first three output buffers are needed.
* The laster parameter is a dummy that's overwritten later.
*/
mlk_poly_getnoise_eta1_4x(&skpv[0], &skpv[1], &skpv[2],
&pkpv[0] /* irrelevant */, noiseseed, 0, 1, 2,
mlk_poly_getnoise_eta1_4x(&skpv.vec[0], &skpv.vec[1], &skpv.vec[2],
&pkpv.vec[0] /* irrelevant */, noiseseed, 0, 1, 2,
0xFF /* irrelevant */);
/* Same here */
mlk_poly_getnoise_eta1_4x(&e[0], &e[1], &e[2], &pkpv[0] /* irrelevant */,
noiseseed, 3, 4, 5, 0xFF /* irrelevant */);
mlk_poly_getnoise_eta1_4x(&e.vec[0], &e.vec[1], &e.vec[2],
&pkpv.vec[0] /* irrelevant */, noiseseed, 3, 4, 5,
0xFF /* irrelevant */);
#elif MLKEM_K == 4
mlk_poly_getnoise_eta1_4x(&skpv[0], &skpv[1], &skpv[2], &skpv[3], noiseseed,
0, 1, 2, 3);
mlk_poly_getnoise_eta1_4x(&e[0], &e[1], &e[2], &e[3], noiseseed, 4, 5, 6, 7);
#endif
mlk_poly_getnoise_eta1_4x(&skpv.vec[0], &skpv.vec[1], &skpv.vec[2],
&skpv.vec[3], noiseseed, 0, 1, 2, 3);
mlk_poly_getnoise_eta1_4x(&e.vec[0], &e.vec[1], &e.vec[2], &e.vec[3],
noiseseed, 4, 5, 6, 7);
#endif /* MLKEM_K == 4 */

mlk_polyvec_ntt(skpv);
mlk_polyvec_ntt(e);
mlk_polyvec_ntt(&skpv);
mlk_polyvec_ntt(&e);

mlk_polyvec_mulcache_compute(skpv_cache, skpv);
mlk_matvec_mul(pkpv, a, skpv, skpv_cache);
mlk_polyvec_tomont(pkpv);
mlk_polyvec_mulcache_compute(&skpv_cache, &skpv);
mlk_matvec_mul(&pkpv, &a, &skpv, &skpv_cache);
mlk_polyvec_tomont(&pkpv);

mlk_polyvec_add(pkpv, e);
mlk_polyvec_reduce(pkpv);
mlk_polyvec_reduce(skpv);
mlk_polyvec_add(&pkpv, &e);
mlk_polyvec_reduce(&pkpv);
mlk_polyvec_reduce(&skpv);

mlk_pack_sk(sk, skpv);
mlk_pack_pk(pk, pkpv, publicseed);
mlk_pack_sk(sk, &skpv);
mlk_pack_pk(pk, &pkpv, publicseed);

/* Specification: Partially implements
* @[FIPS203, Section 3.3, Destruction of intermediate values] */
mlk_zeroize(buf, sizeof(buf));
mlk_zeroize(coins_with_domain_separator, sizeof(coins_with_domain_separator));
mlk_zeroize(a, sizeof(a));
mlk_zeroize(&a, sizeof(a));
mlk_zeroize(&e, sizeof(e));
mlk_zeroize(&skpv, sizeof(skpv));
mlk_zeroize(&skpv_cache, sizeof(skpv_cache));
Expand All @@ -418,7 +438,7 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
mlk_poly v, k, epp;
mlk_polyvec_mulcache sp_cache;

mlk_unpack_pk(pkpv, seed, pk);
mlk_unpack_pk(&pkpv, seed, pk);
mlk_poly_frommsg(&k, m);

/*
Expand All @@ -429,44 +449,47 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
*/
MLK_CT_TESTING_DECLASSIFY(seed, MLKEM_SYMBYTES);

mlk_gen_matrix(at, seed, 1 /* transpose */);
mlk_gen_matrix(&at, seed, 1 /* transpose */);

#if MLKEM_K == 2
mlk_poly_getnoise_eta1122_4x(&sp[0], &sp[1], &ep[0], &ep[1], coins, 0, 1, 2,
3);
mlk_poly_getnoise_eta1122_4x(&sp.vec[0], &sp.vec[1], &ep.vec[0], &ep.vec[1],
coins, 0, 1, 2, 3);
mlk_poly_getnoise_eta2(&epp, coins, 4);
#elif MLKEM_K == 3
/*
* In this call, only the first three output buffers are needed.
* The last parameter is a dummy that's overwritten later.
*/
mlk_poly_getnoise_eta1_4x(&sp[0], &sp[1], &sp[2], &b[0], coins, 0, 1, 2,
0xFF);
mlk_poly_getnoise_eta1_4x(&sp.vec[0], &sp.vec[1], &sp.vec[2], &b.vec[0],
coins, 0, 1, 2, 0xFF);
/* The fourth output buffer in this call _is_ used. */
mlk_poly_getnoise_eta2_4x(&ep[0], &ep[1], &ep[2], &epp, coins, 3, 4, 5, 6);
mlk_poly_getnoise_eta2_4x(&ep.vec[0], &ep.vec[1], &ep.vec[2], &epp, coins, 3,
4, 5, 6);
#elif MLKEM_K == 4
mlk_poly_getnoise_eta1_4x(&sp[0], &sp[1], &sp[2], &sp[3], coins, 0, 1, 2, 3);
mlk_poly_getnoise_eta2_4x(&ep[0], &ep[1], &ep[2], &ep[3], coins, 4, 5, 6, 7);
mlk_poly_getnoise_eta1_4x(&sp.vec[0], &sp.vec[1], &sp.vec[2], &sp.vec[3],
coins, 0, 1, 2, 3);
mlk_poly_getnoise_eta2_4x(&ep.vec[0], &ep.vec[1], &ep.vec[2], &ep.vec[3],
coins, 4, 5, 6, 7);
mlk_poly_getnoise_eta2(&epp, coins, 8);
#endif
#endif /* MLKEM_K == 4 */

mlk_polyvec_ntt(sp);
mlk_polyvec_ntt(&sp);

mlk_polyvec_mulcache_compute(sp_cache, sp);
mlk_matvec_mul(b, at, sp, sp_cache);
mlk_polyvec_basemul_acc_montgomery_cached(&v, pkpv, sp, sp_cache);
mlk_polyvec_mulcache_compute(&sp_cache, &sp);
mlk_matvec_mul(&b, &at, &sp, &sp_cache);
mlk_polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache);

mlk_polyvec_invntt_tomont(b);
mlk_polyvec_invntt_tomont(&b);
mlk_poly_invntt_tomont(&v);

mlk_polyvec_add(b, ep);
mlk_polyvec_add(&b, &ep);
mlk_poly_add(&v, &epp);
mlk_poly_add(&v, &k);

mlk_polyvec_reduce(b);
mlk_polyvec_reduce(&b);
mlk_poly_reduce(&v);

mlk_pack_ciphertext(c, b, &v);
mlk_pack_ciphertext(c, &b, &v);

/* Specification: Partially implements
* @[FIPS203, Section 3.3, Destruction of intermediate values] */
Expand All @@ -475,7 +498,7 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
mlk_zeroize(&sp_cache, sizeof(sp_cache));
mlk_zeroize(&b, sizeof(b));
mlk_zeroize(&v, sizeof(v));
mlk_zeroize(at, sizeof(at));
mlk_zeroize(&at, sizeof(at));
mlk_zeroize(&k, sizeof(k));
mlk_zeroize(&ep, sizeof(ep));
mlk_zeroize(&epp, sizeof(epp));
Expand All @@ -493,12 +516,12 @@ void mlk_indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],
mlk_poly v, sb;
mlk_polyvec_mulcache b_cache;

mlk_unpack_ciphertext(b, &v, c);
mlk_unpack_sk(skpv, sk);
mlk_unpack_ciphertext(&b, &v, c);
mlk_unpack_sk(&skpv, sk);

mlk_polyvec_ntt(b);
mlk_polyvec_mulcache_compute(b_cache, b);
mlk_polyvec_basemul_acc_montgomery_cached(&sb, skpv, b, b_cache);
mlk_polyvec_ntt(&b);
mlk_polyvec_mulcache_compute(&b_cache, &b);
mlk_polyvec_basemul_acc_montgomery_cached(&sb, &skpv, &b, &b_cache);
mlk_poly_invntt_tomont(&sb);

mlk_poly_sub(&v, &sb);
Expand Down
6 changes: 3 additions & 3 deletions mlkem/src/indcpa.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@
*
**************************************************/
MLK_INTERNAL_API
void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
void mlk_gen_matrix(mlk_polymat *a, const uint8_t seed[MLKEM_SYMBYTES],
int transposed)
__contract__(
requires(memory_no_alias(a, sizeof(mlk_polymat)))
requires(memory_no_alias(seed, MLKEM_SYMBYTES))
requires(transposed == 0 || transposed == 1)
assigns(object_whole(a))
ensures(forall(x, 0, MLKEM_K * MLKEM_K,
array_bound(a[x].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))
ensures(forall(x, 0, MLKEM_K, forall(y, 0, MLKEM_K,
array_bound(a->vec[x].vec[y].coeffs, 0, MLKEM_N, 0, MLKEM_Q))))
);

#define mlk_indcpa_keypair_derand MLK_NAMESPACE_K(indcpa_keypair_derand)
Expand Down
6 changes: 3 additions & 3 deletions mlkem/src/kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ int crypto_kem_check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES])
mlk_polyvec p;
uint8_t p_reencoded[MLKEM_POLYVECBYTES];

mlk_polyvec_frombytes(p, pk);
mlk_polyvec_reduce(p);
mlk_polyvec_tobytes(p_reencoded, p);
mlk_polyvec_frombytes(&p, pk);
mlk_polyvec_reduce(&p);
mlk_polyvec_tobytes(p_reencoded, &p);

/* We use a constant-time memcmp here to avoid having to
* declassify the PK before the PCT has succeeded. */
Expand Down
4 changes: 4 additions & 0 deletions mlkem/src/native/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ __contract__(
static MLK_INLINE void mlk_intt_native(int16_t p[MLKEM_N])
__contract__(
requires(memory_no_alias(p, sizeof(int16_t) * MLKEM_N))
requires(array_abs_bound(p, 0, MLKEM_N, INT16_MAX/2))
assigns(memory_slice(p, sizeof(int16_t) * MLKEM_N))
ensures(array_abs_bound(p, 0, MLKEM_N, MLK_INVNTT_BOUND))
);
Expand Down Expand Up @@ -244,6 +245,7 @@ __contract__(
requires(memory_no_alias(b_cache, sizeof(int16_t) * 2 * (MLKEM_N / 2)))
requires(array_bound(a, 0, 2 * MLKEM_N, 0, MLKEM_UINT12_LIMIT))
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
ensures(array_abs_bound(r, 0, MLKEM_N, INT16_MAX/2))
);
#endif /* MLK_CONFIG_MULTILEVEL_WITH_SHARED || MLKEM_K == 2 */

Expand Down Expand Up @@ -277,6 +279,7 @@ __contract__(
requires(memory_no_alias(b_cache, sizeof(int16_t) * 3 * (MLKEM_N / 2)))
requires(array_bound(a, 0, 3 * MLKEM_N, 0, MLKEM_UINT12_LIMIT))
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
ensures(array_abs_bound(r, 0, MLKEM_N, INT16_MAX/2))
);
#endif /* MLK_CONFIG_MULTILEVEL_WITH_SHARED || MLKEM_K == 3 */

Expand Down Expand Up @@ -310,6 +313,7 @@ __contract__(
requires(memory_no_alias(b_cache, sizeof(int16_t) * 4 * (MLKEM_N / 2)))
requires(array_bound(a, 0, 4 * MLKEM_N, 0, MLKEM_UINT12_LIMIT))
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
ensures(array_abs_bound(r, 0, MLKEM_N, INT16_MAX/2))
);
#endif /* MLK_CONFIG_MULTILEVEL_WITH_SHARED || MLKEM_K == 4 */
#endif /* MLK_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */
Expand Down
Loading
Loading