Skip to content

Commit f087cf2

Browse files
Jasmine-geyl-lisen
andauthored
allowing tuple datatype for the 1st param of array_map (#907)
fix arrayMap with array of tuples with single argument Co-authored-by: Lisen <38773813+yl-lisen@users.noreply.github.com>
1 parent 985a012 commit f087cf2

File tree

5 files changed

+83
-10
lines changed

5 files changed

+83
-10
lines changed

src/Functions/array/FunctionArrayMapped.h

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <DataTypes/DataTypeFunction.h>
1818
#include <DataTypes/DataTypeLowCardinality.h>
1919
#include <DataTypes/DataTypeMap.h>
20+
#include <DataTypes/DataTypeTuple.h>
2021
#include <DataTypes/DataTypesNumber.h>
2122

2223
#include <Functions/FunctionHelpers.h>
@@ -115,6 +116,7 @@ class FunctionArrayMapped : public IFunction
115116
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
116117
"Function {} needs one argument with data", getName());
117118

119+
size_t tuple_argument_size = 0;
118120
size_t nested_types_count = is_argument_type_map ? (arguments.size() - 1) * 2 : (arguments.size() - 1);
119121
DataTypes nested_types(nested_types_count);
120122
for (size_t i = 0; i < arguments.size() - 1; ++i)
@@ -128,6 +130,10 @@ class FunctionArrayMapped : public IFunction
128130
getName(),
129131
argument_type_name,
130132
arguments[i + 1]->getName());
133+
134+
if (const auto * tuple_type = checkAndGetDataType<DataTypeTuple>(array_type->getNestedType().get()))
135+
tuple_argument_size = tuple_type->getElements().size();
136+
131137
if constexpr (is_argument_type_map)
132138
{
133139
nested_types[2 * i] = recursiveRemoveLowCardinality(array_type->getKeyType());
@@ -137,14 +143,41 @@ class FunctionArrayMapped : public IFunction
137143
{
138144
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
139145
}
146+
140147
}
141148

142149
const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
143-
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
150+
if (!function_type)
151+
throw Exception(
152+
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
153+
"First argument for this overload of {} must be a function with {} arguments, found {} instead",
154+
getName(),
155+
nested_types.size(),
156+
arguments[0]->getName());
157+
158+
size_t num_function_arguments = function_type->getArgumentTypes().size();
159+
if (tuple_argument_size > 1
160+
&& tuple_argument_size == num_function_arguments)
161+
{
162+
assert(nested_types.size() == 1);
163+
164+
auto argument_type = nested_types[0];
165+
const auto & tuple_type = assert_cast<const DataTypeTuple &>(*argument_type);
166+
167+
nested_types.clear();
168+
nested_types.reserve(tuple_argument_size);
169+
170+
for (const auto & element : tuple_type.getElements())
171+
nested_types.push_back(element);
172+
}
173+
174+
if (num_function_arguments != nested_types.size())
144175
throw Exception(
145176
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
146177
"First argument for this overload of {} must be a function with {} arguments, found {} instead",
147-
getName(), nested_types.size(), arguments[0]->getName());
178+
getName(),
179+
nested_types.size(),
180+
arguments[0]->getName());
148181

149182
arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
150183
}
@@ -268,6 +301,9 @@ class FunctionArrayMapped : public IFunction
268301
throw Exception("First argument for function " + getName() + " must be a function.",
269302
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
270303

304+
const auto & type_function = assert_cast<const DataTypeFunction &>(*arguments[0].type);
305+
size_t num_function_arguments = type_function.getArgumentTypes().size();
306+
271307
ColumnPtr offsets_column;
272308

273309
ColumnPtr column_first_array_ptr;
@@ -316,24 +352,45 @@ class FunctionArrayMapped : public IFunction
316352
getName());
317353
}
318354

355+
const auto * column_tuple = dynamic_cast<const DB::ColumnMap*>(column_array)
356+
? checkAndGetColumn<ColumnTuple>(&dynamic_cast<const DB::ColumnMap*>(column_array)->getNestedData())
357+
: (dynamic_cast<const DB::ColumnArray*>(column_array)
358+
? checkAndGetColumn<ColumnTuple>(&dynamic_cast<const DB::ColumnArray*>(column_array)->getData())
359+
: throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected ColumnMap or ColumnArray, but received a different type."));
360+
361+
size_t tuple_size = column_tuple ? column_tuple->getColumns().size() : 0;
362+
319363
if (i == 1)
320364
{
321365
column_first_array_ptr = column_array_ptr;
322366
column_first_array = column_array;
323367
}
324368

325-
if constexpr (is_argument_type_map)
369+
if (tuple_size > 1 && tuple_size == num_function_arguments)
326370
{
327-
arrays.emplace_back(ColumnWithTypeAndName(
328-
column_array->getNestedData().getColumnPtr(0), recursiveRemoveLowCardinality(array_type->getKeyType()), array_with_type_and_name.name+".key"));
329-
arrays.emplace_back(ColumnWithTypeAndName(
330-
column_array->getNestedData().getColumnPtr(1), recursiveRemoveLowCardinality(array_type->getValueType()), array_with_type_and_name.name+".value"));
371+
const auto & type_tuple = assert_cast<const DataTypeTuple &>(*array_type->getNestedType());
372+
const auto & tuple_names = type_tuple.getElementNames();
373+
374+
arrays.reserve(column_tuple->getColumns().size());
375+
for (size_t j = 0; j < tuple_size; ++j)
376+
arrays.emplace_back(ColumnWithTypeAndName(
377+
column_tuple->getColumnPtr(j), recursiveRemoveLowCardinality(type_tuple.getElement(j)), array_with_type_and_name.name + "." + tuple_names[j]));
331378
}
332379
else
333380
{
334-
arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(),
335-
recursiveRemoveLowCardinality(array_type->getNestedType()),
336-
array_with_type_and_name.name));
381+
if constexpr (is_argument_type_map)
382+
{
383+
arrays.emplace_back(ColumnWithTypeAndName(
384+
column_array->getNestedData().getColumnPtr(0), recursiveRemoveLowCardinality(array_type->getKeyType()), array_with_type_and_name.name+".key"));
385+
arrays.emplace_back(ColumnWithTypeAndName(
386+
column_array->getNestedData().getColumnPtr(1), recursiveRemoveLowCardinality(array_type->getValueType()), array_with_type_and_name.name+".value"));
387+
}
388+
else
389+
{
390+
arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(),
391+
recursiveRemoveLowCardinality(array_type->getNestedType()),
392+
array_with_type_and_name.name));
393+
}
337394
}
338395
}
339396

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[(1)]
2+
[1]
3+
[3]
4+
[3]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
SELECT arrayMap((x) -> x, [tuple(1)]);
2+
SELECT arrayMap((x) -> x.1, [tuple(1)]);
3+
SELECT arrayMap((x) -> x.1 + x.2, [tuple(1, 2)]);
4+
SELECT arrayMap((x, y) -> x + y, [tuple(1, 2)]);
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[(1)]
2+
[1]
3+
[3]
4+
[3]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
SELECT array_map((x) -> x, [tuple_cast(1)]);
2+
SELECT array_map((x) -> x.1, [tuple_cast(1)]);
3+
SELECT array_map((x) -> x.1 + x.2, [tuple_cast(1, 2)]);
4+
SELECT array_map((x, y) -> x + y, [tuple_cast(1, 2)]);

0 commit comments

Comments
 (0)