Skip to content

Commit f854f10

Browse files
committed
Add optimized path for intermediate values aggregator
1 parent f135998 commit f854f10

File tree

5 files changed

+163
-92
lines changed

5 files changed

+163
-92
lines changed

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 68 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -609,63 +609,80 @@ private MethodSpec addIntermediateInput() {
609609
.collect(joining(" && "))
610610
);
611611
}
612-
if (intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
613-
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
614-
}
615-
builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
616-
{
617-
builder.addStatement("int groupId = groups.getInt(groupPosition)");
618-
if (aggState.declaredType().isPrimitive()) {
619-
if (warnExceptions.isEmpty()) {
620-
assert intermediateState.size() == 2;
621-
assert intermediateState.get(1).name().equals("seen");
622-
builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))");
623-
} else {
624-
assert intermediateState.size() == 3;
625-
assert intermediateState.get(1).name().equals("seen");
626-
assert intermediateState.get(2).name().equals("failed");
627-
builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))");
628-
{
629-
builder.addStatement("state.setFailed(groupId)");
612+
var bulkCombineIntermediateMethod = optionalStaticMethod(
613+
declarationType,
614+
requireVoidType(),
615+
requireName("combineIntermediate"),
616+
requireArgs(
617+
Stream.of(
618+
Stream.of(aggState.declaredType(), TypeName.INT, INT_VECTOR), // aggState, positionOffset, groupIds
619+
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType)
620+
).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
621+
)
622+
);
623+
if (bulkCombineIntermediateMethod != null) {
624+
var states = intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::name).collect(Collectors.joining(","));
625+
builder.addStatement("$T.combineIntermediate(state, positionOffset, groups," + states + ")", declarationType);
626+
} else {
627+
if (intermediateState.stream()
628+
.map(AggregatorImplementer.IntermediateStateDesc::elementType)
629+
.anyMatch(n -> n.equals("BYTES_REF"))) {
630+
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
631+
}
632+
builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
633+
{
634+
builder.addStatement("int groupId = groups.getInt(groupPosition)");
635+
if (aggState.declaredType().isPrimitive()) {
636+
if (warnExceptions.isEmpty()) {
637+
assert intermediateState.size() == 2;
638+
assert intermediateState.get(1).name().equals("seen");
639+
builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))");
640+
} else {
641+
assert intermediateState.size() == 3;
642+
assert intermediateState.get(1).name().equals("seen");
643+
assert intermediateState.get(2).name().equals("failed");
644+
builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))");
645+
{
646+
builder.addStatement("state.setFailed(groupId)");
647+
}
648+
builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))");
630649
}
631-
builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))");
632-
}
633650

634-
warningsBlock(builder, () -> {
635-
var name = intermediateState.get(0).name();
636-
var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType());
637-
builder.addStatement(
638-
"state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))",
651+
warningsBlock(builder, () -> {
652+
var name = intermediateState.get(0).name();
653+
var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType());
654+
builder.addStatement(
655+
"state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))",
656+
declarationType,
657+
name,
658+
vectorAccessor
659+
);
660+
});
661+
builder.endControlFlow();
662+
} else {
663+
var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
664+
requireStaticMethod(
639665
declarationType,
640-
name,
641-
vectorAccessor
666+
requireVoidType(),
667+
requireName("combineIntermediate"),
668+
requireArgs(
669+
Stream.of(
670+
Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
671+
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
672+
Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
673+
).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
674+
)
675+
);
676+
builder.addStatement(
677+
"$T.combineIntermediate(state, groupId, "
678+
+ intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
679+
+ (stateHasBlock ? ", groupPosition + positionOffset" : "")
680+
+ ")",
681+
declarationType
642682
);
643-
});
683+
}
644684
builder.endControlFlow();
645-
} else {
646-
var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
647-
requireStaticMethod(
648-
declarationType,
649-
requireVoidType(),
650-
requireName("combineIntermediate"),
651-
requireArgs(
652-
Stream.of(
653-
Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
654-
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
655-
Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
656-
).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
657-
)
658-
);
659-
660-
builder.addStatement(
661-
"$T.combineIntermediate(state, groupId, "
662-
+ intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
663-
+ (stateHasBlock ? ", groupPosition + positionOffset" : "")
664-
+ ")",
665-
declarationType
666-
);
667685
}
668-
builder.endControlFlow();
669686
}
670687
return builder.build();
671688
}

x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java

Lines changed: 2 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java

Lines changed: 1 addition & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java

Lines changed: 85 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,7 @@ static GroupingAggregatorFunction.AddInput wrapAddInput(
2828
if (valuesOrdinal == null) {
2929
return delegate;
3030
}
31-
BytesRefVector dict = valuesOrdinal.getDictionaryVector();
32-
final IntVector hashIds;
33-
BytesRef spare = new BytesRef();
34-
try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
35-
for (int p = 0; p < dict.getPositionCount(); p++) {
36-
hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare)))));
37-
}
38-
hashIds = hashIdsBuilder.build();
39-
}
31+
final IntVector hashIds = hashDict(state, valuesOrdinal.getDictionaryVector());
4032
IntBlock ordinalIds = valuesOrdinal.getOrdinalsBlock();
4133
return new GroupingAggregatorFunction.AddInput() {
4234
@Override
@@ -114,15 +106,7 @@ static GroupingAggregatorFunction.AddInput wrapAddInput(
114106
if (valuesOrdinal == null) {
115107
return delegate;
116108
}
117-
BytesRefVector dict = valuesOrdinal.getDictionaryVector();
118-
final IntVector hashIds;
119-
BytesRef spare = new BytesRef();
120-
try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
121-
for (int p = 0; p < dict.getPositionCount(); p++) {
122-
hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare)))));
123-
}
124-
hashIds = hashIdsBuilder.build();
125-
}
109+
final IntVector hashIds = hashDict(state, valuesOrdinal.getDictionaryVector());
126110
var ordinalIds = valuesOrdinal.getOrdinalsVector();
127111
return new GroupingAggregatorFunction.AddInput() {
128112
@Override
@@ -157,10 +141,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) {
157141

158142
@Override
159143
public void add(int positionOffset, IntVector groupIds) {
160-
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
161-
int groupId = groupIds.getInt(groupPosition);
162-
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
163-
}
144+
addOrdinalInputVector(state, positionOffset, groupIds, ordinalIds, hashIds);
164145
}
165146

166147
@Override
@@ -169,4 +150,86 @@ public void close() {
169150
}
170151
};
171152
}
153+
154+
static IntVector hashDict(ValuesBytesRefAggregator.GroupingState state, BytesRefVector dict) {
155+
BytesRef scratch = new BytesRef();
156+
try (var hashIdsBuilder = dict.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
157+
for (int p = 0; p < dict.getPositionCount(); p++) {
158+
final long hashId = BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, scratch)));
159+
hashIdsBuilder.appendInt(Math.toIntExact(hashId));
160+
}
161+
return hashIdsBuilder.build();
162+
}
163+
}
164+
165+
static void addOrdinalInputBlock(
166+
ValuesBytesRefAggregator.GroupingState state,
167+
int positionOffset,
168+
IntVector groupIds,
169+
IntBlock ordinalIds,
170+
IntVector hashIds
171+
) {
172+
for (int p = 0; p < groupIds.getPositionCount(); p++) {
173+
final int valuePosition = p + positionOffset;
174+
final int groupId = groupIds.getInt(valuePosition);
175+
final int start = ordinalIds.getFirstValueIndex(valuePosition);
176+
final int end = start + ordinalIds.getValueCount(valuePosition);
177+
for (int i = start; i < end; i++) {
178+
int ord = ordinalIds.getInt(i);
179+
state.addValueOrdinal(groupId, hashIds.getInt(ord));
180+
}
181+
}
182+
}
183+
184+
static void addOrdinalInputVector(
185+
ValuesBytesRefAggregator.GroupingState state,
186+
int positionOffset,
187+
IntVector groupIds,
188+
IntVector ordinalIds,
189+
IntVector hashIds
190+
) {
191+
for (int p = 0; p < groupIds.getPositionCount(); p++) {
192+
int groupId = groupIds.getInt(p);
193+
int ord = ordinalIds.getInt(p + positionOffset);
194+
state.addValueOrdinal(groupId, hashIds.getInt(ord));
195+
}
196+
}
197+
198+
static void combineIntermediateInputValues(
199+
ValuesBytesRefAggregator.GroupingState state,
200+
int positionOffset,
201+
IntVector groupIds,
202+
BytesRefBlock values
203+
) {
204+
BytesRefVector dict = null;
205+
IntBlock ordinals = null;
206+
{
207+
final OrdinalBytesRefBlock asOrdinals = values.asOrdinals();
208+
if (asOrdinals != null) {
209+
dict = asOrdinals.getDictionaryVector();
210+
ordinals = asOrdinals.getOrdinalsBlock();
211+
}
212+
}
213+
if (dict != null && dict.getPositionCount() < groupIds.getPositionCount()) {
214+
try (var hashIds = hashDict(state, dict)) {
215+
IntVector ordinalsVector = ordinals.asVector();
216+
if (ordinalsVector != null) {
217+
addOrdinalInputVector(state, positionOffset, groupIds, ordinalsVector, hashIds);
218+
} else {
219+
addOrdinalInputBlock(state, positionOffset, groupIds, ordinals, hashIds);
220+
}
221+
}
222+
} else {
223+
final BytesRef scratch = new BytesRef();
224+
for (int p = 0; p < groupIds.getPositionCount(); p++) {
225+
final int valuePosition = p + positionOffset;
226+
final int groupId = groupIds.getInt(valuePosition);
227+
final int start = values.getFirstValueIndex(valuePosition);
228+
final int end = start + values.getValueCount(valuePosition);
229+
for (int i = start; i < end; i++) {
230+
state.addValue(groupId, values.getBytesRef(i, scratch));
231+
}
232+
}
233+
}
234+
}
172235
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,20 +113,20 @@ $endif$
113113
state.addValue(groupId, v);
114114
}
115115

116-
public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) {
117116
$if(BytesRef)$
118-
BytesRef scratch = new BytesRef();
119-
$endif$
117+
public static void combineIntermediate(GroupingState state, int positionOffset, IntVector groups, $Type$Block values) {
118+
ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values);
119+
}
120+
121+
$else$
122+
public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) {
120123
int start = values.getFirstValueIndex(valuesPosition);
121124
int end = start + values.getValueCount(valuesPosition);
122125
for (int i = start; i < end; i++) {
123-
$if(BytesRef)$
124-
state.addValue(groupId, values.getBytesRef(i, scratch));
125-
$else$
126126
state.addValue(groupId, values.get$Type$(i));
127-
$endif$
128127
}
129128
}
129+
$endif$
130130

131131
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
132132
if (statePosition > state.maxGroupId) {

0 commit comments

Comments
 (0)