Skip to content

Commit 59d2152

Browse files
authored
fix: remove node wrapper and add UTs (#234)
* fix: remove node wrapper and add UTs * fix: do not ignore the empty leaf * fix: donot ignore the empty leaf * fix: typos
1 parent 6f8114f commit 59d2152

File tree

9 files changed

+259
-278
lines changed

9 files changed

+259
-278
lines changed

starky/src/digest.rs

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,31 @@
22
use crate::field_bls12381::Fr as Fr_bls12381;
33
use crate::field_bls12381::FrRepr as FrRepr_bls12381;
44
use crate::field_bn128::{Fr, FrRepr};
5+
use crate::helper;
56
use crate::traits::MTNodeType;
67
use ff::*;
78
use fields::field_gl::Fr as FGL;
89
use serde::de::{SeqAccess, Visitor};
910
use serde::ser::SerializeSeq;
1011
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
12+
use std::any::TypeId;
1113
use std::fmt;
1214
use std::fmt::Display;
1315
use std::marker::PhantomData;
1416

1517
/// the trait F is used to keep track of source data type, so we can implement its deserializer
16-
// TODO: Remove the generic type: F. As it's never used.
1718
#[repr(C)]
1819
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
1920
pub struct ElementDigest<const N: usize, F: PrimeField + Default>(pub [FGL; N], PhantomData<F>);
2021

22+
impl<const N: usize, F: PrimeField + Default> ElementDigest<N, F> {
23+
// FIXME: this is a bit tricky that assuming the len is 4, replace it by N here.
24+
pub fn is_dim_1(&self) -> bool {
25+
let e = self.as_elements();
26+
e[1] == e[2] && e[1] == e[3] && e[1] == FGL::ZERO
27+
}
28+
}
29+
2130
impl<const N: usize, F: PrimeField + Default> MTNodeType for ElementDigest<N, F> {
2231
type BaseField = F;
2332
#[inline(always)]
@@ -77,13 +86,28 @@ impl<const N: usize, F: PrimeField + Default> Serialize for ElementDigest<N, F>
7786
where
7887
S: Serializer,
7988
{
80-
let elems = self.0.to_vec();
81-
82-
let mut seq = serializer.serialize_seq(Some(elems.len()))?;
83-
for v in elems.iter() {
84-
seq.serialize_element(&v.as_int().to_string())?;
89+
let source = TypeId::of::<F>();
90+
if source == TypeId::of::<Fr>() {
91+
let r: Fr = Fr(self.as_scalar::<Fr>());
92+
return serializer.serialize_str(&helper::fr_to_biguint(&r).to_string());
93+
}
94+
if source == TypeId::of::<Fr_bls12381>() {
95+
let r: Fr_bls12381 = Fr_bls12381(self.as_scalar::<Fr_bls12381>());
96+
return serializer.serialize_str(&helper::fr_to_biguint(&r).to_string());
97+
}
98+
if source == TypeId::of::<FGL>() {
99+
let e = self.as_elements();
100+
if self.is_dim_1() {
101+
return serializer.serialize_str(&e[0].as_int().to_string());
102+
} else {
103+
let mut seq = serializer.serialize_seq(Some(4))?;
104+
for v in e.iter() {
105+
seq.serialize_element(&v.as_int().to_string())?;
106+
}
107+
return seq.end();
108+
}
85109
}
86-
seq.end()
110+
panic!("Invalid element to serialize, {:?}", self.0)
87111
}
88112
}
89113

@@ -105,24 +129,34 @@ impl<'de, const N: usize, F: PrimeField + Default> Deserialize<'de> for ElementD
105129
where
106130
A: SeqAccess<'de>,
107131
{
108-
let mut entries = Vec::with_capacity(N);
132+
let mut entries = Vec::new();
109133
while let Some(entry) = seq.next_element::<String>()? {
110134
let entry: u64 = entry.parse().unwrap();
111-
112-
entries.push(FGL::from_repr(fields::field_gl::FrRepr::from(entry)).unwrap());
135+
entries.push(FGL::from(entry));
113136
}
114137
Ok(ElementDigest::<N, F>::new(&entries))
115138
}
116139

140+
// it could be one-dim GL, BN128, or BLS12381
117141
fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
118142
where
119143
E: de::Error,
120144
{
121-
let entry: u64 = s.parse().unwrap();
122-
123-
let data = vec![FGL::from_repr(fields::field_gl::FrRepr::from(entry)).unwrap(); N];
124-
125-
Ok(ElementDigest::<N, F>::new(&data))
145+
let source = TypeId::of::<F>();
146+
if source == TypeId::of::<FGL>() {
147+
// one-dim GL elements
148+
let value = FGL::from_str(s).unwrap();
149+
Ok(ElementDigest::<N, F>::new(&[
150+
value,
151+
FGL::ZERO,
152+
FGL::ZERO,
153+
FGL::ZERO,
154+
]))
155+
} else {
156+
// BN128 or BLS12381
157+
let t = F::from_str(s).unwrap();
158+
Ok(ElementDigest::<N, F>::from_scalar(&t))
159+
}
126160
}
127161
}
128162
deserializer.deserialize_any(EntriesVisitor::<N, F>(Default::default()))
@@ -280,11 +314,10 @@ mod tests {
280314

281315
#[test]
282316
fn test_element_digest_serialize_and_deserialize() {
283-
const N: usize = 3;
317+
const N: usize = 4;
284318
let fields = vec![FGL::one(); N];
285319
let data = ElementDigest::<N, FGL>::new(&fields);
286320
let serialized = serde_json::to_string(&data).unwrap();
287-
println!("Serialized: {}", serialized);
288321

289322
let expect: ElementDigest<N, FGL> = serde_json::from_str(&serialized).unwrap();
290323

starky/src/field_bls12381.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,51 @@
22
use core::ops::{Add, Div, Mul, Neg, Sub};
33
use ff::*;
44

5+
use crate::helper;
6+
use serde::de::{Error, SeqAccess, Visitor};
7+
use serde::ser::SerializeSeq;
8+
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
9+
use std::fmt;
10+
511
#[derive(PrimeField)]
612
#[PrimeFieldModulus = "52435875175126190479447740508185965837690552500527637822603658699938581184513"]
713
#[PrimeFieldGenerator = "7"]
814
pub struct Fr(pub FrRepr);
915

16+
impl Serialize for Fr {
17+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
18+
where
19+
S: Serializer,
20+
{
21+
serializer.serialize_str(&helper::fr_to_biguint(self).to_string())
22+
}
23+
}
24+
25+
impl<'de> Deserialize<'de> for Fr {
26+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
27+
where
28+
D: Deserializer<'de>,
29+
{
30+
struct EntriesVisitor;
31+
32+
impl<'de> Visitor<'de> for EntriesVisitor {
33+
type Value = Fr;
34+
35+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
36+
formatter.write_str("struct Bls12381's Fr")
37+
}
38+
39+
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
40+
where
41+
E: Error,
42+
{
43+
Ok(Self::Value::from_str(v).unwrap())
44+
}
45+
}
46+
deserializer.deserialize_any(EntriesVisitor)
47+
}
48+
}
49+
1050
#[cfg(test)]
1151
mod tests {
1252
use crate::field_bls12381::*;

starky/src/field_bn128.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,52 @@
11
#![allow(unused_imports, clippy::too_many_arguments)]
22
use ff::*;
33

4+
use crate::helper;
5+
use ff::*;
6+
use serde::de::{Error, SeqAccess, Visitor};
7+
use serde::ser::SerializeSeq;
8+
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
9+
use std::fmt;
10+
411
#[derive(PrimeField)]
512
#[PrimeFieldModulus = "21888242871839275222246405745257275088548364400416034343698204186575808495617"]
613
#[PrimeFieldGenerator = "7"]
714
pub struct Fr(pub FrRepr);
815

16+
impl Serialize for Fr {
17+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
18+
where
19+
S: Serializer,
20+
{
21+
serializer.serialize_str(&helper::fr_to_biguint(self).to_string())
22+
}
23+
}
24+
25+
impl<'de> Deserialize<'de> for Fr {
26+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
27+
where
28+
D: Deserializer<'de>,
29+
{
30+
struct EntriesVisitor;
31+
32+
impl<'de> Visitor<'de> for EntriesVisitor {
33+
type Value = Fr;
34+
35+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
36+
formatter.write_str("struct Bn128's Fr")
37+
}
38+
39+
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
40+
where
41+
E: Error,
42+
{
43+
Ok(Self::Value::from_str(v).unwrap())
44+
}
45+
}
46+
deserializer.deserialize_any(EntriesVisitor)
47+
}
48+
}
49+
950
#[cfg(test)]
1051
mod tests {
1152
use crate::field_bn128::*;

starky/src/merklehash.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,10 @@ impl MerkleTree for MerkleTreeGL {
283283
vec![node.as_elements().to_vec()[0]]
284284
}
285285

286+
fn from_basefield(node: &FGL) -> Self::MTNode {
287+
Self::MTNode::new(&[*node, FGL::ZERO, FGL::ZERO, FGL::ZERO])
288+
}
289+
286290
#[cfg(not(any(
287291
target_feature = "avx512bw",
288292
target_feature = "avx512cd",
@@ -582,12 +586,8 @@ mod tests {
582586
#[test]
583587
fn test_merkle_tree_gl_serialize_and_deserialize() {
584588
let data = MerkleTreeGL::new();
585-
586589
let serialized = serde_json::to_string(&data).unwrap();
587-
println!("Serialized: {}", serialized);
588-
589590
let expect: MerkleTreeGL = serde_json::from_str(&serialized).unwrap();
590-
591591
assert_eq!(data, expect);
592592
}
593593
}

starky/src/merklehash_bls12381.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ impl MerkleTree for MerkleTreeBLS12381 {
176176
vec![Fr(node.as_scalar::<Fr>())]
177177
}
178178

179+
fn from_basefield(node: &Fr) -> Self::MTNode {
180+
Self::MTNode::from_scalar(node)
181+
}
182+
179183
fn merkelize(&mut self, buff: Vec<FGL>, width: usize, height: usize) -> Result<()> {
180184
let max_workers = get_max_workers();
181185

@@ -373,12 +377,8 @@ mod tests {
373377
#[test]
374378
fn test_merkle_tree_bls381_serialize_and_deserialize() {
375379
let data = MerkleTreeBLS12381::new();
376-
377380
let serialized = serde_json::to_string(&data).unwrap();
378-
println!("Serialized: {}", serialized);
379-
380381
let expect: MerkleTreeBLS12381 = serde_json::from_str(&serialized).unwrap();
381-
382382
assert_eq!(data, expect);
383383
}
384384
}

starky/src/merklehash_bn128.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ impl MerkleTree for MerkleTreeBN128 {
176176
vec![Fr(node.as_scalar::<Fr>())]
177177
}
178178

179+
fn from_basefield(node: &Fr) -> Self::MTNode {
180+
Self::MTNode::from_scalar(node)
181+
}
182+
179183
fn merkelize(&mut self, buff: Vec<FGL>, width: usize, height: usize) -> Result<()> {
180184
let max_workers = get_max_workers();
181185

@@ -367,12 +371,8 @@ mod tests {
367371
#[test]
368372
fn test_merkle_tree_bn128_serialize_and_deserialize() {
369373
let data = MerkleTreeBN128::new();
370-
371374
let serialized = serde_json::to_string(&data).unwrap();
372-
println!("Serialized: {}", serialized);
373-
374375
let expect: MerkleTreeBN128 = serde_json::from_str(&serialized).unwrap();
375-
376376
assert_eq!(data, expect);
377377
}
378378
}

0 commit comments

Comments
 (0)