|
| 1 | +use std::{ |
| 2 | + array, |
| 3 | + borrow::{Borrow, BorrowMut}, |
| 4 | +}; |
| 5 | + |
| 6 | +use openvm_circuit::arch::{AdapterAirContext, MinimalInstruction, VmAdapterInterface, VmCoreAir}; |
| 7 | +use openvm_circuit_primitives::utils::not; |
| 8 | +use openvm_circuit_primitives_derive::AlignedBorrow; |
| 9 | +use openvm_instructions::{instruction::Instruction, LocalOpcode}; |
| 10 | +use openvm_stark_backend::{ |
| 11 | + interaction::InteractionBuilder, |
| 12 | + p3_air::{AirBuilder, BaseAir}, |
| 13 | + p3_field::{Field, FieldAlgebra, PrimeField32}, |
| 14 | + rap::{BaseAirWithPublicValues, ColumnsAir}, |
| 15 | +}; |
| 16 | +use openvm_womir_transpiler::EqOpcode; |
| 17 | +use serde::{de::DeserializeOwned, Deserialize, Serialize}; |
| 18 | +use serde_big_array::BigArray; |
| 19 | +use struct_reflection::{StructReflection, StructReflectionHelper}; |
| 20 | +use strum::IntoEnumIterator; |
| 21 | + |
| 22 | +use crate::{AdapterRuntimeContextWom, VmCoreChipWom}; |
| 23 | +use openvm_circuit::arch::Result as ResultVm; |
| 24 | + |
| 25 | +#[repr(C)] |
| 26 | +#[derive(AlignedBorrow, StructReflection)] |
| 27 | +pub struct EqCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> { |
| 28 | + pub a: [T; NUM_LIMBS], |
| 29 | + pub b: [T; NUM_LIMBS], |
| 30 | + pub c: [T; NUM_LIMBS], |
| 31 | + |
| 32 | + pub cmp_result: T, |
| 33 | + |
| 34 | + pub opcode_eq_flag: T, |
| 35 | + pub opcode_ne_flag: T, |
| 36 | + |
| 37 | + pub diff_inv_marker: [T; NUM_LIMBS], |
| 38 | +} |
| 39 | + |
| 40 | +#[derive(Copy, Clone, Debug)] |
| 41 | +pub struct EqCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> { |
| 42 | + offset: usize, |
| 43 | +} |
| 44 | + |
| 45 | +impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F> |
| 46 | + for EqCoreAir<NUM_LIMBS, LIMB_BITS> |
| 47 | +{ |
| 48 | + fn width(&self) -> usize { |
| 49 | + EqCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width() |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> ColumnsAir<F> |
| 54 | + for EqCoreAir<NUM_LIMBS, LIMB_BITS> |
| 55 | +{ |
| 56 | + fn columns(&self) -> Option<Vec<String>> { |
| 57 | + EqCoreCols::<F, NUM_LIMBS, LIMB_BITS>::struct_reflection() |
| 58 | + } |
| 59 | +} |
| 60 | + |
| 61 | +impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F> |
| 62 | + for EqCoreAir<NUM_LIMBS, LIMB_BITS> |
| 63 | +{ |
| 64 | +} |
| 65 | + |
| 66 | +impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I> |
| 67 | + for EqCoreAir<NUM_LIMBS, LIMB_BITS> |
| 68 | +where |
| 69 | + AB: InteractionBuilder, |
| 70 | + I: VmAdapterInterface<AB::Expr>, |
| 71 | + I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>, |
| 72 | + I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>, |
| 73 | + I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>, |
| 74 | +{ |
| 75 | + fn eval( |
| 76 | + &self, |
| 77 | + builder: &mut AB, |
| 78 | + local: &[AB::Var], |
| 79 | + _from_pc: AB::Var, |
| 80 | + ) -> AdapterAirContext<AB::Expr, I> { |
| 81 | + let cols: &EqCoreCols<_, NUM_LIMBS, LIMB_BITS> = local.borrow(); |
| 82 | + let flags = [cols.opcode_eq_flag, cols.opcode_ne_flag]; |
| 83 | + |
| 84 | + let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| { |
| 85 | + builder.assert_bool(flag); |
| 86 | + acc + flag.into() |
| 87 | + }); |
| 88 | + builder.assert_bool(is_valid.clone()); |
| 89 | + builder.assert_bool(cols.cmp_result); |
| 90 | + |
| 91 | + let a = &cols.a; |
| 92 | + let b = &cols.b; |
| 93 | + let c = &cols.c; |
| 94 | + let inv_marker = &cols.diff_inv_marker; |
| 95 | + |
| 96 | + // 1 if cmp_result indicates b and c are equal, 0 otherwise |
| 97 | + let cmp_eq = |
| 98 | + cols.cmp_result * cols.opcode_eq_flag + not(cols.cmp_result) * cols.opcode_ne_flag; |
| 99 | + let mut sum = cmp_eq.clone(); |
| 100 | + |
| 101 | + // For EQ, inv_marker is used to check equality of b and c: |
| 102 | + // - If b == c, all inv_marker values must be 0 (sum = 0) |
| 103 | + // - If b != c, inv_marker contains 0s for all positions except ONE position i where b[i] != |
| 104 | + // c[i] |
| 105 | + // - At this position, inv_marker[i] contains the multiplicative inverse of (b[i] - c[i]) |
| 106 | + // - This ensures inv_marker[i] * (b[i] - c[i]) = 1, making the sum = 1 |
| 107 | + // Note: There might be multiple valid inv_marker if b != c. |
| 108 | + // But as long as the trace can provide at least one, that’s sufficient to prove b != c. |
| 109 | + // |
| 110 | + // Note: |
| 111 | + // - If cmp_eq == 0, then it is impossible to have sum != 0 if b == c. |
| 112 | + // - If cmp_eq == 1, then it is impossible for b[i] - c[i] == 0 to pass for all i if b != c. |
| 113 | + for i in 0..NUM_LIMBS { |
| 114 | + sum += (b[i] - c[i]) * inv_marker[i]; |
| 115 | + builder.assert_zero(cmp_eq.clone() * (b[i] - c[i])); |
| 116 | + } |
| 117 | + builder.when(is_valid.clone()).assert_one(sum); |
| 118 | + |
| 119 | + // a == cmp_result |
| 120 | + builder.assert_eq(a[0], cols.cmp_result); |
| 121 | + for limb in a.iter().skip(1) { |
| 122 | + builder.assert_zero(*limb); |
| 123 | + } |
| 124 | + |
| 125 | + let expected_opcode = flags |
| 126 | + .iter() |
| 127 | + .zip(EqOpcode::iter()) |
| 128 | + .fold(AB::Expr::ZERO, |acc, (flag, opcode)| { |
| 129 | + acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8) |
| 130 | + }) |
| 131 | + + AB::Expr::from_canonical_usize(self.offset); |
| 132 | + |
| 133 | + AdapterAirContext { |
| 134 | + to_pc: None, |
| 135 | + reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(), |
| 136 | + writes: [cols.a.map(Into::into)].into(), |
| 137 | + instruction: MinimalInstruction { |
| 138 | + is_valid, |
| 139 | + opcode: expected_opcode, |
| 140 | + } |
| 141 | + .into(), |
| 142 | + } |
| 143 | + } |
| 144 | + |
| 145 | + fn start_offset(&self) -> usize { |
| 146 | + self.offset |
| 147 | + } |
| 148 | +} |
| 149 | + |
| 150 | +#[repr(C)] |
| 151 | +#[derive(Clone, Debug, Serialize, Deserialize)] |
| 152 | +#[serde(bound = "T: Serialize + DeserializeOwned")] |
| 153 | +pub struct EqCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> { |
| 154 | + #[serde(with = "BigArray")] |
| 155 | + pub a: [T; NUM_LIMBS], |
| 156 | + #[serde(with = "BigArray")] |
| 157 | + pub b: [T; NUM_LIMBS], |
| 158 | + #[serde(with = "BigArray")] |
| 159 | + pub c: [T; NUM_LIMBS], |
| 160 | + pub cmp_result: T, |
| 161 | + pub diff_inv_val: T, |
| 162 | + pub diff_idx: usize, |
| 163 | + pub opcode: EqOpcode, |
| 164 | +} |
| 165 | + |
| 166 | +pub struct EqCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> { |
| 167 | + pub air: EqCoreAir<NUM_LIMBS, LIMB_BITS>, |
| 168 | +} |
| 169 | + |
| 170 | +impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> EqCoreChip<NUM_LIMBS, LIMB_BITS> { |
| 171 | + pub fn new(offset: usize) -> Self { |
| 172 | + Self { |
| 173 | + air: EqCoreAir { offset }, |
| 174 | + } |
| 175 | + } |
| 176 | +} |
| 177 | + |
| 178 | +impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize> |
| 179 | + VmCoreChipWom<F, I> for EqCoreChip<NUM_LIMBS, LIMB_BITS> |
| 180 | +where |
| 181 | + I::Reads: Into<[[F; NUM_LIMBS]; 2]>, |
| 182 | + I::Writes: From<[[F; NUM_LIMBS]; 1]>, |
| 183 | +{ |
| 184 | + type Record = EqCoreRecord<F, NUM_LIMBS, LIMB_BITS>; |
| 185 | + type Air = EqCoreAir<NUM_LIMBS, LIMB_BITS>; |
| 186 | + |
| 187 | + #[allow(clippy::type_complexity)] |
| 188 | + fn execute_instruction( |
| 189 | + &self, |
| 190 | + instruction: &Instruction<F>, |
| 191 | + _from_pc: u32, |
| 192 | + _from_frame: u32, |
| 193 | + reads: I::Reads, |
| 194 | + ) -> ResultVm<(AdapterRuntimeContextWom<F, I>, Self::Record)> { |
| 195 | + let Instruction { opcode, .. } = instruction; |
| 196 | + let eq_opcode = EqOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); |
| 197 | + |
| 198 | + let data: [[F; NUM_LIMBS]; 2] = reads.into(); |
| 199 | + let b = data[0].map(|x| x.as_canonical_u32()); |
| 200 | + let c = data[1].map(|y| y.as_canonical_u32()); |
| 201 | + let (cmp_result, diff_idx, diff_inv_val) = run_eq::<F, NUM_LIMBS>(eq_opcode, &b, &c); |
| 202 | + let mut a: [F; NUM_LIMBS] = [F::ZERO; NUM_LIMBS]; |
| 203 | + a[0] = F::from_bool(cmp_result); |
| 204 | + |
| 205 | + let output = AdapterRuntimeContextWom { |
| 206 | + to_pc: None, |
| 207 | + to_fp: None, |
| 208 | + writes: [a].into(), |
| 209 | + }; |
| 210 | + |
| 211 | + let record = EqCoreRecord { |
| 212 | + opcode: eq_opcode, |
| 213 | + a, |
| 214 | + b: data[0], |
| 215 | + c: data[1], |
| 216 | + cmp_result: F::from_bool(cmp_result), |
| 217 | + diff_idx, |
| 218 | + diff_inv_val, |
| 219 | + }; |
| 220 | + |
| 221 | + Ok((output, record)) |
| 222 | + } |
| 223 | + |
| 224 | + fn get_opcode_name(&self, opcode: usize) -> String { |
| 225 | + format!("{:?}", EqOpcode::from_usize(opcode - self.air.offset)) |
| 226 | + } |
| 227 | + |
| 228 | + fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { |
| 229 | + let row_slice: &mut EqCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); |
| 230 | + row_slice.a = record.a; |
| 231 | + row_slice.b = record.b; |
| 232 | + row_slice.c = record.c; |
| 233 | + row_slice.cmp_result = record.cmp_result; |
| 234 | + row_slice.opcode_eq_flag = F::from_bool(record.opcode == EqOpcode::EQ); |
| 235 | + row_slice.opcode_ne_flag = F::from_bool(record.opcode == EqOpcode::NEQ); |
| 236 | + row_slice.diff_inv_marker = array::from_fn(|i| { |
| 237 | + if i == record.diff_idx { |
| 238 | + record.diff_inv_val |
| 239 | + } else { |
| 240 | + F::ZERO |
| 241 | + } |
| 242 | + }); |
| 243 | + } |
| 244 | + |
| 245 | + fn air(&self) -> &Self::Air { |
| 246 | + &self.air |
| 247 | + } |
| 248 | +} |
| 249 | + |
| 250 | +// Returns (cmp_result, diff_idx, x[diff_idx] - y[diff_idx]) |
| 251 | +pub(super) fn run_eq<F: PrimeField32, const NUM_LIMBS: usize>( |
| 252 | + local_opcode: EqOpcode, |
| 253 | + x: &[u32; NUM_LIMBS], |
| 254 | + y: &[u32; NUM_LIMBS], |
| 255 | +) -> (bool, usize, F) { |
| 256 | + for i in 0..NUM_LIMBS { |
| 257 | + if x[i] != y[i] { |
| 258 | + return ( |
| 259 | + local_opcode == EqOpcode::NEQ, |
| 260 | + i, |
| 261 | + (F::from_canonical_u32(x[i]) - F::from_canonical_u32(y[i])).inverse(), |
| 262 | + ); |
| 263 | + } |
| 264 | + } |
| 265 | + (local_opcode == EqOpcode::EQ, 0, F::ZERO) |
| 266 | +} |
0 commit comments