Skip to content

Commit f650af0

Browse files
authored
Union encoding (#49)
1 parent 21554f2 commit f650af0

File tree

2 files changed

+114
-5
lines changed

2 files changed

+114
-5
lines changed

src/common_union.rs

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
use std::sync::{Arc, OnceLock};
22

33
use datafusion::arrow::array::{
4-
Array, ArrayRef, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray,
4+
Array, ArrayRef, AsArray, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray,
55
};
6-
use datafusion::arrow::buffer::Buffer;
6+
use datafusion::arrow::buffer::{Buffer, ScalarBuffer};
77
use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
8+
use datafusion::arrow::error::ArrowError;
89
use datafusion::common::ScalarValue;
910

10-
pub(crate) fn is_json_union(data_type: &DataType) -> bool {
11+
pub fn is_json_union(data_type: &DataType) -> bool {
1112
match data_type {
1213
DataType::Union(fields, UnionMode::Sparse) => fields == &union_fields(),
1314
_ => false,
@@ -64,7 +65,7 @@ impl JsonUnion {
6465
strings: vec![None; length],
6566
arrays: vec![None; length],
6667
objects: vec![None; length],
67-
type_ids: vec![0; length],
68+
type_ids: vec![TYPE_ID_NULL; length],
6869
index: 0,
6970
length,
7071
}
@@ -114,7 +115,7 @@ impl FromIterator<Option<JsonUnionField>> for JsonUnion {
114115
}
115116

116117
impl TryFrom<JsonUnion> for UnionArray {
117-
type Error = datafusion::arrow::error::ArrowError;
118+
type Error = ArrowError;
118119

119120
fn try_from(value: JsonUnion) -> Result<Self, Self::Error> {
120121
let children: Vec<Arc<dyn Array>> = vec![
@@ -199,3 +200,109 @@ impl From<JsonUnionField> for ScalarValue {
199200
}
200201
}
201202
}
203+
204+
pub struct JsonUnionEncoder {
205+
boolean: BooleanArray,
206+
int: Int64Array,
207+
float: Float64Array,
208+
string: StringArray,
209+
array: StringArray,
210+
object: StringArray,
211+
type_ids: ScalarBuffer<i8>,
212+
}
213+
214+
impl JsonUnionEncoder {
215+
#[must_use]
216+
pub fn from_union(union: UnionArray) -> Option<Self> {
217+
if is_json_union(union.data_type()) {
218+
let (_, type_ids, _, c) = union.into_parts();
219+
Some(Self {
220+
boolean: c[1].as_boolean().clone(),
221+
int: c[2].as_primitive().clone(),
222+
float: c[3].as_primitive().clone(),
223+
string: c[4].as_string().clone(),
224+
array: c[5].as_string().clone(),
225+
object: c[6].as_string().clone(),
226+
type_ids,
227+
})
228+
} else {
229+
None
230+
}
231+
}
232+
233+
#[must_use]
234+
#[allow(clippy::len_without_is_empty)]
235+
pub fn len(&self) -> usize {
236+
self.type_ids.len()
237+
}
238+
239+
/// Get the encodable value for a given index
240+
///
241+
/// # Panics
242+
///
243+
/// Panics if the idx is outside the union values or an invalid type id exists in the union.
244+
#[must_use]
245+
pub fn get_value(&self, idx: usize) -> JsonUnionValue {
246+
let type_id = self.type_ids[idx];
247+
match type_id {
248+
TYPE_ID_NULL => JsonUnionValue::JsonNull,
249+
TYPE_ID_BOOL => JsonUnionValue::Bool(self.boolean.value(idx)),
250+
TYPE_ID_INT => JsonUnionValue::Int(self.int.value(idx)),
251+
TYPE_ID_FLOAT => JsonUnionValue::Float(self.float.value(idx)),
252+
TYPE_ID_STR => JsonUnionValue::Str(self.string.value(idx)),
253+
TYPE_ID_ARRAY => JsonUnionValue::Array(self.array.value(idx)),
254+
TYPE_ID_OBJECT => JsonUnionValue::Object(self.object.value(idx)),
255+
_ => panic!("Invalid type_id: {type_id}, not a valid JSON type"),
256+
}
257+
}
258+
}
259+
260+
#[derive(Debug, PartialEq)]
261+
pub enum JsonUnionValue<'a> {
262+
JsonNull,
263+
Bool(bool),
264+
Int(i64),
265+
Float(f64),
266+
Str(&'a str),
267+
Array(&'a str),
268+
Object(&'a str),
269+
}
270+
271+
#[cfg(test)]
272+
mod test {
273+
use super::*;
274+
275+
#[test]
276+
fn test_json_union() {
277+
let json_union = JsonUnion::from_iter(vec![
278+
Some(JsonUnionField::JsonNull),
279+
Some(JsonUnionField::Bool(true)),
280+
Some(JsonUnionField::Bool(false)),
281+
Some(JsonUnionField::Int(42)),
282+
Some(JsonUnionField::Float(42.0)),
283+
Some(JsonUnionField::Str("foo".to_string())),
284+
Some(JsonUnionField::Array("[42]".to_string())),
285+
Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())),
286+
None,
287+
]);
288+
289+
let union_array = UnionArray::try_from(json_union).unwrap();
290+
let encoder = JsonUnionEncoder::from_union(union_array).unwrap();
291+
292+
let values_after: Vec<_> = (0..encoder.len()).map(|idx| encoder.get_value(idx)).collect();
293+
assert_eq!(
294+
values_after,
295+
vec![
296+
JsonUnionValue::JsonNull,
297+
JsonUnionValue::Bool(true),
298+
JsonUnionValue::Bool(false),
299+
JsonUnionValue::Int(42),
300+
JsonUnionValue::Float(42.0),
301+
JsonUnionValue::Str("foo"),
302+
JsonUnionValue::Array("[42]"),
303+
JsonUnionValue::Object(r#"{"foo": 42}"#),
304+
JsonUnionValue::JsonNull,
305+
]
306+
);
307+
}
308+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ mod json_get_str;
1919
mod json_length;
2020
mod rewrite;
2121

22+
pub use common_union::{JsonUnionEncoder, JsonUnionValue};
23+
2224
pub mod functions {
2325
pub use crate::json_as_text::json_as_text;
2426
pub use crate::json_contains::json_contains;

0 commit comments

Comments
 (0)