|
18 | 18 |
|
19 | 19 | package org.apache.spark.sql.catalyst.analysis
|
20 | 20 |
|
21 |
| -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.ResolveSubqueryColumnAliases |
| 21 | +import org.apache.spark.sql.catalyst.dsl.expressions._ |
| 22 | +import org.apache.spark.sql.catalyst.dsl.plans._ |
22 | 23 | import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
|
23 | 24 | import org.apache.spark.sql.catalyst.plans.logical._
|
24 |
| -import org.apache.spark.sql.catalyst.rules.RuleExecutor |
25 | 25 |
|
26 | 26 | class ResolveRecursiveCTESuite extends AnalysisTest {
|
27 | 27 | // Motivated by:
|
28 | 28 | // WITH RECURSIVE t AS (SELECT 1 UNION ALL SELECT * FROM t) SELECT * FROM t;
|
29 | 29 | test("ResolveWithCTE rule on recursive CTE without UnresolvedSubqueryColumnAliases") {
|
30 |
| - // The analyzer will repeat ResolveWithCTE rule twice. |
31 |
| - val rules = Seq(ResolveWithCTE, ResolveWithCTE) |
32 |
| - val analyzer = new RuleExecutor[LogicalPlan] { |
33 |
| - override val batches = Seq(Batch("Resolution", Once, rules: _*)) |
34 |
| - } |
35 |
| - // Since cteDef IDs need to be the same, cteDef for each case will be created by copying |
36 |
| - // this one with its child replaced. |
37 |
| - val cteDef = CTERelationDef(OneRowRelation()) |
38 |
| - val anchor = Project(Seq(Alias(Literal(1), "1")()), OneRowRelation()) |
39 |
| - |
40 |
| - def getBeforePlan(cteDef: CTERelationDef): LogicalPlan = { |
41 |
| - val recursionPart = SubqueryAlias("t", |
42 |
| - CTERelationRef(cteDef.id, false, Seq(), false, recursive = true)) |
43 |
| - |
44 |
| - val cteDefFinal = cteDef.copy(child = |
45 |
| - SubqueryAlias("t", Union(Seq(anchor, recursionPart)))) |
46 |
| - |
| 30 | + val cteId = 0 |
| 31 | + val anchor = Project(Seq(Alias(Literal(1), "c")()), OneRowRelation()) |
| 32 | + |
| 33 | + def getBeforePlan(): LogicalPlan = { |
| 34 | + val cteRef = CTERelationRef( |
| 35 | + cteId, |
| 36 | + _resolved = false, |
| 37 | + output = Seq(), |
| 38 | + isStreaming = false) |
| 39 | + val recursion = cteRef.copy(recursive = true).subquery("t") |
47 | 40 | WithCTE(
|
48 |
| - SubqueryAlias("t", CTERelationRef(cteDefFinal.id, false, Seq(), false, recursive = false)), |
49 |
| - Seq(cteDefFinal)) |
| 41 | + cteRef.copy(recursive = false), |
| 42 | + Seq(CTERelationDef(anchor.union(recursion).subquery("t"), cteId))) |
50 | 43 | }
|
51 | 44 |
|
52 |
| - def getAfterPlan(cteDef: CTERelationDef): LogicalPlan = { |
53 |
| - val saRecursion = SubqueryAlias("t", |
54 |
| - UnionLoopRef(cteDef.id, anchor.output, false)) |
55 |
| - |
56 |
| - val cteDefFinal = cteDef.copy(child = |
57 |
| - SubqueryAlias("t", UnionLoop(cteDef.id, anchor, saRecursion))) |
58 |
| - |
59 |
| - val outerCteRef = CTERelationRef(cteDefFinal.id, true, cteDefFinal.output, false, |
60 |
| - recursive = false) |
61 |
| - |
62 |
| - WithCTE(SubqueryAlias("t", outerCteRef), Seq(cteDefFinal)) |
| 45 | + def getAfterPlan(): LogicalPlan = { |
| 46 | + val recursion = UnionLoopRef(cteId, anchor.output, accumulated = false).subquery("t") |
| 47 | + val cteDef = CTERelationDef(UnionLoop(cteId, anchor, recursion).subquery("t"), cteId) |
| 48 | + val cteRef = CTERelationRef( |
| 49 | + cteId, |
| 50 | + _resolved = true, |
| 51 | + output = cteDef.output, |
| 52 | + isStreaming = false) |
| 53 | + WithCTE(cteRef, Seq(cteDef)) |
63 | 54 | }
|
64 | 55 |
|
65 |
| - val beforePlan = getBeforePlan(cteDef) |
66 |
| - val afterPlan = getAfterPlan(cteDef) |
67 |
| - |
68 |
| - comparePlans(analyzer.execute(beforePlan), afterPlan) |
| 56 | + comparePlans(getAnalyzer.execute(getBeforePlan()), getAfterPlan()) |
69 | 57 | }
|
70 | 58 |
|
71 | 59 | // Motivated by:
|
72 | 60 | // WITH RECURSIVE t(n) AS (SELECT 1 UNION ALL SELECT * FROM t) SELECT * FROM t;
|
73 | 61 | test("ResolveWithCTE rule on recursive CTE with UnresolvedSubqueryColumnAliases") {
|
74 |
| - // The analyzer will repeat ResolveWithCTE rule twice. |
75 |
| - val rules = Seq(ResolveWithCTE, ResolveSubqueryColumnAliases, ResolveWithCTE) |
76 |
| - val analyzer = new RuleExecutor[LogicalPlan] { |
77 |
| - override val batches = Seq(Batch("Resolution", Once, rules: _*)) |
| 62 | + val cteId = 0 |
| 63 | + val anchor = Project(Seq(Alias(Literal(1), "c")()), OneRowRelation()) |
| 64 | + |
| 65 | + def getBeforePlan(): LogicalPlan = { |
| 66 | + val cteRef = CTERelationRef( |
| 67 | + cteId, |
| 68 | + _resolved = false, |
| 69 | + output = Seq(), |
| 70 | + isStreaming = false) |
| 71 | + val recursion = cteRef.copy(recursive = true).subquery("t") |
| 72 | + val cteDef = CTERelationDef( |
| 73 | + UnresolvedSubqueryColumnAliases(Seq("n"), anchor.union(recursion)).subquery("t"), |
| 74 | + cteId) |
| 75 | + WithCTE(cteRef.copy(recursive = false), Seq(cteDef)) |
78 | 76 | }
|
79 |
| - // Since cteDef IDs need to be the same, cteDef for each case will be created by copying |
80 |
| - // this one with its child replaced. |
81 |
| - val cteDef = CTERelationDef(OneRowRelation()) |
82 |
| - val anchor = Project(Seq(Alias(Literal(1), "1")()), OneRowRelation()) |
83 |
| - |
84 |
| - def getBeforePlan(cteDef: CTERelationDef): LogicalPlan = { |
85 |
| - val recursionPart = SubqueryAlias("t", |
86 |
| - CTERelationRef(cteDef.id, false, Seq(), false, recursive = true)) |
87 | 77 |
|
88 |
| - val cteDefFinal = cteDef.copy(child = |
89 |
| - SubqueryAlias("t", |
90 |
| - UnresolvedSubqueryColumnAliases(Seq("n"), |
91 |
| - Union(Seq(anchor, recursionPart))))) |
92 |
| - |
93 |
| - WithCTE( |
94 |
| - SubqueryAlias("t", CTERelationRef(cteDefFinal.id, false, Seq(), false, recursive = false)), |
95 |
| - Seq(cteDefFinal)) |
96 |
| - } |
97 |
| - |
98 |
| - def getAfterPlan(cteDef: CTERelationDef): LogicalPlan = { |
99 |
| - val saRecursion = SubqueryAlias("t", |
100 |
| - Project(Seq(Alias(anchor.output.head, "n")()), |
101 |
| - UnionLoopRef(cteDef.id, anchor.output, false))) |
102 |
| - |
103 |
| - val cteDefFinal = cteDef.copy(child = |
104 |
| - SubqueryAlias("t", |
105 |
| - Project(Seq(Alias(anchor.output.head, "n")()), |
106 |
| - UnionLoop(cteDef.id, anchor, saRecursion)))) |
107 |
| - |
108 |
| - val outerCteRef = CTERelationRef(cteDefFinal.id, true, cteDefFinal.output, false, |
109 |
| - recursive = false) |
110 |
| - |
111 |
| - WithCTE(SubqueryAlias("t", outerCteRef), Seq(cteDefFinal)) |
| 78 | + def getAfterPlan(): LogicalPlan = { |
| 79 | + val col = anchor.output.head |
| 80 | + val recursion = UnionLoopRef(cteId, anchor.output, accumulated = false) |
| 81 | + .select(col.as("n")) |
| 82 | + .subquery("t") |
| 83 | + val cteDef = CTERelationDef( |
| 84 | + UnionLoop(cteId, anchor, recursion).select(col.as("n")).subquery("t"), |
| 85 | + cteId) |
| 86 | + val cteRef = CTERelationRef( |
| 87 | + cteId, |
| 88 | + _resolved = true, |
| 89 | + output = cteDef.output, |
| 90 | + isStreaming = false) |
| 91 | + WithCTE(cteRef, Seq(cteDef)) |
112 | 92 | }
|
113 | 93 |
|
114 |
| - val beforePlan = getBeforePlan(cteDef) |
115 |
| - val afterPlan = getAfterPlan(cteDef) |
116 |
| - |
117 |
| - comparePlans(analyzer.execute(beforePlan), afterPlan) |
| 94 | + comparePlans(getAnalyzer.execute(getBeforePlan()), getAfterPlan()) |
118 | 95 | }
|
119 | 96 | }
|
0 commit comments