Skip to content

Commit 1a67dc6

Browse files
authored
Multiple tables as join right without subsearch is not allowed (#1088)
1 parent 1a81198 commit 1a67dc6

File tree

4 files changed

+133
-4
lines changed

4 files changed

+133
-4
lines changed

integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala

+41-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
99
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
1010
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, EqualTo, Floor, GreaterThan, LessThan, Literal, Multiply, Or, SortOrder}
1111
import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
12-
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, LogicalPlan, Project, Sort, SubqueryAlias}
12+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, LogicalPlan, Project, Sort, SubqueryAlias, Union}
1313
import org.apache.spark.sql.streaming.StreamTest
1414

1515
class FlintSparkPPLJoinITSuite
@@ -1191,4 +1191,44 @@ class FlintSparkPPLJoinITSuite
11911191
| """.stripMargin))
11921192
assert(ex.getMessage.contains("`tt`.`name` cannot be resolved"))
11931193
}
1194+
1195+
// Inner join whose right side is a sub-search listing two table sources: checks the
// returned rows first, then the shape of the unresolved logical plan.
test("test join with union") {
  val query =
    s"""
       | source = $testTable1
       | | inner join left=a, right=b
       | ON a.name = b.name
       | [ source = $testTable2, $testTable2 ]
       | | stats count(salary) by span(age, 10) as age_span
       | """.stripMargin
  val frame = sql(query)

  // Row-level expectation: (count(salary), age_span) pairs.
  val expectedRows: Array[Row] =
    Array(Row(2, 70), Row(4, 20), Row(4, 40), Row(2, 30))
  assertSameRows(expectedRows, frame)

  // Left side: the first test table under alias "a".
  val leftSide =
    SubqueryAlias("a", UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")))

  // Right side: the two sub-search sources, each aliased "b", combined by a by-name Union.
  val rightSide = Union(
    Seq.fill(2)(
      SubqueryAlias("b", UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")))),
    byName = true,
    allowMissingCol = true)

  val joined = Join(
    leftSide,
    rightSide,
    Inner,
    Some(EqualTo(UnresolvedAttribute("a.name"), UnresolvedAttribute("b.name"))),
    JoinHint.NONE)

  // stats count(salary) by span(age, 10) as age_span
  val countSalary = Alias(
    UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedAttribute("salary")), isDistinct = false),
    "count(salary)")()
  val ageSpan = Alias(
    Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
    "age_span")()

  val expectedPlan =
    Project(Seq(UnresolvedStar(None)), Aggregate(Seq(ageSpan), Seq(countSalary, ageSpan), joined))
  val actualPlan: LogicalPlan = frame.queryExecution.logical

  comparePlans(expectedPlan, actualPlan, checkAnalysis = false)
}
11941234
}

ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,9 @@ public LogicalPlan visitJoin(Join node, CatalystPlanContext context) {
354354
Optional<Expression> joinCondition = node.getJoinCondition()
355355
.map(c -> expressionAnalyzer.analyzeJoinCondition(c, context));
356356
context.resetNamedParseExpressions();
357+
LogicalPlan join = join(left, right, node.getJoinType(), joinCondition, node.getJoinHint()).clone();
357358
context.retainAllPlans(p -> p);
358-
return join(left, right, node.getJoinType(), joinCondition, node.getJoinHint());
359+
return join;
359360
});
360361
}
361362

ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.antlr.v4.runtime.ParserRuleContext;
1111
import org.antlr.v4.runtime.Token;
1212
import org.antlr.v4.runtime.tree.ParseTree;
13+
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
1314
import org.opensearch.flint.spark.ppl.OpenSearchPPLParser;
1415
import org.opensearch.flint.spark.ppl.OpenSearchPPLParser.FillNullWithFieldVariousValuesContext;
1516
import org.opensearch.flint.spark.ppl.OpenSearchPPLParser.FillNullWithTheSameValueContext;
@@ -170,7 +171,21 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct
170171
if (ctx.sideAlias().rightAlias != null) {
171172
rightAlias = Optional.of(internalVisitExpression(ctx.sideAlias().rightAlias).toString());
172173
}
173-
174+
// "JOIN on id = uid table1,table2" is not allowed
175+
// "JOIN on id = uid table1,table2 as t2" is not allowed
176+
// "JOIN on id = uid [ source = table1,table2 ]" is allowed
177+
if (ctx.tableOrSubqueryClause().subSearch() == null
178+
&& ctx.tableOrSubqueryClause().tableSourceClause().tableSource().size() > 1) {
179+
UnresolvedPlan plan = visit(ctx.tableOrSubqueryClause());
180+
Relation relation = null;
181+
if (plan instanceof Relation) {
182+
relation = (Relation) plan;
183+
} else if (plan instanceof SubqueryAlias) {
184+
relation = (Relation)((SubqueryAlias) plan).getChild().get(0);
185+
}
186+
throw new SyntaxCheckException("Join command only support two tables."
187+
+ (relation == null ? "" : " But got " + relation.getQualifiedNames()));
188+
}
174189
UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause());
175190
// Add a SubqueryAlias to the right plan when the right alias is present and no duplicated alias existing in right.
176191
UnresolvedPlan right;

ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala

+74-1
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
package org.opensearch.flint.spark.ppl
77

88
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
9+
import org.opensearch.sql.common.antlr.SyntaxCheckException
910
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
1011
import org.scalatest.matchers.should.Matchers
1112

1213
import org.apache.spark.SparkFunSuite
1314
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
1415
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, GreaterThan, LessThan, Literal, Not, SortOrder}
1516
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
16-
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, Project, Sort, SubqueryAlias}
17+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, Project, Sort, SubqueryAlias, Union}
1718

1819
class PPLLogicalPlanJoinTranslatorTestSuite
1920
extends SparkFunSuite
@@ -905,4 +906,76 @@ class PPLLogicalPlanJoinTranslatorTestSuite
905906
val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext)
906907
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
907908
}
909+
910+
// Listing several tables as the join's right side without a sub-search must raise a
// SyntaxCheckException, whether or not the table list carries a trailing alias.
test("join syntax check: only allows two-tables join") {
  val expectedMessage =
    s"Join command only support two tables. But got [$testTable2, $testTable3]"

  // Parses and translates `query`, returning the syntax error it is expected to raise.
  def syntaxErrorOf(query: String): SyntaxCheckException =
    intercept[SyntaxCheckException] {
      planTransformer.visit(plan(pplParser, query), new CatalystPlanContext)
    }

  // Bare table list on the right side.
  val withoutAlias = syntaxErrorOf(s"""
       | source = $testTable1
       | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2, $testTable3
       | | fields t1.name, t2.name
       | """.stripMargin)
  assert(withoutAlias.getMessage === expectedMessage)

  // Same table list followed by an alias.
  val withAlias = syntaxErrorOf(s"""
       | source = $testTable1
       | | JOIN left = t1 right = t2 ON t1.name = t2.name $testTable2, $testTable3 as t2
       | | fields t1.name, t2.name
       | """.stripMargin)
  assert(withAlias.getMessage === expectedMessage)
}
939+
940+
// Several tables inside a sub-search on the right side of a join are accepted. Each query
// below is translated and compared against a Join whose right child is a by-name Union of
// the listed tables: first without aliases, then with aliases on both sides.
test("multiple tables in subsearch in right side should work") {
  val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1"))
  val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2"))
  val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3"))
  val joinCondition = EqualTo(UnresolvedAttribute("name"), UnresolvedAttribute("col"))

  // Case 1: no aliases on either side.
  val logicalPlan = planTransformer.visit(
    plan(
      pplParser,
      s"""
         | source = $testTable1
         | | JOIN ON name = col [ source = $testTable2, $testTable3]
         | """.stripMargin),
    new CatalystPlanContext)
  val join = Join(
    table1,
    Union(Seq(table2, table3), byName = true, allowMissingCol = true),
    Inner,
    Some(joinCondition),
    JoinHint.NONE)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), join)
  comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)

  // Case 2: left source aliased t1, sub-search aliased t2; the expectation wraps each
  // Union child in the t2 alias.
  val logicalPlan2 = planTransformer.visit(
    plan(
      pplParser,
      s"""
         | source = $testTable1 as t1
         | | JOIN ON name = col [ source = $testTable2, $testTable3] as t2
         | """.stripMargin),
    new CatalystPlanContext)
  val join2 = Join(
    SubqueryAlias("t1", table1),
    Union(
      Seq(SubqueryAlias("t2", table2), SubqueryAlias("t2", table3)),
      byName = true,
      allowMissingCol = true),
    Inner,
    Some(joinCondition),
    JoinHint.NONE)
  val expectedPlan2 = Project(Seq(UnresolvedStar(None)), join2)
  // Compare the aliased pair. Re-checking expectedPlan/logicalPlan here would leave the
  // aliased query unverified.
  comparePlans(expectedPlan2, logicalPlan2, checkAnalysis = false)
}
908981
}

0 commit comments

Comments
 (0)