Skip to content

Commit 904e82b

Browse files
committed
Modify finalize of DomainEvaluationAccumulator.
1 parent 9f7ffbc commit 904e82b

File tree

1 file changed

+26
-40
lines changed

1 file changed

+26
-40
lines changed

crates/stwo/src/prover/air/accumulation.rs

Lines changed: 26 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -86,42 +86,26 @@ impl<B: Backend> DomainEvaluationAccumulator<B> {
8686
class = "ConstraintInterpolation"
8787
)
8888
.entered();
89-
let mut cur_poly: Option<SecureCirclePoly<B>> = None;
90-
let twiddles = B::precompute_twiddles(
91-
CanonicCoset::new(self.log_size())
92-
.circle_domain()
93-
.half_coset,
94-
);
9589

96-
for (log_size, values) in self.sub_accumulations.into_iter().enumerate().skip(1) {
97-
let Some(mut values) = values else {
98-
continue;
99-
};
100-
if let Some(prev_poly) = cur_poly {
101-
let eval = SecureColumnByCoords {
102-
columns: prev_poly.0.map(|c| {
103-
c.evaluate_with_twiddles(
104-
CanonicCoset::new(log_size as u32).circle_domain(),
105-
&twiddles,
106-
)
107-
.values
108-
}),
109-
};
110-
B::accumulate(&mut values, &eval);
111-
}
112-
cur_poly = Some(SecureCirclePoly(values.columns.map(|c| {
90+
let sub_accumulations = self.sub_accumulations.into_iter().flatten().collect_vec();
91+
let lifted_accumulation = B::lift_and_accumulate(sub_accumulations);
92+
93+
if let Some(eval) = lifted_accumulation {
94+
let twiddles =
95+
B::precompute_twiddles(CanonicCoset::new(log_size).circle_domain().half_coset);
96+
97+
SecureCirclePoly(eval.columns.map(|c| {
11398
CircleEvaluation::<B, BaseField, BitReversedOrder>::new(
114-
CanonicCoset::new(log_size as u32).circle_domain(),
99+
CanonicCoset::new(log_size).circle_domain(),
115100
c,
116101
)
117102
.interpolate_with_twiddles(&twiddles)
118-
})));
119-
}
120-
cur_poly.unwrap_or_else(|| {
103+
}))
104+
} else {
121105
SecureCirclePoly(std::array::from_fn(|_| {
122106
CirclePoly::new(Col::<B, BaseField>::zeros(1 << log_size))
123107
}))
124-
})
108+
}
125109
}
126110
}
127111

@@ -167,17 +151,15 @@ mod tests {
167151

168152
use super::*;
169153
use crate::core::circle::CirclePoint;
170-
use crate::core::fields::m31::{M31, P};
154+
use crate::core::fields::m31::M31;
171155
use crate::prover::backend::cpu::CpuCircleEvaluation;
172156
use crate::qm31;
173157

174158
#[test]
175-
fn test_domain_evaluation_accumulator() {
176-
// Generate a vector of random sizes with a constant seed.
159+
fn test_domain_evaluation_accumulator_lifted() {
177160
let mut rng = SmallRng::seed_from_u64(0);
178161
const LOG_SIZE_MIN: u32 = 4;
179162
const LOG_SIZE_BOUND: u32 = 10;
180-
const MASK: u32 = P;
181163
let mut log_sizes = (0..100)
182164
.map(|_| rng.gen_range(LOG_SIZE_MIN..LOG_SIZE_BOUND))
183165
.collect::<Vec<_>>();
@@ -188,16 +170,15 @@ mod tests {
188170
.iter()
189171
.map(|log_size| {
190172
(0..(1 << *log_size))
191-
.map(|_| M31::from_u32_unchecked(rng.gen::<u32>() & MASK))
173+
.map(|_| M31::from(rng.gen::<u32>()))
192174
.collect::<Vec<_>>()
193175
})
194176
.collect::<Vec<_>>();
195177
let alpha = qm31!(2, 3, 4, 5);
196178

197-
// Use accumulator.
198179
let mut accumulator = DomainEvaluationAccumulator::<CpuBackend>::new(
199180
alpha,
200-
LOG_SIZE_BOUND,
181+
LOG_SIZE_BOUND - 1,
201182
evaluations.len(),
202183
);
203184
let n_cols_per_size: [(u32, usize); (LOG_SIZE_BOUND - LOG_SIZE_MIN) as usize] =
@@ -210,6 +191,7 @@ mod tests {
210191
.count();
211192
(current_log_size, n_cols)
212193
});
194+
213195
let mut cols = accumulator.columns(n_cols_per_size);
214196
let mut eval_chunk_offset = 0;
215197
for (log_size, n_cols) in n_cols_per_size.iter() {
@@ -221,7 +203,6 @@ mod tests {
221203
if *log_size != *col_log_size {
222204
continue;
223205
}
224-
225206
// The random coefficient powers chunk is in regular order.
226207
let random_coeff_chunk =
227208
&cols[(log_size - LOG_SIZE_MIN) as usize].random_coeff_powers;
@@ -239,13 +220,18 @@ mod tests {
239220
let point = CirclePoint::<SecureField>::get_point(98989892);
240221
let accumulator_res = accumulator_poly.eval_at_point(point);
241222

242-
// Use direct computation.
223+
// Use direct computation: first interpolate each evaluation to obtain a polynomial,
224+
// evaluate its lift at `point`, and accumulate over the evaluations.
243225
let mut res = SecureField::default();
244226
for (log_size, values) in log_sizes.into_iter().zip(evaluations) {
245227
res = res * alpha
246-
+ CpuCircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), values)
247-
.interpolate()
248-
.eval_at_point(point);
228+
+ CpuCircleEvaluation::<BaseField, BitReversedOrder>::new(
229+
CanonicCoset::new(log_size).circle_domain(),
230+
values,
231+
)
232+
.interpolate()
233+
// The max log domain size is LOG_SIZE_BOUND - 1.
234+
.eval_at_point(point.repeated_double(LOG_SIZE_BOUND - 1 - log_size));
249235
}
250236

251237
assert_eq!(accumulator_res, res);

0 commit comments

Comments
 (0)