From ff6251a8c67853a005fa2eaefe09620834714067 Mon Sep 17 00:00:00 2001 From: Paul Horn Date: Fri, 28 Feb 2025 18:14:37 +0100 Subject: [PATCH] Fix roundtrip deserialization of durations --- lib/src/types/duration.rs | 15 +-- lib/src/types/serde/date_time.rs | 14 +-- lib/src/types/serde/duration.rs | 152 ++++++++++++++++++++++++++ lib/src/types/serde/mod.rs | 1 + lib/src/types/serde/typ.rs | 29 ++++- lib/tests/duration_deserialization.rs | 34 ++++++ 6 files changed, 221 insertions(+), 24 deletions(-) create mode 100644 lib/src/types/serde/duration.rs create mode 100644 lib/tests/duration_deserialization.rs diff --git a/lib/src/types/duration.rs b/lib/src/types/duration.rs index 08625bc8..4d222fe1 100644 --- a/lib/src/types/duration.rs +++ b/lib/src/types/duration.rs @@ -4,10 +4,10 @@ use neo4rs_macros::BoltStruct; #[derive(Debug, PartialEq, Eq, Clone, BoltStruct)] #[signature(0xB4, 0x45)] pub struct BoltDuration { - months: BoltInteger, - days: BoltInteger, - seconds: BoltInteger, - nanoseconds: BoltInteger, + pub(crate) months: BoltInteger, + pub(crate) days: BoltInteger, + pub(crate) seconds: BoltInteger, + pub(crate) nanoseconds: BoltInteger, } impl BoltDuration { @@ -31,10 +31,6 @@ impl BoltDuration { .saturating_add(self.days.value.saturating_mul(24 * 3600)) .saturating_add(self.months.value.saturating_mul(2_629_800)) } - - pub(crate) fn nanoseconds(&self) -> i64 { - self.nanoseconds.value - } } impl From for BoltDuration { @@ -53,8 +49,7 @@ impl From for BoltDuration { impl From for std::time::Duration { fn from(value: BoltDuration) -> Self { //TODO: clarify month issue - let seconds = - value.seconds.value + (value.days.value * 24 * 3600) + (value.months.value * 2_629_800); + let seconds = value.seconds(); std::time::Duration::new(seconds as u64, value.nanoseconds.value as u32) } } diff --git a/lib/src/types/serde/date_time.rs b/lib/src/types/serde/date_time.rs index 1b61e51e..b64cb655 100644 --- a/lib/src/types/serde/date_time.rs +++ b/lib/src/types/serde/date_time.rs @@ -2,13 +2,15 @@ use core::fmt; use std::{iter::Peekable, marker::PhantomData}; use serde::de::{ - value::{BorrowedStrDeserializer, MapDeserializer, SeqDeserializer}, + value::{BorrowedStrDeserializer, MapDeserializer}, DeserializeSeed, Error, IntoDeserializer, MapAccess, SeqAccess, Visitor, }; -use crate::types::{serde::builder::SetOnce, BoltLocalDateTime, BoltString}; use crate::{ - types::{BoltDateTime, BoltDateTimeZoneId, BoltDuration, BoltInteger}, + types::{ + serde::builder::SetOnce, BoltDateTime, BoltDateTimeZoneId, BoltInteger, BoltLocalDateTime, + BoltString, + }, DeError, }; @@ -57,12 +59,6 @@ impl BoltDateTimeZoneId { } } -impl BoltDuration { - pub(crate) fn seq_access(&self) -> impl SeqAccess<'_, Error = DeError> { - SeqDeserializer::new([self.seconds(), self.nanoseconds()].into_iter()) - } -} - struct BoltDateTimeZoneIdAccess<'a, const N: usize>( &'a BoltDateTimeZoneId, Peekable<<[Fields; N] as IntoIterator>::IntoIter>, diff --git a/lib/src/types/serde/duration.rs b/lib/src/types/serde/duration.rs new file mode 100644 index 00000000..e7b9fd44 --- /dev/null +++ b/lib/src/types/serde/duration.rs @@ -0,0 +1,152 @@ +use core::fmt; + +use serde::de::{value::SeqDeserializer, Error, MapAccess, SeqAccess, Visitor}; + +use crate::{ + types::{serde::builder::SetOnce, BoltDuration, BoltInteger}, + DeError, +}; + +crate::cenum!(Fields { + Months, + Days, + Seconds, + NanoSeconds, +}); + +impl BoltDuration { + pub(crate) fn seq_access_bolt(&self) -> impl SeqAccess<'_, Error = DeError> { + SeqDeserializer::new( + [ + self.months.value, + self.days.value, + self.seconds.value, + self.nanoseconds.value, + ] + .into_iter(), + ) + } + pub(crate) fn seq_access_external(&self) -> impl SeqAccess<'_, Error = DeError> { + SeqDeserializer::new([self.seconds(), self.nanoseconds.value].into_iter()) + } +} + +pub struct BoltDurationVisitor; + +impl<'de> Visitor<'de> for BoltDurationVisitor { + type Value = BoltDuration; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("BoltDuration struct") + } + + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + let mut builder = DurationBuilder::default(); + + while let Some(key) = map.next_key::()? { + match key { + Fields::Months => builder.months(|| map.next_value())?, + Fields::Days => builder.days(|| map.next_value())?, + Fields::Seconds => builder.seconds(|| map.next_value())?, + Fields::NanoSeconds => builder.nanoseconds(|| map.next_value())?, + } + } + + builder.build() + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + const FIELDS: [Fields; 4] = [ + Fields::Months, + Fields::Days, + Fields::Seconds, + Fields::NanoSeconds, + ]; + + let mut require_next = |field| { + seq.next_element() + .and_then(|value| value.ok_or_else(|| Error::missing_field(field))) + }; + + let mut builder = DurationBuilder::default(); + + for field in FIELDS { + match field { + Fields::Months => builder.months(|| require_next("months"))?, + Fields::Days => builder.days(|| require_next("days"))?, + Fields::Seconds => builder.seconds(|| require_next("seconds"))?, + Fields::NanoSeconds => builder.nanoseconds(|| require_next("nanoseconds"))?, + } + } + + if seq.next_element::()?.is_some() { + return Err(Error::invalid_length(0, &"4")); + } + + builder.build() + } +} + +#[derive(Default)] +pub(crate) struct DurationBuilder { + pub(crate) months: SetOnce, + pub(crate) days: SetOnce, + pub(crate) seconds: SetOnce, + pub(crate) nanoseconds: SetOnce, +} + +impl DurationBuilder { + fn months(&mut self, f: impl FnOnce() -> Result) -> Result<(), E> { + self.months + .try_insert_with(f) + .map_or_else(|_| Err(Error::duplicate_field("months")), |_| Ok(())) + } + + fn days(&mut self, f: impl FnOnce() -> Result) -> Result<(), E> { + self.days + .try_insert_with(f) + .map_or_else(|_| Err(Error::duplicate_field("days")), |_| Ok(())) + } + + fn seconds(&mut self, f: impl FnOnce() -> Result) -> Result<(), E> { + self.seconds + .try_insert_with(f) + .map_or_else(|_| Err(Error::duplicate_field("seconds")), |_| Ok(())) + } + + fn nanoseconds( + &mut self, + f: impl FnOnce() -> Result, + ) -> Result<(), E> { + self.nanoseconds + .try_insert_with(f) + .map_or_else(|_| Err(Error::duplicate_field("nanoseconds")), |_| Ok(())) + } + + fn build(mut self: DurationBuilder) -> Result { + Ok(BoltDuration { + months: self + .months + .take() + .ok_or_else(|| Error::missing_field("months"))?, + days: self + .days + .take() + .ok_or_else(|| Error::missing_field("days"))?, + seconds: self + .seconds + .take() + .ok_or_else(|| Error::missing_field("seconds"))?, + nanoseconds: self + .nanoseconds + .take() + .ok_or_else(|| Error::missing_field("nanoseconds"))?, + }) + } +} diff --git a/lib/src/types/serde/mod.rs b/lib/src/types/serde/mod.rs index f1737628..9a2798ba 100644 --- a/lib/src/types/serde/mod.rs +++ b/lib/src/types/serde/mod.rs @@ -10,6 +10,7 @@ mod builder; mod cenum; mod date_time; mod de; +mod duration; mod element; mod error; mod kind; diff --git a/lib/src/types/serde/typ.rs b/lib/src/types/serde/typ.rs index 89136a56..eb8135da 100644 --- a/lib/src/types/serde/typ.rs +++ b/lib/src/types/serde/typ.rs @@ -2,6 +2,7 @@ use crate::{ types::{ serde::{ date_time::BoltDateTimeVisitor, + duration::BoltDurationVisitor, element::ElementDataDeserializer, node::BoltNodeVisitor, path::BoltPathVisitor, @@ -240,7 +241,9 @@ impl<'de> Visitor<'de> for BoltTypeVisitor { BoltKind::Path => variant .tuple_variant(1, BoltPathVisitor) .map(BoltType::Path), - BoltKind::Duration => variant.tuple_variant(1, self), + BoltKind::Duration => variant + .tuple_variant(1, BoltDurationVisitor) + .map(BoltType::Duration), BoltKind::Date => variant .tuple_variant(1, BoltDateTimeVisitor::::new()) .map(BoltType::Date), @@ -328,7 +331,7 @@ impl<'de> Deserializer<'de> for BoltTypeDeserializer<'de> { BoltType::Point3D(p) => p .into_deserializer() .deserialize_struct(name, fields, visitor), - BoltType::Duration(d) => visitor.visit_seq(d.seq_access()), + BoltType::Duration(d) => visitor.visit_seq(d.seq_access_external()), _ => self.unexpected(visitor), } } @@ -360,7 +363,7 @@ impl<'de> Deserializer<'de> for BoltTypeDeserializer<'de> { BoltType::Point3D(p) => p .into_deserializer() .deserialize_newtype_struct(name, visitor), - BoltType::Duration(d) => visitor.visit_seq(d.seq_access()), + BoltType::Duration(d) => visitor.visit_seq(d.seq_access_external()), BoltType::DateTimeZoneId(dtz) if name == "Timezone" => { visitor.visit_newtype_struct(BorrowedStrDeserializer::new(dtz.tz_id())) } @@ -378,7 +381,8 @@ impl<'de> Deserializer<'de> for BoltTypeDeserializer<'de> { } BoltType::Point2D(p) => p.into_deserializer().deserialize_tuple(len, visitor), BoltType::Point3D(p) => p.into_deserializer().deserialize_tuple(len, visitor), - BoltType::Duration(d) if len == 2 => visitor.visit_seq(d.seq_access()), + BoltType::Duration(d) if len == 2 => visitor.visit_seq(d.seq_access_external()), + BoltType::Duration(d) if len == 4 => visitor.visit_seq(d.seq_access_bolt()), BoltType::DateTimeZoneId(dtz) => visitor.visit_seq( dtz.seq_access( std::any::type_name::() @@ -879,7 +883,8 @@ impl<'de> VariantAccess<'de> for BoltEnum<'de> { BoltType::Point3D(p) => BoltPointDeserializer::new(p).deserialize_tuple(len, visitor), BoltType::Bytes(b) => visitor.visit_borrowed_bytes(&b.value), BoltType::Path(p) => ElementDataDeserializer::new(p).tuple_variant(len, visitor), - BoltType::Duration(d) => visitor.visit_seq(d.seq_access()), + BoltType::Duration(d) if len == 1 => visitor.visit_seq(d.seq_access_bolt()), + BoltType::Duration(d) => visitor.visit_seq(d.seq_access_external()), BoltType::Date(d) => visitor.visit_map(d.map_access()), BoltType::Time(t) => visitor.visit_map(t.map_access()), BoltType::LocalTime(t) => visitor.visit_map(t.map_access()), @@ -2007,6 +2012,20 @@ mod tests { assert_eq!(actual, duration); } + #[test] + fn duration_roundtrip() { + let duration = BoltDuration::from(Duration::new(42, 1337)); + + let bolt = BoltType::Duration(duration.clone()); + + let actual = bolt.to::().unwrap(); + let BoltType::Duration(actual) = actual else { + panic!() + }; + + assert_eq!(actual, duration); + } + fn test_date() -> NaiveDate { NaiveDate::from_ymd_opt(1999, 7, 14).unwrap() } diff --git a/lib/tests/duration_deserialization.rs b/lib/tests/duration_deserialization.rs new file mode 100644 index 00000000..bea4f9a3 --- /dev/null +++ b/lib/tests/duration_deserialization.rs @@ -0,0 +1,34 @@ +use neo4rs::*; + +mod container; + +#[tokio::test] +async fn duration_deserialization() { + let neo4j = container::Neo4jContainer::new().await; + let graph = neo4j.graph(); + + let duration = std::time::Duration::new(5259600, 7); + let mut result = graph + .execute(query("RETURN $d as output").param("d", duration)) + .await + .unwrap(); + let row = result.next().await.unwrap().unwrap(); + let d: std::time::Duration = row.get("output").unwrap(); + assert_eq!(d, duration); + + let mut result = graph + .execute(query("RETURN $d as output").param("d", duration)) + .await + .unwrap(); + let row = result.next().await.unwrap().unwrap(); + let d = row.get::("output").unwrap(); + assert_eq!( + d, + BoltType::Duration(BoltDuration::new( + 0.into(), + 0.into(), + 5259600.into(), + 7.into(), + )) + ); +}