Skip to content

Commit c2e37f7

Browse files
committed
gh-17 fix unary deserialization
1 parent f7053c5 commit c2e37f7

File tree

1 file changed

+184
-16
lines changed

1 file changed

+184
-16
lines changed

src/de.rs

Lines changed: 184 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use serde::de::{self, DeserializeSeed, Deserialize, Deserializer, MapAccess, SeqAccess, VariantAccess, Visitor};
22
use std::fmt;
3+
use std::ops::Neg;
34
use crate::parser::{JSONValue, JSONKeyValuePair, UnaryOperator, from_str as model_from_str, from_bytes as model_from_bytes};
45
use crate::utils::unescape;
56
#[derive(Debug)]
@@ -86,7 +87,7 @@ impl<'de, 'a> Deserializer<'de> for JSONValueDeserializer<'a> {
8687
JSONValue::NaN => visitor.visit_f64(f64::NAN),
8788
JSONValue::Hexadecimal(s) => {
8889
// Optionally convert to integer, or treat as string
89-
match u64::from_str_radix(s.trim_start_matches("0x"), 16) {
90+
match u64::from_str_radix(s.to_lowercase().trim_start_matches("0x"), 16) {
9091
Ok(hex) => {
9192
visitor.visit_u64(hex)
9293
}
@@ -96,13 +97,73 @@ impl<'de, 'a> Deserializer<'de> for JSONValueDeserializer<'a> {
9697
}
9798
}
9899
JSONValue::Unary { operator, value } => {
99-
let sign = match operator {
100-
UnaryOperator::Plus => 1.0,
101-
UnaryOperator::Minus => -1.0,
102-
};
103-
let inner_de = JSONValueDeserializer { input: &**value };
104-
let number: f64 = Deserialize::deserialize(inner_de)?;
105-
visitor.visit_f64(sign * number)
100+
match &**value {
101+
JSONValue::Integer(s) => {
102+
if let Ok(i) = s.parse::<i64>() {
103+
match operator {
104+
UnaryOperator::Plus => {visitor.visit_i64(i)}
105+
UnaryOperator::Minus => {visitor.visit_i64(i.neg())}
106+
}
107+
} else {
108+
match operator {
109+
UnaryOperator::Plus => {
110+
let x = s.parse::<u64>().map_err(de::Error::custom)?;
111+
visitor.visit_u64(x)
112+
}
113+
_ => {
114+
Err(de::Error::custom(format!("Invalid integer literal for unary: {:?}", s)))
115+
116+
}
117+
}
118+
}
119+
}
120+
JSONValue::Float(s) | JSONValue::Exponent(s) => {
121+
if let Ok(f) = s.parse::<f64>() {
122+
match operator {
123+
UnaryOperator::Plus => {visitor.visit_f64(f)}
124+
UnaryOperator::Minus => {visitor.visit_f64(f.neg())}
125+
}
126+
} else {
127+
Err(de::Error::custom(format!("Invalid float literal: {:?}", s)))
128+
}
129+
}
130+
JSONValue::Infinity => {
131+
match operator {
132+
UnaryOperator::Plus => {visitor.visit_f64(f64::INFINITY)}
133+
UnaryOperator::Minus => {visitor.visit_f64(f64::NEG_INFINITY)}
134+
}
135+
}
136+
JSONValue::NaN => {
137+
match operator {
138+
UnaryOperator::Plus => {visitor.visit_f64(f64::NAN)}
139+
UnaryOperator::Minus => {visitor.visit_f64(f64::NAN.neg())}
140+
}
141+
}
142+
JSONValue::Hexadecimal(s) => {
143+
match u64::from_str_radix(s.to_lowercase().trim_start_matches("0x"), 16) {
144+
Ok(hex) => {
145+
match operator {
146+
UnaryOperator::Plus => {
147+
visitor.visit_u64(hex)
148+
}
149+
UnaryOperator::Minus => {
150+
if hex > i64::MAX as u64 {
151+
return Err(de::Error::custom(format!("Overflow when converting {} to i64", s)))
152+
}
153+
let i = hex as i64;
154+
visitor.visit_i64(i)
155+
}
156+
}
157+
}
158+
Err(e) => {
159+
Err(de::Error::custom(format!("Invalid hex {}", e)))
160+
}
161+
}
162+
}
163+
invalid_unary_val => {
164+
Err(de::Error::custom(format!("Invalid unary value: {:?}", invalid_unary_val)))
165+
}
166+
}
106167
}
107168
}
108169
}
@@ -121,21 +182,21 @@ impl<'de, 'a> Deserializer<'de> for JSONValueDeserializer<'a> {
121182
where
122183
V: Visitor<'de>
123184
{
124-
self.deserialize_any(visitor)
185+
self.deserialize_i64(visitor)
125186
}
126187

127188
fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
128189
where
129190
V: Visitor<'de>
130191
{
131-
self.deserialize_any(visitor)
192+
self.deserialize_i64(visitor)
132193
}
133194

134195
fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
135196
where
136197
V: Visitor<'de>
137198
{
138-
self.deserialize_any(visitor)
199+
self.deserialize_i64(visitor)
139200
}
140201

141202
fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
@@ -148,6 +209,55 @@ impl<'de, 'a> Deserializer<'de> for JSONValueDeserializer<'a> {
148209
let i = s.parse::<i64>().map_err(de::Error::custom)?;
149210
visitor.visit_i64(i)
150211
}
212+
JSONValue::Hexadecimal(s) => {
213+
match u64::from_str_radix(s.to_lowercase().trim_start_matches("0x"), 16) {
214+
Ok(hex) => {
215+
if hex > i64::MAX as u64 {
216+
return Err(de::Error::custom(format!("Overflow when converting {} to i64", s)))
217+
}
218+
let i = hex as i64;
219+
visitor.visit_i64(i)
220+
}
221+
Err(e) => {
222+
Err(de::Error::custom(format!("Invalid hex {}", e)))
223+
}
224+
}
225+
}
226+
JSONValue::Unary {operator, value} => {
227+
match &**value {
228+
JSONValue::Integer(s) => {
229+
let i = s.parse::<i64>().map_err(de::Error::custom)?;
230+
match operator {
231+
UnaryOperator::Plus => {visitor.visit_i64(i)}
232+
UnaryOperator::Minus => {visitor.visit_i64(-i)}
233+
}
234+
}
235+
JSONValue::Hexadecimal(s) => {
236+
match u64::from_str_radix(s.to_lowercase().trim_start_matches("0x"), 16) {
237+
Ok(hex) => {
238+
if hex > i64::MAX as u64 {
239+
return Err(de::Error::custom(format!("Overflow when converting {} to i64", s)))
240+
}
241+
let i = hex as i64;
242+
match operator {
243+
UnaryOperator::Plus => {
244+
visitor.visit_i64(i)
245+
}
246+
UnaryOperator::Minus => {
247+
visitor.visit_i64(-i)
248+
}
249+
}
250+
}
251+
Err(e) => {
252+
Err(de::Error::custom(format!("Invalid hex {}", e)))
253+
}
254+
}
255+
}
256+
val => {
257+
Err(de::Error::custom(format!("Unsupported value for i64 {:?}", val)))
258+
}
259+
}
260+
}
151261
_ => self.deserialize_any(visitor),
152262
}
153263
}
@@ -156,35 +266,93 @@ impl<'de, 'a> Deserializer<'de> for JSONValueDeserializer<'a> {
156266
where
157267
V: Visitor<'de>
158268
{
159-
self.deserialize_any(visitor)
269+
self.deserialize_u64(visitor)
160270
}
161271

162272
fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
163273
where
164274
V: Visitor<'de>
165275
{
166-
self.deserialize_any(visitor)
276+
self.deserialize_u64(visitor)
167277
}
168278

169279
fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
170280
where
171281
V: Visitor<'de>
172282
{
173-
self.deserialize_any(visitor)
283+
self.deserialize_u64(visitor)
174284
}
175285

176286
fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
177287
where
178288
V: Visitor<'de>
179289
{
180-
self.deserialize_any(visitor)
290+
match self.input {
291+
JSONValue::Integer(s) => {
292+
let i = s.parse::<u64>().map_err(de::Error::custom)?;
293+
visitor.visit_u64(i)
294+
}
295+
JSONValue::Hexadecimal(s) => {
296+
match u64::from_str_radix(s.to_lowercase().trim_start_matches("0x"), 16) {
297+
Ok(hex) => {
298+
visitor.visit_u64(hex)
299+
}
300+
Err(e) => {
301+
Err(de::Error::custom(format!("Invalid hex {}", e)))
302+
}
303+
}
304+
}
305+
JSONValue::Unary {operator, value} => {
306+
match &**value {
307+
JSONValue::Integer(s) => {
308+
let i = s.parse::<u64>().map_err(de::Error::custom)?;
309+
match operator {
310+
UnaryOperator::Plus => {visitor.visit_u64(i)}
311+
UnaryOperator::Minus => {
312+
if i != 0 {
313+
Err(de::Error::custom(format!("Invalid integer value: {:?}", s)))
314+
} else {
315+
visitor.visit_u64(0)
316+
}
317+
}
318+
}
319+
}
320+
JSONValue::Hexadecimal(s) => {
321+
match u64::from_str_radix(s.to_lowercase().trim_start_matches("0x"), 16) {
322+
Ok(hex) => {
323+
match operator {
324+
UnaryOperator::Plus => {
325+
visitor.visit_u64(hex)
326+
}
327+
UnaryOperator::Minus => {
328+
if hex != 0 {
329+
Err(de::Error::custom(format!("Invalid integer value: {:?}", s)))
330+
} else {
331+
visitor.visit_u64(0)
332+
}
333+
334+
}
335+
}
336+
}
337+
Err(e) => {
338+
Err(de::Error::custom(format!("Invalid hex {}", e)))
339+
}
340+
}
341+
}
342+
val => {
343+
Err(de::Error::custom(format!("Unsupported value for u64 {:?}", val)))
344+
}
345+
}
346+
}
347+
_ => self.deserialize_any(visitor),
348+
}
181349
}
182350

183351
fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
184352
where
185353
V: Visitor<'de>
186354
{
187-
self.deserialize_any(visitor)
355+
self.deserialize_f64(visitor)
188356
}
189357

190358
fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>

0 commit comments

Comments
 (0)