
Commit 650cc1b

Merge branch 'apache:master' into windowFuncFix
2 parents 469baa9 + c2305ed commit 650cc1b

13 files changed, +857 -19 lines changed

python/docs/source/reference/pyspark.sql/functions.rst

Lines changed: 4 additions & 0 deletions
@@ -132,6 +132,7 @@ Mathematical Functions
 radians
 rand
 randn
+random
 rint
 round
 sec
@@ -164,6 +165,7 @@ String Functions
 char
 char_length
 character_length
+chr
 collate
 collation
 concat_ws
@@ -192,6 +194,7 @@ String Functions
 overlay
 position
 printf
+quote
 randstr
 regexp_count
 regexp_extract
@@ -631,6 +634,7 @@ Misc Functions
 try_reflect
 typeof
 user
+uuid
 version

python/pyspark/ml/connect/functions.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def _test() -> None:
     from pyspark.testing import should_test_connect
 
     if not should_test_connect:
-        print(f"Skipping pyspark.ml.connect.functions doctests", file=sys.stderr)
+        print("Skipping pyspark.ml.connect.functions doctests", file=sys.stderr)
         sys.exit(0)
 
     import doctest

python/pyspark/sql/connect/tvf.py

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ def _test() -> None:
     from pyspark.testing import should_test_connect
 
     if not should_test_connect:
-        print(f"Skipping pyspark.sql.connect.tvf doctests", file=sys.stderr)
+        print("Skipping pyspark.sql.connect.tvf doctests", file=sys.stderr)
         sys.exit(0)
 
     import doctest

python/pyspark/sql/tests/connect/test_df_debug.py

Lines changed: 8 additions & 12 deletions
@@ -17,49 +17,45 @@
 
 import unittest
 
-from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
-from pyspark.testing.connectutils import should_test_connect
+from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.utils import have_graphviz, graphviz_requirement_message
 
-if should_test_connect:
-    from pyspark.sql.connect.dataframe import DataFrame
 
-
-class SparkConnectDataFrameDebug(SparkConnectSQLTestCase):
+class SparkConnectDataFrameDebug(ReusedConnectTestCase):
     def test_df_debug_basics(self):
-        df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+        df = self.spark.range(100).repartition(10).groupBy("id").count()
         x = df.collect()  # noqa: F841
         ei = df.executionInfo
 
         root, graph = ei.metrics.extract_graph()
         self.assertIn(root, graph, "The root must be rooted in the graph")
 
     def test_df_quey_execution_empty_before_execution(self):
-        df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+        df = self.spark.range(100).repartition(10).groupBy("id").count()
         ei = df.executionInfo
         self.assertIsNone(ei, "The query execution must be None before the action is executed")
 
     def test_df_query_execution_with_writes(self):
-        df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+        df = self.spark.range(100).repartition(10).groupBy("id").count()
         df.write.save("/tmp/test_df_query_execution_with_writes", format="json", mode="overwrite")
         ei = df.executionInfo
         self.assertIsNotNone(
             ei, "The query execution must be None after the write action is executed"
         )
 
     def test_query_execution_text_format(self):
-        df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+        df = self.spark.range(100).repartition(10).groupBy("id").count()
         df.collect()
         self.assertIn("HashAggregate", df.executionInfo.metrics.toText())
 
         # Different execution mode.
-        df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+        df = self.spark.range(100).repartition(10).groupBy("id").count()
         df.toPandas()
         self.assertIn("HashAggregate", df.executionInfo.metrics.toText())
 
     @unittest.skipIf(not have_graphviz, graphviz_requirement_message)
     def test_df_query_execution_metrics_to_dot(self):
-        df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+        df = self.spark.range(100).repartition(10).groupBy("id").count()
         x = df.collect()  # noqa: F841
         ei = df.executionInfo
 

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 1 addition & 1 deletion
@@ -409,7 +409,7 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
           throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowName)
 
         case e: Expression if e.checkInputDataTypes().isFailure =>
-          TypeCoercionValidation.failOnTypeCheckResult(e, operator)
+          TypeCoercionValidation.failOnTypeCheckResult(e, Some(operator))
 
         case c: Cast if !c.resolved =>
           throw SparkException.internalError(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionValidation.scala

Lines changed: 2 additions & 2 deletions
@@ -27,14 +27,14 @@ import org.apache.spark.sql.types.DataType
 object TypeCoercionValidation extends QueryErrorsBase {
   private val DATA_TYPE_MISMATCH_ERROR = TreeNodeTag[Unit]("dataTypeMismatchError")
 
-  def failOnTypeCheckResult(e: Expression, operator: LogicalPlan): Nothing = {
+  def failOnTypeCheckResult(e: Expression, operator: Option[LogicalPlan] = None): Nothing = {
     e.checkInputDataTypes() match {
       case checkRes: TypeCheckResult.DataTypeMismatch =>
         e.setTagValue(DATA_TYPE_MISMATCH_ERROR, ())
         e.dataTypeMismatch(e, checkRes)
       case TypeCheckResult.TypeCheckFailure(message) =>
         e.setTagValue(DATA_TYPE_MISMATCH_ERROR, ())
-        val extraHint = TypeCoercionValidation.getHintForExpressionCoercion(operator)
+        val extraHint = operator.map(getHintForExpressionCoercion(_)).getOrElse("")
         e.failAnalysis(
           errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
           messageParameters = Map("sqlExpr" -> toSQLExpr(e), "msg" -> message, "hint" -> extraHint)
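
Illustrative sketch (not part of the diff): with the operator parameter now an Option defaulting to None, the same helper serves both the analyzer check above and call sites that have no enclosing plan; when the operator is omitted, the TYPE_CHECK_FAILURE_WITH_HINT message is built with an empty hint. Schematically, using only names that appear in this diff:

    // Fixed-point analyzer (CheckAnalysis above): the enclosing operator supplies a coercion hint.
    TypeCoercionValidation.failOnTypeCheckResult(e, Some(operator))

    // Single-pass resolver (ExpressionResolver below): no plan in scope, so the hint falls back to "".
    TypeCoercionValidation.failOnTypeCheckResult(resolvedExpression)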

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala

Lines changed: 5 additions & 0 deletions
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{
   withPosition,
   FunctionResolution,
   GetViewColumnByNameAndOrdinal,
+  TypeCoercionValidation,
   UnresolvedAlias,
   UnresolvedAttribute,
   UnresolvedFunction,
@@ -984,6 +985,10 @@
   }
 
   private def validateResolvedExpressionGenerically(resolvedExpression: Expression): Unit = {
+    if (resolvedExpression.checkInputDataTypes().isFailure) {
+      TypeCoercionValidation.failOnTypeCheckResult(resolvedExpression)
+    }
+
     if (!resolvedExpression.resolved) {
       throwSinglePassFailedToResolveExpression(resolvedExpression)
     }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 27 additions & 2 deletions
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.QueryContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
 import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, TreePattern}
 import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -90,7 +91,7 @@ object ExtractValue {
   }
 }
 
-trait ExtractValue extends Expression {
+trait ExtractValue extends Expression with QueryErrorsBase {
   override def nullIntolerant: Boolean = true
   final override val nodePatterns: Seq[TreePattern] = Seq(EXTRACT_VALUE)
   val child: Expression
@@ -314,6 +315,30 @@ case class GetArrayItem(
     })
   }
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (_: ArrayType, e2) if !e2.isInstanceOf[IntegralType] =>
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> ordinalNumber(1),
+            "requiredType" -> toSQLType(IntegralType),
+            "inputSql" -> toSQLExpr(right),
+            "inputType" -> toSQLType(right.dataType))
+        )
+      case (e1, _) if !e1.isInstanceOf[ArrayType] =>
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> ordinalNumber(0),
+            "requiredType" -> toSQLType(TypeCollection(ArrayType)),
+            "inputSql" -> toSQLExpr(left),
+            "inputType" -> toSQLType(left.dataType))
+        )
+      case _ => TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
   override protected def withNewChildrenInternal(
       newLeft: Expression, newRight: Expression): GetArrayItem =
     copy(child = newLeft, ordinal = newRight)
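
Illustrative sketch (not part of the diff): one way the new GetArrayItem.checkInputDataTypes could be exercised at the expression level in a catalyst-style test. The two-argument GetArrayItem constructor and the Literal helpers are assumptions about the surrounding test utilities; the UNEXPECTED_INPUT_TYPE subclass and the success case come directly from the hunk above.

    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
    import org.apache.spark.sql.catalyst.expressions.{GetArrayItem, Literal}

    val arr = Literal.create(Seq(1, 2, 3))                   // ArrayType(IntegerType) literal
    val ok = new GetArrayItem(arr, Literal(1))               // integral ordinal: accepted
    val badOrdinal = new GetArrayItem(arr, Literal("a"))     // string ordinal: rejected
    val badChild = new GetArrayItem(Literal(1), Literal(0))  // non-array child: rejected

    assert(ok.checkInputDataTypes() == TypeCheckResult.TypeCheckSuccess)

    badOrdinal.checkInputDataTypes() match {
      case DataTypeMismatch(subClass, _) => assert(subClass == "UNEXPECTED_INPUT_TYPE")
      case other => sys.error(s"unexpected result: $other")
    }

    assert(badChild.checkInputDataTypes().isFailure)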
