From 8d1cb76066e45bf24952124d3edc4357303067e5 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 16 Oct 2024 14:57:31 +0200 Subject: [PATCH] [SPARK-49987][SQL] Fix the error prompt when `seedExpression` is non-foldable in `randstr` ### What changes were proposed in this pull request? The pr aims to - fix the `error prompt` when `seedExpression` is `non-foldable` in `randstr`. - use `toSQLId` to set the parameter value `inputName` for `randstr ` and `uniform` of `NON_FOLDABLE_INPUT`. ### Why are the changes needed? - Let me take an example ```scala val df = Seq(1.1).toDF("a") df.createOrReplaceTempView("t") sql("SELECT randstr(1, a) from t").show(false) ``` - Before image ```shell [DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "randstr(1, a)" due to data type mismatch: the input seedExpression should be a foldable INT or SMALLINT expression; however, got "a". SQLSTATE: 42K09; line 1 pos 7; 'Project [unresolvedalias(randstr(1, a#5, false))] +- SubqueryAlias t +- View (`t`, [a#5]) +- Project [value#1 AS a#5] +- LocalRelation [value#1] ``` - After ```shell [DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "randstr(1, a)" due to data type mismatch: the input seed should be a foldable INT or SMALLINT expression; however, got "a". SQLSTATE: 42K09; line 1 pos 7; 'Project [unresolvedalias(randstr(1, a#5, false))] +- SubqueryAlias t +- View (`t`, [a#5]) +- Project [value#1 AS a#5] +- LocalRelation [value#1] ``` - The `parameter` name (`seedExpression`) in the error message does not match the `parameter` name (`seed`) seen in docs by the end-user. image ### Does this PR introduce _any_ user-facing change? Yes, When `seed` is `non-foldable `, the end-user will get a consistent experience in the error prompt. ### How was this patch tested? Update existed UT. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48490 from panbingkun/SPARK-49987. Authored-by: panbingkun Signed-off-by: Max Gekk --- .../sql/catalyst/expressions/randomExpressions.scala | 8 ++++---- .../resources/sql-tests/analyzer-results/random.sql.out | 8 ++++---- .../src/test/resources/sql-tests/results/random.sql.out | 8 ++++---- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 3cec83facd01d..16bdaa1f7f708 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType} +import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike} @@ -263,7 +263,7 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression, result = DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( - "inputName" -> name, + "inputName" -> toSQLId(name), "inputType" -> requiredType, "inputExpr" -> toSQLExpr(expr))) } else expr.dataType match { @@ -374,14 +374,14 @@ case class RandStr( var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess def requiredType = "INT or SMALLINT" Seq((length, "length", 0), - (seedExpression, "seedExpression", 1)).foreach { + (seedExpression, "seed", 1)).foreach { case (expr: Expression, name: String, index: Int) => if (result == TypeCheckResult.TypeCheckSuccess) { if (!expr.foldable) { result = DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( - "inputName" -> name, + "inputName" -> toSQLId(name), "inputType" -> requiredType, "inputExpr" -> toSQLExpr(expr))) } else expr.dataType match { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out index 133cd6a60a4fb..31919381c99b6 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out @@ -188,7 +188,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seed", + "inputName" : "`seed`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(10, 20, col)\"" }, @@ -211,7 +211,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "min", + "inputName" : "`min`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(col, 10, 0)\"" }, @@ -436,7 +436,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "length", + "inputName" : "`length`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(col, 0)\"" }, @@ -459,7 +459,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seedExpression", + "inputName" : "`seed`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, col)\"" }, diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out index 0b4e5e078ee15..01638abdcec6e 100644 --- a/sql/core/src/test/resources/sql-tests/results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -240,7 +240,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seed", + "inputName" : "`seed`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(10, 20, col)\"" }, @@ -265,7 +265,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "min", + "inputName" : "`min`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(col, 10, 0)\"" }, @@ -520,7 +520,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "length", + "inputName" : "`length`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(col, 0)\"" }, @@ -545,7 +545,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seedExpression", + "inputName" : "`seed`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, col)\"" }, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 47691e1ccd40f..39c839ae5a518 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -478,7 +478,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { intercept[AnalysisException](df.select(expr)), condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( - "inputName" -> "length", + "inputName" -> "`length`", "inputType" -> "INT or SMALLINT", "inputExpr" -> "\"a\"", "sqlExpr" -> "\"randstr(a, 10)\""), @@ -530,7 +530,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { intercept[AnalysisException](df.select(expr)), condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( - "inputName" -> "min", + "inputName" -> "`min`", "inputType" -> "integer or floating-point", "inputExpr" -> "\"a\"", "sqlExpr" -> "\"uniform(a, 10)\""),