diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index dfd56996e1c15..def9c58160002 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -113,16 +113,16 @@ static void selfTest() { @Param({ BYTES_REF, INT, LONG }) public String dataType; - private static Operator operator(DriverContext driverContext, int groups, String dataType) { + private static Operator operator(DriverContext driverContext, int groups, String dataType, AggregatorMode mode) { if (groups == 1) { return new AggregationOperator( - List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)), + List.of(supplier(dataType).aggregatorFactory(mode, List.of(0)).apply(driverContext)), driverContext ); } List groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG)); return new HashAggregationOperator( - List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))), + List.of(supplier(dataType).groupingAggregatorFactory(mode, List.of(1))), () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false), driverContext ) { @@ -177,6 +177,9 @@ private static void checkGrouped(String prefix, int groups, String dataType, Pag // Check them BytesRefBlock values = page.getBlock(1); + if (values.asOrdinals() == null) { + throw new AssertionError(" expected ordinals; but got " + values); + } for (int p = 0; p < groups; p++) { checkExpectedBytesRef(prefix, values, p, expected.get(p)); } @@ -341,13 +344,21 @@ public void run() { private static void run(int groups, String dataType, int opCount) { DriverContext driverContext = driverContext(); - try (Operator operator = operator(driverContext, groups, dataType)) { - Page page = page(groups, dataType); - for (int i = 0; i < opCount; i++) { - operator.addInput(page.shallowCopy()); + try (Operator finalAggregator = operator(driverContext, groups, dataType, AggregatorMode.FINAL)) { + try (Operator initialAggregator = operator(driverContext, groups, dataType, AggregatorMode.INITIAL)) { + Page rawPage = page(groups, dataType); + for (int i = 0; i < opCount; i++) { + initialAggregator.addInput(rawPage.shallowCopy()); + } + initialAggregator.finish(); + Page intermediatePage = initialAggregator.getOutput(); + for (int i = 0; i < opCount; i++) { + finalAggregator.addInput(intermediatePage.shallowCopy()); + } } - operator.finish(); - checkExpected(groups, dataType, operator.getOutput()); + finalAggregator.finish(); + Page outputPage = finalAggregator.getOutput(); + checkExpected(groups, dataType, outputPage); } } diff --git a/docs/changelog/131390.yaml b/docs/changelog/131390.yaml new file mode 100644 index 0000000000000..849adcb1a173a --- /dev/null +++ b/docs/changelog/131390.yaml @@ -0,0 +1,5 @@ +pr: 131390 +summary: Add optimized path for intermediate values aggregator +area: ES|QL +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 8dc6f594ca47a..6042915f70aee 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -58,6 +58,7 @@ import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC; import static org.elasticsearch.compute.gen.Types.INT_ARRAY_BLOCK; import static org.elasticsearch.compute.gen.Types.INT_BIG_ARRAY_BLOCK; +import static org.elasticsearch.compute.gen.Types.INT_BLOCK; import static org.elasticsearch.compute.gen.Types.INT_VECTOR; import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC; import static org.elasticsearch.compute.gen.Types.LIST_INTEGER; @@ -609,77 +610,98 @@ private MethodSpec addIntermediateInput(TypeName groupsType) { .collect(joining(" && ")) ); } - if (intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) { - builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); - } - builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); - { - if (groupsIsBlock) { - builder.beginControlFlow("if (groups.isNull(groupPosition))"); - builder.addStatement("continue"); - builder.endControlFlow(); - builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)"); - builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)"); - builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)"); - builder.addStatement("int groupId = groups.getInt(g)"); - } else { - builder.addStatement("int groupId = groups.getInt(groupPosition)"); + var bulkCombineIntermediateMethod = optionalStaticMethod( + declarationType, + requireVoidType(), + requireName("combineIntermediate"), + requireArgs( + Stream.concat( + // aggState, positionOffset, groupIds + Stream.of(aggState.declaredType(), TypeName.INT, groupsIsBlock ? INT_BLOCK : INT_VECTOR), + intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType) + ).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) + ) + ); + if (bulkCombineIntermediateMethod != null) { + var states = intermediateState.stream() + .map(AggregatorImplementer.IntermediateStateDesc::name) + .collect(Collectors.joining(", ")); + builder.addStatement("$T.combineIntermediate(state, positionOffset, groups, " + states + ")", declarationType); + } else { + if (intermediateState.stream() + .map(AggregatorImplementer.IntermediateStateDesc::elementType) + .anyMatch(n -> n.equals("BYTES_REF"))) { + builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); } - - if (aggState.declaredType().isPrimitive()) { - if (warnExceptions.isEmpty()) { - assert intermediateState.size() == 2; - assert intermediateState.get(1).name().equals("seen"); - builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))"); + builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); + { + if (groupsIsBlock) { + builder.beginControlFlow("if (groups.isNull(groupPosition))"); + builder.addStatement("continue"); + builder.endControlFlow(); + builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)"); + builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)"); + builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)"); + builder.addStatement("int groupId = groups.getInt(g)"); } else { - assert intermediateState.size() == 3; - assert intermediateState.get(1).name().equals("seen"); - assert intermediateState.get(2).name().equals("failed"); - builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))"); - { - builder.addStatement("state.setFailed(groupId)"); - } - builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))"); + builder.addStatement("int groupId = groups.getInt(groupPosition)"); } - warningsBlock(builder, () -> { - var name = intermediateState.get(0).name(); - var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType()); - builder.addStatement( - "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", + if (aggState.declaredType().isPrimitive()) { + if (warnExceptions.isEmpty()) { + assert intermediateState.size() == 2; + assert intermediateState.get(1).name().equals("seen"); + builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))"); + } else { + assert intermediateState.size() == 3; + assert intermediateState.get(1).name().equals("seen"); + assert intermediateState.get(2).name().equals("failed"); + builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))"); + { + builder.addStatement("state.setFailed(groupId)"); + } + builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))"); + } + + warningsBlock(builder, () -> { + var name = intermediateState.get(0).name(); + var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType()); + builder.addStatement( + "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", + declarationType, + name, + vectorAccessor + ); + }); + builder.endControlFlow(); + } else { + var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block); + requireStaticMethod( declarationType, - name, - vectorAccessor + requireVoidType(), + requireName("combineIntermediate"), + requireArgs( + Stream.of( + Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId + intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType), + Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position + ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) + ) ); - }); - builder.endControlFlow(); - } else { - var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block); - requireStaticMethod( - declarationType, - requireVoidType(), - requireName("combineIntermediate"), - requireArgs( - Stream.of( - Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId - intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType), - Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position - ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) - ) - ); - builder.addStatement( - "$T.combineIntermediate(state, groupId, " - + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")) - + (stateHasBlock ? ", groupPosition + positionOffset" : "") - + ")", - declarationType - ); - } - if (groupsIsBlock) { + builder.addStatement( + "$T.combineIntermediate(state, groupId, " + + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")) + + (stateHasBlock ? ", groupPosition + positionOffset" : "") + + ")", + declarationType + ); + } + if (groupsIsBlock) { + builder.endControlFlow(); + } builder.endControlFlow(); } - builder.endControlFlow(); } return builder.build(); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 337c8cde768f9..79077b6628105 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -80,13 +80,12 @@ public static void combine(GroupingState state, int groupId, BytesRef v) { state.addValue(groupId, v); } - public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { - BytesRef scratch = new BytesRef(); - int start = values.getFirstValueIndex(valuesPosition); - int end = start + values.getValueCount(valuesPosition); - for (int i = start; i < end; i++) { - state.addValue(groupId, values.getBytesRef(i, scratch)); - } + public static void combineIntermediate(GroupingState state, int positionOffset, IntVector groups, BytesRefBlock values) { + ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values); + } + + public static void combineIntermediate(GroupingState state, int positionOffset, IntBlock groups, BytesRefBlock values) { + ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values); } public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { @@ -199,7 +198,7 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { } try (var sorted = buildSorted(selected)) { - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { + if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) { return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); } else { return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index da8e93f9cf61a..b76f52d335a03 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -152,18 +152,7 @@ public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page return; } BytesRefBlock values = (BytesRefBlock) valuesUncast; - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); - } - } + ValuesBytesRefAggregator.combineIntermediate(state, positionOffset, groups, values); } private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { @@ -209,18 +198,7 @@ public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Pa return; } BytesRefBlock values = (BytesRefBlock) valuesUncast; - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); - } - } + ValuesBytesRefAggregator.combineIntermediate(state, positionOffset, groups, values); } private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { @@ -255,11 +233,7 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page return; } BytesRefBlock values = (BytesRefBlock) valuesUncast; - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); - } + ValuesBytesRefAggregator.combineIntermediate(state, positionOffset, groups, values); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java index 4a2fa0923abe4..20ec7b08f9cb7 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java @@ -28,15 +28,7 @@ static GroupingAggregatorFunction.AddInput wrapAddInput( if (valuesOrdinal == null) { return delegate; } - BytesRefVector dict = valuesOrdinal.getDictionaryVector(); - final IntVector hashIds; - BytesRef spare = new BytesRef(); - try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) { - for (int p = 0; p < dict.getPositionCount(); p++) { - hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare))))); - } - hashIds = hashIdsBuilder.build(); - } + final IntVector hashIds = hashDict(state, valuesOrdinal.getDictionaryVector()); IntBlock ordinalIds = valuesOrdinal.getOrdinalsBlock(); return new GroupingAggregatorFunction.AddInput() { @Override @@ -85,17 +77,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { @Override public void add(int positionOffset, IntVector groupIds) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - int groupId = groupIds.getInt(groupPosition); - if (ordinalIds.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v))); - } - } + addOrdinalInputBlock(state, positionOffset, groupIds, ordinalIds, hashIds); } @Override @@ -114,15 +96,7 @@ static GroupingAggregatorFunction.AddInput wrapAddInput( if (valuesOrdinal == null) { return delegate; } - BytesRefVector dict = valuesOrdinal.getDictionaryVector(); - final IntVector hashIds; - BytesRef spare = new BytesRef(); - try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) { - for (int p = 0; p < dict.getPositionCount(); p++) { - hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare))))); - } - hashIds = hashIdsBuilder.build(); - } + final IntVector hashIds = hashDict(state, valuesOrdinal.getDictionaryVector()); var ordinalIds = valuesOrdinal.getOrdinalsVector(); return new GroupingAggregatorFunction.AddInput() { @Override @@ -157,10 +131,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { @Override public void add(int positionOffset, IntVector groupIds) { - for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { - int groupId = groupIds.getInt(groupPosition); - state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); - } + addOrdinalInputVector(state, positionOffset, groupIds, ordinalIds, hashIds); } @Override @@ -169,4 +140,114 @@ public void close() { } }; } + + static IntVector hashDict(ValuesBytesRefAggregator.GroupingState state, BytesRefVector dict) { + BytesRef scratch = new BytesRef(); + try (var hashIdsBuilder = dict.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) { + for (int p = 0; p < dict.getPositionCount(); p++) { + final long hashId = BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, scratch))); + hashIdsBuilder.appendInt(Math.toIntExact(hashId)); + } + return hashIdsBuilder.build(); + } + } + + static void addOrdinalInputBlock( + ValuesBytesRefAggregator.GroupingState state, + int positionOffset, + IntVector groupIds, + IntBlock ordinalIds, + IntVector hashIds + ) { + for (int p = 0; p < groupIds.getPositionCount(); p++) { + final int valuePosition = p + positionOffset; + final int groupId = groupIds.getInt(valuePosition); + final int start = ordinalIds.getFirstValueIndex(valuePosition); + final int end = start + ordinalIds.getValueCount(valuePosition); + for (int i = start; i < end; i++) { + int ord = ordinalIds.getInt(i); + state.addValueOrdinal(groupId, hashIds.getInt(ord)); + } + } + } + + static void addOrdinalInputVector( + ValuesBytesRefAggregator.GroupingState state, + int positionOffset, + IntVector groupIds, + IntVector ordinalIds, + IntVector hashIds + ) { + for (int p = 0; p < groupIds.getPositionCount(); p++) { + int groupId = groupIds.getInt(p); + int ord = ordinalIds.getInt(p + positionOffset); + state.addValueOrdinal(groupId, hashIds.getInt(ord)); + } + } + + static void combineIntermediateInputValues( + ValuesBytesRefAggregator.GroupingState state, + int positionOffset, + IntVector groupIds, + BytesRefBlock values + ) { + BytesRefVector dict = null; + IntBlock ordinals = null; + { + final OrdinalBytesRefBlock asOrdinals = values.asOrdinals(); + if (asOrdinals != null) { + dict = asOrdinals.getDictionaryVector(); + ordinals = asOrdinals.getOrdinalsBlock(); + } + } + if (dict != null && dict.getPositionCount() < groupIds.getPositionCount()) { + try (var hashIds = hashDict(state, dict)) { + IntVector ordinalsVector = ordinals.asVector(); + if (ordinalsVector != null) { + addOrdinalInputVector(state, positionOffset, groupIds, ordinalsVector, hashIds); + } else { + addOrdinalInputBlock(state, positionOffset, groupIds, ordinals, hashIds); + } + } + } else { + final BytesRef scratch = new BytesRef(); + for (int p = 0; p < groupIds.getPositionCount(); p++) { + final int valuePosition = p + positionOffset; + final int groupId = groupIds.getInt(valuePosition); + final int start = values.getFirstValueIndex(valuePosition); + final int end = start + values.getValueCount(valuePosition); + for (int i = start; i < end; i++) { + state.addValue(groupId, values.getBytesRef(i, scratch)); + } + } + } + } + + static void combineIntermediateInputValues( + ValuesBytesRefAggregator.GroupingState state, + int positionOffset, + IntBlock groupIds, + BytesRefBlock values + ) { + final BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupId = groupIds.getInt(g); + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + var bytes = values.getBytesRef(v, scratch); + state.addValue(groupId, bytes); + } + } + } + } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index cc084644832ca..d92ac5fa0afce 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -113,20 +113,24 @@ $endif$ state.addValue(groupId, v); } - public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) { $if(BytesRef)$ - BytesRef scratch = new BytesRef(); -$endif$ + public static void combineIntermediate(GroupingState state, int positionOffset, IntVector groups, $Type$Block values) { + ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values); + } + + public static void combineIntermediate(GroupingState state, int positionOffset, IntBlock groups, $Type$Block values) { + ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values); + } + +$else$ + public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); for (int i = start; i < end; i++) { -$if(BytesRef)$ - state.addValue(groupId, values.getBytesRef(i, scratch)); -$else$ state.addValue(groupId, values.get$Type$(i)); -$endif$ } } +$endif$ public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); @@ -304,7 +308,7 @@ $endif$ try (var sorted = buildSorted(selected)) { $if(BytesRef)$ - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { + if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) { return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); } else { return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);