|
1 | 1 | use std::sync::{Arc, OnceLock}; |
2 | 2 |
|
3 | 3 | use datafusion::arrow::array::{ |
4 | | - Array, ArrayRef, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray, |
| 4 | + Array, ArrayRef, AsArray, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray, |
5 | 5 | }; |
6 | | -use datafusion::arrow::buffer::Buffer; |
| 6 | +use datafusion::arrow::buffer::{Buffer, ScalarBuffer}; |
7 | 7 | use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; |
| 8 | +use datafusion::arrow::error::ArrowError; |
8 | 9 | use datafusion::common::ScalarValue; |
9 | 10 |
|
10 | | -pub(crate) fn is_json_union(data_type: &DataType) -> bool { |
| 11 | +pub fn is_json_union(data_type: &DataType) -> bool { |
11 | 12 | match data_type { |
12 | 13 | DataType::Union(fields, UnionMode::Sparse) => fields == &union_fields(), |
13 | 14 | _ => false, |
@@ -64,7 +65,7 @@ impl JsonUnion { |
64 | 65 | strings: vec![None; length], |
65 | 66 | arrays: vec![None; length], |
66 | 67 | objects: vec![None; length], |
67 | | - type_ids: vec![0; length], |
| 68 | + type_ids: vec![TYPE_ID_NULL; length], |
68 | 69 | index: 0, |
69 | 70 | length, |
70 | 71 | } |
@@ -114,7 +115,7 @@ impl FromIterator<Option<JsonUnionField>> for JsonUnion { |
114 | 115 | } |
115 | 116 |
|
116 | 117 | impl TryFrom<JsonUnion> for UnionArray { |
117 | | - type Error = datafusion::arrow::error::ArrowError; |
| 118 | + type Error = ArrowError; |
118 | 119 |
|
119 | 120 | fn try_from(value: JsonUnion) -> Result<Self, Self::Error> { |
120 | 121 | let children: Vec<Arc<dyn Array>> = vec![ |
@@ -199,3 +200,109 @@ impl From<JsonUnionField> for ScalarValue { |
199 | 200 | } |
200 | 201 | } |
201 | 202 | } |
| 203 | + |
| 204 | +pub struct JsonUnionEncoder { |
| 205 | + boolean: BooleanArray, |
| 206 | + int: Int64Array, |
| 207 | + float: Float64Array, |
| 208 | + string: StringArray, |
| 209 | + array: StringArray, |
| 210 | + object: StringArray, |
| 211 | + type_ids: ScalarBuffer<i8>, |
| 212 | +} |
| 213 | + |
| 214 | +impl JsonUnionEncoder { |
| 215 | + #[must_use] |
| 216 | + pub fn from_union(union: UnionArray) -> Option<Self> { |
| 217 | + if is_json_union(union.data_type()) { |
| 218 | + let (_, type_ids, _, c) = union.into_parts(); |
| 219 | + Some(Self { |
| 220 | + boolean: c[1].as_boolean().clone(), |
| 221 | + int: c[2].as_primitive().clone(), |
| 222 | + float: c[3].as_primitive().clone(), |
| 223 | + string: c[4].as_string().clone(), |
| 224 | + array: c[5].as_string().clone(), |
| 225 | + object: c[6].as_string().clone(), |
| 226 | + type_ids, |
| 227 | + }) |
| 228 | + } else { |
| 229 | + None |
| 230 | + } |
| 231 | + } |
| 232 | + |
| 233 | + #[must_use] |
| 234 | + #[allow(clippy::len_without_is_empty)] |
| 235 | + pub fn len(&self) -> usize { |
| 236 | + self.type_ids.len() |
| 237 | + } |
| 238 | + |
| 239 | + /// Get the encodable value for a given index |
| 240 | + /// |
| 241 | + /// # Panics |
| 242 | + /// |
| 243 | + /// Panics if the idx is outside the union values or an invalid type id exists in the union. |
| 244 | + #[must_use] |
| 245 | + pub fn get_value(&self, idx: usize) -> JsonUnionValue { |
| 246 | + let type_id = self.type_ids[idx]; |
| 247 | + match type_id { |
| 248 | + TYPE_ID_NULL => JsonUnionValue::JsonNull, |
| 249 | + TYPE_ID_BOOL => JsonUnionValue::Bool(self.boolean.value(idx)), |
| 250 | + TYPE_ID_INT => JsonUnionValue::Int(self.int.value(idx)), |
| 251 | + TYPE_ID_FLOAT => JsonUnionValue::Float(self.float.value(idx)), |
| 252 | + TYPE_ID_STR => JsonUnionValue::Str(self.string.value(idx)), |
| 253 | + TYPE_ID_ARRAY => JsonUnionValue::Array(self.array.value(idx)), |
| 254 | + TYPE_ID_OBJECT => JsonUnionValue::Object(self.object.value(idx)), |
| 255 | + _ => panic!("Invalid type_id: {type_id}, not a valid JSON type"), |
| 256 | + } |
| 257 | + } |
| 258 | +} |
| 259 | + |
| 260 | +#[derive(Debug, PartialEq)] |
| 261 | +pub enum JsonUnionValue<'a> { |
| 262 | + JsonNull, |
| 263 | + Bool(bool), |
| 264 | + Int(i64), |
| 265 | + Float(f64), |
| 266 | + Str(&'a str), |
| 267 | + Array(&'a str), |
| 268 | + Object(&'a str), |
| 269 | +} |
| 270 | + |
| 271 | +#[cfg(test)] |
| 272 | +mod test { |
| 273 | + use super::*; |
| 274 | + |
| 275 | + #[test] |
| 276 | + fn test_json_union() { |
| 277 | + let json_union = JsonUnion::from_iter(vec![ |
| 278 | + Some(JsonUnionField::JsonNull), |
| 279 | + Some(JsonUnionField::Bool(true)), |
| 280 | + Some(JsonUnionField::Bool(false)), |
| 281 | + Some(JsonUnionField::Int(42)), |
| 282 | + Some(JsonUnionField::Float(42.0)), |
| 283 | + Some(JsonUnionField::Str("foo".to_string())), |
| 284 | + Some(JsonUnionField::Array("[42]".to_string())), |
| 285 | + Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())), |
| 286 | + None, |
| 287 | + ]); |
| 288 | + |
| 289 | + let union_array = UnionArray::try_from(json_union).unwrap(); |
| 290 | + let encoder = JsonUnionEncoder::from_union(union_array).unwrap(); |
| 291 | + |
| 292 | + let values_after: Vec<_> = (0..encoder.len()).map(|idx| encoder.get_value(idx)).collect(); |
| 293 | + assert_eq!( |
| 294 | + values_after, |
| 295 | + vec![ |
| 296 | + JsonUnionValue::JsonNull, |
| 297 | + JsonUnionValue::Bool(true), |
| 298 | + JsonUnionValue::Bool(false), |
| 299 | + JsonUnionValue::Int(42), |
| 300 | + JsonUnionValue::Float(42.0), |
| 301 | + JsonUnionValue::Str("foo"), |
| 302 | + JsonUnionValue::Array("[42]"), |
| 303 | + JsonUnionValue::Object(r#"{"foo": 42}"#), |
| 304 | + JsonUnionValue::JsonNull, |
| 305 | + ] |
| 306 | + ); |
| 307 | + } |
| 308 | +} |
0 commit comments