Skip to content

Commit 2fffb96

Browse files
authored
Fix bug with handling of null values in dictionaries (#70)
1 parent 999d672 commit 2fffb96

File tree

2 files changed

+94
-4
lines changed

2 files changed

+94
-4
lines changed

src/common.rs

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::sync::Arc;
33

44
use datafusion::arrow::array::{
55
downcast_array, AnyDictionaryArray, Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, LargeStringArray,
6-
PrimitiveArray, RunArray, StringArray, StringViewArray,
6+
PrimitiveArray, PrimitiveBuilder, RunArray, StringArray, StringViewArray,
77
};
88
use datafusion::arrow::compute::kernels::cast;
99
use datafusion::arrow::compute::take;
@@ -245,6 +245,34 @@ fn invoke_array_array<R: InvokeResult>(
245245
}
246246
}
247247

248+
/// Transform keys that may be pointing to values with nulls to nulls themselves.
249+
/// keys = `[0, 1, 2, 3]`, values = `[null, "a", null, "b"]`
250+
/// into
251+
/// keys = `[null, 0, null, 1]`, values = `["a", "b"]`
252+
///
253+
/// Arrow / `DataFusion` assumes that dictionary values do not contain nulls, nulls are encoded by the keys.
254+
/// Not following this invariant causes invalid dictionary arrays to be built later on inside of `DataFusion`
255+
/// when arrays are concacted and such.
256+
fn remap_dictionary_key_nulls(keys: PrimitiveArray<Int64Type>, values: ArrayRef) -> DictionaryArray<Int64Type> {
257+
// fast path: no nulls in values
258+
if values.null_count() == 0 {
259+
return DictionaryArray::new(keys, values);
260+
}
261+
262+
let mut new_keys_builder = PrimitiveBuilder::<Int64Type>::new();
263+
264+
for key in &keys {
265+
match key {
266+
Some(k) if values.is_null(k.as_usize()) => new_keys_builder.append_null(),
267+
Some(k) => new_keys_builder.append_value(k),
268+
None => new_keys_builder.append_null(),
269+
}
270+
}
271+
272+
let new_keys = new_keys_builder.finish();
273+
DictionaryArray::new(new_keys, values)
274+
}
275+
248276
fn invoke_array_scalars<R: InvokeResult>(
249277
json_array: &ArrayRef,
250278
path: &[JsonPath],
@@ -281,7 +309,7 @@ fn invoke_array_scalars<R: InvokeResult>(
281309
let type_ids = values.as_union().type_ids();
282310
keys = mask_dictionary_keys(&keys, type_ids);
283311
}
284-
Ok(Arc::new(DictionaryArray::new(keys, values)))
312+
Ok(Arc::new(remap_dictionary_key_nulls(keys, values)))
285313
} else {
286314
// this is what cast would do under the hood to unpack a dictionary into an array of its values
287315
Ok(take(&values, json_array.keys(), None)?)

tests/main.rs

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::sync::Arc;
22

3-
use datafusion::arrow::array::{ArrayRef, RecordBatch};
4-
use datafusion::arrow::datatypes::{Field, Int8Type, Schema};
3+
use datafusion::arrow::array::{Array, ArrayRef, DictionaryArray, RecordBatch};
4+
use datafusion::arrow::datatypes::{Field, Int64Type, Int8Type, Schema};
55
use datafusion::arrow::{array::StringDictionaryBuilder, datatypes::DataType};
66
use datafusion::assert_batches_eq;
77
use datafusion::common::ScalarValue;
@@ -1280,6 +1280,68 @@ async fn test_dict_haystack() {
12801280
assert_batches_eq!(expected, &batches);
12811281
}
12821282

1283+
/// Panic if any non-null key of the dictionary `array` points at a null entry in its values.
///
/// `array` must be a `DictionaryArray<Int64Type>`; any other array type fails the downcast and panics.
/// On failure the non-null keys and the values array are printed before panicking.
fn check_for_null_dictionary_values(array: &dyn Array) {
    let array = array.as_any().downcast_ref::<DictionaryArray<Int64Type>>().unwrap();
    // null keys are allowed, so only the non-null ones are collected
    let keys = array
        .keys()
        .iter()
        .filter_map(|x| x.map(|v| usize::try_from(v).unwrap()))
        .collect::<Vec<_>>();
    let values_array = array.values();
    // no non-null keys should point to a null value; a single pass over the keys
    // avoids searching the whole key list once per null value
    if keys.iter().any(|&k| values_array.is_null(k)) {
        println!("keys: {:?}", keys);
        println!("values: {:?}", values_array);
        panic!("keys should not contain null values");
    }
}
1303+
1304+
/// Test that we don't output nulls in dictionary values.
1305+
/// This can cause issues with arrow-rs and DataFusion; they expect nulls to be in keys.
1306+
#[tokio::test]
1307+
async fn test_dict_get_no_null_values() {
1308+
let ctx = build_dict_schema().await;
1309+
1310+
let sql = "select json_get(x, 'baz') v from data";
1311+
let expected = [
1312+
"+------------+",
1313+
"| v |",
1314+
"+------------+",
1315+
"| |",
1316+
"| {str=fizz} |",
1317+
"| |",
1318+
"| {str=abcd} |",
1319+
"| |",
1320+
"| {str=fizz} |",
1321+
"| {str=fizz} |",
1322+
"| {str=fizz} |",
1323+
"| {str=fizz} |",
1324+
"| |",
1325+
"+------------+",
1326+
];
1327+
let batches = ctx.sql(&sql).await.unwrap().collect().await.unwrap();
1328+
assert_batches_eq!(expected, &batches);
1329+
for batch in batches {
1330+
check_for_null_dictionary_values(batch.column(0).as_ref());
1331+
}
1332+
1333+
let sql = "select json_get_str(x, 'baz') v from data";
1334+
let expected = [
1335+
"+------+", "| v |", "+------+", "| |", "| fizz |", "| |", "| abcd |", "| |", "| fizz |",
1336+
"| fizz |", "| fizz |", "| fizz |", "| |", "+------+",
1337+
];
1338+
let batches = ctx.sql(&sql).await.unwrap().collect().await.unwrap();
1339+
assert_batches_eq!(expected, &batches);
1340+
for batch in batches {
1341+
check_for_null_dictionary_values(batch.column(0).as_ref());
1342+
}
1343+
}
1344+
12831345
#[tokio::test]
12841346
async fn test_dict_haystack_filter() {
12851347
let sql = "select json_data v from dicts where json_get(json_data, 'foo') is not null";

0 commit comments

Comments
 (0)