|
6 | 6 | package org.opensearch.flint.spark.ppl
|
7 | 7 |
|
8 | 8 | import org.opensearch.flint.spark.ppl.PlaneUtils.plan
|
| 9 | +import org.opensearch.sql.common.antlr.SyntaxCheckException |
9 | 10 | import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
|
10 | 11 | import org.scalatest.matchers.should.Matchers
|
11 | 12 |
|
12 | 13 | import org.apache.spark.SparkFunSuite
|
13 | 14 | import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
|
14 | 15 | import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, GreaterThan, LessThan, Literal, Not, SortOrder}
|
15 | 16 | import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
|
16 |
| -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, Project, Sort, SubqueryAlias} |
| 17 | +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, Project, Sort, SubqueryAlias, Union} |
17 | 18 |
|
18 | 19 | class PPLLogicalPlanJoinTranslatorTestSuite
|
19 | 20 | extends SparkFunSuite
|
@@ -905,4 +906,76 @@ class PPLLogicalPlanJoinTranslatorTestSuite
|
905 | 906 | val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext)
|
906 | 907 | comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
|
907 | 908 | }
|
| 909 | + |
| 910 | + test("join syntax check: only allows two-tables join") { |
| 911 | + val expectedMessage = |
| 912 | + s"Join command only support two tables. But got [$testTable2, $testTable3]" |
| 913 | + val thrown1 = intercept[SyntaxCheckException] { |
| 914 | + planTransformer.visit( |
| 915 | + plan( |
| 916 | + pplParser, |
| 917 | + s""" |
| 918 | + | source = $testTable1 |
| 919 | + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2, $testTable3 |
| 920 | + | | fields t1.name, t2.name |
| 921 | + | """.stripMargin), |
| 922 | + new CatalystPlanContext) |
| 923 | + } |
| 924 | + assert(thrown1.getMessage === expectedMessage) |
| 925 | + |
| 926 | + val thrown2 = intercept[SyntaxCheckException] { |
| 927 | + planTransformer.visit( |
| 928 | + plan( |
| 929 | + pplParser, |
| 930 | + s""" |
| 931 | + | source = $testTable1 |
| 932 | + | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2, $testTable3 as t2 |
| 933 | + | | fields t1.name, t2.name |
| 934 | + | """.stripMargin), |
| 935 | + new CatalystPlanContext) |
| 936 | + } |
| 937 | + assert(thrown2.getMessage === expectedMessage) |
| 938 | + } |
| 939 | + |
| 940 | + test("multiple tables in subsearch in right side should work") { |
| 941 | + val logicalPlan = planTransformer.visit( |
| 942 | + plan( |
| 943 | + pplParser, |
| 944 | + s""" |
| 945 | + | source = $testTable1 |
| 946 | + | | JOIN ON name = col [ source = $testTable2, $testTable3] |
| 947 | + | """.stripMargin), |
| 948 | + new CatalystPlanContext) |
| 949 | + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) |
| 950 | + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) |
| 951 | + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) |
| 952 | + val join = Join( |
| 953 | + table1, |
| 954 | + Union(Seq(table2, table3), byName = true, allowMissingCol = true), |
| 955 | + Inner, |
| 956 | + Some(EqualTo(UnresolvedAttribute("name"), UnresolvedAttribute("col"))), |
| 957 | + JoinHint.NONE) |
| 958 | + val expectedPlan = Project(Seq(UnresolvedStar(None)), join) |
| 959 | + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) |
| 960 | + |
| 961 | + val logicalPlan2 = planTransformer.visit( |
| 962 | + plan( |
| 963 | + pplParser, |
| 964 | + s""" |
| 965 | + | source = $testTable1 as t1 |
| 966 | + | | JOIN ON name = col [ source = $testTable2, $testTable3] as t2 |
| 967 | + | """.stripMargin), |
| 968 | + new CatalystPlanContext) |
| 969 | + val join2 = Join( |
| 970 | + SubqueryAlias("t1", table1), |
| 971 | + Union( |
| 972 | + Seq(SubqueryAlias("t2", table2), SubqueryAlias("t2", table3)), |
| 973 | + byName = true, |
| 974 | + allowMissingCol = true), |
| 975 | + Inner, |
| 976 | + Some(EqualTo(UnresolvedAttribute("name"), UnresolvedAttribute("col"))), |
| 977 | + JoinHint.NONE) |
| 978 | + val expectedPlan2 = Project(Seq(UnresolvedStar(None)), join2) |
| 979 | + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) |
| 980 | + } |
908 | 981 | }
|
0 commit comments