diff --git a/internal/domain/fft.go b/internal/domain/fft.go index 1037057..b42be08 100644 --- a/internal/domain/fft.go +++ b/internal/domain/fft.go @@ -2,6 +2,9 @@ package domain import ( "math/big" + "math/bits" + "slices" + "sync" bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -21,7 +24,9 @@ import ( // The elements are returned in order as opposed to being returned in // bit-reversed order. func (domain *Domain) FftG1(values []bls12381.G1Affine) []bls12381.G1Affine { - return fftG1(values, domain.Generator) + fftVals := slices.Clone(values) + fftG1(fftVals, domain.Generator) + return fftVals } // Computes an IFFT(Inverse Fast Fourier Transform) of the G1 elements. @@ -32,7 +37,8 @@ func (domain *Domain) IfftG1(values []bls12381.G1Affine) []bls12381.G1Affine { var invDomainBI big.Int domain.CardinalityInv.BigInt(&invDomainBI) - inverseFFT := fftG1(values, domain.GeneratorInv) + inverseFFT := slices.Clone(values) + fftG1(inverseFFT, domain.GeneratorInv) // scale by the inverse of the domain size for i := 0; i < len(inverseFFT); i++ { @@ -47,53 +53,92 @@ func (domain *Domain) IfftG1(values []bls12381.G1Affine) []bls12381.G1Affine { // This is the actual implementation of [FftG1] with the same convention. // That is, the returned slice is in "normal", rather than bit-reversed order. // We assert that values is a slice of length n==2^i and nthRootOfUnity is a primitive n'th root of unity. -func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1Affine { - n := len(values) - if n == 1 { - return values +// func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1Affine { +// n := len(values) +// if n == 1 { +// return values +// } + +// var generatorSquared fr.Element +// generatorSquared.Square(&nthRootOfUnity) // generator with order n/2 + +// // split the input slice into a (copy of) the values at even resp. odd indices. +// even, odd := takeEvenOdd(values) + +// // perform FFT recursively on those parts. +// fftEven := fftG1(even, generatorSquared) +// fftOdd := fftG1(odd, generatorSquared) + +// // combine them to get the result +// // - evaluations[k] = fftEven[k] + w^k * fftOdd[k] +// // - evaluations[k] = fftEven[k] - w^k * fftOdd[k] +// // where w is a n'th primitive root of unity. +// inputPoint := fr.One() +// evaluations := make([]bls12381.G1Affine, n) +// for k := 0; k < n/2; k++ { +// var tmp bls12381.G1Affine + +// var inputPointBI big.Int +// inputPoint.BigInt(&inputPointBI) + +// if inputPoint.IsOne() { +// tmp.Set(&fftOdd[k]) +// } else { +// tmp.ScalarMultiplication(&fftOdd[k], &inputPointBI) +// } + +// evaluations[k].Add(&fftEven[k], &tmp) +// evaluations[k+n/2].Sub(&fftEven[k], &tmp) + +// // we could take this from precomputed values in Domain (as domain.roots[n*k]), but then we would need to pass the domain. +// // At any rate, we don't really need to optimize here. +// inputPoint.Mul(&inputPoint, &nthRootOfUnity) +// } + +// return evaluations +// } + +func fftG1(a []bls12381.G1Affine, omega fr.Element) { + n := uint(len(a)) + logN := log2PowerOf2(uint64(n)) + if n != 1<> 1 + wm := new(fr.Element).Exp(omega, new(big.Int).SetUint64(uint64(n/m))) + + var wg sync.WaitGroup + for k := uint(0); k < n; k += m { + wg.Add(1) + go func(k uint) { + defer wg.Done() + w := new(fr.Element).SetOne() + for j := uint(0); j < halfM; j++ { + var t bls12381.G1Affine + var bi big.Int + t.ScalarMultiplication(&a[k+j+halfM], w.BigInt(&bi)) + u := a[k+j] + a[k+j].Add(&u, &t) + a[k+j+halfM].Sub(&u, &t) + w.Mul(w, wm) + } + }(k) } - - evaluations[k].Add(&fftEven[k], &tmp) - evaluations[k+n/2].Sub(&fftEven[k], &tmp) - - // we could take this from precomputed values in Domain (as domain.roots[n*k]), but then we would need to pass the domain. - // At any rate, we don't really need to optimize here. - inputPoint.Mul(&inputPoint, &nthRootOfUnity) + wg.Wait() } - - return evaluations } func (d *Domain) FftFr(values []fr.Element) []fr.Element { - return fftFr(values, d.Generator) + fftVals := slices.Clone(values) + fftFr(fftVals, d.Generator) + return fftVals } func (d *Domain) IfftFr(values []fr.Element) []fr.Element { @@ -101,7 +146,8 @@ func (d *Domain) IfftFr(values []fr.Element) []fr.Element { invDomain.SetInt64(int64(len(values))) invDomain.Inverse(&invDomain) - inverseFFT := fftFr(values, d.GeneratorInv) + inverseFFT := slices.Clone(values) + fftFr(inverseFFT, d.GeneratorInv) // scale by the inverse of the domain size for i := 0; i < len(inverseFFT); i++ { @@ -110,34 +156,72 @@ func (d *Domain) IfftFr(values []fr.Element) []fr.Element { return inverseFFT } -func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element { - n := len(values) - if n == 1 { - return values +func log2PowerOf2(n uint64) uint { + if n == 0 || (n&(n-1)) != 0 { + panic("Input must be a power of 2 and not zero") } - var generatorSquared fr.Element - generatorSquared.Square(&nthRootOfUnity) // generator with order n/2 - - even, odd := takeEvenOdd(values) - - fftEven := fftFr(even, generatorSquared) - fftOdd := fftFr(odd, generatorSquared) + return uint(bits.TrailingZeros64(n)) +} - inputPoint := fr.One() - evaluations := make([]fr.Element, n) - for k := 0; k < n/2; k++ { - var tmp fr.Element - tmp.Mul(&inputPoint, &fftOdd[k]) +func fftFr(a []fr.Element, omega fr.Element) { + n := uint(len(a)) + logN := log2PowerOf2(uint64(n)) - evaluations[k].Add(&fftEven[k], &tmp) - evaluations[k+n/2].Sub(&fftEven[k], &tmp) + if n != 1<> 1 + wm := new(fr.Element).Exp(omega, new(big.Int).SetUint64(uint64(n/m))) + + for k := uint(0); k < n; k += m { + w := new(fr.Element).SetOne() + for j := uint(0); j < halfM; j++ { + t := new(fr.Element).Mul(&a[k+j+halfM], w) + u := a[k+j] + a[k+j].Add(&u, t) + a[k+j+halfM].Sub(&u, t) + w.Mul(w, wm) + } + } } - return evaluations } +// func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element { +// n := len(values) +// if n == 1 { +// return values +// } + +// var generatorSquared fr.Element +// generatorSquared.Square(&nthRootOfUnity) // generator with order n/2 + +// even, odd := takeEvenOdd(values) + +// fftEven := fftFr(even, generatorSquared) +// fftOdd := fftFr(odd, generatorSquared) + +// inputPoint := fr.One() +// evaluations := make([]fr.Element, n) +// for k := 0; k < n/2; k++ { +// var tmp fr.Element +// tmp.Mul(&inputPoint, &fftOdd[k]) + +// evaluations[k].Add(&fftEven[k], &tmp) +// evaluations[k+n/2].Sub(&fftEven[k], &tmp) + +// inputPoint.Mul(&inputPoint, &nthRootOfUnity) +// } +// return evaluations +// } + // takeEvenOdd Takes a slice and return two slices // The first slice contains (a copy of) all of the elements // at even indices, the second slice contains @@ -146,17 +230,17 @@ func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element { // We assume that the length of the given values slice is even // so the returned arrays will be the same length. // This is the case for a radix-2 FFT -func takeEvenOdd[T interface{}](values []T) ([]T, []T) { - n := len(values) - even := make([]T, 0, n/2) - odd := make([]T, 0, n/2) - for i := 0; i < n; i++ { - if i%2 == 0 { - even = append(even, values[i]) - } else { - odd = append(odd, values[i]) - } - } - - return even, odd -} +// func takeEvenOdd[T interface{}](values []T) ([]T, []T) { +// n := len(values) +// even := make([]T, 0, n/2) +// odd := make([]T, 0, n/2) +// for i := 0; i < n; i++ { +// if i%2 == 0 { +// even = append(even, values[i]) +// } else { +// odd = append(odd, values[i]) +// } +// } + +// return even, odd +// } diff --git a/internal/kzg/kzg_prove.go b/internal/kzg/kzg_prove.go index e9bebf0..8b5e59a 100644 --- a/internal/kzg/kzg_prove.go +++ b/internal/kzg/kzg_prove.go @@ -3,15 +3,34 @@ package kzg import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/crate-crypto/go-eth-kzg/internal/domain" + "github.com/crate-crypto/go-eth-kzg/internal/poly" ) +func Open(domain *domain.Domain, polyCoeff []fr.Element, evaluationPoint fr.Element, ck *CommitKey, numGoRoutines int) (OpeningProof, error) { + + outputPoint := poly.PolyEval(polyCoeff, evaluationPoint) + + quotient := poly.DividePolyByXminusA(polyCoeff, evaluationPoint) + + comm, err := ck.Commit(quotient, 0) + if err != nil { + return OpeningProof{}, nil + } + + return OpeningProof{ + QuotientCommitment: *comm, + InputPoint: evaluationPoint, + ClaimedValue: outputPoint, + }, nil +} + // Open verifies that a polynomial f(x) when evaluated at a point `z` is equal to `f(z)` // // numGoRoutines is used to configure the amount of concurrency needed. Setting this // value to a negative number or 0 will make it default to the number of CPUs. // // [compute_kzg_proof_impl]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#compute_kzg_proof_impl -func Open(domain *domain.Domain, p Polynomial, evaluationPoint fr.Element, ck *CommitKey, numGoRoutines int) (OpeningProof, error) { +func Open_(domain *domain.Domain, p Polynomial, evaluationPoint fr.Element, ck *CommitKey, numGoRoutines int) (OpeningProof, error) { if len(p) == 0 || len(p) > len(ck.G1) { return OpeningProof{}, ErrInvalidPolynomialSize } diff --git a/internal/kzg/kzg_test.go b/internal/kzg/kzg_test.go index 801b192..47efaf5 100644 --- a/internal/kzg/kzg_test.go +++ b/internal/kzg/kzg_test.go @@ -12,9 +12,9 @@ import ( func TestProofVerifySmoke(t *testing.T) { domain := domain.NewDomain(4) - srs, _ := newLagrangeSRSInsecure(*domain, big.NewInt(1234)) + srs, _ := newMonomialSRSInsecure(*domain, big.NewInt(1234)) - // polynomial in lagrange form + // polynomial in monomial form poly := Polynomial{fr.NewElement(2), fr.NewElement(3), fr.NewElement(4), fr.NewElement(5)} comm, _ := srs.CommitKey.Commit(poly, 0) @@ -29,7 +29,7 @@ func TestProofVerifySmoke(t *testing.T) { func TestBatchVerifySmoke(t *testing.T) { domain := domain.NewDomain(4) - srs, _ := newLagrangeSRSInsecure(*domain, big.NewInt(1234)) + srs, _ := newMonomialSRSInsecure(*domain, big.NewInt(1234)) numProofs := 10 commitments := make([]Commitment, 0, numProofs) diff --git a/internal/poly/poly.go b/internal/poly/poly.go index ad84c07..fc428e3 100644 --- a/internal/poly/poly.go +++ b/internal/poly/poly.go @@ -85,15 +85,13 @@ func equalPoly(a, b PolynomialCoeff) bool { // PolyEval evaluates a polynomial f(x) at a point z, computing f(z). // The polynomial is given in coefficient form, and `z` is denoted as inputPoint. func PolyEval(poly PolynomialCoeff, inputPoint fr.Element) fr.Element { - result := fr.NewElement(0) - - for i := len(poly) - 1; i >= 0; i-- { - tmp := fr.Element{} - tmp.Mul(&result, &inputPoint) - result.Add(&tmp, &poly[i]) + res := poly[len(poly)-1] + for i := len(poly) - 2; i >= 0; i-- { + res.Mul(&res, &inputPoint) + res.Add(&res, &poly[i]) } - return result + return res } // DividePolyByXminusA computes f(x) / (x - a) and returns the quotient. diff --git a/prove.go b/prove.go index 6c1d04c..9d1e277 100644 --- a/prove.go +++ b/prove.go @@ -1,6 +1,7 @@ package goethkzg import ( + "github.com/crate-crypto/go-eth-kzg/internal/domain" "github.com/crate-crypto/go-eth-kzg/internal/kzg" ) @@ -62,8 +63,11 @@ func (c *Context) ComputeBlobKZGProof(blob *Blob, blobCommitment KZGCommitment, // 2. Compute Fiat-Shamir challenge evaluationChallenge := computeChallenge(blob, blobCommitment) + domain.BitReverse(polynomial) + polyCoeff := c.domain.IfftFr(polynomial) + // 3. Create opening proof - openingProof, err := kzg.Open(c.domain, polynomial, evaluationChallenge, c.commitKeyLagrange, numGoRoutines) + openingProof, err := kzg.Open(c.domain, polyCoeff, evaluationChallenge, c.commitKeyMonomial, numGoRoutines) if err != nil { return KZGProof{}, err } @@ -95,8 +99,11 @@ func (c *Context) ComputeKZGProof(blob *Blob, inputPointBytes Scalar, numGoRouti return KZGProof{}, [32]byte{}, err } + domain.BitReverse(polynomial) + polyCoeff := c.domain.IfftFr(polynomial) + // 2. Create opening proof - openingProof, err := kzg.Open(c.domain, polynomial, inputPoint, c.commitKeyLagrange, numGoRoutines) + openingProof, err := kzg.Open(c.domain, polyCoeff, inputPoint, c.commitKeyMonomial, numGoRoutines) if err != nil { return KZGProof{}, [32]byte{}, err }