diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala new file mode 100644 index 0000000000000..6fbd70318b256 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.SparkException +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering.{SORT_NULLS_FIRST, SORT_NULLS_LAST} +import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection.{SORT_DIRECTION_ASCENDING, SORT_DIRECTION_DESCENDING} +import org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBoundary, FrameType} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.connect.common.DataTypeProtoConverter +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder +import org.apache.spark.sql.expressions.ScalaUserDefinedFunction +import org.apache.spark.sql.internal._ + +/** + * Converter for [[ColumnNode]] to [[proto.Expression]] conversions. + */ +object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) { + override def apply(node: ColumnNode): proto.Expression = { + val builder = proto.Expression.newBuilder() + // TODO(SPARK-49273) support Origin in Connect Scala Client. + node match { + case Literal(value, None, _) => + builder.setLiteral(toLiteralProtoBuilder(value)) + + case Literal(value, Some(dataType), _) => + builder.setLiteral(toLiteralProtoBuilder(value, dataType)) + + case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => + val b = builder.getUnresolvedAttributeBuilder + .setUnparsedIdentifier(unparsedIdentifier) + .setIsMetadataColumn(isMetadataColumn) + planId.foreach(b.setPlanId) + + case UnresolvedStar(unparsedTarget, planId, _) => + val b = builder.getUnresolvedStarBuilder + unparsedTarget.foreach(b.setUnparsedTarget) + planId.foreach(b.setPlanId) + + case UnresolvedRegex(regex, planId, _) => + val b = builder.getUnresolvedRegexBuilder + .setColName(regex) + planId.foreach(b.setPlanId) + + case UnresolvedFunction(functionName, arguments, isDistinct, isUserDefinedFunction, _, _) => + // TODO(SPARK-49087) use internal namespace. + builder.getUnresolvedFunctionBuilder + .setFunctionName(functionName) + .setIsUserDefinedFunction(isUserDefinedFunction) + .setIsDistinct(isDistinct) + .addAllArguments(arguments.map(apply).asJava) + + case Alias(child, name, metadata, _) => + val b = builder.getAliasBuilder.setExpr(apply(child)) + name.foreach(b.addName) + metadata.foreach(m => b.setMetadata(m.json)) + + case Cast(child, dataType, evalMode, _) => + val b = builder.getCastBuilder + .setExpr(apply(child)) + .setType(DataTypeProtoConverter.toConnectProtoType(dataType)) + evalMode.foreach { mode => + val convertedMode = mode match { + case Cast.Try => proto.Expression.Cast.EvalMode.EVAL_MODE_TRY + case Cast.Ansi => proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI + case Cast.Legacy => proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY + } + b.setEvalMode(convertedMode) + } + + case SqlExpression(expression, _) => + builder.getExpressionStringBuilder.setExpression(expression) + + case s: SortOrder => + builder.setSortOrder(convertSortOrder(s)) + + case Window(windowFunction, windowSpec, _) => + val b = builder.getWindowBuilder + .setWindowFunction(apply(windowFunction)) + .addAllPartitionSpec(windowSpec.partitionColumns.map(apply).asJava) + .addAllOrderSpec(windowSpec.sortColumns.map(convertSortOrder).asJava) + windowSpec.frame.foreach { frame => + b.getFrameSpecBuilder + .setFrameType(frame.frameType match { + case WindowFrame.Row => FrameType.FRAME_TYPE_ROW + case WindowFrame.Range => FrameType.FRAME_TYPE_RANGE + }) + .setLower(convertFrameBoundary(frame.lower)) + .setUpper(convertFrameBoundary(frame.upper)) + } + + case UnresolvedExtractValue(child, extraction, _) => + builder.getUnresolvedExtractValueBuilder + .setChild(apply(child)) + .setExtraction(apply(extraction)) + + case UpdateFields(structExpression, fieldName, valueExpression, _) => + val b = builder.getUpdateFieldsBuilder + .setStructExpression(apply(structExpression)) + .setFieldName(fieldName) + valueExpression.foreach(v => b.setValueExpression(apply(v))) + + case v: UnresolvedNamedLambdaVariable => + builder.setUnresolvedNamedLambdaVariable(convertNamedLambdaVariable(v)) + + case LambdaFunction(function, arguments, _) => + builder.getLambdaFunctionBuilder + .setFunction(apply(function)) + .addAllArguments(arguments.map(convertNamedLambdaVariable).asJava) + + case InvokeInlineUserDefinedFunction(udf: ScalaUserDefinedFunction, arguments, false, _) => + val b = builder.getCommonInlineUserDefinedFunctionBuilder + .setScalarScalaUdf(udf.udf) + .setDeterministic(udf.deterministic) + .addAllArguments(arguments.map(apply).asJava) + udf.givenName.foreach(b.setFunctionName) + + case CaseWhenOtherwise(branches, otherwise, _) => + val b = builder.getUnresolvedFunctionBuilder + .setFunctionName("when") + branches.foreach { case (condition, value) => + b.addArguments(apply(condition)) + b.addArguments(apply(value)) + } + otherwise.foreach { value => + b.addArguments(apply(value)) + } + + case ProtoColumnNode(e, _) => + return e + + case node => + throw SparkException.internalError("Unsupported ColumnNode: " + node) + } + builder.build() + } + + private def convertSortOrder(s: SortOrder): proto.Expression.SortOrder = { + proto.Expression.SortOrder + .newBuilder() + .setChild(apply(s.child)) + .setDirection(s.sortDirection match { + case SortOrder.Ascending => SORT_DIRECTION_ASCENDING + case SortOrder.Descending => SORT_DIRECTION_DESCENDING + }) + .setNullOrdering(s.nullOrdering match { + case SortOrder.NullsFirst => SORT_NULLS_FIRST + case SortOrder.NullsLast => SORT_NULLS_LAST + }) + .build() + } + + private def convertFrameBoundary(boundary: WindowFrame.FrameBoundary): FrameBoundary = { + val builder = FrameBoundary.newBuilder() + boundary match { + case WindowFrame.UnboundedPreceding => builder.setUnbounded(true) + case WindowFrame.UnboundedFollowing => builder.setUnbounded(true) + case WindowFrame.CurrentRow => builder.setCurrentRow(true) + case WindowFrame.Value(value) => builder.setValue(apply(value)) + } + builder.build() + } + + private def convertNamedLambdaVariable( + v: UnresolvedNamedLambdaVariable): proto.Expression.UnresolvedNamedLambdaVariable = { + proto.Expression.UnresolvedNamedLambdaVariable.newBuilder().addNameParts(v.name).build() + } +} + +case class ProtoColumnNode( + expr: proto.Expression, + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { + override def sql: String = expr.toString +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index dcf7f67551d30..781e5c17f5f83 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket} +import org.apache.spark.sql.internal.UserDefinedFunctionLike import org.apache.spark.sql.types.DataType import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils} @@ -101,13 +102,14 @@ case class ScalaUserDefinedFunction private[sql] ( serializedUdfPacket: Array[Byte], inputTypes: Seq[proto.DataType], outputType: proto.DataType, - name: Option[String], + givenName: Option[String], override val nullable: Boolean, override val deterministic: Boolean, aggregate: Boolean) - extends UserDefinedFunction { + extends UserDefinedFunction + with UserDefinedFunctionLike { - private[expressions] lazy val udf = { + private[sql] lazy val udf = { val scalaUdfBuilder = proto.ScalarScalaUDF .newBuilder() .setPayload(ByteString.copyFrom(serializedUdfPacket)) @@ -128,10 +130,10 @@ case class ScalaUserDefinedFunction private[sql] ( .setScalarScalaUdf(udf) .addAllArguments(exprs.map(_.expr).asJava) - name.foreach(udfBuilder.setFunctionName) + givenName.foreach(udfBuilder.setFunctionName) } - override def withName(name: String): ScalaUserDefinedFunction = copy(name = Option(name)) + override def withName(name: String): ScalaUserDefinedFunction = copy(givenName = Option(name)) override def asNonNullable(): ScalaUserDefinedFunction = copy(nullable = false) @@ -143,9 +145,11 @@ case class ScalaUserDefinedFunction private[sql] ( .setDeterministic(deterministic) .setScalarScalaUdf(udf) - name.foreach(builder.setFunctionName) + givenName.foreach(builder.setFunctionName) builder.build() } + + override def name: String = givenName.getOrElse("UDF") } object ScalaUserDefinedFunction { @@ -195,7 +199,7 @@ object ScalaUserDefinedFunction { serializedUdfPacket = udfPacketBytes, inputTypes = inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType), outputType = DataTypeProtoConverter.toConnectProtoType(outputEncoder.dataType), - name = None, + givenName = None, nullable = true, deterministic = true, aggregate = aggregate) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala index 1d8d164c9541c..d884d176841b7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala @@ -43,7 +43,7 @@ class UDFClassLoadingE2ESuite extends ConnectFunSuite with RemoteSparkSession { serializedUdfPacket = udfByteArray, inputTypes = Seq(ProtoDataTypes.IntegerType), outputType = ProtoDataTypes.IntegerType, - name = Some("dummyUdf"), + givenName = Some("dummyUdf"), nullable = true, deterministic = true, aggregate = false) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala new file mode 100644 index 0000000000000..28025534dc67a --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect + +import org.apache.spark.SparkException +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.Expression.Window.WindowFrame.FrameBoundary +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveIntEncoder +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ProtoDataTypes} +import org.apache.spark.sql.expressions.ScalaUserDefinedFunction +import org.apache.spark.sql.internal._ +import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.types.{BinaryType, DataType, DoubleType, LongType, MetadataBuilder, ShortType, StringType, StructType} + +/** + * Test suite for [[ColumnNode]] to [[proto.Expression]] conversions. + */ +class ColumnNodeToProtoConverterSuite extends ConnectFunSuite { + private def testConversion( + node: => ColumnNode, + expected: proto.Expression): proto.Expression = { + val expression = ColumnNodeToProtoConverter(node) + assert(expression == expected) + expression + } + + private def expr(f: proto.Expression.Builder => Unit): proto.Expression = { + val builder = proto.Expression.newBuilder() + f(builder) + builder.build() + } + + private def attribute(name: String): proto.Expression = + expr(_.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name).setIsMetadataColumn(false)) + + private def structField( + name: String, + dataType: proto.DataType, + nullable: Boolean = true): proto.DataType.StructField = { + proto.DataType.StructField + .newBuilder() + .setName(name) + .setDataType(dataType) + .setNullable(nullable) + .build() + } + + test("literal") { + testConversion(Literal(1), expr(_.getLiteralBuilder.setInteger(1).build())) + testConversion( + Literal("foo", Option(StringType)), + expr(_.getLiteralBuilder.setString("foo").build())) + val dataType = new StructType() + .add("_1", DoubleType) + .add("_2", StringType) + .add("_3", DoubleType) + .add("_4", StringType) + val stringTypeWithCollation = proto.DataType + .newBuilder() + .setString(proto.DataType.String.newBuilder().setCollation("UTF8_BINARY")) + .build() + testConversion( + Literal((12.0, "north", 60.0, "west"), Option(dataType)), + expr { b => + val builder = b.getLiteralBuilder.getStructBuilder + builder.getStructTypeBuilder.getStructBuilder + .addFields(structField("_1", ProtoDataTypes.DoubleType)) + .addFields(structField("_2", stringTypeWithCollation)) + .addFields(structField("_3", ProtoDataTypes.DoubleType)) + .addFields(structField("_4", stringTypeWithCollation)) + builder.addElements(proto.Expression.Literal.newBuilder().setDouble(12.0)) + builder.addElements(proto.Expression.Literal.newBuilder().setString("north")) + builder.addElements(proto.Expression.Literal.newBuilder().setDouble(60.0)) + builder.addElements(proto.Expression.Literal.newBuilder().setString("west")) + }) + } + + test("attribute") { + testConversion(UnresolvedAttribute("x"), attribute("x")) + testConversion( + UnresolvedAttribute("y", Option(44L), isMetadataColumn = true), + expr( + _.getUnresolvedAttributeBuilder + .setUnparsedIdentifier("y") + .setPlanId(44L) + .setIsMetadataColumn(true))) + } + + test("star") { + testConversion(UnresolvedStar(None), expr(_.getUnresolvedStarBuilder)) + testConversion( + UnresolvedStar(Option("x.y.z.*")), + expr(_.getUnresolvedStarBuilder.setUnparsedTarget("x.y.z.*"))) + testConversion( + UnresolvedStar(None, Option(10L)), + expr(_.getUnresolvedStarBuilder.setPlanId(10L))) + } + + test("regex") { + testConversion( + UnresolvedRegex("`(_1)?+.+`"), + expr(_.getUnresolvedRegexBuilder.setColName("`(_1)?+.+`"))) + testConversion( + UnresolvedRegex("a", planId = Option(11L)), + expr(_.getUnresolvedRegexBuilder.setColName("a").setPlanId(11L))) + } + + test("function") { + testConversion( + UnresolvedFunction("+", Seq(UnresolvedAttribute("a"), Literal(1))), + expr( + _.getUnresolvedFunctionBuilder + .setFunctionName("+") + .setIsDistinct(false) + .addArguments(attribute("a")) + .addArguments(expr(_.getLiteralBuilder.setInteger(1))))) + testConversion( + UnresolvedFunction( + "db1.myAgg", + Seq(UnresolvedAttribute("a")), + isDistinct = true, + isUserDefinedFunction = true), + expr( + _.getUnresolvedFunctionBuilder + .setFunctionName("db1.myAgg") + .setIsDistinct(true) + .setIsUserDefinedFunction(true) + .addArguments(attribute("a")))) + } + + test("alias") { + testConversion( + Alias(Literal("qwe"), "newA" :: Nil), + expr( + _.getAliasBuilder + .setExpr(expr(_.getLiteralBuilder.setString("qwe"))) + .addName("newA"))) + val metadata = new MetadataBuilder().putLong("q", 10).build() + testConversion( + Alias(UnresolvedAttribute("a"), "b" :: Nil, Option(metadata)), + expr( + _.getAliasBuilder + .setExpr(attribute("a")) + .addName("b") + .setMetadata("""{"q":10}"""))) + testConversion( + Alias(UnresolvedAttribute("complex"), "newA" :: "newB" :: Nil), + expr( + _.getAliasBuilder + .setExpr(attribute("complex")) + .addName("newA") + .addName("newB"))) + } + + private def testCast( + dataType: DataType, + colEvalMode: Cast.EvalMode, + catEvalMode: proto.Expression.Cast.EvalMode): Unit = { + testConversion( + Cast(UnresolvedAttribute("attr"), dataType, Option(colEvalMode)), + expr( + _.getCastBuilder + .setExpr(attribute("attr")) + .setType(DataTypeProtoConverter.toConnectProtoType(dataType)) + .setEvalMode(catEvalMode))) + } + + test("cast") { + testConversion( + Cast(UnresolvedAttribute("str"), DoubleType), + expr( + _.getCastBuilder + .setExpr(attribute("str")) + .setType(ProtoDataTypes.DoubleType))) + + testCast(LongType, Cast.Legacy, proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY) + testCast(BinaryType, Cast.Try, proto.Expression.Cast.EvalMode.EVAL_MODE_TRY) + testCast(ShortType, Cast.Ansi, proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI) + } + + private def testSortOrder( + colDirection: SortOrder.SortDirection, + colNullOrdering: SortOrder.NullOrdering, + catDirection: proto.Expression.SortOrder.SortDirection, + catNullOrdering: proto.Expression.SortOrder.NullOrdering): Unit = { + testConversion( + SortOrder(UnresolvedAttribute("unsorted"), colDirection, colNullOrdering), + expr( + _.getSortOrderBuilder + .setChild(attribute("unsorted")) + .setNullOrdering(catNullOrdering) + .setDirection(catDirection))) + } + + test("sortOrder") { + testSortOrder( + SortOrder.Ascending, + SortOrder.NullsFirst, + proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING, + proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST) + testSortOrder( + SortOrder.Ascending, + SortOrder.NullsLast, + proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING, + proto.Expression.SortOrder.NullOrdering.SORT_NULLS_LAST) + testSortOrder( + SortOrder.Descending, + SortOrder.NullsFirst, + proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_DESCENDING, + proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST) + testSortOrder( + SortOrder.Descending, + SortOrder.NullsLast, + proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_DESCENDING, + proto.Expression.SortOrder.NullOrdering.SORT_NULLS_LAST) + } + + private def testWindowFrame( + colFrameType: WindowFrame.FrameType, + colLower: WindowFrame.FrameBoundary, + colUpper: WindowFrame.FrameBoundary, + catFrameType: proto.Expression.Window.WindowFrame.FrameType, + catLower: proto.Expression.Window.WindowFrame.FrameBoundary, + catUpper: proto.Expression.Window.WindowFrame.FrameBoundary): Unit = { + testConversion( + Window( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("a"))), + WindowSpec( + Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")), + Seq(SortOrder(UnresolvedAttribute("d"), SortOrder.Descending, SortOrder.NullsLast)), + Option(WindowFrame(colFrameType, colLower, colUpper)))), + expr( + _.getWindowBuilder + .setWindowFunction( + expr(_.getUnresolvedFunctionBuilder + .setFunctionName("sum") + .setIsDistinct(false) + .addArguments(attribute("a")))) + .addPartitionSpec(attribute("b")) + .addPartitionSpec(attribute("c")) + .addOrderSpec(proto.Expression.SortOrder + .newBuilder() + .setChild(attribute("d")) + .setDirection(proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_DESCENDING) + .setNullOrdering(proto.Expression.SortOrder.NullOrdering.SORT_NULLS_LAST)) + .getFrameSpecBuilder + .setFrameType(catFrameType) + .setLower(catLower) + .setUpper(catUpper))) + } + + test("window") { + testConversion( + Window( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("a"))), + WindowSpec(Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")), Nil, None)), + expr( + _.getWindowBuilder + .setWindowFunction( + expr( + _.getUnresolvedFunctionBuilder + .setFunctionName("sum") + .setIsDistinct(false) + .addArguments(attribute("a")))) + .addPartitionSpec(attribute("b")) + .addPartitionSpec(attribute("c")))) + testWindowFrame( + WindowFrame.Row, + WindowFrame.Value(Literal(-10)), + WindowFrame.UnboundedFollowing, + proto.Expression.Window.WindowFrame.FrameType.FRAME_TYPE_ROW, + FrameBoundary.newBuilder().setValue(expr(_.getLiteralBuilder.setInteger(-10))).build(), + FrameBoundary.newBuilder().setUnbounded(true).build()) + testWindowFrame( + WindowFrame.Range, + WindowFrame.UnboundedPreceding, + WindowFrame.CurrentRow, + proto.Expression.Window.WindowFrame.FrameType.FRAME_TYPE_RANGE, + FrameBoundary.newBuilder().setUnbounded(true).build(), + FrameBoundary.newBuilder().setCurrentRow(true).build()) + } + + test("lambda") { + val colX = UnresolvedNamedLambdaVariable("x") + val catX = proto.Expression.UnresolvedNamedLambdaVariable + .newBuilder() + .addNameParts(colX.name) + .build() + testConversion( + LambdaFunction(UnresolvedFunction("+", Seq(colX, UnresolvedAttribute("y"))), Seq(colX)), + expr( + _.getLambdaFunctionBuilder + .setFunction( + expr( + _.getUnresolvedFunctionBuilder + .setFunctionName("+") + .addArguments(expr(_.setUnresolvedNamedLambdaVariable(catX))) + .addArguments(attribute("y")))) + .addArguments(catX))) + } + + test("sql") { + testConversion( + SqlExpression("1 + 1"), + expr(_.getExpressionStringBuilder.setExpression("1 + 1"))) + } + + test("caseWhen") { + testConversion( + CaseWhenOtherwise( + Seq(UnresolvedAttribute("c1") -> Literal("r1")), + Option(Literal("fallback"))), + expr( + _.getUnresolvedFunctionBuilder + .setFunctionName("when") + .addArguments(attribute("c1")) + .addArguments(expr(_.getLiteralBuilder.setString("r1"))) + .addArguments(expr(_.getLiteralBuilder.setString("fallback"))))) + } + + test("extract field") { + testConversion( + UnresolvedExtractValue(UnresolvedAttribute("struct"), Literal("cl_a")), + expr( + _.getUnresolvedExtractValueBuilder + .setChild(attribute("struct")) + .setExtraction(expr(_.getLiteralBuilder.setString("cl_a"))))) + } + + test("update field") { + testConversion( + UpdateFields(UnresolvedAttribute("struct"), "col_b", Option(Literal("cl_a"))), + expr( + _.getUpdateFieldsBuilder + .setStructExpression(attribute("struct")) + .setFieldName("col_b") + .setValueExpression(expr(_.getLiteralBuilder.setString("cl_a"))))) + + testConversion( + UpdateFields(UnresolvedAttribute("struct"), "col_c", None), + expr( + _.getUpdateFieldsBuilder + .setStructExpression(attribute("struct")) + .setFieldName("col_c"))) + } + + test("udf") { + val udf = + ScalaUserDefinedFunction((i: Int) => i, Seq(PrimitiveIntEncoder), PrimitiveIntEncoder) + val named = udf.withName("boo") + testConversion( + InvokeInlineUserDefinedFunction(named, Seq(UnresolvedAttribute(("a")))), + expr( + _.getCommonInlineUserDefinedFunctionBuilder + .setFunctionName("boo") + .setDeterministic(true) + .setScalarScalaUdf(named.udf) + .addArguments(attribute("a")))) + } + + test("extension") { + val e = attribute("name") + testConversion(ProtoColumnNode(e), e) + } + + test("unsupported") { + intercept[SparkException](ColumnNodeToProtoConverter(Nope())) + } +} + +private[connect] case class Nope(override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { + override def sql: String = "nope" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index f3ae2187c579a..e92ce9139aae1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -121,7 +121,7 @@ class Column(val node: ColumnNode) extends Logging { def this(name: String) = this(withOrigin { name match { case "*" => internal.UnresolvedStar(None) - case _ if name.endsWith(".*") => internal.UnresolvedStar(Option(name.dropRight(2))) + case _ if name.endsWith(".*") => internal.UnresolvedStar(Option(name)) case _ => internal.UnresolvedAttribute(name) } }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 55fa107a57106..50838858d5886 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -57,7 +57,10 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn) case UnresolvedStar(unparsedTarget, None, _) => - analysis.UnresolvedStar(unparsedTarget.map(analysis.UnresolvedAttribute.parseAttributeName)) + val target = unparsedTarget.map { t => + analysis.UnresolvedAttribute.parseAttributeName(t.stripSuffix(".*")) + } + analysis.UnresolvedStar(target) case UnresolvedStar(None, Some(planId), _) => analysis.UnresolvedDataFrameStar(planId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala index 0fbfe762df918..c993aa8e52031 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala @@ -95,7 +95,7 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { test("star") { testConversion(UnresolvedStar(None), analysis.UnresolvedStar(None)) testConversion( - UnresolvedStar(Option("x.y.z")), + UnresolvedStar(Option("x.y.z.*")), analysis.UnresolvedStar(Option(Seq("x", "y", "z")))) testConversion( UnresolvedStar(None, Option(10L)), @@ -282,6 +282,11 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { Seq(catX))) } + test("sql") { + // Direct comparison because Origin is a bit messed up. + assert(Converter(SqlExpression("1 + 1")) == Converter.parser.parseExpression("1 + 1")) + } + test("caseWhen") { testConversion( CaseWhenOtherwise(