Skip to content

Commit 66dd7dd

Browse files
cloud-fandongjoon-hyun
authored andcommitted
[SPARK-50739][SQL][FOLLOW] Simplify ResolveRecursiveCTESuite with dsl
### What changes were proposed in this pull request? A followup of #49351 to simplify the test via dsl. ### Why are the changes needed? code cleanup ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? N/A ### Was this patch authored or co-authored using generative AI tooling? no Closes #49557 from cloud-fan/clean. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent 6b491bc commit 66dd7dd

File tree

3 files changed

+57
-87
lines changed

3 files changed

+57
-87
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.analysis
1919

2020
import scala.collection.mutable
2121

22+
import org.apache.spark.sql.AnalysisException
2223
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
2324
import org.apache.spark.sql.catalyst.plans.logical._
2425
import org.apache.spark.sql.catalyst.rules.Rule
2526
import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
26-
import org.apache.spark.sql.errors.QueryCompilationErrors
2727

2828
/**
2929
* Updates CTE references with the resolve output attributes of corresponding CTE definitions.
@@ -144,8 +144,9 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
144144
// Project (as UnresolvedSubqueryColumnAliases have not been substituted with the
145145
// Project yet), leaving us with cases of SubqueryAlias->Union and SubqueryAlias->
146146
// UnresolvedSubqueryColumnAliases->Union. The same applies to Distinct Union.
147-
throw QueryCompilationErrors.invalidRecursiveCteError(
148-
"Unsupported recursive CTE UNION placement.")
147+
throw new AnalysisException(
148+
errorClass = "INVALID_RECURSIVE_CTE",
149+
messageParameters = Map.empty)
149150
}
150151
}
151152
withCTE.copy(cteDefs = newCTEDefs)

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4332,12 +4332,4 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
43324332
origin = origin
43334333
)
43344334
}
4335-
4336-
def invalidRecursiveCteError(error: String): Throwable = {
4337-
new AnalysisException(
4338-
errorClass = "INVALID_RECURSIVE_CTE",
4339-
messageParameters = Map(
4340-
"error" -> error
4341-
))
4342-
}
43434335
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveRecursiveCTESuite.scala

Lines changed: 53 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -18,102 +18,79 @@
1818

1919
package org.apache.spark.sql.catalyst.analysis
2020

21-
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.ResolveSubqueryColumnAliases
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
22+
import org.apache.spark.sql.catalyst.dsl.plans._
2223
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
2324
import org.apache.spark.sql.catalyst.plans.logical._
24-
import org.apache.spark.sql.catalyst.rules.RuleExecutor
2525

2626
class ResolveRecursiveCTESuite extends AnalysisTest {
2727
// Motivated by:
2828
// WITH RECURSIVE t AS (SELECT 1 UNION ALL SELECT * FROM t) SELECT * FROM t;
2929
test("ResolveWithCTE rule on recursive CTE without UnresolvedSubqueryColumnAliases") {
30-
// The analyzer will repeat ResolveWithCTE rule twice.
31-
val rules = Seq(ResolveWithCTE, ResolveWithCTE)
32-
val analyzer = new RuleExecutor[LogicalPlan] {
33-
override val batches = Seq(Batch("Resolution", Once, rules: _*))
34-
}
35-
// Since cteDef IDs need to be the same, cteDef for each case will be created by copying
36-
// this one with its child replaced.
37-
val cteDef = CTERelationDef(OneRowRelation())
38-
val anchor = Project(Seq(Alias(Literal(1), "1")()), OneRowRelation())
39-
40-
def getBeforePlan(cteDef: CTERelationDef): LogicalPlan = {
41-
val recursionPart = SubqueryAlias("t",
42-
CTERelationRef(cteDef.id, false, Seq(), false, recursive = true))
43-
44-
val cteDefFinal = cteDef.copy(child =
45-
SubqueryAlias("t", Union(Seq(anchor, recursionPart))))
46-
30+
val cteId = 0
31+
val anchor = Project(Seq(Alias(Literal(1), "c")()), OneRowRelation())
32+
33+
def getBeforePlan(): LogicalPlan = {
34+
val cteRef = CTERelationRef(
35+
cteId,
36+
_resolved = false,
37+
output = Seq(),
38+
isStreaming = false)
39+
val recursion = cteRef.copy(recursive = true).subquery("t")
4740
WithCTE(
48-
SubqueryAlias("t", CTERelationRef(cteDefFinal.id, false, Seq(), false, recursive = false)),
49-
Seq(cteDefFinal))
41+
cteRef.copy(recursive = false),
42+
Seq(CTERelationDef(anchor.union(recursion).subquery("t"), cteId)))
5043
}
5144

52-
def getAfterPlan(cteDef: CTERelationDef): LogicalPlan = {
53-
val saRecursion = SubqueryAlias("t",
54-
UnionLoopRef(cteDef.id, anchor.output, false))
55-
56-
val cteDefFinal = cteDef.copy(child =
57-
SubqueryAlias("t", UnionLoop(cteDef.id, anchor, saRecursion)))
58-
59-
val outerCteRef = CTERelationRef(cteDefFinal.id, true, cteDefFinal.output, false,
60-
recursive = false)
61-
62-
WithCTE(SubqueryAlias("t", outerCteRef), Seq(cteDefFinal))
45+
def getAfterPlan(): LogicalPlan = {
46+
val recursion = UnionLoopRef(cteId, anchor.output, accumulated = false).subquery("t")
47+
val cteDef = CTERelationDef(UnionLoop(cteId, anchor, recursion).subquery("t"), cteId)
48+
val cteRef = CTERelationRef(
49+
cteId,
50+
_resolved = true,
51+
output = cteDef.output,
52+
isStreaming = false)
53+
WithCTE(cteRef, Seq(cteDef))
6354
}
6455

65-
val beforePlan = getBeforePlan(cteDef)
66-
val afterPlan = getAfterPlan(cteDef)
67-
68-
comparePlans(analyzer.execute(beforePlan), afterPlan)
56+
comparePlans(getAnalyzer.execute(getBeforePlan()), getAfterPlan())
6957
}
7058

7159
// Motivated by:
7260
// WITH RECURSIVE t(n) AS (SELECT 1 UNION ALL SELECT * FROM t) SELECT * FROM t;
7361
test("ResolveWithCTE rule on recursive CTE with UnresolvedSubqueryColumnAliases") {
74-
// The analyzer will repeat ResolveWithCTE rule twice.
75-
val rules = Seq(ResolveWithCTE, ResolveSubqueryColumnAliases, ResolveWithCTE)
76-
val analyzer = new RuleExecutor[LogicalPlan] {
77-
override val batches = Seq(Batch("Resolution", Once, rules: _*))
62+
val cteId = 0
63+
val anchor = Project(Seq(Alias(Literal(1), "c")()), OneRowRelation())
64+
65+
def getBeforePlan(): LogicalPlan = {
66+
val cteRef = CTERelationRef(
67+
cteId,
68+
_resolved = false,
69+
output = Seq(),
70+
isStreaming = false)
71+
val recursion = cteRef.copy(recursive = true).subquery("t")
72+
val cteDef = CTERelationDef(
73+
UnresolvedSubqueryColumnAliases(Seq("n"), anchor.union(recursion)).subquery("t"),
74+
cteId)
75+
WithCTE(cteRef.copy(recursive = false), Seq(cteDef))
7876
}
79-
// Since cteDef IDs need to be the same, cteDef for each case will be created by copying
80-
// this one with its child replaced.
81-
val cteDef = CTERelationDef(OneRowRelation())
82-
val anchor = Project(Seq(Alias(Literal(1), "1")()), OneRowRelation())
83-
84-
def getBeforePlan(cteDef: CTERelationDef): LogicalPlan = {
85-
val recursionPart = SubqueryAlias("t",
86-
CTERelationRef(cteDef.id, false, Seq(), false, recursive = true))
8777

88-
val cteDefFinal = cteDef.copy(child =
89-
SubqueryAlias("t",
90-
UnresolvedSubqueryColumnAliases(Seq("n"),
91-
Union(Seq(anchor, recursionPart)))))
92-
93-
WithCTE(
94-
SubqueryAlias("t", CTERelationRef(cteDefFinal.id, false, Seq(), false, recursive = false)),
95-
Seq(cteDefFinal))
96-
}
97-
98-
def getAfterPlan(cteDef: CTERelationDef): LogicalPlan = {
99-
val saRecursion = SubqueryAlias("t",
100-
Project(Seq(Alias(anchor.output.head, "n")()),
101-
UnionLoopRef(cteDef.id, anchor.output, false)))
102-
103-
val cteDefFinal = cteDef.copy(child =
104-
SubqueryAlias("t",
105-
Project(Seq(Alias(anchor.output.head, "n")()),
106-
UnionLoop(cteDef.id, anchor, saRecursion))))
107-
108-
val outerCteRef = CTERelationRef(cteDefFinal.id, true, cteDefFinal.output, false,
109-
recursive = false)
110-
111-
WithCTE(SubqueryAlias("t", outerCteRef), Seq(cteDefFinal))
78+
def getAfterPlan(): LogicalPlan = {
79+
val col = anchor.output.head
80+
val recursion = UnionLoopRef(cteId, anchor.output, accumulated = false)
81+
.select(col.as("n"))
82+
.subquery("t")
83+
val cteDef = CTERelationDef(
84+
UnionLoop(cteId, anchor, recursion).select(col.as("n")).subquery("t"),
85+
cteId)
86+
val cteRef = CTERelationRef(
87+
cteId,
88+
_resolved = true,
89+
output = cteDef.output,
90+
isStreaming = false)
91+
WithCTE(cteRef, Seq(cteDef))
11292
}
11393

114-
val beforePlan = getBeforePlan(cteDef)
115-
val afterPlan = getAfterPlan(cteDef)
116-
117-
comparePlans(analyzer.execute(beforePlan), afterPlan)
94+
comparePlans(getAnalyzer.execute(getBeforePlan()), getAfterPlan())
11895
}
11996
}

0 commit comments

Comments
 (0)