Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 161 additions & 77 deletions internal/domain/fft.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package domain

import (
"math/big"
"math/bits"
"slices"
"sync"

bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
Expand All @@ -21,7 +24,9 @@ import (
// The elements are returned in order as opposed to being returned in
// bit-reversed order.
func (domain *Domain) FftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
return fftG1(values, domain.Generator)
fftVals := slices.Clone(values)
fftG1(fftVals, domain.Generator)
return fftVals
}

// Computes an IFFT(Inverse Fast Fourier Transform) of the G1 elements.
Expand All @@ -32,7 +37,8 @@ func (domain *Domain) IfftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
var invDomainBI big.Int
domain.CardinalityInv.BigInt(&invDomainBI)

inverseFFT := fftG1(values, domain.GeneratorInv)
inverseFFT := slices.Clone(values)
fftG1(inverseFFT, domain.GeneratorInv)

// scale by the inverse of the domain size
for i := 0; i < len(inverseFFT); i++ {
Expand All @@ -47,61 +53,101 @@ func (domain *Domain) IfftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
// This is the actual implementation of [FftG1] with the same convention.
// That is, the returned slice is in "normal", rather than bit-reversed order.
// We assert that values is a slice of length n==2^i and nthRootOfUnity is a primitive n'th root of unity.
func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1Affine {
n := len(values)
if n == 1 {
return values
// func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1Affine {
// n := len(values)
// if n == 1 {
// return values
// }

// var generatorSquared fr.Element
// generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

// // split the input slice into a (copy of) the values at even resp. odd indices.
// even, odd := takeEvenOdd(values)

// // perform FFT recursively on those parts.
// fftEven := fftG1(even, generatorSquared)
// fftOdd := fftG1(odd, generatorSquared)

// // combine them to get the result
// // - evaluations[k] = fftEven[k] + w^k * fftOdd[k]
// // - evaluations[k] = fftEven[k] - w^k * fftOdd[k]
// // where w is a n'th primitive root of unity.
// inputPoint := fr.One()
// evaluations := make([]bls12381.G1Affine, n)
// for k := 0; k < n/2; k++ {
// var tmp bls12381.G1Affine

// var inputPointBI big.Int
// inputPoint.BigInt(&inputPointBI)

// if inputPoint.IsOne() {
// tmp.Set(&fftOdd[k])
// } else {
// tmp.ScalarMultiplication(&fftOdd[k], &inputPointBI)
// }

// evaluations[k].Add(&fftEven[k], &tmp)
// evaluations[k+n/2].Sub(&fftEven[k], &tmp)

// // we could take this from precomputed values in Domain (as domain.roots[n*k]), but then we would need to pass the domain.
// // At any rate, we don't really need to optimize here.
// inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// }

// return evaluations
// }

func fftG1(a []bls12381.G1Affine, omega fr.Element) {
n := uint(len(a))
logN := log2PowerOf2(uint64(n))
if n != 1<<logN {
panic("input size must be a power of 2")
}

var generatorSquared fr.Element
generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

// split the input slice into a (copy of) the values at even resp. odd indices.
even, odd := takeEvenOdd(values)

// perform FFT recursively on those parts.
fftEven := fftG1(even, generatorSquared)
fftOdd := fftG1(odd, generatorSquared)

// combine them to get the result
// - evaluations[k] = fftEven[k] + w^k * fftOdd[k]
// - evaluations[k] = fftEven[k] - w^k * fftOdd[k]
// where w is a n'th primitive root of unity.
inputPoint := fr.One()
evaluations := make([]bls12381.G1Affine, n)
for k := 0; k < n/2; k++ {
var tmp bls12381.G1Affine

var inputPointBI big.Int
inputPoint.BigInt(&inputPointBI)

if inputPoint.IsOne() {
tmp.Set(&fftOdd[k])
} else {
tmp.ScalarMultiplication(&fftOdd[k], &inputPointBI)
// Bit-reversal permutation
BitReverse(a)

// Main FFT computation
for s := uint(1); s <= logN; s++ {
m := uint(1) << s
halfM := m >> 1
wm := new(fr.Element).Exp(omega, new(big.Int).SetUint64(uint64(n/m)))

var wg sync.WaitGroup
for k := uint(0); k < n; k += m {
wg.Add(1)
go func(k uint) {
defer wg.Done()
w := new(fr.Element).SetOne()
for j := uint(0); j < halfM; j++ {
var t bls12381.G1Affine
var bi big.Int
t.ScalarMultiplication(&a[k+j+halfM], w.BigInt(&bi))
u := a[k+j]
a[k+j].Add(&u, &t)
a[k+j+halfM].Sub(&u, &t)
w.Mul(w, wm)
}
}(k)
}

evaluations[k].Add(&fftEven[k], &tmp)
evaluations[k+n/2].Sub(&fftEven[k], &tmp)

// we could take this from precomputed values in Domain (as domain.roots[n*k]), but then we would need to pass the domain.
// At any rate, we don't really need to optimize here.
inputPoint.Mul(&inputPoint, &nthRootOfUnity)
wg.Wait()
}

return evaluations
}

func (d *Domain) FftFr(values []fr.Element) []fr.Element {
return fftFr(values, d.Generator)
fftVals := slices.Clone(values)
fftFr(fftVals, d.Generator)
return fftVals
}

func (d *Domain) IfftFr(values []fr.Element) []fr.Element {
var invDomain fr.Element
invDomain.SetInt64(int64(len(values)))
invDomain.Inverse(&invDomain)

inverseFFT := fftFr(values, d.GeneratorInv)
inverseFFT := slices.Clone(values)
fftFr(inverseFFT, d.GeneratorInv)

// scale by the inverse of the domain size
for i := 0; i < len(inverseFFT); i++ {
Expand All @@ -110,34 +156,72 @@ func (d *Domain) IfftFr(values []fr.Element) []fr.Element {
return inverseFFT
}

func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element {
n := len(values)
if n == 1 {
return values
func log2PowerOf2(n uint64) uint {
if n == 0 || (n&(n-1)) != 0 {
panic("Input must be a power of 2 and not zero")
}

var generatorSquared fr.Element
generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

even, odd := takeEvenOdd(values)

fftEven := fftFr(even, generatorSquared)
fftOdd := fftFr(odd, generatorSquared)
return uint(bits.TrailingZeros64(n))
}

inputPoint := fr.One()
evaluations := make([]fr.Element, n)
for k := 0; k < n/2; k++ {
var tmp fr.Element
tmp.Mul(&inputPoint, &fftOdd[k])
func fftFr(a []fr.Element, omega fr.Element) {
n := uint(len(a))
logN := log2PowerOf2(uint64(n))

evaluations[k].Add(&fftEven[k], &tmp)
evaluations[k+n/2].Sub(&fftEven[k], &tmp)
if n != 1<<logN {
panic("input size must be a power of 2")
}

inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// Bit-reversal permutation
BitReverse(a)

// Main FFT computation
for s := uint(1); s <= logN; s++ {
m := uint(1) << s
halfM := m >> 1
wm := new(fr.Element).Exp(omega, new(big.Int).SetUint64(uint64(n/m)))

for k := uint(0); k < n; k += m {
w := new(fr.Element).SetOne()
for j := uint(0); j < halfM; j++ {
t := new(fr.Element).Mul(&a[k+j+halfM], w)
u := a[k+j]
a[k+j].Add(&u, t)
a[k+j+halfM].Sub(&u, t)
w.Mul(w, wm)
}
}
}
return evaluations
}

// func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element {
// n := len(values)
// if n == 1 {
// return values
// }

// var generatorSquared fr.Element
// generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

// even, odd := takeEvenOdd(values)

// fftEven := fftFr(even, generatorSquared)
// fftOdd := fftFr(odd, generatorSquared)

// inputPoint := fr.One()
// evaluations := make([]fr.Element, n)
// for k := 0; k < n/2; k++ {
// var tmp fr.Element
// tmp.Mul(&inputPoint, &fftOdd[k])

// evaluations[k].Add(&fftEven[k], &tmp)
// evaluations[k+n/2].Sub(&fftEven[k], &tmp)

// inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// }
// return evaluations
// }

// takeEvenOdd Takes a slice and return two slices
// The first slice contains (a copy of) all of the elements
// at even indices, the second slice contains
Expand All @@ -146,17 +230,17 @@ func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element {
// We assume that the length of the given values slice is even
// so the returned arrays will be the same length.
// This is the case for a radix-2 FFT
func takeEvenOdd[T interface{}](values []T) ([]T, []T) {
n := len(values)
even := make([]T, 0, n/2)
odd := make([]T, 0, n/2)
for i := 0; i < n; i++ {
if i%2 == 0 {
even = append(even, values[i])
} else {
odd = append(odd, values[i])
}
}

return even, odd
}
// func takeEvenOdd[T interface{}](values []T) ([]T, []T) {
// n := len(values)
// even := make([]T, 0, n/2)
// odd := make([]T, 0, n/2)
// for i := 0; i < n; i++ {
// if i%2 == 0 {
// even = append(even, values[i])
// } else {
// odd = append(odd, values[i])
// }
// }

// return even, odd
// }
21 changes: 20 additions & 1 deletion internal/kzg/kzg_prove.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,34 @@ package kzg
import (
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/crate-crypto/go-eth-kzg/internal/domain"
"github.com/crate-crypto/go-eth-kzg/internal/poly"
)

func Open(domain *domain.Domain, polyCoeff []fr.Element, evaluationPoint fr.Element, ck *CommitKey, numGoRoutines int) (OpeningProof, error) {

outputPoint := poly.PolyEval(polyCoeff, evaluationPoint)

quotient := poly.DividePolyByXminusA(polyCoeff, evaluationPoint)

comm, err := ck.Commit(quotient, 0)
if err != nil {
return OpeningProof{}, nil
}

return OpeningProof{
QuotientCommitment: *comm,
InputPoint: evaluationPoint,
ClaimedValue: outputPoint,
}, nil
}

// Open verifies that a polynomial f(x) when evaluated at a point `z` is equal to `f(z)`
//
// numGoRoutines is used to configure the amount of concurrency needed. Setting this
// value to a negative number or 0 will make it default to the number of CPUs.
//
// [compute_kzg_proof_impl]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#compute_kzg_proof_impl
func Open(domain *domain.Domain, p Polynomial, evaluationPoint fr.Element, ck *CommitKey, numGoRoutines int) (OpeningProof, error) {
func Open_(domain *domain.Domain, p Polynomial, evaluationPoint fr.Element, ck *CommitKey, numGoRoutines int) (OpeningProof, error) {
if len(p) == 0 || len(p) > len(ck.G1) {
return OpeningProof{}, ErrInvalidPolynomialSize
}
Expand Down
6 changes: 3 additions & 3 deletions internal/kzg/kzg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (

func TestProofVerifySmoke(t *testing.T) {
domain := domain.NewDomain(4)
srs, _ := newLagrangeSRSInsecure(*domain, big.NewInt(1234))
srs, _ := newMonomialSRSInsecure(*domain, big.NewInt(1234))

// polynomial in lagrange form
// polynomial in monomial form
poly := Polynomial{fr.NewElement(2), fr.NewElement(3), fr.NewElement(4), fr.NewElement(5)}

comm, _ := srs.CommitKey.Commit(poly, 0)
Expand All @@ -29,7 +29,7 @@ func TestProofVerifySmoke(t *testing.T) {

func TestBatchVerifySmoke(t *testing.T) {
domain := domain.NewDomain(4)
srs, _ := newLagrangeSRSInsecure(*domain, big.NewInt(1234))
srs, _ := newMonomialSRSInsecure(*domain, big.NewInt(1234))

numProofs := 10
commitments := make([]Commitment, 0, numProofs)
Expand Down
12 changes: 5 additions & 7 deletions internal/poly/poly.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,13 @@ func equalPoly(a, b PolynomialCoeff) bool {
// PolyEval evaluates a polynomial f(x) at a point z, computing f(z).
// The polynomial is given in coefficient form, and `z` is denoted as inputPoint.
func PolyEval(poly PolynomialCoeff, inputPoint fr.Element) fr.Element {
result := fr.NewElement(0)

for i := len(poly) - 1; i >= 0; i-- {
tmp := fr.Element{}
tmp.Mul(&result, &inputPoint)
result.Add(&tmp, &poly[i])
res := poly[len(poly)-1]
for i := len(poly) - 2; i >= 0; i-- {
res.Mul(&res, &inputPoint)
res.Add(&res, &poly[i])
}

return result
return res
}

// DividePolyByXminusA computes f(x) / (x - a) and returns the quotient.
Expand Down
Loading