diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java index 25c69d0a47edb..5437881f2bcf9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -260,6 +260,7 @@ public final class SystemSessionProperties public static final String HYPERLOGLOG_STANDARD_ERROR_WARNING_THRESHOLD = "hyperloglog_standard_error_warning_threshold"; public static final String PREFER_MERGE_JOIN_FOR_SORTED_INPUTS = "prefer_merge_join_for_sorted_inputs"; public static final String PREFER_SORT_MERGE_JOIN = "prefer_sort_merge_join"; + public static final String ENABLE_SORTED_EXCHANGES = "enable_sorted_exchanges"; public static final String SEGMENTED_AGGREGATION_ENABLED = "segmented_aggregation_enabled"; public static final String USE_HISTORY_BASED_PLAN_STATISTICS = "use_history_based_plan_statistics"; public static final String TRACK_HISTORY_BASED_PLAN_STATISTICS = "track_history_based_plan_statistics"; @@ -1399,6 +1400,11 @@ public SystemSessionProperties( "Prefer sort merge join for all joins. 
A SortNode is added if input is not already sorted.", featuresConfig.isPreferSortMergeJoin(), true), + booleanProperty( + ENABLE_SORTED_EXCHANGES, + "(Experimental) Enable pushing sort operations down to exchange nodes for distributed queries", + featuresConfig.isEnableSortedExchanges(), + false), booleanProperty( SEGMENTED_AGGREGATION_ENABLED, "Enable segmented aggregation.", @@ -2932,6 +2938,11 @@ public static boolean preferSortMergeJoin(Session session) return session.getSystemProperty(PREFER_SORT_MERGE_JOIN, Boolean.class); } + public static boolean isEnableSortedExchanges(Session session) + { + return session.getSystemProperty(ENABLE_SORTED_EXCHANGES, Boolean.class); + } + public static boolean isSegmentedAggregationEnabled(Session session) { return session.getSystemProperty(SEGMENTED_AGGREGATION_ENABLED, Boolean.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateMachine.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateMachine.java index 598de9e487ce3..7073e7f49b2e6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateMachine.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateMachine.java @@ -1211,6 +1211,7 @@ private static StageInfo pruneStatsFromStageInfo(StageInfo stage) plan.getPartitioning(), plan.getTableScanSchedulingOrder(), plan.getPartitioningScheme(), + plan.getOutputOrderingScheme(), plan.getStageExecutionDescriptor(), plan.isOutputTableWriterFragment(), plan.getStatsAndCosts().map(QueryStateMachine::pruneHistogramsFromStatsAndCosts), diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java index 98a9bd80502f5..a3b0ebf06b787 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java @@ -669,6 +669,7 @@ private Optional performRuntimeOptimizations(StreamingSubPlan subP fragment.getPartitioning(), scheduleOrder(newRoot), fragment.getPartitioningScheme(), + fragment.getOutputOrderingScheme(), fragment.getStageExecutionDescriptor(), fragment.isOutputTableWriterFragment(), estimatedStatsAndCosts, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 7214485a7d7a6..abd51e0b73d37 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -234,6 +234,7 @@ public class FeaturesConfig private boolean streamingForPartialAggregationEnabled; private boolean preferMergeJoinForSortedInputs; private boolean preferSortMergeJoin; + private boolean enableSortedExchanges; private boolean segmentedAggregationEnabled; private int maxStageCountForEagerScheduling = 25; @@ -2290,6 +2291,19 @@ public FeaturesConfig setPreferSortMergeJoin(boolean preferSortMergeJoin) return this; } + public boolean isEnableSortedExchanges() + { + return enableSortedExchanges; + } + + @Config("optimizer.experimental.enable-sorted-exchanges") + @ConfigDescription("(Experimental) Enable pushing sort operations down to exchange nodes for distributed queries") + public FeaturesConfig setEnableSortedExchanges(boolean enableSortedExchanges) + { + this.enableSortedExchanges = enableSortedExchanges; + return this; + } + public boolean isSegmentedAggregationEnabled() { return segmentedAggregationEnabled; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java index bc77373a5ecbd..301a697a87891 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner; +import com.facebook.airlift.log.Logger; import com.facebook.presto.Session; import com.facebook.presto.cost.StatsAndCosts; import com.facebook.presto.metadata.Metadata; @@ -26,6 +27,7 @@ import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.MetadataDeleteNode; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.Partitioning; import com.facebook.presto.spi.plan.PartitioningHandle; @@ -93,6 +95,7 @@ public abstract class BasePlanFragmenter extends SimplePlanRewriter { + private static final Logger log = Logger.get(BasePlanFragmenter.class); private final Session session; private final Metadata metadata; private final PlanNodeIdAllocator idAllocator; @@ -158,6 +161,7 @@ private SubPlan buildFragment(PlanNode root, FragmentProperties properties, Plan properties.getPartitioningHandle(), schedulingOrder, properties.getPartitioningScheme(), + properties.getOutputOrderingScheme(), StageExecutionDescriptor.ungroupedExecution(), outputTableWriterFragment, Optional.of(statsAndCosts.getForSubplan(root)), @@ -300,7 +304,21 @@ private PlanNode createRemoteStreamingExchange(ExchangeNode exchange, RewriteCon ImmutableList.Builder builder = ImmutableList.builder(); for (int sourceIndex = 0; sourceIndex < exchange.getSources().size(); sourceIndex++) { - FragmentProperties childProperties = new FragmentProperties(translateOutputLayout(partitioningScheme, exchange.getInputs().get(sourceIndex))); + PartitioningScheme childPartitioningScheme = translateOutputLayout(partitioningScheme, exchange.getInputs().get(sourceIndex)); + FragmentProperties childProperties = new 
FragmentProperties(childPartitioningScheme); + + // Hand the exchange's ordering scheme to the child fragment unchanged. NOTE(review): unlike the partitioning scheme it is not translated to the child's input variables - confirm the names always match + Optional childOutputOrderingScheme = Optional.empty(); + if (exchange.getOrderingScheme().isPresent()) { + log.debug("Found ordering scheme on ExchangeNode %s. Transferring to child", exchange.getId()); + childOutputOrderingScheme = exchange.getOrderingScheme(); + } + else { + log.debug("No ordering scheme on ExchangeNode %s", exchange.getId()); + } + + // Set the output ordering scheme for the child fragment + childProperties.setOutputOrderingScheme(childOutputOrderingScheme); builder.add(buildSubPlan(exchange.getSources().get(sourceIndex), childProperties, context)); } @@ -435,11 +453,24 @@ public static class FragmentProperties private Optional partitioningHandle = Optional.empty(); private final Set partitionedSources = new HashSet<>(); + // Output ordering scheme for the fragment - this gets transferred to the PlanFragment + private Optional outputOrderingScheme = Optional.empty(); + public FragmentProperties(PartitioningScheme partitioningScheme) { this.partitioningScheme = partitioningScheme; } + public void setOutputOrderingScheme(Optional outputOrderingScheme) + { + this.outputOrderingScheme = requireNonNull(outputOrderingScheme, "outputOrderingScheme is null"); + } + + public Optional getOutputOrderingScheme() + { + return outputOrderingScheme; + } + public List getChildren() { return children; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java index 31e9c2347ecd8..1fdf87a5865b7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java @@ -17,6 +17,7 @@ import com.facebook.airlift.json.Codec; import com.facebook.presto.common.type.Type; import 
com.facebook.presto.cost.StatsAndCosts; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PartitioningHandle; import com.facebook.presto.spi.plan.PartitioningScheme; import com.facebook.presto.spi.plan.PlanFragmentId; @@ -54,6 +55,10 @@ public class PlanFragment private final PartitioningScheme partitioningScheme; private final StageExecutionDescriptor stageExecutionDescriptor; + // Describes the ordering of the fragment's output data + // This is separate from partitioningScheme as ordering is orthogonal to partitioning + private final Optional outputOrderingScheme; + // Only true for output table writer and false for temporary table writers private final boolean outputTableWriterFragment; private final Optional statsAndCosts; @@ -73,6 +78,7 @@ public PlanFragment( @JsonProperty("partitioning") PartitioningHandle partitioning, @JsonProperty("tableScanSchedulingOrder") List tableScanSchedulingOrder, @JsonProperty("partitioningScheme") PartitioningScheme partitioningScheme, + @JsonProperty("outputOrderingScheme") Optional outputOrderingScheme, @JsonProperty("stageExecutionDescriptor") StageExecutionDescriptor stageExecutionDescriptor, @JsonProperty("outputTableWriterFragment") boolean outputTableWriterFragment, @JsonProperty("statsAndCosts") Optional statsAndCosts, @@ -84,6 +90,7 @@ public PlanFragment( this.partitioning = requireNonNull(partitioning, "partitioning is null"); this.tableScanSchedulingOrder = ImmutableList.copyOf(requireNonNull(tableScanSchedulingOrder, "tableScanSchedulingOrder is null")); this.stageExecutionDescriptor = requireNonNull(stageExecutionDescriptor, "stageExecutionDescriptor is null"); + this.outputOrderingScheme = requireNonNull(outputOrderingScheme, "outputOrderingScheme is null"); this.outputTableWriterFragment = outputTableWriterFragment; this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null"); this.jsonRepresentation = requireNonNull(jsonRepresentation, "jsonRepresentation is 
null"); @@ -156,6 +163,12 @@ public Optional getStatsAndCosts() return statsAndCosts; } + @JsonProperty + public Optional getOutputOrderingScheme() + { + return outputOrderingScheme; + } + @JsonProperty public Optional getJsonRepresentation() { @@ -187,6 +200,7 @@ private PlanFragment forTaskSerialization() id, root, variables, partitioning, tableScanSchedulingOrder, partitioningScheme, + outputOrderingScheme, stageExecutionDescriptor, outputTableWriterFragment, Optional.empty(), @@ -246,6 +260,7 @@ public PlanFragment withBucketToPartition(Optional bucketToPartition) partitioning, tableScanSchedulingOrder, partitioningScheme.withBucketToPartition(bucketToPartition), + outputOrderingScheme, stageExecutionDescriptor, outputTableWriterFragment, statsAndCosts, @@ -261,6 +276,7 @@ public PlanFragment withFixedLifespanScheduleGroupedExecution(List c partitioning, tableScanSchedulingOrder, partitioningScheme, + outputOrderingScheme, StageExecutionDescriptor.fixedLifespanScheduleGroupedExecution(capableTableScanNodes, totalLifespans), outputTableWriterFragment, statsAndCosts, @@ -276,6 +292,7 @@ public PlanFragment withDynamicLifespanScheduleGroupedExecution(List partitioning, tableScanSchedulingOrder, partitioningScheme, + outputOrderingScheme, StageExecutionDescriptor.dynamicLifespanScheduleGroupedExecution(capableTableScanNodes, totalLifespans), outputTableWriterFragment, statsAndCosts, @@ -291,6 +308,7 @@ public PlanFragment withRecoverableGroupedExecution(List capableTabl partitioning, tableScanSchedulingOrder, partitioningScheme, + outputOrderingScheme, StageExecutionDescriptor.recoverableGroupedExecution(capableTableScanNodes, totalLifespans), outputTableWriterFragment, statsAndCosts, @@ -306,6 +324,7 @@ public PlanFragment withSubPlan(PlanNode subPlan) partitioning, tableScanSchedulingOrder, partitioningScheme, + outputOrderingScheme, stageExecutionDescriptor, outputTableWriterFragment, statsAndCosts, diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java index 41c0f15975284..21e94c5564026 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java @@ -231,6 +231,7 @@ private static SubPlan reassignPartitioningHandleIfNecessaryHelper(Metadata meta outputPartitioningScheme.isScaleWriters(), outputPartitioningScheme.getEncoding(), outputPartitioningScheme.getBucketToPartition()), + fragment.getOutputOrderingScheme(), fragment.getStageExecutionDescriptor(), fragment.isOutputTableWriterFragment(), fragment.getStatsAndCosts(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 2d7c8be053645..bb8a1aed5ec0f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -190,6 +190,7 @@ import com.facebook.presto.sql.planner.optimizations.ShardJoins; import com.facebook.presto.sql.planner.optimizations.SimplifyPlanWithEmptyInput; import com.facebook.presto.sql.planner.optimizations.SortMergeJoinOptimizer; +import com.facebook.presto.sql.planner.optimizations.SortedExchangeRule; import com.facebook.presto.sql.planner.optimizations.StatsRecordingPlanOptimizer; import com.facebook.presto.sql.planner.optimizations.TransformQuantifiedComparisonApplyToLateralJoin; import com.facebook.presto.sql.planner.optimizations.UnaliasSymbolReferences; @@ -947,8 +948,10 @@ public PlanOptimizers( // MergeJoinForSortedInputOptimizer can avoid the local exchange for a join operation // Should be placed after AddExchanges, but before AddLocalExchange // To replace the JoinNode to 
MergeJoin ahead of AddLocalExchange to avoid adding extra local exchange + // SortedExchangeRule pushes sorts down to exchange nodes for distributed queries builder.add(new MergeJoinForSortedInputOptimizer(metadata, featuresConfig.isNativeExecutionEnabled()), - new SortMergeJoinOptimizer(metadata, featuresConfig.isNativeExecutionEnabled())); + new SortMergeJoinOptimizer(metadata, featuresConfig.isNativeExecutionEnabled()), + new SortedExchangeRule(metadata)); // Optimizers above this don't understand local exchanges, so be careful moving this. builder.add(new AddLocalExchanges(metadata, featuresConfig.isNativeExecutionEnabled())); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortedExchangeRule.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortedExchangeRule.java new file mode 100644 index 0000000000000..54ed05d629dd2 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortedExchangeRule.java @@ -0,0 +1,196 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.SortNode; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +/** + * Optimizer rule that pushes sort operations down to exchange nodes where possible. + * This optimization is beneficial for distributed queries where data needs to be sorted + * after shuffling, as it allows sorting to happen during the shuffle operation itself + * rather than requiring an explicit SortNode afterward. 
+ * + * The rule looks for SortNode → ExchangeNode patterns and attempts to merge them into + * a single sorted exchange node when: + * - The exchange is a REMOTE REPARTITION exchange + * - The exchange doesn't already have an ordering scheme + * - All ordering variables are available in the exchange output + */ +public class SortedExchangeRule + implements PlanOptimizer +{ + private static final Logger log = Logger.get(SortedExchangeRule.class); + private final Metadata metadata; + private boolean isEnabledForTesting; + + public SortedExchangeRule(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public void setEnabledForTesting(boolean isSet) + { + isEnabledForTesting = isSet; + } + + @Override + public boolean isEnabled(Session session) + { + return com.facebook.presto.SystemSessionProperties.isEnableSortedExchanges(session) || isEnabledForTesting; + } + + @Override + public PlanOptimizerResult optimize( + PlanNode plan, + Session session, + TypeProvider types, + VariableAllocator variableAllocator, + PlanNodeIdAllocator idAllocator, + WarningCollector warningCollector) + { + requireNonNull(plan, "plan is null"); + requireNonNull(session, "session is null"); + requireNonNull(types, "types is null"); + requireNonNull(variableAllocator, "variableAllocator is null"); + requireNonNull(idAllocator, "idAllocator is null"); + requireNonNull(warningCollector, "warningCollector is null"); + + if (isEnabled(session)) { + Rewriter rewriter = new Rewriter(idAllocator); + PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, null); + return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); + } + return PlanOptimizerResult.optimizerResult(plan, false); + } + + private static class Rewriter + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + private boolean planChanged; + + public Rewriter(PlanNodeIdAllocator idAllocator) + { + this.idAllocator 
= requireNonNull(idAllocator, "idAllocator is null"); + } + + public boolean isPlanChanged() + { + return planChanged; + } + + @Override + public PlanNode visitSort(SortNode node, RewriteContext context) + { + PlanNode source = context.rewrite(node.getSource()); + + // Try to push sort down to exchange if the source is an exchange node + Optional sortedExchange = pushSortToExchangeIfPossible(source, node.getOrderingScheme()); + if (sortedExchange.isPresent()) { + planChanged = true; + return sortedExchange.get(); + } + + // If push-down not possible, keep the original sort + if (source != node.getSource()) { + return new SortNode( + node.getSourceLocation(), + node.getId(), + source, + node.getOrderingScheme(), + node.isPartial(), + node.getPartitionBy()); + } + + return node; + } + + /** + * Attempts to push the sorting operation down to the Exchange node if the plan structure allows it. + * This is beneficial for distributed queries where we can sort during the shuffle operation instead of + * adding an explicit SortNode. 
+ * + * @param plan The plan node that needs sorting + * @param orderingScheme The required ordering scheme + * @return Optional containing the enhanced exchange node if push-down is possible, empty otherwise + */ + private Optional pushSortToExchangeIfPossible(PlanNode plan, OrderingScheme orderingScheme) + { + // Check if this is a suitable exchange node for sort push-down + if (!(plan instanceof ExchangeNode)) { + log.debug("Unable to push sorting to exchange because child is not exchange node"); + return Optional.empty(); + } + + ExchangeNode exchangeNode = (ExchangeNode) plan; + + // Only push sort down to single-source REPARTITION exchanges in remote scope + // The rewrite below keeps only the first source, so a multi-source exchange would silently lose inputs + if (exchangeNode.getType() != ExchangeNode.Type.REPARTITION || + !exchangeNode.getScope().isRemote() || exchangeNode.getSources().size() != 1) { + log.debug("Unable to push sorting to exchange because it is not a single-source REMOTE REPARTITION exchange " + + "(Type: %s, Scope: %s, Sources: %s)", exchangeNode.getType(), exchangeNode.getScope(), exchangeNode.getSources().size()); + return Optional.empty(); + } + + // Don't push down if the exchange already has ordering requirements + if (exchangeNode.getOrderingScheme().isPresent()) { + log.debug("Unable to push sorting to exchange because exchange already has ordering scheme"); + return Optional.empty(); + } + + // Check if all variables in the ordering scheme are available in the exchange output + if (!exchangeNode.getOutputVariables().containsAll(orderingScheme.getOrderByVariables())) { + log.debug("Unable to push sorting to exchange because not all ordering variables are available in exchange output"); + return Optional.empty(); + } + + // Create a new sorted exchange node + try { + ExchangeNode sortedExchange = ExchangeNode.sortedPartitionedExchange( + idAllocator.getNextId(), + exchangeNode.getScope(), + exchangeNode.getSources().get(0), + exchangeNode.getPartitioningScheme().getPartitioning(), + exchangeNode.getPartitioningScheme().getHashColumn(), + orderingScheme); + + 
log.debug("Successfully pushed sorting down to REMOTE REPARTITION exchange with orderingScheme=%s", + orderingScheme); + return Optional.of(sortedExchange); + } + catch (Exception e) { + log.warn(e, "Failed to create sorted exchange"); + // If creating sorted exchange fails, fall back to explicit sort + return Optional.empty(); + } + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java index 32ea16c500394..5f57df98de5a6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java @@ -145,7 +145,8 @@ public ExchangeNode( orderingScheme.ifPresent(ordering -> { PartitioningHandle partitioningHandle = partitioningScheme.getPartitioning().getHandle(); - checkArgument(!scope.isRemote() || partitioningHandle.equals(SINGLE_DISTRIBUTION), "remote merging exchange requires single distribution"); + // Remote exchanges with an ordering scheme are no longer required to use SINGLE_DISTRIBUTION: Presto-on-Spark can sort on external shuffle systems + // NOTE(review): dropping the check relaxes validation for every engine, not only Presto-on-Spark - confirm this is intended checkArgument(!scope.isLocal() || partitioningHandle.equals(FIXED_PASSTHROUGH_DISTRIBUTION), "local merging exchange requires passthrough distribution"); checkArgument(partitioningScheme.getOutputLayout().containsAll(ordering.getOrderByVariables()), "Partitioning scheme does not supply all required ordering symbols"); }); @@ -275,6 +276,38 @@ public static ExchangeNode mergingExchange(PlanNodeId id, Scope scope, PlanNode Optional.of(orderingScheme)); } + /** + * Creates an exchange node that performs sorting during the shuffle operation. + * This is used for merge joins where we want to push down sorting to the exchange layer. 
+ */ + public static ExchangeNode sortedPartitionedExchange(PlanNodeId id, Scope scope, PlanNode child, Partitioning partitioning, Optional hashColumn, OrderingScheme sortOrder) + { + return new ExchangeNode( + child.getSourceLocation(), + id, + REPARTITION, + scope, + new PartitioningScheme(partitioning, child.getOutputVariables(), hashColumn, false, false, COLUMNAR, Optional.empty()), + ImmutableList.of(child), + ImmutableList.of(child.getOutputVariables()), + true, // Ensure source ordering since we're sorting + Optional.of(sortOrder)); + } + + /** + * Creates a system partitioned exchange with sorting during shuffle. + */ + public static ExchangeNode sortedSystemPartitionedExchange(PlanNodeId id, Scope scope, PlanNode child, List partitioningColumns, Optional hashColumn, OrderingScheme sortOrder) + { + return sortedPartitionedExchange( + id, + scope, + child, + Partitioning.create(FIXED_HASH_DISTRIBUTION, partitioningColumns), + hashColumn, + sortOrder); + } + @JsonProperty public Type getType() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 2b7059d12e02a..4d65f803664fd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -436,6 +436,7 @@ private static String formatFragment( Joiner.on(", ").join(partitioningScheme.getPartitioning().getArguments()), formatHash(partitioningScheme.getHashColumn()))); } + builder.append(indentString(1)).append(format("Output ordering: %s%n", fragment.getOutputOrderingScheme())); builder.append(indentString(1)).append(format("Output encoding: %s%n", fragment.getPartitioningScheme().getEncoding())); builder.append(indentString(1)).append(format("Stage Execution Strategy: %s%n", 
fragment.getStageExecutionDescriptor().getStageExecutionStrategy())); @@ -467,6 +468,7 @@ public static String graphvizLogicalPlan(PlanNode plan, TypeProvider types, Stat SINGLE_DISTRIBUTION, ImmutableList.of(plan.getId()), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), plan.getOutputVariables()), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(estimatedStatsAndCosts), @@ -1132,9 +1134,15 @@ public Void visitSort(SortNode node, Void context) @Override public Void visitRemoteSource(RemoteSourceNode node, Void context) { + String nodeName = "RemoteSource"; + String orderingSchemeStr = ""; + if (node.getOrderingScheme().isPresent()) { + orderingSchemeStr = " " + node.getOrderingScheme().get(); + nodeName = "RemoteMerge"; + } addNode(node, - format("Remote%s", node.getOrderingScheme().isPresent() ? "Merge" : "Source"), - format("[%s]", Joiner.on(',').join(node.getSourceFragmentIds())), + nodeName, + format("[%s]%s", Joiner.on(',').join(node.getSourceFragmentIds()), orderingSchemeStr), ImmutableList.of(), ImmutableList.of(), node.getSourceFragmentIds()); diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java b/presto-main-base/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java index 9bd75207fa9cf..ecbe7c01c33a1 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java @@ -126,6 +126,7 @@ public MockRemoteTask createTableScanTask(TaskId taskId, InternalNode newNode, L SOURCE_DISTRIBUTION, ImmutableList.of(sourceId), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(variable)), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/execution/TaskTestUtils.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TaskTestUtils.java index e4a8d9c9edde2..a13f796547479 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TaskTestUtils.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TaskTestUtils.java @@ -126,6 +126,7 @@ public static PlanFragment createPlanFragment() ImmutableList.of(TABLE_SCAN_NODE_ID), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(VARIABLE)) .withBucketToPartition(Optional.of(new int[1])), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java index bce1acccd12f3..cac4f07885d74 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java @@ -174,6 +174,7 @@ private static PlanFragment createExchangePlanFragment() SOURCE_DISTRIBUTION, ImmutableList.of(planNode.getId()), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputVariables()), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestAdaptivePhasedExecutionPolicy.java b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestAdaptivePhasedExecutionPolicy.java index 8f50e98cd8d02..3a9ca81197856 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestAdaptivePhasedExecutionPolicy.java +++ 
b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestAdaptivePhasedExecutionPolicy.java @@ -159,6 +159,7 @@ private static PlanFragment createPlanFragment(PlanFragmentId fragmentId, PlanNo SOURCE_DISTRIBUTION, ImmutableList.of(remoteSourcePlanNode.getId()), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), remoteSourcePlanNode.getOutputVariables()), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java index 8c932566c4453..a42ed6eb014f8 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java @@ -280,6 +280,7 @@ private static PlanFragment createFragment(PlanNode planNode) SOURCE_DISTRIBUTION, ImmutableList.of(planNode.getId()), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputVariables()), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java index 5ea371568b9c3..2c08ed14e0fd8 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java @@ -515,6 +515,7 @@ private static SubPlan createPlan() SOURCE_DISTRIBUTION, ImmutableList.of(TABLE_SCAN_NODE_ID), new 
PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(variable)), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 9f55c3b2eab19..eeece131d73b3 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -196,6 +196,7 @@ public void testDefaults() .setHyperloglogStandardErrorWarningThreshold(0.004) .setPreferMergeJoinForSortedInputs(false) .setPreferSortMergeJoin(false) + .setEnableSortedExchanges(false) .setSegmentedAggregationEnabled(false) .setQueryAnalyzerTimeout(new Duration(3, MINUTES)) .setQuickDistinctLimitEnabled(false) @@ -415,6 +416,7 @@ public void testExplicitPropertyMappings() .put("hyperloglog-standard-error-warning-threshold", "0.02") .put("optimizer.prefer-merge-join-for-sorted-inputs", "true") .put("experimental.optimizer.prefer-sort-merge-join", "true") + .put("optimizer.experimental.enable-sorted-exchanges", "true") .put("optimizer.segmented-aggregation-enabled", "true") .put("planner.query-analyzer-timeout", "10s") .put("optimizer.quick-distinct-limit-enabled", "true") @@ -631,6 +633,7 @@ public void testExplicitPropertyMappings() .setHyperloglogStandardErrorWarningThreshold(0.02) .setPreferMergeJoinForSortedInputs(true) .setPreferSortMergeJoin(true) + .setEnableSortedExchanges(true) .setSegmentedAggregationEnabled(true) .setQueryAnalyzerTimeout(new Duration(10, SECONDS)) .setQuickDistinctLimitEnabled(true) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java index 
5422304b52935..9789e8bda4652 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java @@ -190,6 +190,7 @@ private LocalExecutionPlan getLocalExecutionPlan(Session session, PlanNode plan, SOURCE_DISTRIBUTION, ImmutableList.of(new PlanNodeId("sourceId")), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of()), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSortedExchangeRule.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSortedExchangeRule.java new file mode 100644 index 0000000000000..e20e72e0930b5 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSortedExchangeRule.java @@ -0,0 +1,207 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.Optimizer; +import com.facebook.presto.sql.planner.Plan; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.stream.Collectors; + +import static com.facebook.presto.SystemSessionProperties.ENABLE_SORTED_EXCHANGES; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +/** + * Tests for the SortedExchangeRule optimizer which pushes sort operations + * down to exchange nodes for distributed queries. + */ +public class TestSortedExchangeRule + extends BasePlanTest +{ + private Session getSessionWithSortedExchangesEnabled() + { + return Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(ENABLE_SORTED_EXCHANGES, "true") + .build(); + } + + @Test + public void testPushSortToRemoteRepartitionExchange() + { + @Language("SQL") String sql = "SELECT orderkey, custkey FROM orders ORDER BY orderkey"; + + Plan plan = plan(getSessionWithSortedExchangesEnabled(), sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false); + + // Find all remote repartition exchanges with ordering + List sortedExchanges = findSortedRemoteExchanges(plan.getRoot()); + + // Verify at least one sorted exchange exists + assertFalse(sortedExchanges.isEmpty(), "Expected at least one sorted exchange"); + + // Verify the ordering contains orderkey + boolean hasOrderkeyOrdering = sortedExchanges.stream() + .anyMatch(exchange -> exchange.getOrderingScheme().isPresent() && + exchange.getOrderingScheme().get().getOrderByVariables().stream() + .anyMatch(v -> v.getName().equals("orderkey"))); + assertTrue(hasOrderkeyOrdering, "Expected exchange with orderkey ordering"); + } + + 
@Test + public void testSortedExchangeDisabledBySessionProperty() + { + @Language("SQL") String sql = "SELECT orderkey, custkey FROM orders ORDER BY orderkey"; + + Session disabledSession = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(ENABLE_SORTED_EXCHANGES, "false") + .build(); + + Plan plan = plan(disabledSession, sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false); + + // Find all remote repartition exchanges + List exchanges = PlanNodeSearcher.searchFrom(plan.getRoot()) + .where(node -> node instanceof ExchangeNode && + ((ExchangeNode) node).getScope().isRemote() && + ((ExchangeNode) node).getType() == ExchangeNode.Type.REPARTITION) + .findAll() + .stream() + .map(ExchangeNode.class::cast) + .collect(Collectors.toList()); + + // When disabled, exchanges should not have ordering (or there should be explicit sort nodes) + // Just verify the plan builds successfully + assertTrue(exchanges.size() > 0, "Expected at least one exchange"); + } + + @Test + public void testSortOnMultipleColumns() + { + @Language("SQL") String sql = "SELECT orderkey, custkey, totalprice FROM orders ORDER BY orderkey, custkey"; + + Plan plan = plan(getSessionWithSortedExchangesEnabled(), sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false); + List sortedExchanges = findSortedRemoteExchanges(plan.getRoot()); + + assertFalse(sortedExchanges.isEmpty(), "Expected at least one sorted exchange"); + + // Verify ordering contains both columns + boolean hasMultiColumnOrdering = sortedExchanges.stream() + .anyMatch(exchange -> exchange.getOrderingScheme().isPresent() && + exchange.getOrderingScheme().get().getOrderByVariables().size() >= 2); + assertTrue(hasMultiColumnOrdering, "Expected exchange with multi-column ordering"); + } + + @Test + public void testSortWithDescendingOrder() + { + @Language("SQL") String sql = "SELECT orderkey, custkey FROM orders ORDER BY orderkey DESC"; + + Plan plan = plan(getSessionWithSortedExchangesEnabled(), sql, 
Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false); + List sortedExchanges = findSortedRemoteExchanges(plan.getRoot()); + + assertFalse(sortedExchanges.isEmpty(), "Expected at least one sorted exchange"); + } + + @Test + public void testSortedExchangeWithJoin() + { + @Language("SQL") String sql = "SELECT o.orderkey, o.custkey " + + "FROM orders o " + + "JOIN customer c ON o.custkey = c.custkey " + + "ORDER BY o.orderkey"; + + Plan plan = plan(getSessionWithSortedExchangesEnabled(), sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false); + List sortedExchanges = findSortedRemoteExchanges(plan.getRoot()); + + assertFalse(sortedExchanges.isEmpty(), "Expected at least one sorted exchange"); + } + + @Test + public void testSortedExchangeWithAggregation() + { + @Language("SQL") String sql = "SELECT custkey, SUM(totalprice) as total " + + "FROM orders " + + "GROUP BY custkey " + + "ORDER BY custkey"; + + Plan plan = plan(getSessionWithSortedExchangesEnabled(), sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false); + List sortedExchanges = findSortedRemoteExchanges(plan.getRoot()); + + assertFalse(sortedExchanges.isEmpty(), "Expected at least one sorted exchange"); + } + + @Test + public void testSortWithLimit() + { + @Language("SQL") String sql = "SELECT orderkey, custkey FROM orders ORDER BY orderkey LIMIT 100"; + + Plan plan = plan(getSessionWithSortedExchangesEnabled(), sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false); + List sortedExchanges = findSortedRemoteExchanges(plan.getRoot()); + + assertFalse(sortedExchanges.isEmpty(), "Expected at least one sorted exchange"); + } + + @Test + public void testMultipleSortsInQuery() + { + @Language("SQL") String sql = "SELECT orderkey, custkey FROM " + + "(SELECT orderkey, custkey FROM orders ORDER BY custkey) " + + "ORDER BY orderkey"; + + Plan plan = plan(getSessionWithSortedExchangesEnabled(), sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false); + List sortedExchanges = 
findSortedRemoteExchanges(plan.getRoot()); + + assertFalse(sortedExchanges.isEmpty(), "Expected at least one sorted exchange"); + } + + @Test + public void testSortWithNullsFirst() + { + @Language("SQL") String sql = "SELECT orderkey, custkey FROM orders ORDER BY orderkey NULLS FIRST"; + + Plan plan = plan(getSessionWithSortedExchangesEnabled(), sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false); + List sortedExchanges = findSortedRemoteExchanges(plan.getRoot()); + + assertFalse(sortedExchanges.isEmpty(), "Expected at least one sorted exchange"); + } + + @Test + public void testSortWithNullsLast() + { + @Language("SQL") String sql = "SELECT orderkey, custkey FROM orders ORDER BY orderkey NULLS LAST"; + + Plan plan = plan(getSessionWithSortedExchangesEnabled(), sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false); + List sortedExchanges = findSortedRemoteExchanges(plan.getRoot()); + + assertFalse(sortedExchanges.isEmpty(), "Expected at least one sorted exchange"); + } + + private List findSortedRemoteExchanges(PlanNode root) + { + return PlanNodeSearcher.searchFrom(root) + .where(node -> node instanceof ExchangeNode && + ((ExchangeNode) node).getScope().isRemote() && + ((ExchangeNode) node).getType() == ExchangeNode.Type.REPARTITION && + ((ExchangeNode) node).getOrderingScheme().isPresent()) + .findAll() + .stream() + .map(ExchangeNode.class::cast) + .collect(Collectors.toList()); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/planPrinter/TestPlanPrinter.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/planPrinter/TestPlanPrinter.java index 4872e86029341..076de6a3f3252 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/planPrinter/TestPlanPrinter.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/planPrinter/TestPlanPrinter.java @@ -101,6 +101,7 @@ private String domainToPrintedScan(VariableReferenceExpression variable, ColumnH SOURCE_DISTRIBUTION, 
ImmutableList.of(scanNode.getId()), new PartitioningScheme(Partitioning.create(SOURCE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(variable)), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-base/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java b/presto-main-base/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java index e04f3a67bdabb..c88db53d01a26 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java +++ b/presto-main-base/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java @@ -200,6 +200,7 @@ private static PlanFragment createTestPlanFragment(int id, PlanNode node) SOURCE_DISTRIBUTION, ImmutableList.of(TEST_TABLE_SCAN_NODE_ID), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of()), + Optional.empty(), ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp index 227feb506fed7..cd1cdb712f50f 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp @@ -2294,6 +2294,29 @@ core::PlanFragment VeloxBatchQueryPlanConverter::toVeloxQueryPlan( return planFragment; } + // Convert outputOrderingScheme to sortingKeys and sortingOrders + std::optional> sortingOrders = + std::nullopt; + std::optional> sortingKeys = + std::nullopt; + + if (fragment.outputOrderingScheme && + !fragment.outputOrderingScheme->orderBy.empty()) { + std::vector orders; + std::vector keys; + + orders.reserve(fragment.outputOrderingScheme->orderBy.size()); + keys.reserve(fragment.outputOrderingScheme->orderBy.size()); + + for (const auto& ordering : fragment.outputOrderingScheme->orderBy) { + 
keys.emplace_back(exprConverter_.toVeloxExpr(ordering.variable)); + orders.emplace_back(toVeloxSortOrder(ordering.sortOrder)); + } + + sortingKeys = std::move(keys); + sortingOrders = std::move(orders); + } + const auto partitionAndSerializeNode = std::make_shared( fmt::format("{}.ps", partitionedOutputNode->id()), @@ -2302,7 +2325,9 @@ core::PlanFragment VeloxBatchQueryPlanConverter::toVeloxQueryPlan( partitionedOutputNode->outputType(), partitionedOutputNode->sources()[0], partitionedOutputNode->isReplicateNullsAndAny(), - partitionedOutputNode->partitionFunctionSpecPtr()); + partitionedOutputNode->partitionFunctionSpecPtr(), + sortingOrders, + sortingKeys); planFragment.planNode = std::make_shared( fmt::format("{}.sw", partitionedOutputNode->id()), diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index efb585849ef31..093f0957f3031 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -8616,6 +8616,13 @@ void to_json(json& j, const PlanFragment& p) { "PlanFragment", "StageExecutionDescriptor", "stageExecutionDescriptor"); + to_json_key( + j, + "outputOrderingScheme", + p.outputOrderingScheme, + "PlanFragment", + "OrderingScheme", + "outputOrderingScheme"); to_json_key( j, "outputTableWriterFragment", @@ -8670,6 +8677,13 @@ void from_json(const json& j, PlanFragment& p) { "PlanFragment", "StageExecutionDescriptor", "stageExecutionDescriptor"); + from_json_key( + j, + "outputOrderingScheme", + p.outputOrderingScheme, + "PlanFragment", + "OrderingScheme", + "outputOrderingScheme"); from_json_key( j, "outputTableWriterFragment", diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index 
2b1e4eb66c14e..21f86a3cdf203 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -2002,6 +2002,7 @@ struct PlanFragment { List tableScanSchedulingOrder = {}; PartitioningScheme partitioningScheme = {}; StageExecutionDescriptor stageExecutionDescriptor = {}; + std::shared_ptr outputOrderingScheme = {}; bool outputTableWriterFragment = {}; std::shared_ptr jsonRepresentation = {}; }; diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/PlanFragment.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/PlanFragment.cpp.inc index 54c08ebbe1c65..2149c2c7b7457 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/PlanFragment.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/PlanFragment.cpp.inc @@ -53,6 +53,13 @@ void to_json(json& j, const PlanFragment& p) { "PlanFragment", "StageExecutionDescriptor", "stageExecutionDescriptor"); + to_json_key( + j, + "outputOrderingScheme", + p.outputOrderingScheme, + "PlanFragment", + "OrderingScheme", + "outputOrderingScheme"); to_json_key( j, "outputTableWriterFragment", @@ -107,6 +114,13 @@ void from_json(const json& j, PlanFragment& p) { "PlanFragment", "StageExecutionDescriptor", "stageExecutionDescriptor"); + from_json_key( + j, + "outputOrderingScheme", + p.outputOrderingScheme, + "PlanFragment", + "OrderingScheme", + "outputOrderingScheme"); from_json_key( j, "outputTableWriterFragment", diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/PlanFragment.hpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/PlanFragment.hpp.inc index b02ee2acdce97..8381767069599 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/PlanFragment.hpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/PlanFragment.hpp.inc @@ -21,6 +21,7 @@ 
struct PlanFragment { List tableScanSchedulingOrder = {}; PartitioningScheme partitioningScheme = {}; StageExecutionDescriptor stageExecutionDescriptor = {}; + std::shared_ptr outputOrderingScheme = {}; bool outputTableWriterFragment = {}; std::shared_ptr jsonRepresentation = {}; }; diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/AbstractPrestoSparkQueryExecution.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/AbstractPrestoSparkQueryExecution.java index b76da7f2ddc2e..e17eee2554696 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/AbstractPrestoSparkQueryExecution.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/AbstractPrestoSparkQueryExecution.java @@ -53,6 +53,7 @@ import com.facebook.presto.spark.classloader_interface.IPrestoSparkQueryExecution; import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutor; import com.facebook.presto.spark.classloader_interface.MutablePartitionId; +import com.facebook.presto.spark.classloader_interface.MutablePartitionIdOrdering; import com.facebook.presto.spark.classloader_interface.PrestoSparkExecutionException; import com.facebook.presto.spark.classloader_interface.PrestoSparkJavaExecutionTaskInputs; import com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow; @@ -77,6 +78,7 @@ import com.facebook.presto.spi.connector.ConnectorCapabilities; import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider; import com.facebook.presto.spi.page.PagesSerde; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PartitioningHandle; import com.facebook.presto.spi.plan.PartitioningScheme; import com.facebook.presto.spi.plan.PlanFragmentId; @@ -293,13 +295,16 @@ public AbstractPrestoSparkQueryExecution( protected static JavaPairRDD partitionBy( int planFragmentId, JavaPairRDD rdd, - PartitioningScheme partitioningScheme) + 
PartitioningScheme partitioningScheme, Optional orderingScheme) { Partitioner partitioner = createPartitioner(partitioningScheme); JavaPairRDD javaPairRdd = rdd.partitionBy(partitioner); ShuffledRDD shuffledRdd = (ShuffledRDD) javaPairRdd.rdd(); shuffledRdd.setSerializer(new PrestoSparkShuffleSerializer()); shuffledRdd.setName(getRDDName(planFragmentId)); + if (orderingScheme.isPresent()) { + shuffledRdd.setKeyOrdering(new MutablePartitionIdOrdering()); + } return JavaPairRDD.fromRDD( shuffledRdd, classTag(MutablePartitionId.class), @@ -548,7 +553,11 @@ public RddAndMore createRdd(SubPlan subPlan } else { RddAndMore childRdd = createRdd(child, PrestoSparkMutableRow.class, tableWriteInfo); - rddInputs.put(childFragment.getId(), partitionBy(childFragment.getId().getId(), childRdd.getRdd(), child.getFragment().getPartitioningScheme())); + rddInputs.put(childFragment.getId(), partitionBy( + childFragment.getId().getId(), + childRdd.getRdd(), + child.getFragment().getPartitioningScheme(), + child.getFragment().getOutputOrderingScheme())); broadcastDependencies.addAll(childRdd.getBroadcastDependencies()); } } @@ -890,7 +899,11 @@ protected synchronized RddAndMore createRdd // For intermediate, non-broadcast stages - we use partitioned RDD // These stages produce PrestoSparkMutableRow if (outputType == PrestoSparkMutableRow.class) { - rdd = (JavaPairRDD) partitionBy(subPlan.getFragment().getId().getId(), (JavaPairRDD) rdd, subPlan.getFragment().getPartitioningScheme()); + rdd = (JavaPairRDD) partitionBy( + subPlan.getFragment().getId().getId(), + (JavaPairRDD) rdd, + subPlan.getFragment().getPartitioningScheme(), + subPlan.getFragment().getOutputOrderingScheme()); } RddAndMore rddAndMore = new RddAndMore(rdd, broadcastDependencies.build(), Optional.ofNullable(subPlan.getFragment().getPartitioningScheme().getPartitioning().getHandle())); diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java 
b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java index 6b0cb50d4267c..e20ef8a002477 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java @@ -151,15 +151,6 @@ public JavaPairRDD crea partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioning.equals(SOURCE_DISTRIBUTION) || partitioning.getConnectorId().isPresent()) { - for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) { - if (remoteSource.isEnsureSourceOrdering() || remoteSource.getOrderingScheme().isPresent()) { - throw new PrestoException(NOT_SUPPORTED, format( - "Order sensitive exchange is not supported by Presto on Spark. fragmentId: %s, sourceFragmentIds: %s", - fragment.getId(), - remoteSource.getSourceFragmentIds())); - } - } - return createRdd( sparkContext, session, diff --git a/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/MutablePartitionIdOrdering.java b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/MutablePartitionIdOrdering.java new file mode 100644 index 0000000000000..21b358732f391 --- /dev/null +++ b/presto-spark-classloader-interface/src/main/java/com/facebook/presto/spark/classloader_interface/MutablePartitionIdOrdering.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spark.classloader_interface; + +import scala.Function1; +import scala.Some; +import scala.math.Ordering; + +public class MutablePartitionIdOrdering + implements Ordering +{ + @Override + public Some tryCompare(MutablePartitionId mutablePartitionId, MutablePartitionId t1) + { + return new scala.Some(new Object()); + } + + @Override + public int compare(MutablePartitionId p1, MutablePartitionId p2) + { + return Integer.compare(p1.getPartition(), p2.getPartition()); + } + + @Override + public boolean lteq(MutablePartitionId mutablePartitionId, MutablePartitionId t1) + { + return mutablePartitionId.getPartition() <= t1.getPartition(); + } + + @Override + public boolean gteq(MutablePartitionId mutablePartitionId, MutablePartitionId t1) + { + return mutablePartitionId.getPartition() >= t1.getPartition(); + } + + @Override + public boolean lt(MutablePartitionId mutablePartitionId, MutablePartitionId t1) + { + return mutablePartitionId.getPartition() < t1.getPartition(); + } + + @Override + public boolean gt(MutablePartitionId mutablePartitionId, MutablePartitionId t1) + { + return mutablePartitionId.getPartition() > t1.getPartition(); + } + + @Override + public boolean equiv(MutablePartitionId mutablePartitionId, MutablePartitionId t1) + { + return mutablePartitionId.getPartition() == t1.getPartition(); + } + + @Override + public MutablePartitionId max(MutablePartitionId mutablePartitionId, MutablePartitionId t1) + { + return mutablePartitionId; + } + + @Override + public MutablePartitionId min(MutablePartitionId mutablePartitionId, MutablePartitionId t1) + { + return mutablePartitionId; + } + + @Override + public Ordering reverse() + { + try { + return (Ordering) this.clone(); + } + catch (CloneNotSupportedException e) { + throw new RuntimeException(e); + } + } + + @Override + public Ordering on(Function1 function1) + { + return 
null; + } + + @Override + public Ordering.Ops mkOrderingOps(MutablePartitionId mutablePartitionId) + { + return null; + } +}