diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 264d9b7c3a033..20b9682b9c681 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1659,6 +1659,11 @@
       "Not allowed to implement multiple UDF interfaces, UDF class <className>."
     ]
   },
+  "NAMED_ARGUMENTS_SUPPORT_DISABLED" : {
+    "message" : [
+      "Cannot call function <functionName> because named argument references are not enabled here. In this case, the named argument reference was <argument>. Set \"spark.sql.allowNamedFunctionArguments\" to \"true\" to turn on the feature."
+    ]
+  },
   "NESTED_AGGREGATE_FUNCTION" : {
     "message" : [
       "It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query."
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index 6c9b3a712665b..fb440ef8d376e 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -443,6 +443,7 @@ CONCAT_PIPE: '||';
 HAT: '^';
 COLON: ':';
 ARROW: '->';
+FAT_ARROW: '=>';
 HENT_START: '/*+';
 HENT_END: '*/';
 QUESTION: '?';
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index d1e672e9472dc..ab6c0d0861f89 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -789,7 +789,7 @@ inlineTable
     ;
 
 functionTable
-    : funcName=functionName LEFT_PAREN (expression (COMMA expression)*)? RIGHT_PAREN tableAlias
+    : funcName=functionName LEFT_PAREN (functionArgument (COMMA functionArgument)*)? RIGHT_PAREN tableAlias
     ;
 
 tableAlias
@@ -862,6 +862,15 @@ expression
     : booleanExpression
     ;
 
+namedArgumentExpression
+    : key=identifier FAT_ARROW value=expression
+    ;
+
+functionArgument
+    : expression
+    | namedArgumentExpression
+    ;
+
 expressionSeq
     : expression (COMMA expression)*
     ;
@@ -921,7 +930,8 @@ primaryExpression
     | LEFT_PAREN namedExpression (COMMA namedExpression)+ RIGHT_PAREN #rowConstructor
     | LEFT_PAREN query RIGHT_PAREN #subqueryExpression
     | IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN #identifierClause
-    | functionName LEFT_PAREN (setQuantifier? argument+=expression (COMMA argument+=expression)*)? RIGHT_PAREN
+    | functionName LEFT_PAREN (setQuantifier? argument+=functionArgument
+        (COMMA argument+=functionArgument)*)? RIGHT_PAREN
       (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)?
       (nullsOption=(IGNORE | RESPECT) NULLS)? ( OVER windowSpec)? #functionCall
     | identifier ARROW expression #lambda
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/NamedArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/NamedArgumentExpression.scala
new file mode 100644
index 0000000000000..e8e6980805bda
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/NamedArgumentExpression.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.types.DataType
+
+/**
+ * This represents an argument expression to a function call accompanied by an
+ * explicit reference to the corresponding argument name as a string. In this way,
+ * the analyzer can make sure that the provided values match up with the intended
+ * arguments, and the arguments may appear in any order.
+ * This unary expression is unevaluable because we intend to replace it with
+ * the provided value itself during query analysis (after possibly rearranging
+ * the parsed argument list to match up the names to the expected function
+ * signature).
+ *
+ * SQL Syntax: key => value
+ * SQL grammar: key=identifier FAT_ARROW value=expression
+ *
+ * Example usage with the "encode" scalar function:
+ * SELECT encode("abc", charset => "utf-8");
+ * The second argument generates NamedArgumentExpression("charset", Literal("utf-8"))
+ * SELECT encode(charset => "utf-8", value => "abc");
+ * Here both arguments are named, so they may appear in either order.
+ *
+ * @param key The name of the function argument
+ * @param value The value of the function argument
+ */
+case class NamedArgumentExpression(key: String, value: Expression)
+  extends UnaryExpression with Unevaluable {
+
+  override def dataType: DataType = value.dataType
+
+  override def toString: String = s"$key => $value"
+
+  // NamedArgumentExpression has a single child, which is its value expression,
+  // so the value expression can be resolved by Analyzer rules recursively.
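+  // This matters because the wrapped value may itself be an arbitrary unresolved expression.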
+  // For example, when the value is a built-in function expression,
+  // it must be resolved by [[ResolveFunctions]].
+  override def child: Expression = value
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(value = newChild)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index ca62de12e7b7b..9a395924c451c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1527,6 +1527,18 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
     }
   }
 
+  private def extractNamedArgument(expr: FunctionArgumentContext, funcName: String): Expression = {
+    Option(expr.namedArgumentExpression).map { n =>
+      if (conf.getConf(SQLConf.ALLOW_NAMED_FUNCTION_ARGUMENTS)) {
+        NamedArgumentExpression(n.key.getText, expression(n.value))
+      } else {
+        throw QueryCompilationErrors.namedArgumentsNotEnabledError(funcName, n.key.getText)
+      }
+    }.getOrElse {
+      expression(expr)
+    }
+  }
+
   private def withTimeTravel(
       ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) {
     val v = ctx.version
@@ -1551,17 +1563,22 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
       Seq.empty
     }
 
-    withFuncIdentClause(func.functionName, ident => {
-      if (ident.length > 1) {
-        throw QueryParsingErrors.invalidTableValuedFunctionNameError(ident, ctx)
-      }
+    withFuncIdentClause(
+      func.functionName,
+      ident => {
+        if (ident.length > 1) {
+          throw QueryParsingErrors.invalidTableValuedFunctionNameError(ident, ctx)
+        }
+        val args = func.functionArgument.asScala.map { e =>
+          extractNamedArgument(e, func.functionName.getText)
+        }.toSeq
 
-      val tvf = UnresolvedTableValuedFunction(ident, func.expression.asScala.map(expression).toSeq)
+        val tvf = UnresolvedTableValuedFunction(ident, args)
 
-      val tvfAliases = if (aliases.nonEmpty) UnresolvedTVFAliases(ident, tvf, aliases) else tvf
+        val tvfAliases = if (aliases.nonEmpty) UnresolvedTVFAliases(ident, tvf, aliases) else tvf
 
-      tvfAliases.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan)
-    })
+        tvfAliases.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan)
+      })
   }
 
   /**
@@ -2186,7 +2203,9 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
     val name = ctx.functionName.getText
     val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
-    // Call `toSeq`, otherwise `ctx.argument.asScala.map(expression)` is `Buffer` in Scala 2.13
-    val arguments = ctx.argument.asScala.map(expression).toSeq match {
+    // Call `toSeq`, otherwise the mapped arguments remain a `Buffer` in Scala 2.13
+    val arguments = ctx.argument.asScala.map { e =>
+      extractNamedArgument(e, name)
+    }.toSeq match {
       case Seq(UnresolvedStar(None)) if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct =>
         // Transform COUNT(*) into COUNT(1).
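Taken together, the grammar and AstBuilder changes above parse each `name => value` argument into a NamedArgumentExpression wrapper, gated on spark.sql.allowNamedFunctionArguments. Below is a minimal, illustrative sketch of both paths through the Catalyst parser; the demo object is hypothetical and assumes driver-side code where the thread-local SQLConf.get is the conf that AstBuilder's SQLConfHelper mixin consults:

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.SQLConf

// Hypothetical demo object, not part of this patch.
object NamedArgumentParsingDemo {
  def main(args: Array[String]): Unit = {
    // With the flag at its default of true, the named argument parses into a
    // NamedArgumentExpression nested inside the unresolved function call.
    val parsed = CatalystSqlParser.parseExpression("encode('abc', charset => 'utf-8')")
    val named = parsed.collect { case n: NamedArgumentExpression => n }
    assert(named.map(_.key) == Seq("charset"))

    // With the flag off, the guard in extractNamedArgument rejects the reference
    // with the new NAMED_ARGUMENTS_SUPPORT_DISABLED error class (surfaced as a
    // ParseException, which extends AnalysisException).
    SQLConf.get.setConf(SQLConf.ALLOW_NAMED_FUNCTION_ARGUMENTS, false)
    try {
      CatalystSqlParser.parseExpression("encode('abc', charset => 'utf-8')")
      assert(false, "expected the named argument reference to be rejected")
    } catch {
      case e: AnalysisException =>
        assert(e.getErrorClass == "NAMED_ARGUMENTS_SUPPORT_DISABLED")
    } finally {
      SQLConf.get.setConf(SQLConf.ALLOW_NAMED_FUNCTION_ARGUMENTS, true)
    }
  }
}

Note that at this stage the wrapper only parses; nothing yet rearranges the wrapped values against a function's expected signature, which is why the golden files below still report DATATYPE_MISMATCH for mask even with the feature enabled.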
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 91ebc12b5cd29..e6cbe068a9288 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -193,6 +193,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
       messageParameters = Map("configKey" -> configKey))
   }
 
+  def namedArgumentsNotEnabledError(functionName: String, argumentName: String): Throwable = {
+    new AnalysisException(
+      errorClass = "NAMED_ARGUMENTS_SUPPORT_DISABLED",
+      messageParameters = Map(
+        "functionName" -> toSQLId(functionName),
+        "argument" -> toSQLId(argumentName))
+    )
+  }
+
   def unresolvedUsingColForJoinError(
       colName: String, suggestion: String, side: String): Throwable = {
     new AnalysisException(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 825cee5c6b985..d60f5d170e709 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -321,6 +321,13 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val ALLOW_NAMED_FUNCTION_ARGUMENTS = buildConf("spark.sql.allowNamedFunctionArguments")
+    .doc("If true, Spark will turn on support for named parameters for all functions that have" +
+      " it implemented.")
+    .version("3.5.0")
+    .booleanConf
+    .createWithDefault(true)
+
   val DYNAMIC_PARTITION_PRUNING_ENABLED =
     buildConf("spark.sql.optimizer.dynamicPartitionPruning.enabled")
       .doc("When true, we will generate predicate for partition column when it's used as join key")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 08d3f6b3d0730..1b9c2709ecd1f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
 import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, IntervalUtils}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.internal.SQLConf
@@ -329,6 +330,23 @@ class ExpressionParserSuite extends AnalysisTest {
       parameters = Map("error" -> "'x'", "hint" -> ": extra input 'x'"))
   }
 
+  test("function expressions with named arguments") {
+    assertEqual("encode(value => 'abc', charset => 'utf-8')",
+      $"encode".function(NamedArgumentExpression("value", Literal("abc")),
+        NamedArgumentExpression("charset", Literal("utf-8"))))
+    assertEqual("encode('abc', charset => 'utf-8')",
+      $"encode".function(Literal("abc"), NamedArgumentExpression("charset", Literal("utf-8"))))
+    assertEqual("encode(charset => 'utf-8', 'abc')",
+      $"encode".function(NamedArgumentExpression("charset", Literal("utf-8")), Literal("abc")))
+    assertEqual("encode('abc', charset =>
'utf' || '-8')", + $"encode".function(Literal("abc"), NamedArgumentExpression("charset", + Concat(Literal("utf") :: Literal("-8") :: Nil)))) + val unresolvedAlias = Project(Seq(UnresolvedAlias(Literal("1"))), OneRowRelation()) + assertEqual("encode(value => ((select '1')), charset => 'utf-8')", + $"encode".function(NamedArgumentExpression("value", ScalarSubquery(plan = unresolvedAlias)), + NamedArgumentExpression("charset", Literal("utf-8")))) + } + private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name)) test("lambda functions") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index ded8aaf74305d..228a287e14f49 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -1412,6 +1412,35 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select a, b from db.c; ;; ;", table("db", "c").select($"a", $"b")) } + test("table valued function with named arguments") { + // All named arguments + assertEqual( + "select * from my_tvf(arg1 => 'value1', arg2 => true)", + UnresolvedTableValuedFunction("my_tvf", + NamedArgumentExpression("arg1", Literal("value1")) :: + NamedArgumentExpression("arg2", Literal(true)) :: Nil).select(star())) + + // Unnamed and named arguments + assertEqual( + "select * from my_tvf(2, arg1 => 'value1', arg2 => true)", + UnresolvedTableValuedFunction("my_tvf", + Literal(2) :: + NamedArgumentExpression("arg1", Literal("value1")) :: + NamedArgumentExpression("arg2", Literal(true)) :: Nil).select(star())) + + // Mixed arguments + assertEqual( + "select * from my_tvf(arg1 => 'value1', 2, arg2 => true)", + UnresolvedTableValuedFunction("my_tvf", + NamedArgumentExpression("arg1", Literal("value1")) :: + Literal(2) :: + NamedArgumentExpression("arg2", Literal(true)) :: Nil).select(star())) + assertEqual( + "select * from my_tvf(group => 'abc')", + UnresolvedTableValuedFunction("my_tvf", + NamedArgumentExpression("group", Literal("abc")) :: Nil).select(star())) + } + test("SPARK-32106: TRANSFORM plan") { // verify schema less assertEqual( diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out new file mode 100644 index 0000000000000..faa05535cb322 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out @@ -0,0 +1,112 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd') +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"namedargumentexpression(q)\"", + "inputName" : "upperChar", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"mask(AbCD123-@$#, namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 98, + "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" + } ] +} + + +-- !query +SELECT mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', 
digitChar => 'd', str => 'AbCD123-@$#') +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"namedargumentexpression(Q)\"", + "inputName" : "upperChar", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"mask(namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d), namedargumentexpression(AbCD123-@$#))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 105, + "fragment" : "mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbCD123-@$#')" + } ] +} + + +-- !query +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd') +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"namedargumentexpression(q)\"", + "inputName" : "upperChar", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"mask(AbCD123-@$#, namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(d), NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 80, + "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd')" + } ] +} + + +-- !query +SELECT mask(lowerChar => 'q', upperChar => 'Q', digitChar => 'd', str => 'AbCD123-@$#') +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"namedargumentexpression(Q)\"", + "inputName" : "upperChar", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"mask(namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(d), namedargumentexpression(AbCD123-@$#), NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 87, + "fragment" : "mask(lowerChar => 'q', upperChar => 'Q', digitChar => 'd', str => 'AbCD123-@$#')" + } ] +} + + +-- !query +SELECT mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd') +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.INPUT_SIZE_NOT_ONE", + "sqlState" : "42K09", + "messageParameters" : { + "exprName" : "upperChar", + "sqlExpr" : "\"mask(namedargumentexpression(q), AbCD123-@$#, namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 98, + "fragment" : "mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" + } ] +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/named-function-arguments.sql b/sql/core/src/test/resources/sql-tests/inputs/named-function-arguments.sql new file mode 100644 index 0000000000000..aeb7b1e85cd8c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/named-function-arguments.sql @@ -0,0 +1,5 @@ +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd'); +SELECT mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbCD123-@$#'); +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd'); +SELECT mask(lowerChar => 'q', upperChar => 
'Q', digitChar => 'd', str => 'AbCD123-@$#'); +SELECT mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd'); diff --git a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out new file mode 100644 index 0000000000000..842374542ec6e --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out @@ -0,0 +1,122 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"namedargumentexpression(q)\"", + "inputName" : "upperChar", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"mask(AbCD123-@$#, namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 98, + "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" + } ] +} + + +-- !query +SELECT mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbCD123-@$#') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"namedargumentexpression(Q)\"", + "inputName" : "upperChar", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"mask(namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d), namedargumentexpression(AbCD123-@$#))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 105, + "fragment" : "mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbCD123-@$#')" + } ] +} + + +-- !query +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"namedargumentexpression(q)\"", + "inputName" : "upperChar", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"mask(AbCD123-@$#, namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(d), NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 80, + "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd')" + } ] +} + + +-- !query +SELECT mask(lowerChar => 'q', upperChar => 'Q', digitChar => 'd', str => 'AbCD123-@$#') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"namedargumentexpression(Q)\"", + "inputName" : "upperChar", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"mask(namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(d), namedargumentexpression(AbCD123-@$#), NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + 
"startIndex" : 8, + "stopIndex" : 87, + "fragment" : "mask(lowerChar => 'q', upperChar => 'Q', digitChar => 'd', str => 'AbCD123-@$#')" + } ] +} + + +-- !query +SELECT mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.INPUT_SIZE_NOT_ONE", + "sqlState" : "42K09", + "messageParameters" : { + "exprName" : "upperChar", + "sqlExpr" : "\"mask(namedargumentexpression(q), AbCD123-@$#, namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 98, + "fragment" : "mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" + } ] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala index 58e3fefc8bf0a..a7d5046245df9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala @@ -20,16 +20,33 @@ package org.apache.spark.sql.errors import org.apache.spark.SparkThrowable import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId import org.apache.spark.sql.test.SharedSparkSession // Turn of the length check because most of the tests check entire error messages // scalastyle:off line.size.limit -class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession { +class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQLHelper { private def parseException(sqlText: String): SparkThrowable = { intercept[ParseException](sql(sqlText).collect()) } + test("NAMED_ARGUMENTS_SUPPORT_DISABLED: named arguments not turned on") { + withSQLConf("spark.sql.allowNamedFunctionArguments" -> "false") { + checkError( + exception = parseException("SELECT * FROM encode(value => 'abc', charset => 'utf-8')"), + errorClass = "NAMED_ARGUMENTS_SUPPORT_DISABLED", + parameters = Map("functionName" -> toSQLId("encode"), "argument" -> toSQLId("value")) + ) + checkError( + exception = parseException("SELECT explode(arr => array(10, 20))"), + errorClass = "NAMED_ARGUMENTS_SUPPORT_DISABLED", + parameters = Map("functionName"-> toSQLId("explode"), "argument" -> toSQLId("arr")) + ) + } + } + test("UNSUPPORTED_FEATURE: LATERAL join with NATURAL join not supported") { checkError( exception = parseException("SELECT * FROM t1 NATURAL JOIN LATERAL (SELECT c1 + c2 AS c2)"), @@ -368,6 +385,25 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession { parameters = Map("error" -> "end of input", "hint" -> "")) } + def checkParseSyntaxError(sqlCommand: String, errorString: String, hint: String = ""): Unit = { + checkError( + exception = parseException(sqlCommand), + errorClass = "PARSE_SYNTAX_ERROR", + sqlState = "42601", + parameters = Map("error" -> errorString, "hint" -> hint) + ) + } + + test("PARSE_SYNTAX_ERROR: named arguments invalid syntax") { + checkParseSyntaxError("select * from my_tvf(arg1 ==> 'value1')", "'>'") + checkParseSyntaxError("select * from my_tvf(arg1 = => 'value1')", "'=>'") + checkParseSyntaxError("select * from my_tvf((arg1 => 'value1'))", "'=>'") + 
checkParseSyntaxError("select * from my_tvf(arg1 => )", "')'") + checkParseSyntaxError("select * from my_tvf(arg1 => , 42)", "','") + checkParseSyntaxError("select * from my_tvf(my_tvf.arg1 => 'value1')", "'=>'") + checkParseSyntaxError("select * from my_tvf(arg1 => table t1)", "'t1'", hint = ": extra input 't1'") + } + test("PARSE_SYNTAX_ERROR: extraneous input") { checkError( exception = parseException("select 1 1"),