Skip to content

Commit 3195dda

Browse files
authored
support time_weighted functions (#872)
* support time_weighted functions * bug fix combinator function * combinator function * code revise and add tests * fix merge and serialize problems * code revise
1 parent 78e8aba commit 3195dda

8 files changed

+655
-3
lines changed

src/AggregateFunctions/AggregateFunctionAvgWeighted.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using AvgWeightedFieldType = std::conditional_t<is_decimal<T>,
1515
NearestFieldType<T>>>;
1616

1717
template <typename T, typename U>
18-
using MaxFieldType = std::conditional_t<(sizeof(AvgWeightedFieldType<T>) > sizeof(AvgWeightedFieldType<U>)),
18+
using MaxFieldType = std::conditional_t<(sizeof(AvgWeightedFieldType<T>) >= sizeof(AvgWeightedFieldType<U>)),
1919
AvgWeightedFieldType<T>, AvgWeightedFieldType<U>>;
2020

2121
template <typename Value, typename Weight>
@@ -30,7 +30,7 @@ class AggregateFunctionAvgWeighted final :
3030

3131
using Numerator = typename Base::Numerator;
3232
using Denominator = typename Base::Denominator;
33-
using Fraction = typename Base::Fraction;
33+
using Fraction = typename Base::Fraction;
3434

3535
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
3636
{

src/AggregateFunctions/AggregateFunctionFactory.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,17 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
179179
query_context->addQueryFactoriesInfo(Context::QueryLogFactories::AggregateFunctionCombinator, combinator_name);
180180

181181
String nested_name = name.substr(0, name.size() - combinator_name.size());
182+
183+
if (combinator_name == "_time_weighted")
184+
{
185+
if (nested_name == "avg")
186+
nested_name = "avg_weighted";
187+
else if (nested_name == "median")
188+
nested_name = "median_exact_weighted";
189+
else
190+
throw Exception(ErrorCodes::ILLEGAL_AGGREGATION, "Unknown aggregate function '{}'", name);
191+
}
192+
182193
/// Nested identical combinators (i.e. uniqCombinedIfIf) is not
183194
/// supported (since they don't work -- silently).
184195
///
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
2+
#include <AggregateFunctions/AggregateFunctionTimeWeighted.h>
3+
#include <AggregateFunctions/Helpers.h>
4+
#include <AggregateFunctions/FactoryHelpers.h>
5+
#include <DataTypes/DataTypeDate.h>
6+
#include <DataTypes/DataTypeDate32.h>
7+
#include <DataTypes/DataTypeDateTime.h>
8+
#include <DataTypes/DataTypeDateTime64.h>
9+
#include <AggregateFunctions/IAggregateFunction.h>
10+
11+
#include <memory>
12+
13+
namespace DB
14+
{
15+
16+
namespace
17+
{
18+
19+
class AggregateFunctionCombinatorTimeWeighted final : public IAggregateFunctionCombinator
20+
{
21+
public:
22+
String getName() const override { return "_time_weighted"; }
23+
24+
DataTypes transformArguments(const DataTypes & arguments) const override
25+
{
26+
if (arguments.size() != 2 && arguments.size() != 3)
27+
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Incorrect number of arguments for aggregate function with {} suffix", this->getName());
28+
29+
const auto & data_type_time_weight = arguments[1];
30+
const WhichDataType t_dt(data_type_time_weight);
31+
32+
if (!t_dt.isDateOrDate32() && !t_dt.isDateTime() && !t_dt.isDateTime64())
33+
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Types {} are non-conforming as time weighted arguments for aggregate function {}", data_type_time_weight->getName(), this->getName());
34+
35+
if (arguments.size() == 3)
36+
{
37+
const auto & data_type_third_arg = arguments[2];
38+
39+
if(!data_type_third_arg->equals(*data_type_time_weight))
40+
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The second and the third argument should be the same for aggregate function {}, but now it's {} and {}", this->getName(), data_type_third_arg->getName(), data_type_time_weight->getName());
41+
}
42+
43+
return {arguments[0], std::make_shared<DataTypeUInt64>()};
44+
}
45+
46+
/// Decimal128 and Decimal256 aren't supported
47+
#define AT_SWITCH(LINE) \
48+
switch (which.idx) \
49+
{ \
50+
LINE(Int8); LINE(Int16); LINE(Int32); LINE(Int64); LINE(Int128); LINE(Int256); \
51+
LINE(UInt8); LINE(UInt16); LINE(UInt32); LINE(UInt64); LINE(UInt128); LINE(UInt256); \
52+
LINE(Decimal32); LINE(Decimal64); \
53+
LINE(Float32); LINE(Float64); \
54+
default: return nullptr; \
55+
}
56+
57+
// Not using helper functions because there are no templates for binary decimal/numeric function.
58+
template <class... TArgs>
59+
IAggregateFunction * create(const IDataType & first_type, const IDataType & second_type, TArgs && ... args) const
60+
{
61+
const WhichDataType which(first_type);
62+
63+
#define LINE(Type) \
64+
case TypeIndex::Type: return create<Type, TArgs...>(second_type, std::forward<TArgs>(args)...)
65+
AT_SWITCH(LINE)
66+
#undef LINE
67+
}
68+
template <class First, class ... TArgs>
69+
IAggregateFunction * create(const IDataType & second_type, TArgs && ... args) const
70+
{
71+
const WhichDataType which(second_type);
72+
73+
switch (which.idx)
74+
{
75+
case TypeIndex::Date: return new AggregateFunctionTimeWeighted<First, DataTypeDate::FieldType>(std::forward<TArgs>(args)...);
76+
case TypeIndex::Date32: return new AggregateFunctionTimeWeighted<First, DataTypeDate32::FieldType>(std::forward<TArgs>(args)...);
77+
case TypeIndex::DateTime: return new AggregateFunctionTimeWeighted<First, DataTypeDateTime::FieldType>(std::forward<TArgs>(args)...);
78+
case TypeIndex::DateTime64: return new AggregateFunctionTimeWeighted<First, DataTypeDateTime64::FieldType>(std::forward<TArgs>(args)...);
79+
default: return nullptr;
80+
}
81+
}
82+
83+
AggregateFunctionPtr transformAggregateFunction(
84+
const AggregateFunctionPtr & nested_function,
85+
const AggregateFunctionProperties &,
86+
const DataTypes & arguments,
87+
const Array & params) const override
88+
{
89+
AggregateFunctionPtr ptr;
90+
const auto & data_type = arguments[0];
91+
const auto & data_type_time_weight = arguments[1];
92+
ptr.reset(create(*data_type, *data_type_time_weight, nested_function, arguments, params));
93+
if(!ptr)
94+
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal argument types existed in {} function", this->getName());
95+
96+
return ptr;
97+
}
98+
};
99+
}
100+
101+
void registerAggregateFunctionCombinatorTimeWeighted(AggregateFunctionCombinatorFactory & factory)
102+
{
103+
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorTimeWeighted>());
104+
}
105+
}

0 commit comments

Comments
 (0)