|
58 | 58 | import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC;
|
59 | 59 | import static org.elasticsearch.compute.gen.Types.INT_ARRAY_BLOCK;
|
60 | 60 | import static org.elasticsearch.compute.gen.Types.INT_BIG_ARRAY_BLOCK;
|
| 61 | +import static org.elasticsearch.compute.gen.Types.INT_BLOCK; |
61 | 62 | import static org.elasticsearch.compute.gen.Types.INT_VECTOR;
|
62 | 63 | import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC;
|
63 | 64 | import static org.elasticsearch.compute.gen.Types.LIST_INTEGER;
|
@@ -609,77 +610,98 @@ private MethodSpec addIntermediateInput(TypeName groupsType) {
|
609 | 610 | .collect(joining(" && "))
|
610 | 611 | );
|
611 | 612 | }
|
612 |
| - if (intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) { |
613 |
| - builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); |
614 |
| - } |
615 |
| - builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); |
616 |
| - { |
617 |
| - if (groupsIsBlock) { |
618 |
| - builder.beginControlFlow("if (groups.isNull(groupPosition))"); |
619 |
| - builder.addStatement("continue"); |
620 |
| - builder.endControlFlow(); |
621 |
| - builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)"); |
622 |
| - builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)"); |
623 |
| - builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)"); |
624 |
| - builder.addStatement("int groupId = groups.getInt(g)"); |
625 |
| - } else { |
626 |
| - builder.addStatement("int groupId = groups.getInt(groupPosition)"); |
| 613 | + var bulkCombineIntermediateMethod = optionalStaticMethod( |
| 614 | + declarationType, |
| 615 | + requireVoidType(), |
| 616 | + requireName("combineIntermediate"), |
| 617 | + requireArgs( |
| 618 | + Stream.concat( |
| 619 | + // aggState, positionOffset, groupIds |
| 620 | + Stream.of(aggState.declaredType(), TypeName.INT, groupsIsBlock ? INT_BLOCK : INT_VECTOR), |
| 621 | + intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType) |
| 622 | + ).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) |
| 623 | + ) |
| 624 | + ); |
| 625 | + if (bulkCombineIntermediateMethod != null) { |
| 626 | + var states = intermediateState.stream() |
| 627 | + .map(AggregatorImplementer.IntermediateStateDesc::name) |
| 628 | + .collect(Collectors.joining(", ")); |
| 629 | + builder.addStatement("$T.combineIntermediate(state, positionOffset, groups, " + states + ")", declarationType); |
| 630 | + } else { |
| 631 | + if (intermediateState.stream() |
| 632 | + .map(AggregatorImplementer.IntermediateStateDesc::elementType) |
| 633 | + .anyMatch(n -> n.equals("BYTES_REF"))) { |
| 634 | + builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); |
627 | 635 | }
|
628 |
| - |
629 |
| - if (aggState.declaredType().isPrimitive()) { |
630 |
| - if (warnExceptions.isEmpty()) { |
631 |
| - assert intermediateState.size() == 2; |
632 |
| - assert intermediateState.get(1).name().equals("seen"); |
633 |
| - builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))"); |
| 636 | + builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); |
| 637 | + { |
| 638 | + if (groupsIsBlock) { |
| 639 | + builder.beginControlFlow("if (groups.isNull(groupPosition))"); |
| 640 | + builder.addStatement("continue"); |
| 641 | + builder.endControlFlow(); |
| 642 | + builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)"); |
| 643 | + builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)"); |
| 644 | + builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)"); |
| 645 | + builder.addStatement("int groupId = groups.getInt(g)"); |
634 | 646 | } else {
|
635 |
| - assert intermediateState.size() == 3; |
636 |
| - assert intermediateState.get(1).name().equals("seen"); |
637 |
| - assert intermediateState.get(2).name().equals("failed"); |
638 |
| - builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))"); |
639 |
| - { |
640 |
| - builder.addStatement("state.setFailed(groupId)"); |
641 |
| - } |
642 |
| - builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))"); |
| 647 | + builder.addStatement("int groupId = groups.getInt(groupPosition)"); |
643 | 648 | }
|
644 | 649 |
|
645 |
| - warningsBlock(builder, () -> { |
646 |
| - var name = intermediateState.get(0).name(); |
647 |
| - var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType()); |
648 |
| - builder.addStatement( |
649 |
| - "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", |
| 650 | + if (aggState.declaredType().isPrimitive()) { |
| 651 | + if (warnExceptions.isEmpty()) { |
| 652 | + assert intermediateState.size() == 2; |
| 653 | + assert intermediateState.get(1).name().equals("seen"); |
| 654 | + builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))"); |
| 655 | + } else { |
| 656 | + assert intermediateState.size() == 3; |
| 657 | + assert intermediateState.get(1).name().equals("seen"); |
| 658 | + assert intermediateState.get(2).name().equals("failed"); |
| 659 | + builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))"); |
| 660 | + { |
| 661 | + builder.addStatement("state.setFailed(groupId)"); |
| 662 | + } |
| 663 | + builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))"); |
| 664 | + } |
| 665 | + |
| 666 | + warningsBlock(builder, () -> { |
| 667 | + var name = intermediateState.get(0).name(); |
| 668 | + var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType()); |
| 669 | + builder.addStatement( |
| 670 | + "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))", |
| 671 | + declarationType, |
| 672 | + name, |
| 673 | + vectorAccessor |
| 674 | + ); |
| 675 | + }); |
| 676 | + builder.endControlFlow(); |
| 677 | + } else { |
| 678 | + var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block); |
| 679 | + requireStaticMethod( |
650 | 680 | declarationType,
|
651 |
| - name, |
652 |
| - vectorAccessor |
| 681 | + requireVoidType(), |
| 682 | + requireName("combineIntermediate"), |
| 683 | + requireArgs( |
| 684 | + Stream.of( |
| 685 | + Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId |
| 686 | + intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType), |
| 687 | + Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position |
| 688 | + ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) |
| 689 | + ) |
653 | 690 | );
|
654 |
| - }); |
655 |
| - builder.endControlFlow(); |
656 |
| - } else { |
657 |
| - var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block); |
658 |
| - requireStaticMethod( |
659 |
| - declarationType, |
660 |
| - requireVoidType(), |
661 |
| - requireName("combineIntermediate"), |
662 |
| - requireArgs( |
663 |
| - Stream.of( |
664 |
| - Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId |
665 |
| - intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType), |
666 |
| - Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position |
667 |
| - ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) |
668 |
| - ) |
669 |
| - ); |
670 | 691 |
|
671 |
| - builder.addStatement( |
672 |
| - "$T.combineIntermediate(state, groupId, " |
673 |
| - + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")) |
674 |
| - + (stateHasBlock ? ", groupPosition + positionOffset" : "") |
675 |
| - + ")", |
676 |
| - declarationType |
677 |
| - ); |
678 |
| - } |
679 |
| - if (groupsIsBlock) { |
| 692 | + builder.addStatement( |
| 693 | + "$T.combineIntermediate(state, groupId, " |
| 694 | + + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")) |
| 695 | + + (stateHasBlock ? ", groupPosition + positionOffset" : "") |
| 696 | + + ")", |
| 697 | + declarationType |
| 698 | + ); |
| 699 | + } |
| 700 | + if (groupsIsBlock) { |
| 701 | + builder.endControlFlow(); |
| 702 | + } |
680 | 703 | builder.endControlFlow();
|
681 | 704 | }
|
682 |
| - builder.endControlFlow(); |
683 | 705 | }
|
684 | 706 | return builder.build();
|
685 | 707 | }
|
|
0 commit comments