Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions cryptography/bls12_381/benches/benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate_crypto_internal_eth_kzg_bls12_381::{
batch_inversion,
ff::Field,
fixed_base_msm::{FixedBaseMSM, UsePrecomp},
fixed_base_msm::FixedBaseMSMPrecompBLST,
fixed_base_msm_window::FixedBaseMSMPrecompWindow,
g1_batch_normalize, g2_batch_normalize,
group::Group,
lincomb::{g1_lincomb, g1_lincomb_unsafe, g2_lincomb, g2_lincomb_unsafe},
Expand All @@ -28,12 +29,18 @@ pub fn fixed_base_msm(c: &mut Criterion) {
.into_iter()
.map(|p| p.into())
.collect();
let fbm = FixedBaseMSM::new(generators, UsePrecomp::Yes { width: 8 });
let scalars: Vec<_> = random_scalars(length);

c.bench_function("bls12_381 fixed_base_msm length=64 width=8", |b| {
let fbm = FixedBaseMSMPrecompBLST::new(generators.clone(), 8);
let scalars: Vec<_> = random_scalars(length);
c.bench_function("bls12_381 fixed_base_msm length=64 width=8 (blst)", |b| {
b.iter(|| fbm.msm(scalars.clone()))
});

let fbm = FixedBaseMSMPrecompWindow::new(&generators, 8);
let scalars: Vec<_> = random_scalars(length);
c.bench_function("bls12_381 fixed_base_msm length=64 width=8 (rust)", |b| {
b.iter(|| fbm.msm(&scalars))
});
}

pub fn bench_msm(c: &mut Criterion) {
Expand Down
96 changes: 96 additions & 0 deletions cryptography/bls12_381/src/booth_encoding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::ops::Neg;

// Code was taken from: https://github.com/privacy-scaling-explorations/halo2curves/blob/b753a832e92d5c86c5c997327a9cf9de86a18851/src/msm.rs#L13
pub fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
// Booth encoding:
// * step by `window` size
// * slice by size of `window + 1``
// * each window overlap by 1 bit
// * append a zero bit to the least significant end
// Indexing rule for example window size 3 where we slice by 4 bits:
// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3 -2, -2, -1, -1, 0]``
// So we can reduce the bucket size without preprocessing scalars
// and remembering them as in classic signed digit encoding

let skip_bits = (window_index * window_size).saturating_sub(1);
let skip_bytes = skip_bits / 8;

// fill into a u32
let mut v: [u8; 4] = [0; 4];
for (dst, src) in v.iter_mut().zip(el.iter().skip(skip_bytes)) {
*dst = *src
}
let mut tmp = u32::from_le_bytes(v);

// pad with one 0 if slicing the least significant window
if window_index == 0 {
tmp <<= 1;
}

// remove further bits
tmp >>= skip_bits - (skip_bytes * 8);
// apply the booth window
tmp &= (1 << (window_size + 1)) - 1;

let sign = tmp & (1 << window_size) == 0;

// div ceil by 2
tmp = (tmp + 1) >> 1;

// find the booth action index
if sign {
tmp as i32
} else {
((!(tmp - 1) & ((1 << window_size) - 1)) as i32).neg()
}
}

#[cfg(test)]
mod tests {
use std::ops::Neg;

use super::get_booth_index;
use crate::G1Point;
use blstrs::{G1Projective, Scalar};
use ff::{Field, PrimeField};

#[test]
fn smoke_scalar_mul() {
use group::prime::PrimeCurveAffine;
let gen = G1Point::generator();
let s = -Scalar::ONE;

let res = gen * s;

let got = mul(&s, &gen, 4);

assert_eq!(G1Point::from(res), got)
}

fn mul(scalar: &Scalar, point: &G1Point, window: usize) -> G1Point {
let u = scalar.to_bytes_le();
let n = Scalar::NUM_BITS as usize / window + 1;

let table = (0..=1 << (window - 1))
.map(|i| point * Scalar::from(i as u64))
.collect::<Vec<_>>();

let mut acc: G1Projective = G1Point::default().into();
for i in (0..n).rev() {
for _ in 0..window {
acc = acc + acc;
}

let idx = get_booth_index(i as usize, window, u.as_ref());

if idx.is_negative() {
acc += table[idx.unsigned_abs() as usize].neg();
}
if idx.is_positive() {
acc += table[idx.unsigned_abs() as usize];
}
}

acc.into()
}
}
18 changes: 9 additions & 9 deletions cryptography/bls12_381/src/fixed_base_msm.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::{G1Projective, Scalar};
use crate::{fixed_base_msm_window::FixedBaseMSMPrecompWindow, G1Projective, Scalar};
use blstrs::{Fp, G1Affine};

/// FixedBaseMSMPrecomp computes a multi scalar multiplication using pre-computations.
///
/// It uses batch addition to amortize the cost of adding multiple points together.
#[derive(Debug)]
pub struct FixedBaseMSMPrecomp {
pub struct FixedBaseMSMPrecompBLST {
table: Vec<blst::blst_p1_affine>,
wbits: usize,
num_points: usize,
Expand All @@ -27,23 +27,23 @@ pub enum UsePrecomp {
/// of memory.
#[derive(Debug)]
pub enum FixedBaseMSM {
Precomp(FixedBaseMSMPrecomp),
Precomp(FixedBaseMSMPrecompWindow),
NoPrecomp(Vec<G1Affine>),
}

impl FixedBaseMSM {
pub fn new(generators: Vec<G1Affine>, use_precomp: UsePrecomp) -> Self {
match use_precomp {
UsePrecomp::Yes { width } => {
FixedBaseMSM::Precomp(FixedBaseMSMPrecomp::new(generators, width))
FixedBaseMSM::Precomp(FixedBaseMSMPrecompWindow::new(&generators, width))
}
UsePrecomp::No => FixedBaseMSM::NoPrecomp(generators),
}
}

pub fn msm(&self, scalars: Vec<Scalar>) -> G1Projective {
match self {
FixedBaseMSM::Precomp(precomp) => precomp.msm(scalars),
FixedBaseMSM::Precomp(precomp) => precomp.msm(&scalars),
FixedBaseMSM::NoPrecomp(generators) => {
use crate::lincomb::g1_lincomb;
g1_lincomb(generators, &scalars)
Expand All @@ -53,7 +53,7 @@ impl FixedBaseMSM {
}
}

impl FixedBaseMSMPrecomp {
impl FixedBaseMSMPrecompBLST {
pub fn new(generators_affine: Vec<G1Affine>, wbits: usize) -> Self {
let num_points = generators_affine.len();
let table_size_bytes =
Expand All @@ -74,7 +74,7 @@ impl FixedBaseMSMPrecomp {

let scratch_space_size = unsafe { blst::blst_p1s_mult_wbits_scratch_sizeof(num_points) };

FixedBaseMSMPrecomp {
FixedBaseMSMPrecompBLST {
table,
wbits,
num_points,
Expand Down Expand Up @@ -120,7 +120,7 @@ impl FixedBaseMSMPrecomp {

#[cfg(test)]
mod tests {
use super::{FixedBaseMSMPrecomp, UsePrecomp};
use super::{FixedBaseMSMPrecompBLST, UsePrecomp};
use crate::{fixed_base_msm::FixedBaseMSM, G1Projective, Scalar};
use ff::Field;
use group::Group;
Expand Down Expand Up @@ -158,7 +158,7 @@ mod tests {
let generators: Vec<_> = (0..length)
.map(|_| G1Projective::random(&mut rand::thread_rng()).into())
.collect();
let fbm = FixedBaseMSMPrecomp::new(generators, 8);
let fbm = FixedBaseMSMPrecompBLST::new(generators, 8);
for val in fbm.table.into_iter() {
let is_inf =
unsafe { blst::blst_p1_affine_is_inf(&val as *const blst::blst_p1_affine) };
Expand Down
143 changes: 143 additions & 0 deletions cryptography/bls12_381/src/fixed_base_msm_window.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
use crate::{
batch_add::multi_batch_addition_binary_tree_stride, booth_encoding::get_booth_index,
g1_batch_normalize, G1Projective, Scalar,
};
use blstrs::G1Affine;
use ff::PrimeField;
use group::Group;

// Note: This is the same strategy that blst uses
#[derive(Debug)]
pub struct FixedBaseMSMPrecompWindow {
table: Vec<Vec<G1Affine>>,
wbits: usize,
}

impl FixedBaseMSMPrecompWindow {
pub fn new(points: &[G1Affine], wbits: usize) -> Self {
// For every point `P`, wbits indicates that we should compute
// 1 * P, ..., (2^{wbits} - 1) * P
//
// The total amount of memory is roughly (numPoints * 2^{wbits} - 1)
// where each point is 64 bytes.
//
let precomputed_points: Vec<_> = points
.into_iter()
.map(|point| Self::precompute_points(wbits, *point))
.collect();

Self {
table: precomputed_points,
wbits,
}
}
// Given a point, we precompute P,..., (2^{w-1}-1) * P
fn precompute_points(wbits: usize, point: G1Affine) -> Vec<G1Affine> {
let mut lookup_table = Vec::with_capacity(1 << (wbits - 1));

// Convert to projective for faster operations
let mut current = G1Projective::from(point);

// Compute and store multiples
for _ in 0..(1 << (wbits - 1)) {
lookup_table.push(current);
current += point;
}

g1_batch_normalize(&lookup_table)
}

pub fn msm(&self, scalars: &[Scalar]) -> G1Projective {
let scalars_bytes: Vec<_> = scalars.iter().map(|a| a.to_bytes_le()).collect();
let number_of_windows = Scalar::NUM_BITS as usize / self.wbits + 1;

let mut windows_of_points = vec![Vec::with_capacity(scalars.len()); number_of_windows];

for window_idx in 0..number_of_windows {
for (scalar_idx, scalar_bytes) in scalars_bytes.iter().enumerate() {
let sub_table = &self.table[scalar_idx];
let point_idx = get_booth_index(window_idx, self.wbits, scalar_bytes.as_ref());

if point_idx == 0 {
continue;
}
let sign = point_idx.is_positive();
let point_idx = point_idx.unsigned_abs() as usize - 1;
let mut point = sub_table[point_idx];
if !sign {
point = -point;
}

windows_of_points[window_idx].push(point);
}
}

let accumulated_points = multi_batch_addition_binary_tree_stride(windows_of_points);

// Now accumulate the windows by doubling wbits times
let mut result: G1Projective = *accumulated_points.last().unwrap();
for point in accumulated_points.into_iter().rev().skip(1) {
// Double the result 'wbits' times
for _ in 0..self.wbits {
result = result.double();
}
// Add the accumulated point for this window
result += point;
}

result
}
}

#[cfg(test)]
mod tests {
use super::*;
use ff::Field;
use group::prime::PrimeCurveAffine;

#[test]
fn precomp_lookup_table() {
use group::Group;
let lookup_table = FixedBaseMSMPrecompWindow::precompute_points(7, G1Affine::generator());

for i in 1..lookup_table.len() {
let expected = G1Projective::generator() * Scalar::from((i + 1) as u64);
assert_eq!(lookup_table[i], expected.into(),)
}
}

#[test]
fn msm_blst_precomp() {
let length = 64;
let generators: Vec<_> = (0..length)
.map(|_| G1Projective::random(&mut rand::thread_rng()).into())
.collect();
let scalars: Vec<_> = (0..length)
.map(|_| Scalar::random(&mut rand::thread_rng()))
.collect();

let res = crate::lincomb::g1_lincomb(&generators, &scalars)
.expect("number of generators and number of scalars is equal");

let fbm = FixedBaseMSMPrecompWindow::new(&generators, 7);
let result = fbm.msm(&scalars);

assert_eq!(res, result);
}

#[test]
fn bench_window_sizes_msm() {
let length = 64;
let generators: Vec<_> = (0..length)
.map(|_| G1Projective::random(&mut rand::thread_rng()).into())
.collect();
let scalars: Vec<_> = (0..length)
.map(|_| Scalar::random(&mut rand::thread_rng()))
.collect();

for i in 2..=14 {
let fbm = FixedBaseMSMPrecompWindow::new(&generators, i);
fbm.msm(&scalars);
}
}
}
2 changes: 2 additions & 0 deletions cryptography/bls12_381/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
pub mod batch_add;
pub mod batch_inversion;
mod booth_encoding;
pub mod fixed_base_msm;
pub mod fixed_base_msm_window;
pub mod lincomb;

// Re-exporting the blstrs crate
Expand Down
Loading