Skip to content

Commit 95fab56

Browse files
committed
Merge branch 'main' into lte_gte
2 parents 8a803ab + 0290715 commit 95fab56

File tree

8 files changed

+409
-77
lines changed

8 files changed

+409
-77
lines changed

extensions/circuit/src/eq/core.rs

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
use std::{
2+
array,
3+
borrow::{Borrow, BorrowMut},
4+
};
5+
6+
use openvm_circuit::arch::{AdapterAirContext, MinimalInstruction, VmAdapterInterface, VmCoreAir};
7+
use openvm_circuit_primitives::utils::not;
8+
use openvm_circuit_primitives_derive::AlignedBorrow;
9+
use openvm_instructions::{instruction::Instruction, LocalOpcode};
10+
use openvm_stark_backend::{
11+
interaction::InteractionBuilder,
12+
p3_air::{AirBuilder, BaseAir},
13+
p3_field::{Field, FieldAlgebra, PrimeField32},
14+
rap::{BaseAirWithPublicValues, ColumnsAir},
15+
};
16+
use openvm_womir_transpiler::EqOpcode;
17+
use serde::{de::DeserializeOwned, Deserialize, Serialize};
18+
use serde_big_array::BigArray;
19+
use struct_reflection::{StructReflection, StructReflectionHelper};
20+
use strum::IntoEnumIterator;
21+
22+
use crate::{AdapterRuntimeContextWom, VmCoreChipWom};
23+
use openvm_circuit::arch::Result as ResultVm;
24+
25+
#[repr(C)]
26+
#[derive(AlignedBorrow, StructReflection)]
27+
pub struct EqCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
28+
pub a: [T; NUM_LIMBS],
29+
pub b: [T; NUM_LIMBS],
30+
pub c: [T; NUM_LIMBS],
31+
32+
pub cmp_result: T,
33+
34+
pub opcode_eq_flag: T,
35+
pub opcode_ne_flag: T,
36+
37+
pub diff_inv_marker: [T; NUM_LIMBS],
38+
}
39+
40+
#[derive(Copy, Clone, Debug)]
41+
pub struct EqCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
42+
offset: usize,
43+
}
44+
45+
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
46+
for EqCoreAir<NUM_LIMBS, LIMB_BITS>
47+
{
48+
fn width(&self) -> usize {
49+
EqCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
50+
}
51+
}
52+
53+
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> ColumnsAir<F>
54+
for EqCoreAir<NUM_LIMBS, LIMB_BITS>
55+
{
56+
fn columns(&self) -> Option<Vec<String>> {
57+
EqCoreCols::<F, NUM_LIMBS, LIMB_BITS>::struct_reflection()
58+
}
59+
}
60+
61+
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
62+
for EqCoreAir<NUM_LIMBS, LIMB_BITS>
63+
{
64+
}
65+
66+
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
67+
for EqCoreAir<NUM_LIMBS, LIMB_BITS>
68+
where
69+
AB: InteractionBuilder,
70+
I: VmAdapterInterface<AB::Expr>,
71+
I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
72+
I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
73+
I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
74+
{
75+
fn eval(
76+
&self,
77+
builder: &mut AB,
78+
local: &[AB::Var],
79+
_from_pc: AB::Var,
80+
) -> AdapterAirContext<AB::Expr, I> {
81+
let cols: &EqCoreCols<_, NUM_LIMBS, LIMB_BITS> = local.borrow();
82+
let flags = [cols.opcode_eq_flag, cols.opcode_ne_flag];
83+
84+
let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
85+
builder.assert_bool(flag);
86+
acc + flag.into()
87+
});
88+
builder.assert_bool(is_valid.clone());
89+
builder.assert_bool(cols.cmp_result);
90+
91+
let a = &cols.a;
92+
let b = &cols.b;
93+
let c = &cols.c;
94+
let inv_marker = &cols.diff_inv_marker;
95+
96+
// 1 if cmp_result indicates b and c are equal, 0 otherwise
97+
let cmp_eq =
98+
cols.cmp_result * cols.opcode_eq_flag + not(cols.cmp_result) * cols.opcode_ne_flag;
99+
let mut sum = cmp_eq.clone();
100+
101+
// For EQ, inv_marker is used to check equality of b and c:
102+
// - If b == c, all inv_marker values must be 0 (sum = 0)
103+
// - If b != c, inv_marker contains 0s for all positions except ONE position i where b[i] !=
104+
// c[i]
105+
// - At this position, inv_marker[i] contains the multiplicative inverse of (b[i] - c[i])
106+
// - This ensures inv_marker[i] * (b[i] - c[i]) = 1, making the sum = 1
107+
// Note: There might be multiple valid inv_marker if b != c.
108+
// But as long as the trace can provide at least one, that’s sufficient to prove b != c.
109+
//
110+
// Note:
111+
// - If cmp_eq == 0, then it is impossible to have sum != 0 if b == c.
112+
// - If cmp_eq == 1, then it is impossible for b[i] - c[i] == 0 to pass for all i if b != c.
113+
for i in 0..NUM_LIMBS {
114+
sum += (b[i] - c[i]) * inv_marker[i];
115+
builder.assert_zero(cmp_eq.clone() * (b[i] - c[i]));
116+
}
117+
builder.when(is_valid.clone()).assert_one(sum);
118+
119+
// a == cmp_result
120+
builder.assert_eq(a[0], cols.cmp_result);
121+
for limb in a.iter().skip(1) {
122+
builder.assert_zero(*limb);
123+
}
124+
125+
let expected_opcode = flags
126+
.iter()
127+
.zip(EqOpcode::iter())
128+
.fold(AB::Expr::ZERO, |acc, (flag, opcode)| {
129+
acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
130+
})
131+
+ AB::Expr::from_canonical_usize(self.offset);
132+
133+
AdapterAirContext {
134+
to_pc: None,
135+
reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
136+
writes: [cols.a.map(Into::into)].into(),
137+
instruction: MinimalInstruction {
138+
is_valid,
139+
opcode: expected_opcode,
140+
}
141+
.into(),
142+
}
143+
}
144+
145+
fn start_offset(&self) -> usize {
146+
self.offset
147+
}
148+
}
149+
150+
#[repr(C)]
151+
#[derive(Clone, Debug, Serialize, Deserialize)]
152+
#[serde(bound = "T: Serialize + DeserializeOwned")]
153+
pub struct EqCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
154+
#[serde(with = "BigArray")]
155+
pub a: [T; NUM_LIMBS],
156+
#[serde(with = "BigArray")]
157+
pub b: [T; NUM_LIMBS],
158+
#[serde(with = "BigArray")]
159+
pub c: [T; NUM_LIMBS],
160+
pub cmp_result: T,
161+
pub diff_inv_val: T,
162+
pub diff_idx: usize,
163+
pub opcode: EqOpcode,
164+
}
165+
166+
pub struct EqCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
167+
pub air: EqCoreAir<NUM_LIMBS, LIMB_BITS>,
168+
}
169+
170+
impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> EqCoreChip<NUM_LIMBS, LIMB_BITS> {
171+
pub fn new(offset: usize) -> Self {
172+
Self {
173+
air: EqCoreAir { offset },
174+
}
175+
}
176+
}
177+
178+
impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
179+
VmCoreChipWom<F, I> for EqCoreChip<NUM_LIMBS, LIMB_BITS>
180+
where
181+
I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
182+
I::Writes: From<[[F; NUM_LIMBS]; 1]>,
183+
{
184+
type Record = EqCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
185+
type Air = EqCoreAir<NUM_LIMBS, LIMB_BITS>;
186+
187+
#[allow(clippy::type_complexity)]
188+
fn execute_instruction(
189+
&self,
190+
instruction: &Instruction<F>,
191+
_from_pc: u32,
192+
_from_frame: u32,
193+
reads: I::Reads,
194+
) -> ResultVm<(AdapterRuntimeContextWom<F, I>, Self::Record)> {
195+
let Instruction { opcode, .. } = instruction;
196+
let eq_opcode = EqOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
197+
198+
let data: [[F; NUM_LIMBS]; 2] = reads.into();
199+
let b = data[0].map(|x| x.as_canonical_u32());
200+
let c = data[1].map(|y| y.as_canonical_u32());
201+
let (cmp_result, diff_idx, diff_inv_val) = run_eq::<F, NUM_LIMBS>(eq_opcode, &b, &c);
202+
let mut a: [F; NUM_LIMBS] = [F::ZERO; NUM_LIMBS];
203+
a[0] = F::from_bool(cmp_result);
204+
205+
let output = AdapterRuntimeContextWom {
206+
to_pc: None,
207+
to_fp: None,
208+
writes: [a].into(),
209+
};
210+
211+
let record = EqCoreRecord {
212+
opcode: eq_opcode,
213+
a,
214+
b: data[0],
215+
c: data[1],
216+
cmp_result: F::from_bool(cmp_result),
217+
diff_idx,
218+
diff_inv_val,
219+
};
220+
221+
Ok((output, record))
222+
}
223+
224+
fn get_opcode_name(&self, opcode: usize) -> String {
225+
format!("{:?}", EqOpcode::from_usize(opcode - self.air.offset))
226+
}
227+
228+
fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
229+
let row_slice: &mut EqCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut();
230+
row_slice.a = record.a;
231+
row_slice.b = record.b;
232+
row_slice.c = record.c;
233+
row_slice.cmp_result = record.cmp_result;
234+
row_slice.opcode_eq_flag = F::from_bool(record.opcode == EqOpcode::EQ);
235+
row_slice.opcode_ne_flag = F::from_bool(record.opcode == EqOpcode::NEQ);
236+
row_slice.diff_inv_marker = array::from_fn(|i| {
237+
if i == record.diff_idx {
238+
record.diff_inv_val
239+
} else {
240+
F::ZERO
241+
}
242+
});
243+
}
244+
245+
fn air(&self) -> &Self::Air {
246+
&self.air
247+
}
248+
}
249+
250+
// Returns (cmp_result, diff_idx, x[diff_idx] - y[diff_idx])
251+
pub(super) fn run_eq<F: PrimeField32, const NUM_LIMBS: usize>(
252+
local_opcode: EqOpcode,
253+
x: &[u32; NUM_LIMBS],
254+
y: &[u32; NUM_LIMBS],
255+
) -> (bool, usize, F) {
256+
for i in 0..NUM_LIMBS {
257+
if x[i] != y[i] {
258+
return (
259+
local_opcode == EqOpcode::NEQ,
260+
i,
261+
(F::from_canonical_u32(x[i]) - F::from_canonical_u32(y[i])).inverse(),
262+
);
263+
}
264+
}
265+
(local_opcode == EqOpcode::EQ, 0, F::ZERO)
266+
}

extensions/circuit/src/eq/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
use super::adapters::{WomBaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
2+
use crate::VmChipWrapperWom;
3+
4+
mod core;
5+
pub use core::*;
6+
7+
pub type EqChipWom<F> = VmChipWrapperWom<
8+
F,
9+
WomBaseAluAdapterChip<F>,
10+
EqCoreChip<RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>,
11+
>;

extensions/circuit/src/extension.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
1717
use openvm_instructions::{LocalOpcode, PhantomDiscriminant};
1818
use openvm_stark_backend::p3_field::PrimeField32;
1919
use openvm_womir_transpiler::{
20-
AllocateFrameOpcode, BaseAluOpcode, ConstOpcodes, CopyIntoFrameOpcode, DivRemOpcode,
20+
AllocateFrameOpcode, BaseAluOpcode, ConstOpcodes, CopyIntoFrameOpcode, DivRemOpcode, EqOpcode,
2121
HintStoreOpcode, JaafOpcode, JumpOpcode, LessThanOpcode, LoadStoreOpcode, MulOpcode, Phantom,
2222
ShiftOpcode,
2323
};
@@ -114,6 +114,7 @@ pub enum WomirIExecutor<F: PrimeField32> {
114114
DivRem(WomDivRemChip<F>),
115115
Shift(Rv32ShiftChip<F>),
116116
LoadStore(Rv32LoadStoreChip<F>),
117+
Eq(EqChipWom<F>),
117118
// LoadSignExtend(Rv32LoadSignExtendChip<F>),
118119
// BranchEqual(Rv32BranchEqualChip<F>),
119120
// BranchLessThan(Rv32BranchLessThanChip<F>),
@@ -245,6 +246,20 @@ impl<F: PrimeField32> VmExtension<F> for WomirI {
245246
);
246247
inventory.add_executor(lt_chip, LessThanOpcode::iter().map(|x| x.global_opcode()))?;
247248

249+
let eq_chip = EqChipWom::new(
250+
WomBaseAluAdapterChip::new(
251+
execution_bus,
252+
program_bus,
253+
frame_bus,
254+
memory_bridge,
255+
bitwise_lu_chip.clone(),
256+
),
257+
EqCoreChip::new(EqOpcode::CLASS_OFFSET),
258+
offline_memory.clone(),
259+
shared_fp.clone(),
260+
);
261+
inventory.add_executor(eq_chip, EqOpcode::iter().map(|x| x.global_opcode()))?;
262+
248263
let mut hintstore_chip = HintStoreChip::new(
249264
execution_bus,
250265
frame_bus,

extensions/circuit/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod branch_lt;
77
mod consts;
88
mod copy_into_frame;
99
mod divrem;
10+
mod eq;
1011
mod hintstore;
1112
mod jaaf;
1213
mod jump;
@@ -23,6 +24,7 @@ pub use branch_lt::*;
2324
pub use consts::*;
2425
pub use copy_into_frame::*;
2526
pub use divrem::*;
27+
pub use eq::*;
2628
pub use hintstore::*;
2729
pub use jaaf::*;
2830
pub use jump::*;

extensions/transpiler/src/instructions.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,28 @@ pub enum LessThanOpcode {
7373
SLTU,
7474
}
7575

76+
#[derive(
77+
Copy,
78+
Clone,
79+
Debug,
80+
PartialEq,
81+
Eq,
82+
PartialOrd,
83+
Ord,
84+
EnumCount,
85+
EnumIter,
86+
FromRepr,
87+
LocalOpcode,
88+
Serialize,
89+
Deserialize,
90+
)]
91+
#[opcode_offset = 0x120c]
92+
#[repr(usize)]
93+
pub enum EqOpcode {
94+
EQ,
95+
NEQ,
96+
}
97+
7698
#[derive(
7799
Copy,
78100
Clone,

integration/src/instruction_builder.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use openvm_instructions::{instruction::Instruction, riscv, LocalOpcode, SystemOpcode, VmOpcode};
22
use openvm_stark_backend::p3_field::PrimeField32;
33
use openvm_womir_transpiler::{
4-
AllocateFrameOpcode, BaseAluOpcode, ConstOpcodes, CopyIntoFrameOpcode, DivRemOpcode,
4+
AllocateFrameOpcode, BaseAluOpcode, ConstOpcodes, CopyIntoFrameOpcode, DivRemOpcode, EqOpcode,
55
HintStoreOpcode, JaafOpcode, JumpOpcode, LessThanOpcode, LoadStoreOpcode, MulOpcode, Phantom,
66
ShiftOpcode,
77
};
@@ -141,6 +141,18 @@ pub fn gt_s<F: PrimeField32>(rd: usize, rs1: usize, rs2: usize) -> Instruction<F
141141
lt_s(rd, rs2, rs1)
142142
}
143143

144+
pub fn eq<F: PrimeField32>(rd: usize, rs1: usize, rs2: usize) -> Instruction<F> {
145+
instr_r(EqOpcode::EQ.global_opcode().as_usize(), rd, rs1, rs2)
146+
}
147+
148+
pub fn eqi<F: PrimeField32>(rd: usize, rs1: usize, imm: usize) -> Instruction<F> {
149+
instr_i(EqOpcode::EQ.global_opcode().as_usize(), rd, rs1, imm)
150+
}
151+
152+
pub fn neq<F: PrimeField32>(rd: usize, rs1: usize, rs2: usize) -> Instruction<F> {
153+
instr_r(EqOpcode::NEQ.global_opcode().as_usize(), rd, rs1, rs2)
154+
}
155+
144156
pub fn const_32_imm<F: PrimeField32>(
145157
target_reg: usize,
146158
imm_lo: u16,

0 commit comments

Comments
 (0)