Skip to content

Commit 2564379

Browse files
authored
Add optimized path for intermediate values aggregator (#131390)
Similar to #127849, this change adds an optimized path for leveraging ordinal blocks of intermediate input pages in the Values aggregator. Below are the micro-benchmark results. Before: ``` // 1 raw input page + 1000 intermediate input pages Benchmark (dataType) (groups) Mode Cnt Score Error Units ValuesAggregatorBenchmark.run BytesRef 1 avgt 2 0.382 ms/op ValuesAggregatorBenchmark.run BytesRef 1000 avgt 2 112.293 ms/op ValuesAggregatorBenchmark.run BytesRef 1000000 avgt 2 113182.908 ms/op ``` ``` After: // 1 raw input page + 1000 intermediate input pages Benchmark (dataType) (groups) Mode Cnt Score Error Units ValuesAggregatorBenchmark.run BytesRef 1 avgt 2 0.378 ms/op ValuesAggregatorBenchmark.run BytesRef 1000 avgt 2 34.410 ms/op ValuesAggregatorBenchmark.run BytesRef 1000000 avgt 2 64654.830 ms/op ``` 1K groups: 112 ms -> 34.4ms 1M groups: 113s -> 64s More to come with #130510 Relates #127849
1 parent 2381e5d commit 2564379

File tree

7 files changed

+245
-149
lines changed

7 files changed

+245
-149
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,16 @@ static void selfTest() {
113113
@Param({ BYTES_REF, INT, LONG })
114114
public String dataType;
115115

116-
private static Operator operator(DriverContext driverContext, int groups, String dataType) {
116+
private static Operator operator(DriverContext driverContext, int groups, String dataType, AggregatorMode mode) {
117117
if (groups == 1) {
118118
return new AggregationOperator(
119-
List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
119+
List.of(supplier(dataType).aggregatorFactory(mode, List.of(0)).apply(driverContext)),
120120
driverContext
121121
);
122122
}
123123
List<BlockHash.GroupSpec> groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG));
124124
return new HashAggregationOperator(
125-
List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))),
125+
List.of(supplier(dataType).groupingAggregatorFactory(mode, List.of(1))),
126126
() -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false),
127127
driverContext
128128
) {
@@ -177,6 +177,9 @@ private static void checkGrouped(String prefix, int groups, String dataType, Pag
177177

178178
// Check them
179179
BytesRefBlock values = page.getBlock(1);
180+
if (values.asOrdinals() == null) {
181+
throw new AssertionError(" expected ordinals; but got " + values);
182+
}
180183
for (int p = 0; p < groups; p++) {
181184
checkExpectedBytesRef(prefix, values, p, expected.get(p));
182185
}
@@ -341,13 +344,21 @@ public void run() {
341344

342345
private static void run(int groups, String dataType, int opCount) {
343346
DriverContext driverContext = driverContext();
344-
try (Operator operator = operator(driverContext, groups, dataType)) {
345-
Page page = page(groups, dataType);
346-
for (int i = 0; i < opCount; i++) {
347-
operator.addInput(page.shallowCopy());
347+
try (Operator finalAggregator = operator(driverContext, groups, dataType, AggregatorMode.FINAL)) {
348+
try (Operator initialAggregator = operator(driverContext, groups, dataType, AggregatorMode.INITIAL)) {
349+
Page rawPage = page(groups, dataType);
350+
for (int i = 0; i < opCount; i++) {
351+
initialAggregator.addInput(rawPage.shallowCopy());
352+
}
353+
initialAggregator.finish();
354+
Page intermediatePage = initialAggregator.getOutput();
355+
for (int i = 0; i < opCount; i++) {
356+
finalAggregator.addInput(intermediatePage.shallowCopy());
357+
}
348358
}
349-
operator.finish();
350-
checkExpected(groups, dataType, operator.getOutput());
359+
finalAggregator.finish();
360+
Page outputPage = finalAggregator.getOutput();
361+
checkExpected(groups, dataType, outputPage);
351362
}
352363
}
353364

docs/changelog/131390.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 131390
2+
summary: Add optimized path for intermediate values aggregator
3+
area: ES|QL
4+
type: enhancement
5+
issues: []

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 84 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC;
5959
import static org.elasticsearch.compute.gen.Types.INT_ARRAY_BLOCK;
6060
import static org.elasticsearch.compute.gen.Types.INT_BIG_ARRAY_BLOCK;
61+
import static org.elasticsearch.compute.gen.Types.INT_BLOCK;
6162
import static org.elasticsearch.compute.gen.Types.INT_VECTOR;
6263
import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC;
6364
import static org.elasticsearch.compute.gen.Types.LIST_INTEGER;
@@ -609,77 +610,98 @@ private MethodSpec addIntermediateInput(TypeName groupsType) {
609610
.collect(joining(" && "))
610611
);
611612
}
612-
if (intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
613-
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
614-
}
615-
builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
616-
{
617-
if (groupsIsBlock) {
618-
builder.beginControlFlow("if (groups.isNull(groupPosition))");
619-
builder.addStatement("continue");
620-
builder.endControlFlow();
621-
builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)");
622-
builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)");
623-
builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)");
624-
builder.addStatement("int groupId = groups.getInt(g)");
625-
} else {
626-
builder.addStatement("int groupId = groups.getInt(groupPosition)");
613+
var bulkCombineIntermediateMethod = optionalStaticMethod(
614+
declarationType,
615+
requireVoidType(),
616+
requireName("combineIntermediate"),
617+
requireArgs(
618+
Stream.concat(
619+
// aggState, positionOffset, groupIds
620+
Stream.of(aggState.declaredType(), TypeName.INT, groupsIsBlock ? INT_BLOCK : INT_VECTOR),
621+
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType)
622+
).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
623+
)
624+
);
625+
if (bulkCombineIntermediateMethod != null) {
626+
var states = intermediateState.stream()
627+
.map(AggregatorImplementer.IntermediateStateDesc::name)
628+
.collect(Collectors.joining(", "));
629+
builder.addStatement("$T.combineIntermediate(state, positionOffset, groups, " + states + ")", declarationType);
630+
} else {
631+
if (intermediateState.stream()
632+
.map(AggregatorImplementer.IntermediateStateDesc::elementType)
633+
.anyMatch(n -> n.equals("BYTES_REF"))) {
634+
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
627635
}
628-
629-
if (aggState.declaredType().isPrimitive()) {
630-
if (warnExceptions.isEmpty()) {
631-
assert intermediateState.size() == 2;
632-
assert intermediateState.get(1).name().equals("seen");
633-
builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))");
636+
builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
637+
{
638+
if (groupsIsBlock) {
639+
builder.beginControlFlow("if (groups.isNull(groupPosition))");
640+
builder.addStatement("continue");
641+
builder.endControlFlow();
642+
builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)");
643+
builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)");
644+
builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)");
645+
builder.addStatement("int groupId = groups.getInt(g)");
634646
} else {
635-
assert intermediateState.size() == 3;
636-
assert intermediateState.get(1).name().equals("seen");
637-
assert intermediateState.get(2).name().equals("failed");
638-
builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))");
639-
{
640-
builder.addStatement("state.setFailed(groupId)");
641-
}
642-
builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))");
647+
builder.addStatement("int groupId = groups.getInt(groupPosition)");
643648
}
644649

645-
warningsBlock(builder, () -> {
646-
var name = intermediateState.get(0).name();
647-
var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType());
648-
builder.addStatement(
649-
"state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))",
650+
if (aggState.declaredType().isPrimitive()) {
651+
if (warnExceptions.isEmpty()) {
652+
assert intermediateState.size() == 2;
653+
assert intermediateState.get(1).name().equals("seen");
654+
builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))");
655+
} else {
656+
assert intermediateState.size() == 3;
657+
assert intermediateState.get(1).name().equals("seen");
658+
assert intermediateState.get(2).name().equals("failed");
659+
builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))");
660+
{
661+
builder.addStatement("state.setFailed(groupId)");
662+
}
663+
builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))");
664+
}
665+
666+
warningsBlock(builder, () -> {
667+
var name = intermediateState.get(0).name();
668+
var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType());
669+
builder.addStatement(
670+
"state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))",
671+
declarationType,
672+
name,
673+
vectorAccessor
674+
);
675+
});
676+
builder.endControlFlow();
677+
} else {
678+
var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
679+
requireStaticMethod(
650680
declarationType,
651-
name,
652-
vectorAccessor
681+
requireVoidType(),
682+
requireName("combineIntermediate"),
683+
requireArgs(
684+
Stream.of(
685+
Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
686+
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
687+
Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
688+
).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
689+
)
653690
);
654-
});
655-
builder.endControlFlow();
656-
} else {
657-
var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
658-
requireStaticMethod(
659-
declarationType,
660-
requireVoidType(),
661-
requireName("combineIntermediate"),
662-
requireArgs(
663-
Stream.of(
664-
Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
665-
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
666-
Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
667-
).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
668-
)
669-
);
670691

671-
builder.addStatement(
672-
"$T.combineIntermediate(state, groupId, "
673-
+ intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
674-
+ (stateHasBlock ? ", groupPosition + positionOffset" : "")
675-
+ ")",
676-
declarationType
677-
);
678-
}
679-
if (groupsIsBlock) {
692+
builder.addStatement(
693+
"$T.combineIntermediate(state, groupId, "
694+
+ intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
695+
+ (stateHasBlock ? ", groupPosition + positionOffset" : "")
696+
+ ")",
697+
declarationType
698+
);
699+
}
700+
if (groupsIsBlock) {
701+
builder.endControlFlow();
702+
}
680703
builder.endControlFlow();
681704
}
682-
builder.endControlFlow();
683705
}
684706
return builder.build();
685707
}

x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java

Lines changed: 7 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java

Lines changed: 3 additions & 29 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)