Skip to content

Commit 10bb481

Browse files
authored
Merge pull request #897 from pq-code-package/basemul_john
Switch mlkem_poly_basemul_acc_montgomery_cached_* proofs to integer specs
2 parents 081e65d + 44b0cab commit 10bb481

File tree

7 files changed

+553
-247
lines changed

7 files changed

+553
-247
lines changed

.github/workflows/hol_light.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,4 @@ jobs:
9898
gh_token: ${{ secrets.GITHUB_TOKEN }}
9999
nix-shell: 'hol_light'
100100
script: |
101-
tests hol_light -p ${{ matrix.proof.name }}
101+
tests hol_light -p ${{ matrix.proof.name }} --verbose

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ The functional correctness of various AArch64 assembly routines is established u
6767
- ML-KEM Arithmetic:
6868
* AArch64 forward NTT: [mlkem_ntt.S](proofs/hol_light/arm/mlkem/mlkem_ntt.S)
6969
* AArch64 inverse NTT: [mlkem_intt.S](proofs/hol_light/arm/mlkem/mlkem_intt.S)
70+
* AArch64 base multiplications: [mlkem_poly_basemul_acc_montgomery_cached_k2.S](proofs/hol_light/arm/mlkem/mlkem_poly_basemul_acc_montgomery_cached_k2.S) [mlkem_poly_basemul_acc_montgomery_cached_k3.S](proofs/hol_light/arm/mlkem/mlkem_poly_basemul_acc_montgomery_cached_k3.S) [mlkem_poly_basemul_acc_montgomery_cached_k4.S](proofs/hol_light/arm/mlkem/mlkem_poly_basemul_acc_montgomery_cached_k4.S)
7071
* AArch64 modular reduction: [mlkem_poly_reduce.S](proofs/hol_light/arm/mlkem/mlkem_poly_reduce.S)
7172
* AArch64 conversion to Montgomery form: [mlkem_poly_tomont.S](proofs/hol_light/arm/mlkem/mlkem_poly_tomont.S)
7273
* AArch64 'multiplication cache' computation: [mlkem_poly_mulcache_compute.S](proofs/hol_light/arm/mlkem/mlkem_poly_mulcache_compute.S)

proofs/hol_light/arm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ At present, this directory contains functional correctness proofs for the follow
2020
- ML-KEM Arithmetic:
2121
* AArch64 forward NTT: [mlkem_ntt.S](mlkem/mlkem_ntt.S)
2222
* AArch64 inverse NTT: [mlkem_intt.S](mlkem/mlkem_intt.S)
23+
* AArch64 base multiplications: [mlkem_poly_basemul_acc_montgomery_cached_k2.S](mlkem/mlkem_poly_basemul_acc_montgomery_cached_k2.S) [mlkem_poly_basemul_acc_montgomery_cached_k3.S](mlkem/mlkem_poly_basemul_acc_montgomery_cached_k3.S) [mlkem_poly_basemul_acc_montgomery_cached_k4.S](mlkem/mlkem_poly_basemul_acc_montgomery_cached_k4.S)
2324
* AArch64 conversion to Montgomery form: [mlkem_poly_tomont.S](mlkem/mlkem_poly_tomont.S)
2425
* AArch64 modular reduction: [mlkem_poly_reduce.S](mlkem/mlkem_poly_reduce.S)
2526
* AArch64 'multiplication cache' computation: [mlkem_poly_mulcache_compute.S](mlkem/mlkem_poly_mulcache_compute.S)
@@ -49,4 +50,3 @@ make -C proofs/hol_light/arm
4950
will build and run the proofs. Note that this make take hours even on powerful machines.
5051

5152
For convenience, you can also use `tests hol_light` which wraps the `make` invocation above; see `tests hol_light --help`.
52-

proofs/hol_light/arm/proofs/mlkem_poly_basemul_acc_montgomery_cached_k2.ml

Lines changed: 106 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -184,45 +184,29 @@ let poly_basemul_acc_montgomery_cached_k2_mc = define_assert_from_elf
184184
];;
185185

186186
let pmull = define
187-
`pmull (x0: 16 word) (x1 : 16 word) (y0 : 16 word) (y1 : 16 word) =
188-
word_add (word_mul ((word_sx x1) : 32 word) (word_sx y1))
189-
(word_mul ((word_sx x0) : 32 word) (word_sx y0))`;;
187+
`pmull (x0: int) (x1 : int) (y0 : int) (y1 : int) = x1 * y1 + x0 * y0`;;
190188

191189
let pmull_acc2 = define
192-
`pmull_acc2 (x00: 16 word) (x01 : 16 word) (y00 : 16 word) (y01 : 16 word)
193-
(x10: 16 word) (x11 : 16 word) (y10 : 16 word) (y11 : 16 word) =
194-
word_add (pmull x10 x11 y10 y11) (pmull x00 x01 y00 y01)`;;
195-
196-
let montred = define
197-
`montred (x : 32 word) =
198-
word_subword (
199-
word_add (
200-
word_mul (
201-
(word_sx : 16 word -> 32 word) (
202-
word_mul (
203-
word_subword x (0,16)
204-
) (word 3327)
205-
)
206-
)
207-
(word 3329)
208-
) x
209-
) (16, 16)`;;
190+
`pmull_acc2 (x00: int) (x01 : int) (y00 : int) (y01 : int)
191+
(x10: int) (x11 : int) (y10 : int) (y11 : int) =
192+
pmull x10 x11 y10 y11 + pmull x00 x01 y00 y01`;;
210193

211194
let pmul_acc2 = define
212-
`pmul_acc2 (x00: 16 word) (x01 : 16 word) (y00 : 16 word) (y01 : 16 word)
213-
(x10: 16 word) (x11 : 16 word) (y10 : 16 word) (y11 : 16 word) =
214-
montred (pmull_acc2 x00 x01 y00 y01 x10 x11 y10 y11)`;;
195+
`pmul_acc2 (x00: int) (x01 : int) (y00 : int) (y01 : int)
196+
(x10: int) (x11 : int) (y10 : int) (y11 : int) =
197+
(&(inverse_mod 3329 65536) *
198+
pmull_acc2 x00 x01 y00 y01 x10 x11 y10 y11) rem &3329`;;
215199

216-
let basemul2_even_raw = define
217-
`basemul2_even_raw x0 y0 y0t x1 y1 y1t = \i.
200+
let basemul2_even = define
201+
`basemul2_even x0 y0 y0t x1 y1 y1t = \i.
218202
pmul_acc2 (x0 (2 * i)) (x0 (2 * i + 1))
219203
(y0 (2 * i)) (y0t i)
220204
(x1 (2 * i)) (x1 (2 * i + 1))
221205
(y1 (2 * i)) (y1t i)
222206
`;;
223207

224-
let basemul2_odd_raw = define
225-
`basemul2_odd_raw x0 y0 x1 y1 = \i.
208+
let basemul2_odd = define
209+
`basemul2_odd x0 y0 x1 y1 = \i.
226210
pmul_acc2 (x0 (2 * i)) (x0 (2 * i + 1))
227211
(y0 (2 * i + 1)) (y0 (2 * i))
228212
(x1 (2 * i)) (x1 (2 * i + 1))
@@ -231,6 +215,70 @@ let basemul2_odd_raw = define
231215

232216
let poly_basemul_acc_montgomery_cached_k2_EXEC = ARM_MK_EXEC_RULE poly_basemul_acc_montgomery_cached_k2_mc;;
233217

218+
(* ------------------------------------------------------------------------- *)
219+
(* Hacky tweaking conversion to write away non-free state component reads. *)
220+
(* ------------------------------------------------------------------------- *)
221+
222+
let lemma = prove
223+
(`!base size s n.
224+
n + 2 <= size
225+
==> read(memory :> bytes16(word_add base (word n))) s =
226+
word((read (memory :> bytes(base,size)) s DIV 2 EXP (8 * n)))`,
227+
REPEAT STRIP_TAC THEN REWRITE_TAC[READ_COMPONENT_COMPOSE] THEN
228+
SPEC_TAC(`read memory s`,`m:int64->byte`) THEN GEN_TAC THEN
229+
REWRITE_TAC[READ_BYTES_DIV] THEN
230+
REWRITE_TAC[bytes16; READ_COMPONENT_COMPOSE; asword; through; read] THEN
231+
ONCE_REWRITE_TAC[GSYM WORD_MOD_SIZE] THEN REWRITE_TAC[DIMINDEX_16] THEN
232+
REWRITE_TAC[ARITH_RULE `16 = 8 * 2`; READ_BYTES_MOD] THEN
233+
ASM_SIMP_TAC[ARITH_RULE `n + 2 <= size ==> MIN (size - n) 2 = MIN 2 2`]);;
234+
235+
let BOUNDED_QUANT_READ_MEM = prove
236+
(`(!x base s.
237+
(!i. i < n
238+
==> read(memory :> bytes16(word_add base (word(2 * i)))) s =
239+
x i) <=>
240+
(!i. i < n
241+
==> word((read(memory :> bytes(base,2 * n)) s DIV 2 EXP (16 * i))) =
242+
x i)) /\
243+
(!x p base s.
244+
(!i. i < n
245+
==> (ival(read(memory :> bytes16(word_add base (word(2 * i)))) s) ==
246+
x i) (mod p)) <=>
247+
(!i. i < n
248+
==> (ival(word((read(memory :> bytes(base,2 * n)) s DIV 2 EXP (16 * i))):int16) ==
249+
x i) (mod p))) /\
250+
(!x p c base s.
251+
(!i. i < n /\ c i
252+
==> (ival(read(memory :> bytes16(word_add base (word(2 * i)))) s) ==
253+
x i) (mod p)) <=>
254+
(!i. i < n /\ c i
255+
==> (ival(word((read(memory :> bytes(base,2 * n)) s DIV 2 EXP (16 * i))):int16) ==
256+
x i) (mod p)))`,
257+
REPEAT STRIP_TAC THEN
258+
MP_TAC(ISPECL [`base:int64`; `2 * n`] lemma) THEN
259+
SIMP_TAC[ARITH_RULE `2 * i + 2 <= 2 * n <=> i < n`] THEN
260+
REWRITE_TAC[ARITH_RULE `8 * 2 * i = 16 * i`]);;
261+
262+
let even_odd_split_lemma = prove
263+
(`(!i. i < 128 ==> P (4 * i) i /\ Q(4 * i + 2) i) <=>
264+
(!i. i < 256 /\ EVEN i ==> P(2 * i) (i DIV 2)) /\
265+
(!i. i < 256 /\ ODD i ==> Q(2 * i) (i DIV 2))`,
266+
REWRITE_TAC[IMP_CONJ] THEN
267+
CONV_TAC(ONCE_DEPTH_CONV EXPAND_CASES_CONV) THEN
268+
CONV_TAC NUM_REDUCE_CONV THEN
269+
CONV_TAC CONJ_ACI_RULE);;
270+
271+
let TWEAK_CONV =
272+
REWRITE_CONV[even_odd_split_lemma] THENC
273+
GEN_REWRITE_CONV TOP_DEPTH_CONV [WORD_RULE
274+
`word_add x (word(a + b)) = word_add (word_add x (word a)) (word b)`] THENC
275+
REWRITE_CONV[BOUNDED_QUANT_READ_MEM] THENC
276+
NUM_REDUCE_CONV;;
277+
278+
(* ------------------------------------------------------------------------- *)
279+
(* Main proof. *)
280+
(* ------------------------------------------------------------------------- *)
281+
234282
let poly_basemul_acc_montgomery_cached_k2_GOAL = `forall srcA srcB srcBt dst x0 y0 y0t x1 y1 y1t pc.
235283
ALL (nonoverlapping (dst, 512))
236284
[(word pc, LENGTH poly_basemul_acc_montgomery_cached_k2_mc); (srcA, 1024); (srcB, 1024); (srcBt, 512)]
@@ -246,10 +294,13 @@ let poly_basemul_acc_montgomery_cached_k2_GOAL = `forall srcA srcB srcBt dst x0
246294
(!i. i < 256 ==> read(memory :> bytes16(word_add srcB (word (512 + 2 * i)))) s = y1 i) /\
247295
(!i. i < 128 ==> read(memory :> bytes16(word_add srcBt (word (256 + 2 * i)))) s = y1t i))
248296
(\s. read PC s = word (pc + 640) /\
249-
(!i. i < 128 ==> read(memory :> bytes16(word_add dst (word (4 * i)))) s =
250-
basemul2_even_raw x0 y0 y0t x1 y1 y1t i /\
251-
read(memory :> bytes16(word_add dst (word (4 * i + 2)))) s =
252-
basemul2_odd_raw x0 y0 x1 y1 i))
297+
((!i. i < 256 ==> abs(ival(x0 i)) <= &2 pow 12 /\ abs(ival(x1 i)) <= &2 pow 12)
298+
==> (!i. i < 128
299+
==> (ival(read(memory :> bytes16(word_add dst (word (4 * i)))) s) ==
300+
basemul2_even (ival o x0) (ival o y0) (ival o y0t)
301+
(ival o x1) (ival o y1) (ival o y1t) i) (mod &3329) /\
302+
(ival(read(memory :> bytes16(word_add dst (word (4 * i + 2)))) s) ==
303+
basemul2_odd (ival o x0) (ival o y0) (ival o x1) (ival o y1) i) (mod &3329))))
253304
// Register and memory footprint
254305
(MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,,
255306
MAYCHANGE [Q8; Q9; Q10; Q11; Q12; Q13; Q14; Q15] ,,
@@ -289,9 +340,10 @@ let poly_basemul_acc_montgomery_cached_k2_SPEC = prove(poly_basemul_acc_montgome
289340
This reduces the proof time *)
290341
REPEAT STRIP_TAC THEN
291342
MAP_EVERY (fun n -> ARM_STEPS_TAC poly_basemul_acc_montgomery_cached_k2_EXEC [n] THEN
292-
(SIMD_SIMPLIFY_TAC [pmull; GSYM WORD_ADD_ASSOC; pmull_acc2; montred; pmul_acc2])) (1--805) THEN
343+
(SIMD_SIMPLIFY_TAC [montred])) (1--805) THEN
293344

294-
ENSURES_FINAL_STATE_TAC THEN
345+
ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN
346+
CONV_TAC(LAND_CONV(ONCE_DEPTH_CONV EXPAND_CASES_CONV)) THEN STRIP_TAC THEN
295347
REPEAT CONJ_TAC THEN
296348
ASM_REWRITE_TAC [] THEN
297349

@@ -307,24 +359,24 @@ let poly_basemul_acc_montgomery_cached_k2_SPEC = prove(poly_basemul_acc_montgome
307359
CONV_TAC(ONCE_DEPTH_CONV let_CONV) THEN
308360
ASM_REWRITE_TAC [WORD_ADD_0] THEN
309361

310-
(* Forget all assumptions *)
311-
POP_ASSUM_LIST (K ALL_TAC) THEN
362+
(* Forget all state-related assumptions, but keep bounds at least *)
363+
DISCARD_STATE_TAC "s805" THEN
312364

313365
(* Split into one congruence goals per index. *)
314366
REPEAT CONJ_TAC THEN
315-
316-
REWRITE_TAC[basemul2_even_raw; basemul2_odd_raw] THEN
367+
REWRITE_TAC[basemul2_even; basemul2_odd;
368+
pmul_acc2; pmull_acc2; pmull; o_THM] THEN
317369
CONV_TAC(ONCE_DEPTH_CONV EL_CONV) THEN
318-
CONV_TAC(REPEATC (CHANGED_CONV (ONCE_DEPTH_CONV (NUM_MULT_CONV ORELSEC NUM_ADD_CONV)))) THEN
319-
REFL_TAC
320-
);;
370+
CONV_TAC NUM_REDUCE_CONV THEN
371+
372+
(* Solve the congruence goals *)
321373

322-
let TWEAK_CONV =
323-
ONCE_DEPTH_CONV let_CONV THENC
324-
ONCE_DEPTH_CONV EXPAND_CASES_CONV THENC
325-
ONCE_DEPTH_CONV NUM_MULT_CONV THENC
326-
ONCE_DEPTH_CONV NUM_ADD_CONV THENC
327-
PURE_REWRITE_CONV [WORD_ADD_0];;
374+
ASSUM_LIST((fun ths -> W(MP_TAC o CONJUNCT1 o GEN_CONGBOUND_RULE ths o
375+
rand o lhand o rator o snd))) THEN
376+
REWRITE_TAC[GSYM INT_REM_EQ] THEN CONV_TAC INT_REM_DOWN_CONV THEN
377+
MATCH_MP_TAC EQ_IMP THEN AP_TERM_TAC THEN AP_THM_TAC THEN AP_TERM_TAC THEN
378+
CONV_TAC INT_RING
379+
);;
328380

329381
let poly_basemul_acc_montgomery_cached_k2_SPEC' = prove(
330382
`forall srcA srcB srcBt dst x0 y0 y0t x1 y1 y1t pc returnaddress stackpointer.
@@ -351,10 +403,13 @@ let poly_basemul_acc_montgomery_cached_k2_SPEC' = prove(
351403
(!i. i < 128 ==> read(memory :> bytes16(word_add srcBt (word (256 + 2 * i)))) s = y1t i)
352404
)
353405
(\s. read PC s = returnaddress /\
354-
(!i. i < 128 ==> read(memory :> bytes16(word_add dst (word (4 * i)))) s =
355-
basemul2_even_raw x0 y0 y0t x1 y1 y1t i /\
356-
read(memory :> bytes16(word_add dst (word (4 * i + 2)))) s =
357-
basemul2_odd_raw x0 y0 x1 y1 i)
406+
((!i. i < 256 ==> abs(ival(x0 i)) <= &2 pow 12 /\ abs(ival(x1 i)) <= &2 pow 12)
407+
==> (!i. i < 128
408+
==> (ival(read(memory :> bytes16(word_add dst (word (4 * i)))) s) ==
409+
basemul2_even (ival o x0) (ival o y0) (ival o y0t)
410+
(ival o x1) (ival o y1) (ival o y1t) i) (mod &3329) /\
411+
(ival(read(memory :> bytes16(word_add dst (word (4 * i + 2)))) s) ==
412+
basemul2_odd (ival o x0) (ival o y0) (ival o x1) (ival o y1) i) (mod &3329)))
358413
)
359414
// Register and memory footprint
360415
(MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,,

0 commit comments

Comments
 (0)