diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala index 1af7558200de3..828a609a10e9c 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala @@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.Column +import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} // scalastyle:off: object.name @@ -41,7 +42,7 @@ object functions { def from_avro( data: Column, jsonFormatSchema: String): Column = { - Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty)) + AvroDataToCatalyst(data, jsonFormatSchema, Map.empty) } /** @@ -62,7 +63,7 @@ object functions { data: Column, jsonFormatSchema: String, options: java.util.Map[String, String]): Column = { - Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap)) + AvroDataToCatalyst(data, jsonFormatSchema, options.asScala.toMap) } /** @@ -74,7 +75,7 @@ object functions { */ @Experimental def to_avro(data: Column): Column = { - Column(CatalystDataToAvro(data.expr, None)) + CatalystDataToAvro(data, None) } /** @@ -87,6 +88,6 @@ object functions { */ @Experimental def to_avro(data: Column, jsonFormatSchema: String): Column = { - Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema))) + CatalystDataToAvro(data, Some(jsonFormatSchema)) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 07c9e5190da00..9a5935e1c7410 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -306,6 +306,12 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.TypedColumn.expr"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.TypedColumn$"), + // ColumnNode conversions + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.Converter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$Converter$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$RichColumn"), + // Datasource V2 partition transforms ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"), @@ -433,6 +439,9 @@ object CheckConnectJvmClientCompatibility { // SQLImplicits ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.session"), + // Column API + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Column.expr"), + // Steaming API ProblemFilters.exclude[MissingTypesProblem]( "org.apache.spark.sql.streaming.DataStreamWriter" // Client version extends Logging diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 12a71dbd7c7f8..3eca9c5cb0612 100644 --- 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -24,9 +24,9 @@ import java.time.LocalDateTime import java.util.Properties import org.apache.spark.SparkException -import org.apache.spark.sql.{Column, DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -303,7 +303,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { ArrayType(DecimalType(2, 2), true)) // Test write null values. df.select(df.queryExecution.analyzed.output.map { a => - Column(Literal.create(null, a.dataType)).as(a.name) + lit(null).cast(a.dataType).as(a.name) }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties) } diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala index 2700764399606..31050887936bd 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -20,6 +20,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.Column +import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} import org.apache.spark.sql.protobuf.utils.ProtobufUtils // scalastyle:off: object.name @@ -66,15 +67,11 @@ object functions { */ @Experimental def from_protobuf( - data: Column, - messageName: String, - binaryFileDescriptorSet: Array[Byte], - options: java.util.Map[String, String]): Column = { - Column( - ProtobufDataToCatalyst( - data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap - ) - ) + data: Column, + messageName: String, + binaryFileDescriptorSet: Array[Byte], + options: java.util.Map[String, String]): Column = { + ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap) } /** @@ -93,7 +90,7 @@ object functions { @Experimental def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = { val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent))) + ProtobufDataToCatalyst(data, messageName, Some(fileContent)) } /** @@ -112,7 +109,7 @@ object functions { @Experimental def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) : Column = { - Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet))) + ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet)) } /** @@ -132,7 +129,7 @@ object functions { */ @Experimental def from_protobuf(data: Column, messageClassName: String): Column = { - Column(ProtobufDataToCatalyst(data.expr, messageClassName)) + ProtobufDataToCatalyst(data, messageClassName) } /** @@ -156,7 +153,7 @@ object functions { data: Column, messageClassName: String, options: java.util.Map[String, String]): Column = { - Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap)) + ProtobufDataToCatalyst(data, messageClassName, None, 
options.asScala.toMap) } /** @@ -194,7 +191,7 @@ object functions { @Experimental def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) : Column = { - Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet))) + CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet)) } /** * Converts a column into binary of protobuf format. The Protobuf definition is provided @@ -216,9 +213,7 @@ object functions { descFilePath: String, options: java.util.Map[String, String]): Column = { val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - Column( - CatalystDataToProtobuf(data.expr, messageName, Some(fileContent), options.asScala.toMap) - ) + CatalystDataToProtobuf(data, messageName, Some(fileContent), options.asScala.toMap) } /** @@ -242,11 +237,7 @@ object functions { binaryFileDescriptorSet: Array[Byte], options: java.util.Map[String, String] ): Column = { - Column( - CatalystDataToProtobuf( - data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap - ) - ) + CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap) } /** @@ -266,7 +257,7 @@ object functions { */ @Experimental def to_protobuf(data: Column, messageClassName: String): Column = { - Column(CatalystDataToProtobuf(data.expr, messageClassName)) + CatalystDataToProtobuf(data, messageClassName) } /** @@ -288,6 +279,6 @@ object functions { @Experimental def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String]) : Column = { - Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap)) + CatalystDataToProtobuf(data, messageClassName, None, options.asScala.toMap) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index e6fec7f014d49..30f3e4c4af021 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -118,7 +118,7 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) } val mappedOutputCols = inputColNames.zip(tds).map { case (colName, td) => - dataset.col(colName).expr.dataType match { + SchemaUtils.getSchemaField(dataset.schema, colName).dataType match { case DoubleType => when(!col(colName).isNaN && col(colName) > td, lit(1.0)) .otherwise(lit(0.0)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 7cf39d4750314..6c10630e7bb82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -189,13 +189,12 @@ class StringIndexer @Since("1.4.0") ( private def getSelectedCols(dataset: Dataset[_], inputCols: Seq[String]): Seq[Column] = { inputCols.map { colName => val col = dataset.col(colName) - if (col.expr.dataType == StringType) { - col - } else { - // We don't count for NaN values. Because `StringIndexerAggregator` only processes strings, - // we replace NaNs with null in advance. - when(!isnan(col), col).cast(StringType) - } + // We don't count for NaN values. Because `StringIndexerAggregator` only processes strings, + // we replace NaNs with null in advance. 
+ val fpTypes = Seq(DoubleType, FloatType).map(_.catalogString) + when(typeof(col).isin(fpTypes: _*) && isnan(col), lit(null)) + .otherwise(col) + .cast(StringType) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index e9aeec0876dc8..831a8a33afecb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -87,17 +87,17 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) // Schema transformation. val schema = dataset.schema - val vectorCols = $(inputCols).filter { c => - dataset.col(c).expr.dataType match { - case _: VectorUDT => true - case _ => false - } + val inputColsWithField = $(inputCols).map { c => + c -> SchemaUtils.getSchemaField(schema, c) + } + + val vectorCols = inputColsWithField.collect { + case (c, field) if field.dataType.isInstanceOf[VectorUDT] => c } val vectorColsLengths = VectorAssembler.getLengths( dataset, vectorCols.toImmutableArraySeq, $(handleInvalid)) - val featureAttributesMap = $(inputCols).map { c => - val field = SchemaUtils.getSchemaField(schema, c) + val featureAttributesMap = inputColsWithField.map { case (c, field) => field.dataType match { case DoubleType => val attribute = Attribute.fromStructField(field) @@ -144,8 +144,8 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) val assembleFunc = udf { r: Row => VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*) }.asNondeterministic() - val args = $(inputCols).map { c => - dataset(c).expr.dataType match { + val args = inputColsWithField.map { case (c, field) => + field.dataType match { case DoubleType => dataset(c) case _: VectorUDT => dataset(c) case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 9388205a751ec..4c3242c132090 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputT import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -248,16 +249,13 @@ private[ml] class SummaryBuilderImpl( ) extends SummaryBuilder { override def summary(featuresCol: Column, weightCol: Column): Column = { - - val agg = SummaryBuilderImpl.MetricsAggregate( + SummaryBuilderImpl.MetricsAggregate( requestedMetrics, requestedCompMetrics, - featuresCol.expr, - weightCol.expr, + featuresCol, + weightCol, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - Column(agg.toAggregateExpression()) } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 03bf9c89aa2dc..68fce9d2ff15d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -110,6 +110,7 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy"), // SPARK-49022: Use Column API 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.TypedColumn.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.TypedColumn.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.WindowSpec.this") ) diff --git a/python/pyspark/sql/classic/column.py b/python/pyspark/sql/classic/column.py index 2504ec7406482..931378a08187f 100644 --- a/python/pyspark/sql/classic/column.py +++ b/python/pyspark/sql/classic/column.py @@ -75,10 +75,6 @@ def _to_java_column(col: "ColumnOrName") -> "JavaObject": return jcol -def _to_java_expr(col: "ColumnOrName") -> "JavaObject": - return _to_java_column(col).expr() - - @overload def _to_seq(sc: "SparkContext", cols: Iterable["JavaObject"]) -> "JavaObject": ... diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index b93723fbc6254..06834553ea96a 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -238,7 +238,7 @@ def applyInPandas( udf = pandas_udf(func, returnType=schema, functionType=PandasUDFType.GROUPED_MAP) df = self._df udf_column = udf(*[df[col] for col in df.columns]) - jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) + jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc) return DataFrame(jdf, self.session) def applyInPandasWithState( @@ -356,7 +356,7 @@ def applyInPandasWithState( df = self._df udf_column = udf(*[df[col] for col in df.columns]) jdf = self._jgd.applyInPandasWithState( - udf_column._jc.expr(), + udf_column._jc, self.session._jsparkSession.parseDataType(outputStructType.json()), self.session._jsparkSession.parseDataType(stateStructType.json()), outputMode, @@ -523,7 +523,7 @@ def transformWithStateUDF( udf_column = udf(*[df[col] for col in df.columns]) jdf = self._jgd.transformWithStateInPandas( - udf_column._jc.expr(), + udf_column._jc, self.session._jsparkSession.parseDataType(outputStructType.json()), outputMode, timeMode, @@ -653,7 +653,7 @@ def applyInArrow( ) # type: ignore[call-overload] df = self._df udf_column = udf(*[df[col] for col in df.columns]) - jdf = self._jgd.flatMapGroupsInArrow(udf_column._jc.expr()) + jdf = self._jgd.flatMapGroupsInArrow(udf_column._jc) return DataFrame(jdf, self.session) def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps": @@ -793,7 +793,7 @@ def applyInPandas( all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2) udf_column = udf(*all_cols) - jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr()) + jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc) return DataFrame(jdf, self._gd1.session) def applyInArrow( @@ -891,7 +891,7 @@ def applyInArrow( all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2) udf_column = udf(*all_cols) - jdf = self._gd1._jgd.flatMapCoGroupsInArrow(self._gd2._jgd, udf_column._jc.expr()) + jdf = self._gd1._jgd.flatMapCoGroupsInArrow(self._gd2._jgd, udf_column._jc) return DataFrame(jdf, self._gd1.session) @staticmethod diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py index b02fe018b688e..c11a8b9d8d4d2 100644 --- a/python/pyspark/sql/pandas/map_ops.py +++ b/python/pyspark/sql/pandas/map_ops.py @@ -53,7 +53,7 @@ def mapInPandas( udf_column = udf(*[self[col] for col in self.columns]) jrp = self._build_java_profile(profile) - jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier, jrp) + jdf = self._jdf.mapInPandas(udf_column._jc, barrier, jrp) 
return DataFrame(jdf, self.sparkSession) def mapInArrow( @@ -75,7 +75,7 @@ def mapInArrow( udf_column = udf(*[self[col] for col in self.columns]) jrp = self._build_java_profile(profile) - jdf = self._jdf.mapInArrow(udf_column._jc.expr(), barrier, jrp) + jdf = self._jdf.mapInArrow(udf_column._jc, barrier, jrp) return DataFrame(jdf, self.sparkSession) def _build_java_profile( diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py index 1482d2407b37d..011563d7006e8 100644 --- a/python/pyspark/sql/sql_formatter.py +++ b/python/pyspark/sql/sql_formatter.py @@ -48,13 +48,14 @@ def _convert_value(self, val: Any, field_name: str) -> Optional[str]: from py4j.java_gateway import is_instance_of from pyspark import SparkContext - from pyspark.sql import Column, DataFrame + from pyspark.sql import Column, DataFrame, SparkSession if isinstance(val, Column): - assert SparkContext._gateway is not None + jsession = SparkSession.active()._jsparkSession + jexpr = jsession.expression(val._jc) + assert SparkContext._gateway is not None gw = SparkContext._gateway - jexpr = val._jc.expr() if is_instance_of( gw, jexpr, "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute" ) or is_instance_of( diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 04e886e4d35b7..9cf93938528f8 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -397,15 +397,13 @@ def _create_judf(self, func: Callable[..., Any]) -> "JavaObject": return judf def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column: - from pyspark.sql.classic.column import _to_java_expr, _to_seq + from pyspark.sql.classic.column import _to_java_column, _to_seq sc = get_active_spark_context() assert sc._jvm is not None - jexprs = [_to_java_expr(arg) for arg in args] + [ - sc._jvm.org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression( - key, _to_java_expr(value) - ) + jcols = [_to_java_column(arg) for arg in args] + [ + sc._jvm.PythonSQLUtils.namedArgumentExpression(key, _to_java_column(value)) for key, value in kwargs.items() ] @@ -424,9 +422,7 @@ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column: UserWarning, ) judf = self._judf - jUDFExpr = judf.builder(_to_seq(sc, jexprs)) - jPythonUDF = judf.fromUDFExpr(jUDFExpr) - return Column(jPythonUDF) + return Column(judf.apply(_to_seq(sc, jcols))) # Disallow enabling two profilers at the same time. if profiler_enabled and memory_profiler_enabled: @@ -450,7 +446,7 @@ def func(*args: Any, **kwargs: Any) -> Any: func.__signature__ = inspect.signature(f) # type: ignore[attr-defined] judf = self._create_judf(func) - jUDFExpr = judf.builder(_to_seq(sc, jexprs)) + jUDFExpr = judf.builderWithColumns(_to_seq(sc, jcols)) jPythonUDF = judf.fromUDFExpr(jUDFExpr) id = jUDFExpr.resultId().id() sc.profiler_collector.add_profiler(id, profiler) @@ -468,14 +464,13 @@ def func(*args: Any, **kwargs: Any) -> Any: func.__signature__ = inspect.signature(f) # type: ignore[attr-defined] judf = self._create_judf(func) - jUDFExpr = judf.builder(_to_seq(sc, jexprs)) + jUDFExpr = judf.builderWithColumns(_to_seq(sc, jcols)) jPythonUDF = judf.fromUDFExpr(jUDFExpr) id = jUDFExpr.resultId().id() sc.profiler_collector.add_profiler(id, memory_profiler) else: judf = self._judf - jUDFExpr = judf.builder(_to_seq(sc, jexprs)) - jPythonUDF = judf.fromUDFExpr(jUDFExpr) + jPythonUDF = judf.apply(_to_seq(sc, jcols)) return Column(jPythonUDF) # This function is for improving the online help system in the interactive interpreter. 
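Everything above follows one pattern: public and semi-public entry points (Avro, Protobuf, MLlib, the PySpark py4j bridges) stop calling `Column(expr)` or `column.expr` and instead hand `Column`s across the boundary, with the newly imported `org.apache.spark.sql.internal.ExpressionUtils.{column, expression}` doing the conversion to and from Catalyst on the JVM side. The sketch below is a hypothetical reading of what those two helpers amount to, pieced together from other hunks in this patch (`ExpressionColumnNode`, the `ColumnNodeToExpressionConverter` companion, and `Column.apply(ColumnNode)`); the object name is illustrative and the real `ExpressionUtils` may differ.

```scala
package org.apache.spark.sql.internal

import scala.language.implicitConversions

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression

// Hypothetical stand-in for the ExpressionUtils imported throughout this patch.
private[sql] object ExpressionUtilsSketch {
  // Expression => Column: wrap a Catalyst expression (e.g. AvroDataToCatalyst)
  // in a ColumnNode so it can be returned where the API promises a Column.
  implicit def column(e: Expression): Column = Column(ExpressionColumnNode(e))

  // Column => Expression: convert the Column's node back into a Catalyst
  // expression for internal code that still needs one.
  implicit def expression(c: Column): Expression =
    ColumnNodeToExpressionConverter(c.node)
}
```

With both conversions in implicit scope, `AvroDataToCatalyst(data, jsonFormatSchema, Map.empty)` type-checks as a `Column` and `data` can be passed straight into the expression constructor, which is exactly what the rewritten `from_avro`/`from_protobuf` bodies rely on.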
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7ef7c2f6345b2..7352d2bf94a0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -912,6 +912,7 @@ object FunctionRegistry { registerInternalExpression[PandasMode]("pandas_mode") registerInternalExpression[EWM]("ewm") registerInternalExpression[NullIndex]("null_index") + registerInternalExpression[CastTimestampNTZToLong]("timestamp_ntz_to_long") private def makeExprInfoForVirtualOperator(name: String, usage: String): ExpressionInfo = { new ExpressionInfo( diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index c8aba5d19fe7f..a535f75719ecd 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -44,7 +44,7 @@ import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID} import org.apache.spark.ml.{functions => MLFunctions} import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest} -import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession} +import org.apache.spark.sql.{withOrigin, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession} import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar} @@ -79,6 +79,7 @@ import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeout import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils, UserDefinedFunctionUtils} +import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf, ProtobufDataToCatalyst} import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger} import org.apache.spark.sql.types._ @@ -108,6 +109,7 @@ class SparkConnectPlanner( @Since("4.0.0") @DeveloperApi def session: SparkSession = sessionHolder.session + import sessionHolder.session.RichColumn private[connect] def parser = session.sessionState.sqlParser @@ -552,7 +554,7 @@ class SparkConnectPlanner( .ofRows(session, transformRelation(rel.getInput)) .stat .sampleBy( - col = Column(transformExpression(rel.getCol)), + col = column(transformExpression(rel.getCol)), fractions = 
fractions.toMap, seed = if (rel.hasSeed) rel.getSeed else Utils.random.nextLong) .logicalPlan @@ -644,17 +646,17 @@ class SparkConnectPlanner( val pythonUdf = transformPythonUDF(commonUdf) val cols = rel.getGroupingExpressionsList.asScala.toSeq.map(expr => - Column(transformExpression(expr))) + column(transformExpression(expr))) val group = Dataset .ofRows(session, transformRelation(rel.getInput)) .groupBy(cols: _*) pythonUdf.evalType match { case PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF => - group.flatMapGroupsInPandas(pythonUdf).logicalPlan + group.flatMapGroupsInPandas(column(pythonUdf)).logicalPlan case PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF => - group.flatMapGroupsInArrow(pythonUdf).logicalPlan + group.flatMapGroupsInArrow(column(pythonUdf)).logicalPlan case _ => throw InvalidPlanInput( @@ -763,10 +765,10 @@ class SparkConnectPlanner( case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => val inputCols = rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr => - Column(transformExpression(expr))) + column(transformExpression(expr))) val otherCols = rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr => - Column(transformExpression(expr))) + column(transformExpression(expr))) val input = Dataset .ofRows(session, transformRelation(rel.getInput)) @@ -980,7 +982,7 @@ class SparkConnectPlanner( private def transformApplyInPandasWithState(rel: proto.ApplyInPandasWithState): LogicalPlan = { val pythonUdf = transformPythonUDF(rel.getFunc) val cols = - rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr))) + rel.getGroupingExpressionsList.asScala.toSeq.map(expr => column(transformExpression(expr))) val outputSchema = parseSchema(rel.getOutputSchema) @@ -990,7 +992,7 @@ class SparkConnectPlanner( .ofRows(session, transformRelation(rel.getInput)) .groupBy(cols: _*) .applyInPandasWithState( - pythonUdf, + column(pythonUdf), outputSchema, stateSchema, rel.getOutputMode, @@ -1078,7 +1080,7 @@ class SparkConnectPlanner( Metadata.empty } - (alias.getName(0), Column(transformExpression(alias.getExpr)), metadata) + (alias.getName(0), column(transformExpression(alias.getExpr)), metadata) }.unzip3 Dataset @@ -1126,7 +1128,7 @@ class SparkConnectPlanner( private def transformUnpivot(rel: proto.Unpivot): LogicalPlan = { val ids = rel.getIdsList.asScala.toArray.map { expr => - Column(transformExpression(expr)) + column(transformExpression(expr)) } if (!rel.hasValues) { @@ -1139,7 +1141,7 @@ class SparkConnectPlanner( transformRelation(rel.getInput)) } else { val values = rel.getValues.getValuesList.asScala.toArray.map { expr => - Column(transformExpression(expr)) + column(transformExpression(expr)) } Unpivot( @@ -1168,7 +1170,7 @@ class SparkConnectPlanner( private def transformCollectMetrics(rel: proto.CollectMetrics, planId: Long): LogicalPlan = { val metrics = rel.getMetricsList.asScala.toSeq.map { expr => - Column(transformExpression(expr)) + column(transformExpression(expr)) } val name = rel.getName val input = transformRelation(rel.getInput) @@ -2236,10 +2238,10 @@ class SparkConnectPlanner( private def transformAsOfJoin(rel: proto.AsOfJoin): LogicalPlan = { val left = Dataset.ofRows(session, transformRelation(rel.getLeft)) val right = Dataset.ofRows(session, transformRelation(rel.getRight)) - val leftAsOf = Column(transformExpression(rel.getLeftAsOf)) - val rightAsOf = Column(transformExpression(rel.getRightAsOf)) + val leftAsOf = column(transformExpression(rel.getLeftAsOf)) + val rightAsOf = 
column(transformExpression(rel.getRightAsOf)) val joinType = rel.getJoinType - val tolerance = if (rel.hasTolerance) Column(transformExpression(rel.getTolerance)) else null + val tolerance = if (rel.hasTolerance) column(transformExpression(rel.getTolerance)) else null val allowExactMatches = rel.getAllowExactMatches val direction = rel.getDirection @@ -2255,7 +2257,7 @@ class SparkConnectPlanner( allowExactMatches = allowExactMatches, direction = direction) } else { - val joinExprs = if (rel.hasJoinExpr) Column(transformExpression(rel.getJoinExpr)) else null + val joinExprs = if (rel.hasJoinExpr) column(transformExpression(rel.getJoinExpr)) else null left.joinAsOf( other = right, leftAsOf = leftAsOf, @@ -2296,7 +2298,7 @@ class SparkConnectPlanner( private def transformDrop(rel: proto.Drop): LogicalPlan = { var output = Dataset.ofRows(session, transformRelation(rel.getInput)) if (rel.getColumnsCount > 0) { - val cols = rel.getColumnsList.asScala.toSeq.map(expr => Column(transformExpression(expr))) + val cols = rel.getColumnsList.asScala.toSeq.map(expr => column(transformExpression(expr))) output = output.drop(cols.head, cols.tail: _*) } if (rel.getColumnNamesCount > 0) { @@ -2371,7 +2373,7 @@ class SparkConnectPlanner( rel.getPivot.getValuesList.asScala.toSeq.map(transformLiteral) } else { RelationalGroupedDataset - .collectPivotValues(Dataset.ofRows(session, input), Column(pivotExpr)) + .collectPivotValues(Dataset.ofRows(session, input), column(pivotExpr)) .map(expressions.Literal.apply) } logical.Pivot( @@ -2697,12 +2699,12 @@ class SparkConnectPlanner( if (!namedArguments.isEmpty) { session.sql( sql.getQuery, - namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))), + namedArguments.asScala.toMap.transform((_, e) => column(transformExpression(e))), tracker) } else if (!posArguments.isEmpty) { session.sql( sql.getQuery, - posArguments.asScala.map(e => Column(transformExpression(e))).toArray, + posArguments.asScala.map(e => column(transformExpression(e))).toArray, tracker) } else if (!args.isEmpty) { session.sql( @@ -2953,7 +2955,7 @@ class SparkConnectPlanner( if (writeOperation.getPartitioningColumnsCount > 0) { val names = writeOperation.getPartitioningColumnsList.asScala .map(transformExpression) - .map(Column(_)) + .map(column) .toSeq w.partitionedBy(names.head, names.tail: _*) } @@ -2971,7 +2973,7 @@ class SparkConnectPlanner( w.create() } case proto.WriteOperationV2.Mode.MODE_OVERWRITE => - w.overwrite(Column(transformExpression(writeOperation.getOverwriteCondition))) + w.overwrite(column(transformExpression(writeOperation.getOverwriteCondition))) case proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS => w.overwritePartitions() case proto.WriteOperationV2.Mode.MODE_APPEND => @@ -3521,7 +3523,7 @@ class SparkConnectPlanner( val sourceDs = Dataset.ofRows(session, transformRelation(cmd.getSourceTablePlan)) var mergeInto = sourceDs - .mergeInto(cmd.getTargetTableName, Column(transformExpression(cmd.getMergeCondition))) + .mergeInto(cmd.getTargetTableName, column(transformExpression(cmd.getMergeCondition))) .withNewMatchedActions(matchedActions: _*) .withNewNotMatchedActions(notMatchedActions: _*) .withNewNotMatchedBySourceActions(notMatchedBySourceActions: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 26df8cd9294b7..f3ae2187c579a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala 
@@ -22,16 +22,10 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{LEFT_EXPR, RIGHT_EXPR} -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.toPrettySQL -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.catalyst.parser.DataTypeParser import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.{ColumnNode, ExpressionColumnNode, TypedAggUtils} +import org.apache.spark.sql.internal.ColumnNode import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ @@ -39,18 +33,8 @@ private[spark] object Column { def apply(colName: String): Column = new Column(colName) - def apply(expr: Expression): Column = Column(ExpressionColumnNode(expr)) - def apply(node: => ColumnNode): Column = withOrigin(new Column(node)) - private[sql] def generateAlias(e: Expression): String = { - e match { - case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => - a.aggregateFunction.toString - case expr => toPrettySQL(expr) - } - } - private[sql] def fn(name: String, inputs: Column*): Column = { fn(name, isDistinct = false, inputs: _*) } @@ -89,30 +73,9 @@ private[spark] object Column { @Stable class TypedColumn[-T, U]( node: ColumnNode, - private[sql] val encoder: Encoder[U], - private[sql] val inputType: Option[(ExpressionEncoder[_], Seq[Attribute])] = None) + private[sql] val encoder: Encoder[U]) extends Column(node) { - override lazy val expr: Expression = { - val expression = internal.ColumnNodeToExpressionConverter(node) - inputType match { - case Some((inputEncoder, inputAttributes)) => - TypedAggUtils.withInputType(expression, inputEncoder, inputAttributes) - case None => - expression - } - } - - /** - * Inserts the specific input type and schema into any expressions that are expected to operate - * on a decoded object. - */ - private[sql] def withInputType( - inputEncoder: ExpressionEncoder[_], - inputAttributes: Seq[Attribute]): TypedColumn[T, U] = { - new TypedColumn[T, U](node, encoder, Option((inputEncoder, inputAttributes))) - } - /** * Gives the [[TypedColumn]] a name (alias). * If the current `TypedColumn` has metadata associated with it, this metadata will be propagated @@ -155,8 +118,6 @@ class TypedColumn[-T, U]( */ @Stable class Column(val node: ColumnNode) extends Logging { - lazy val expr: Expression = internal.ColumnNodeToExpressionConverter(node) - def this(name: String) = this(withOrigin { name match { case "*" => internal.UnresolvedStar(None) @@ -184,36 +145,13 @@ class Column(val node: ColumnNode) extends Logging { override def hashCode: Int = this.node.normalized.hashCode() - /** - * Returns the expression for this column either with an existing or auto assigned name. - */ - private[sql] def named: NamedExpression = expr match { - case expr: NamedExpression => expr - - // Leave an unaliased generator with an empty list of names since the analyzer will generate - // the correct defaults after the nested expression's type has been resolved. 
- case g: Generator => MultiAlias(g, Nil) - - // If we have a top level Cast, there is a chance to give it a better alias, if there is a - // NamedExpression under this Cast. - case c: Cast => - c.transformUp { - case c @ Cast(_: NamedExpression, _, _, _) => UnresolvedAlias(c) - } match { - case ne: NamedExpression => ne - case _ => UnresolvedAlias(expr, Some(Column.generateAlias)) - } - - case expr: Expression => UnresolvedAlias(expr, Some(Column.generateAlias)) - } - /** * Provides a type hint about the expected return value of this column. This information can * be used by operations such as `select` on a [[Dataset]] to automatically convert the * results into the correct JVM types. * @since 1.6.0 */ - def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](node, encoderFor[U]) + def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](node, implicitly[Encoder[U]]) /** * Extracts a value or values from a complex type. @@ -1203,7 +1141,7 @@ class Column(val node: ColumnNode) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: String): Column = cast(CatalystSqlParser.parseDataType(to)) + def cast(to: String): Column = cast(DataTypeParser.parseDataType(to)) /** * Casts the column to a different data type and the result is null on failure. @@ -1234,7 +1172,7 @@ class Column(val node: ColumnNode) extends Logging { * @since 4.0.0 */ def try_cast(to: String): Column = { - try_cast(CatalystSqlParser.parseDataType(to)) + try_cast(DataTypeParser.parseDataType(to)) } private def sortOrder( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 231d361810f84..2af5bce69087f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -26,6 +26,7 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ @@ -36,6 +37,7 @@ import org.apache.spark.util.ArrayImplicits._ */ @Stable final class DataFrameNaFunctions private[sql](df: DataFrame) { + import df.sparkSession.RichColumn /** * Returns a new `DataFrame` that drops rows containing any null or NaN values. @@ -398,7 +400,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { (attr.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType))) { replaceCol(attr, replacementMap) } else { - Column(attr) + column(attr) } } df.select(projections : _*) @@ -431,7 +433,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case v: jl.Integer => fillCol[Integer](attr, v) case v: jl.Boolean => fillCol[Boolean](attr, v.booleanValue()) case v: String => fillCol[String](attr, v) - }.getOrElse(Column(attr)) + }.getOrElse(column(attr)) } df.select(projections : _*) } @@ -441,7 +443,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * with `replacement`. 
*/ private def fillCol[T](attr: Attribute, replacement: T): Column = { - fillCol(attr.dataType, attr.name, Column(attr), replacement) + fillCol(attr.dataType, attr.name, column(attr), replacement) } /** @@ -468,7 +470,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val branches = replacementMap.flatMap { case (source, target) => Seq(Literal(source), buildExpr(target)) }.toSeq - Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name) + column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name) } private def convertToDouble(v: Any): Double = v match { @@ -502,7 +504,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { // Filtering condition: // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. val predicate = AtLeastNNonNulls(minNonNulls, cols) - df.filter(Column(predicate)) + df.filter(column(predicate)) } private[sql] def fillValue(value: Any, cols: Option[Seq[String]]): DataFrame = { @@ -538,9 +540,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } // Only fill if the column is part of the cols list. if (typeMatches && cols.exists(_.semanticEquals(col))) { - fillCol(col.dataType, col.name, Column(col), value) + fillCol(col.dataType, col.name, column(col), value) } else { - Column(col) + column(col) } } df.select(projections : _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 790d15267a574..15be4e2d22668 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -22,11 +22,9 @@ import java.{lang => jl, util => ju} import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, CountMinSketchAgg} +import org.apache.spark.sql.Encoders.BINARY import org.apache.spark.sql.execution.stat._ -import org.apache.spark.sql.functions.col -import org.apache.spark.sql.types._ +import org.apache.spark.sql.functions.{col, count_min_sketch, lit} import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} @@ -503,16 +501,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { eps: Double, confidence: Double, seed: Int): CountMinSketch = withOrigin { - val countMinSketchAgg = new CountMinSketchAgg( - col.expr, - Literal(eps, DoubleType), - Literal(confidence, DoubleType), - Literal(seed, IntegerType) - ) - val bytes = df.select( - Column(countMinSketchAgg.toAggregateExpression(false)) - ).head().getAs[Array[Byte]](0) - countMinSketchAgg.deserialize(bytes) + val cms = count_min_sketch(col, lit(eps), lit(confidence), lit(seed)) + val bytes = df.select(cms).as(BINARY).head() + CountMinSketch.readFrom(bytes) } /** @@ -561,14 +552,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @since 2.0.0 */ def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = withOrigin { - val bloomFilterAgg = new BloomFilterAggregate( - col.expr, - Literal(expectedNumItems, LongType), - Literal(numBits, LongType) - ) - val bytes = df.select( - Column(bloomFilterAgg.toAggregateExpression(false)) - ).head().getAs[Array[Byte]](0) - bloomFilterAgg.deserialize(bytes) + val bf = Column.internalFn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits)) + 
val bytes = df.select(bf).as(BINARY).head() + BloomFilter.readFrom(bytes) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 9b824074533af..d9ad0003a5255 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -41,6 +41,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) private val df: DataFrame = ds.toDF() private val sparkSession = ds.sparkSession + import sparkSession.expression private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table) @@ -88,7 +89,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = { def ref(name: String): NamedReference = LogicalExpressions.parseReference(name) - val asTransforms = (column +: columns).map(_.expr).map { + val asTransforms = (column +: columns).map(expression).map { case PartitionTransform.YEARS(Seq(attr: Attribute)) => LogicalExpressions.years(ref(attr.name)) case PartitionTransform.MONTHS(Seq(attr: Attribute)) => @@ -185,7 +186,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) @throws(classOf[NoSuchTableException]) def overwrite(condition: Column): Unit = { val overwrite = OverwriteByExpression.byName( - UnresolvedRelation(tableName), logicalPlan, condition.expr, options.toMap) + UnresolvedRelation(tableName), logicalPlan, expression(condition), options.toMap) runCommand(overwrite) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 94129d2e8b58b..fcb15760e29a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -60,7 +60,9 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable} import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions +import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.TypedAggUtils.withInputType import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -220,6 +222,8 @@ class Dataset[T] private[sql]( queryExecution.sparkSession } + import sparkSession.RichColumn + // A globally unique id of this Dataset. 
private[sql] val id = Dataset.curId.getAndIncrement() @@ -289,7 +293,7 @@ class Dataset[T] private[sql]( truncate: Int): Seq[Seq[String]] = { val newDf = commandResultOptimized.toDF() val castCols = newDf.logicalPlan.output.map { col => - Column(ToPrettyString(col)) + column(ToPrettyString(col)) } val data = newDf.select(castCols: _*).take(numRows + 1) @@ -549,7 +553,7 @@ class Dataset[T] private[sql]( s"New column names (${colNames.size}): " + colNames.mkString(", ")) val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => - Column(oldAttribute).as(newName) + column(oldAttribute).as(newName) } select(newCols : _*) } @@ -1298,11 +1302,11 @@ class Dataset[T] private[sql]( tolerance: Column, allowExactMatches: Boolean, direction: String): DataFrame = { - val joinExprs = usingColumns.map { column => - EqualTo(resolve(column), other.resolve(column)) - }.reduceOption(And).map(Column.apply).orNull - - joinAsOf(other, leftAsOf, rightAsOf, joinExprs, joinType, + val joinConditions = usingColumns.map { name => + this(name) === other(name) + } + val joinCondition = joinConditions.reduceOption(_ && _).orNull + joinAsOf(other, leftAsOf, rightAsOf, joinCondition, joinType, tolerance, allowExactMatches, direction) } @@ -1470,12 +1474,12 @@ class Dataset[T] private[sql]( */ def col(colName: String): Column = colName match { case "*" => - Column(ResolvedStar(queryExecution.analyzed.output)) + column(ResolvedStar(queryExecution.analyzed.output)) case _ => if (sparkSession.sessionState.conf.supportQuotedRegexColumnName) { colRegex(colName) } else { - Column(addDataFrameIdToCol(resolve(colName))) + column(addDataFrameIdToCol(resolve(colName))) } } @@ -1489,7 +1493,7 @@ class Dataset[T] private[sql]( * @since 3.5.0 */ def metadataColumn(colName: String): Column = - Column(queryExecution.analyzed.getMetadataAttributeByName(colName)) + column(queryExecution.analyzed.getMetadataAttributeByName(colName)) // Attach the dataset id and column position to the column reference, so that we can detect // ambiguous self-join correctly. See the rule `DetectAmbiguousSelfJoin`. 
@@ -1519,11 +1523,11 @@ class Dataset[T] private[sql]( val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis colName match { case ParserUtils.escapedIdentifier(columnNameRegex) => - Column(UnresolvedRegex(columnNameRegex, None, caseSensitive)) + column(UnresolvedRegex(columnNameRegex, None, caseSensitive)) case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) => - Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive)) + column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive)) case _ => - Column(addDataFrameIdToCol(resolve(colName))) + column(addDataFrameIdToCol(resolve(colName))) } } @@ -1623,9 +1627,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = sparkSession.withActive { - select(exprs.map { expr => - Column(sparkSession.sessionState.sqlParser.parseExpression(expr)) - }: _*) + select(exprs.map(functions.expr): _*) } /** @@ -1641,7 +1643,8 @@ class Dataset[T] private[sql]( */ def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder: ExpressionEncoder[U1] = encoderFor(c1.encoder) - val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) + val tc1 = withInputType(c1.named, exprEnc, logicalPlan.output) + val project = Project(tc1 :: Nil, logicalPlan) if (!encoder.isSerializedAsStructForTopLevel) { new Dataset[U1](sparkSession, project, encoder) @@ -1658,8 +1661,7 @@ class Dataset[T] private[sql]( */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(c => encoderFor(c.encoder)) - val namedColumns = - columns.map(_.withInputType(exprEnc, logicalPlan.output).named) + val namedColumns = columns.map(c => withInputType(c.named, exprEnc, logicalPlan.output)) val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) new Dataset(execution, ExpressionEncoder.tuple(encoders)) } @@ -1737,7 +1739,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def filter(conditionExpr: String): Dataset[T] = sparkSession.withActive { - filter(Column(sparkSession.sessionState.sqlParser.parseExpression(conditionExpr))) + filter(functions.expr(conditionExpr)) } /** @@ -2849,7 +2851,7 @@ class Dataset[T] private[sql]( resolver(field.name, colName) } match { case Some((colName: String, col: Column)) => col.as(colName) - case _ => Column(field) + case _ => column(field) } } @@ -3053,7 +3055,7 @@ class Dataset[T] private[sql]( val allColumns = queryExecution.analyzed.output val remainingCols = allColumns.filter { attribute => colNames.forall(n => !resolver(attribute.name, n)) - }.map(attribute => Column(attribute)) + }.map(attribute => column(attribute)) if (remainingCols.size == allColumns.size) { toDF() } else { @@ -3518,9 +3520,10 @@ class Dataset[T] private[sql]( * workers. */ private[sql] def mapInPandas( - func: PythonUDF, + funcCol: Column, isBarrier: Boolean = false, profile: ResourceProfile = null): DataFrame = { + val func = funcCol.expr Dataset.ofRows( sparkSession, MapInPandas( @@ -3537,9 +3540,10 @@ class Dataset[T] private[sql]( * Each partition is each iterator consisting of `pyarrow.RecordBatch`s as batches. */ private[sql] def mapInArrow( - func: PythonUDF, + funcCol: Column, isBarrier: Boolean = false, profile: ResourceProfile = null): DataFrame = { + val func = funcCol.expr Dataset.ofRows( sparkSession, MapInArrow( @@ -4245,7 +4249,7 @@ class Dataset[T] private[sql]( * This is for 'distributed-sequence' default index in pandas API on Spark. 
*/ private[sql] def withSequenceColumn(name: String) = { - select(Column(DistributedSequenceID()).alias(name), col("*")) + select(column(DistributedSequenceID()).alias(name), col("*")) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index a672f29966df7..e3ea33a7504bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator -import org.apache.spark.sql.internal.TypedAggUtils +import org.apache.spark.sql.internal.TypedAggUtils.{aggKeyColumn, withInputType} import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode} /** @@ -49,6 +49,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private def logicalPlan = queryExecution.analyzed private def sparkSession = queryExecution.sparkSession + import queryExecution.sparkSession._ /** * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the @@ -969,9 +970,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( */ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(c => encoderFor(c.encoder)) - val namedColumns = - columns.map(_.withInputType(vExprEnc, dataAttributes).named) - val keyColumn = TypedAggUtils.aggKeyColumn(kExprEnc, groupingAttributes) + val namedColumns = columns.map(c => withInputType(c.named, vExprEnc, dataAttributes)) + val keyColumn = aggKeyColumn(kExprEnc, groupingAttributes) val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sparkSession, aggregate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala index b7f9c96f82e04..d8042720577df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala @@ -46,7 +46,8 @@ class MergeIntoWriter[T] private[sql] ( private val df: DataFrame = ds.toDF() - private val sparkSession = ds.sparkSession + private[sql] val sparkSession = ds.sparkSession + import sparkSession.RichColumn private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table) @@ -231,6 +232,8 @@ class MergeIntoWriter[T] private[sql] ( case class WhenMatched[T] private[sql]( mergeIntoWriter: MergeIntoWriter[T], condition: Option[Expression]) { + import mergeIntoWriter.sparkSession.RichColumn + /** * Specifies an action to update all matched rows in the DataFrame. * @@ -277,6 +280,7 @@ case class WhenMatched[T] private[sql]( case class WhenNotMatched[T] private[sql]( mergeIntoWriter: MergeIntoWriter[T], condition: Option[Expression]) { + import mergeIntoWriter.sparkSession.RichColumn /** * Specifies an action to insert all non-matched rows into the DataFrame. 
@@ -312,6 +316,7 @@ case class WhenNotMatched[T] private[sql]( case class WhenNotMatchedBySource[T] private[sql]( mergeIntoWriter: MergeIntoWriter[T], condition: Option[Expression]) { + import mergeIntoWriter.sparkSession.RichColumn /** * Specifies an action to update all non-matched rows in the target DataFrame when diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 076da3a880131..3cafe0d98f1bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -35,6 +35,8 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.ExpressionUtils.{column, generateAlias} +import org.apache.spark.sql.internal.TypedAggUtils.withInputType import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{NumericType, StructType} import org.apache.spark.util.ArrayImplicits._ @@ -56,6 +58,7 @@ class RelationalGroupedDataset protected[sql]( private[sql] val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) { import RelationalGroupedDataset._ + import df.sparkSession._ private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { @scala.annotation.nowarn("cat=deprecation") @@ -250,7 +253,7 @@ class RelationalGroupedDataset protected[sql]( def agg(expr: Column, exprs: Column*): DataFrame = { toDF((expr +: exprs).map { case typed: TypedColumn[_, _] => - typed.withInputType(df.exprEnc, df.logicalPlan.output).expr + withInputType(typed.expr, df.exprEnc, df.logicalPlan.output) case c => c.expr }) } @@ -508,7 +511,7 @@ class RelationalGroupedDataset protected[sql]( broadcastVars: Array[Broadcast[Object]], outputSchema: StructType): DataFrame = { val groupingNamedExpressions = groupingExprs.map(alias) - val groupingCols = groupingNamedExpressions.map(Column(_)) + val groupingCols = groupingNamedExpressions.map(column) val groupingDataFrame = df.select(groupingCols : _*) val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) Dataset.ofRows( @@ -538,7 +541,8 @@ class RelationalGroupedDataset protected[sql]( * This function uses Apache Arrow as serialization format between Java executors and Python * workers. */ - private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { + private[sql] def flatMapGroupsInPandas(column: Column): DataFrame = { + val expr = column.expr.asInstanceOf[PythonUDF] require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, "Must pass a grouped map pandas udf") require(expr.dataType.isInstanceOf[StructType], @@ -570,7 +574,8 @@ class RelationalGroupedDataset protected[sql]( * This function uses Apache Arrow as serialization format between Java executors and Python * workers. 
*/ - private[sql] def flatMapGroupsInArrow(expr: PythonUDF): DataFrame = { + private[sql] def flatMapGroupsInArrow(column: Column): DataFrame = { + val expr = column.expr.asInstanceOf[PythonUDF] require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, "Must pass a grouped map arrow udf") require(expr.dataType.isInstanceOf[StructType], @@ -602,7 +607,8 @@ class RelationalGroupedDataset protected[sql]( */ private[sql] def flatMapCoGroupsInPandas( r: RelationalGroupedDataset, - expr: PythonUDF): DataFrame = { + column: Column): DataFrame = { + val expr = column.expr.asInstanceOf[PythonUDF] require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, "Must pass a cogrouped map pandas udf") require(this.groupingExprs.length == r.groupingExprs.length, @@ -648,7 +654,8 @@ class RelationalGroupedDataset protected[sql]( */ private[sql] def flatMapCoGroupsInArrow( r: RelationalGroupedDataset, - expr: PythonUDF): DataFrame = { + column: Column): DataFrame = { + val expr = column.expr.asInstanceOf[PythonUDF] require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, "Must pass a cogrouped map arrow udf") require(this.groupingExprs.length == r.groupingExprs.length, @@ -697,7 +704,7 @@ class RelationalGroupedDataset protected[sql]( * workers. */ private[sql] def applyInPandasWithState( - func: PythonUDF, + func: Column, outputStructType: StructType, stateStructType: StructType, outputModeStr: String, @@ -715,7 +722,7 @@ class RelationalGroupedDataset protected[sql]( val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) val outputAttrs = toAttributes(outputStructType) val plan = FlatMapGroupsInPandasWithState( - func, + func.expr, groupingAttrs, outputAttrs, stateStructType, @@ -737,7 +744,7 @@ class RelationalGroupedDataset protected[sql]( * workers. 
*/ private[sql] def transformWithStateInPandas( - func: PythonUDF, + func: Column, outputStructType: StructType, outputModeStr: String, timeModeStr: String): DataFrame = { @@ -751,7 +758,7 @@ class RelationalGroupedDataset protected[sql]( val timeMode = TimeModes(timeModeStr) val plan = TransformWithStateInPandas( - func, + func.expr, groupingAttrs, outputAttrs, outputMode, @@ -808,7 +815,7 @@ private[sql] object RelationalGroupedDataset { private def alias(expr: Expression): NamedExpression = expr match { case expr: NamedExpression => expr - case a: AggregateExpression => UnresolvedAlias(a, Some(Column.generateAlias)) + case a: AggregateExpression => UnresolvedAlias(a, Some(generateAlias)) case _ if !expr.resolved => UnresolvedAlias(expr, None) case expr: Expression => Alias(expr, toPrettySQL(expr))() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index d64623a744fe4..975d90df9047e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -39,7 +39,8 @@ import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation} import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils @@ -928,6 +929,24 @@ class SparkSession private( private[sql] def leafNodeDefaultParallelism: Int = { conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM).getOrElse(sparkContext.defaultParallelism) } + + private[sql] object Converter extends ColumnNodeToExpressionConverter with Serializable { + override protected def parser: ParserInterface = sessionState.sqlParser + override protected def conf: SQLConf = sessionState.conf + } + + private[sql] def expression(e: Column): Expression = Converter(e.node) + + private[sql] implicit class RichColumn(val column: Column) { + /** + * Returns the expression for this column. + */ + def expr: Expression = Converter(column.node) + /** + * Returns the expression for this column either with an existing or auto assigned name. 
+ */ + def named: NamedExpression = ExpressionUtils.toNamed(expr) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index dbb3a333bfb11..6b497553dcb0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -32,12 +32,13 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.internal.{ExpressionColumnNode, SQLConf} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{MutableURLClassLoader, Utils} @@ -140,40 +141,34 @@ private[sql] object PythonSQLUtils extends Logging { } } - def castTimestampNTZToLong(c: Column): Column = Column(CastTimestampNTZToLong(c.expr)) + def castTimestampNTZToLong(c: Column): Column = + Column.internalFn("timestamp_ntz_to_long", c) def ewm(e: Column, alpha: Double, ignoreNA: Boolean): Column = - Column(new EWM(e.expr, alpha, ignoreNA)) + Column.internalFn("ewm", e, lit(alpha), lit(ignoreNA)) - def nullIndex(e: Column): Column = Column(NullIndex(e.expr)) + def nullIndex(e: Column): Column = Column.internalFn("null_index", e) - def pandasProduct(e: Column, ignoreNA: Boolean): Column = { - Column(PandasProduct(e.expr, ignoreNA).toAggregateExpression(false)) - } + def pandasProduct(e: Column, ignoreNA: Boolean): Column = + Column.internalFn("pandas_product", e, lit(ignoreNA)) - def pandasStddev(e: Column, ddof: Int): Column = { - Column(PandasStddev(e.expr, ddof).toAggregateExpression(false)) - } + def pandasStddev(e: Column, ddof: Int): Column = + Column.internalFn("pandas_stddev", e, lit(ddof)) - def pandasVariance(e: Column, ddof: Int): Column = { - Column(PandasVariance(e.expr, ddof).toAggregateExpression(false)) - } + def pandasVariance(e: Column, ddof: Int): Column = + Column.internalFn("pandas_var", e, lit(ddof)) - def pandasSkewness(e: Column): Column = { - Column(PandasSkewness(e.expr).toAggregateExpression(false)) - } + def pandasSkewness(e: Column): Column = + Column.internalFn("pandas_skew", e) - def pandasKurtosis(e: Column): Column = { - Column(PandasKurtosis(e.expr).toAggregateExpression(false)) - } + def pandasKurtosis(e: Column): Column = + Column.internalFn("pandas_kurt", e) - def pandasMode(e: Column, ignoreNA: Boolean): Column = { - Column(PandasMode(e.expr, ignoreNA).toAggregateExpression(false)) - } + def pandasMode(e: Column, ignoreNA: Boolean): Column = + Column.internalFn("pandas_mode", e, lit(ignoreNA)) - def pandasCovar(col1: Column, col2: Column, ddof: Int): Column = { - Column(PandasCovar(col1.expr, col2.expr, ddof).toAggregateExpression(false)) - } + def pandasCovar(col1: Column, col2: Column, ddof: Int): Column = + Column.internalFn("pandas_covar", col1, col2, lit(ddof)) 
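An illustrative sketch, for orientation only, of the two idioms the hunks above introduce; `spark`, the column names and the numeric arguments are assumed scaffolding, while `RichColumn`, `expr`, `named` and `Column.internalFn` are the APIs shown in the diff:

// Sketch only: code living inside sql/core, with an active SparkSession `spark` in scope.
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, lit}
import spark.RichColumn                  // session-scoped implicit added to SparkSession above

val c = col("price") * lit(1.1)
val asExpr = c.expr                      // ColumnNode -> Catalyst Expression via spark.Converter
val asNamed = c.named                    // same, wrapped as a NamedExpression when a name is needed

// Helpers such as ewm above no longer build Catalyst nodes directly; they emit a call that is
// presumably resolved by its registered internal name during analysis:
val smoothed = Column.internalFn("ewm", col("price"), lit(0.5), lit(false))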
def unresolvedNamedLambdaVariable(name: String): Column = Column(internal.UnresolvedNamedLambdaVariable.apply(name)) @@ -184,13 +179,12 @@ private[sql] object PythonSQLUtils extends Logging { Column(internal.LambdaFunction(function.node, arguments)) } - def namedArgumentExpression(name: String, e: Column): Column = - Column(ExpressionColumnNode(NamedArgumentExpression(name, e.expr))) + def namedArgumentExpression(name: String, e: Column): Column = NamedArgumentExpression(name, e) def distributedIndex(): Column = { val expr = MonotonicallyIncreasingID() expr.setTagValue(FunctionRegistry.FUNC_ALIAS, "distributed_index") - Column(ExpressionColumnNode(expr)) + expr } @scala.annotation.varargs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index ffef4996fe052..3832d73044078 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -565,7 +565,7 @@ object ScalaAggregator { bufferEncoder = encoderFor(uda.aggregator.bufferEncoder), nullable = uda.nullable, isDeterministic = uda.deterministic, - aggregatorName = Option(uda.name)) + aggregatorName = uda.givenName) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 7acd1cb0852b9..5a9adf8ab553d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -26,9 +26,9 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{COUNT, DATABASE_NAME, ERROR, TABLE_NAME, TIME} -import org.apache.spark.sql.{Column, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.ResolvedIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, CatalogTableType, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.{DataSourceUtils, InMemoryFileIndex} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.internal.{SessionState, SQLConf} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.Utils @@ -483,17 +484,17 @@ object CommandUtils extends Logging { partitionValueSpec: Option[TablePartitionSpec]): Map[TablePartitionSpec, BigInt] = { val filter = if (partitionValueSpec.isDefined) { val filters = partitionValueSpec.get.map { - case (columnName, value) => EqualTo(UnresolvedAttribute(columnName), Literal(value)) + case (columnName, value) => col(columnName) === lit(value) } - filters.reduce(And) + filters.reduce(_ && _) } else { - Literal.TrueLiteral + lit(true) } val tableDf = sparkSession.table(tableMeta.identifier) - val 
partitionColumns = tableMeta.partitionColumnNames.map(Column(_)) + val partitionColumns = tableMeta.partitionColumnNames.map(col) - val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count() + val df = tableDf.filter(filter).groupBy(partitionColumns: _*).count() df.collect().map { r => val partitionColumnValues = partitionColumns.indices.map { i => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 32406460ad713..ea1f5e6ae1340 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, NamedParametersSupport, OneRowRelation} import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} import org.apache.spark.sql.types.{DataType, StructType} /** @@ -63,9 +64,11 @@ case class UserDefinedPythonFunction( } } + def builderWithColumns(e: Seq[Column]): Expression = builder(e.map(expression)) + /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ def apply(exprs: Column*): Column = { - fromUDFExpr(builder(exprs.map(_.expr))) + fromUDFExpr(builder(exprs.map(expression))) } /** @@ -73,8 +76,8 @@ case class UserDefinedPythonFunction( */ def fromUDFExpr(expr: Expression): Column = { expr match { - case udaf: PythonUDAF => Column(udaf.toAggregateExpression()) - case _ => Column(expr) + case udaf: PythonUDAF => udaf.toAggregateExpression() + case _ => expr } } } @@ -157,7 +160,7 @@ case class UserDefinedPythonTableFunction( /** Returns a [[DataFrame]] that will evaluate to calling this UDTF with the given input. 
*/ def apply(session: SparkSession, exprs: Column*): DataFrame = { - val udtf = builder(exprs.map(_.expr), session.sessionState.sqlParser) + val udtf = builder(exprs.map(session.expression), session.sessionState.sqlParser) Dataset.ofRows(session, udtf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index c05562fc083ca..148766f9d0026 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -22,12 +22,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.{functions, Column, DataFrame} +import org.apache.spark.sql.{functions, DataFrame} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -57,8 +58,7 @@ object FrequentItems extends Logging { val sizeOfMap = (1 / support).toInt val frequentItemCols = cols.map { col => - val aggExpr = new CollectFrequentItems(functions.col(col).expr, sizeOfMap) - Column(aggExpr.toAggregateExpression(isDistinct = false)).as(s"${col}_freqItems") + column(new CollectFrequentItems(functions.col(col), sizeOfMap)).as(s"${col}_freqItems") } df.select(frequentItemCols: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index d059f5ada576b..dd7fee455b4df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -21,7 +21,6 @@ import java.util.Locale import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.errors.QueryExecutionErrors @@ -73,7 +72,7 @@ object StatFunctions extends Logging { require(field.dataType.isInstanceOf[NumericType], s"Quantile calculation for column $colName with data type ${field.dataType}" + " is not supported.") - Column(Cast(Column(colName).expr, DoubleType)) + Column(colName).cast(DoubleType) } val emptySummaries = Array.fill(cols.size)( new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError)) @@ -252,7 +251,7 @@ object StatFunctions extends Logging { .withColumnRenamed("_1", "summary") } else { val valueColumns = columnNames.map { columnName => - Column(ElementAt(col(columnName).expr, col("summary").expr)).as(columnName) + element_at(col(columnName), col("summary")).as(columnName) } import org.apache.spark.util.ArrayImplicits._ ds.select(mapColumns: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala index 68bda47cf8ce0..b6340a35e7703 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala @@ -43,13 +43,12 @@ private[sql] object TypedAggUtils { * Insert inputs into typed aggregate expressions. For untyped aggregate expressions, * the resolving is handled in the analyzer directly. */ - private[sql] def withInputType( - expr: Expression, + private[sql] def withInputType[T <: Expression]( + expr: T, inputEncoder: ExpressionEncoder[_], - inputAttributes: Seq[Attribute]): Expression = { + inputAttributes: Seq[Attribute]): T = { val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes) - - expr transform { + val transformed = expr transform { case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => ta.withInputInfo( deser = unresolvedDeserializer, @@ -57,6 +56,7 @@ private[sql] object TypedAggUtils { schema = inputEncoder.schema ) } + transformed.asInstanceOf[T] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index ea6e36680da45..55fa107a57106 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -16,20 +16,25 @@ */ package org.apache.spark.sql.internal +import scala.language.implicitConversions + import UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.SparkException -import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.{Column, Dataset, SparkSession} import org.apache.spark.sql.catalyst.{analysis, expressions, CatalystTypeConverters} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAlias} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Generator, NamedExpression, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} -import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF, TypedAggregateExpression} import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator} +import org.apache.spark.sql.types.{DataType, NullType} /** * Convert a [[ColumnNode]] into an [[Expression]]. 
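For orientation, a minimal round-trip sketch of the wrappers defined in the hunks just below (`ExpressionUtils`, `ExpressionColumnNode`, `ColumnNodeExpression`); this is internal API, so the snippet assumes it is compiled inside the org.apache.spark.sql packages:

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}

val c: Column = column(Literal(1))  // Expression -> Column, wrapped in an ExpressionColumnNode
val e: Expression = expression(c)   // Column -> Expression; yields the original Literal again,
                                    // because the two wrappers unwrap each other in their apply methods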
@@ -173,7 +178,13 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres
       toScalaUDF(udf, arguments.map(apply))
     case ExpressionColumnNode(expression, _) =>
-      expression
+      val transformed = expression.transformDown {
+        case ColumnNodeExpression(node) => apply(node)
+      }
+      transformed match {
+        case f: AggregateFunction => f.toAggregateExpression()
+        case _ => transformed
+      }
     case node =>
       throw SparkException.internalError("Unsupported ColumnNode: " + node)
@@ -234,7 +245,7 @@ private[sql] object ColumnNodeToExpressionConverter extends ColumnNodeToExpressi
 /**
  * [[ColumnNode]] wrapper for an [[Expression]].
  */
-private[sql] case class ExpressionColumnNode(
+private[sql] case class ExpressionColumnNode private(
     expression: Expression,
     override val origin: Origin = CurrentOrigin.get) extends ColumnNode {
   override def normalize(): ExpressionColumnNode = {
@@ -247,3 +258,67 @@ private[sql] case class ExpressionColumnNode(
   override def sql: String = expression.sql
 }
+
+private[sql] object ExpressionColumnNode {
+  def apply(e: Expression): ColumnNode = e match {
+    case ColumnNodeExpression(node) => node
+    case _ => new ExpressionColumnNode(e)
+  }
+}
+
+private[internal] case class ColumnNodeExpression private(node: ColumnNode) extends Unevaluable {
+  override def nullable: Boolean = true
+  override def dataType: DataType = NullType
+  override def children: Seq[Expression] = Nil
+  override protected def withNewChildrenInternal(c: IndexedSeq[Expression]): Expression = this
+}
+
+private[sql] object ColumnNodeExpression {
+  def apply(node: ColumnNode): Expression = node match {
+    case ExpressionColumnNode(e, _) => e
+    case _ => new ColumnNodeExpression(node)
+  }
+}
+
+private[spark] object ExpressionUtils {
+  /**
+   * Create an Expression-backed Column.
+   */
+  implicit def column(e: Expression): Column = Column(ExpressionColumnNode(e))
+
+  /**
+   * Create a ColumnNode-backed Expression. Please note that this has to be converted to an actual
+   * Expression before it is used.
+   */
+  implicit def expression(c: Column): Expression = ColumnNodeExpression(c.node)
+
+  /**
+   * Returns the expression either with an existing or auto-assigned name.
+   */
+  def toNamed(expr: Expression): NamedExpression = expr match {
+    case expr: NamedExpression => expr
+
+    // Leave an unaliased generator with an empty list of names since the analyzer will generate
+    // the correct defaults after the nested expression's type has been resolved.
+    case g: Generator => MultiAlias(g, Nil)
+
+    // If we have a top level Cast, there is a chance to give it a better alias, if there is a
+    // NamedExpression under this Cast.
+ case c: expressions.Cast => + c.transformUp { + case c @ expressions.Cast(_: NamedExpression, _, _, _) => UnresolvedAlias(c) + } match { + case ne: NamedExpression => ne + case _ => UnresolvedAlias(expr, Some(generateAlias)) + } + + case expr: Expression => UnresolvedAlias(expr, Some(generateAlias)) + } + + def generateAlias(e: Expression): String = { + e match { + case AggregateExpression(f: TypedAggregateExpression, _, _, _, _) => f.toString + case expr => toPrettySQL(expr) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 64b5128872610..68a7a4b8b2412 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -978,15 +978,15 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("SPARK-37646: lit") { assert(lit($"foo") == $"foo") assert(lit($"foo") == $"foo") - assert(lit(1).expr == Column(Literal(1)).expr) - assert(lit(null).expr == Column(Literal(null, NullType)).expr) + assert(lit(1).expr == Literal(1)) + assert(lit(null).expr == Literal(null, NullType)) } test("typedLit") { assert(typedLit($"foo") == $"foo") assert(typedLit($"foo") == $"foo") - assert(typedLit(1).expr == Column(Literal(1)).expr) - assert(typedLit[String](null).expr == Column(Literal(null, StringType)).expr) + assert(typedLit(1).expr == Literal(1)) + assert(typedLit[String](null).expr == Literal(null, StringType)) val df = Seq(Tuple1(0)).toDF("a") // Only check the types `lit` cannot handle diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 48ac2cc5d4044..8c1cc6c3bea1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.objects.MapObjects import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{ArrayType, BooleanType, Decimal, DoubleType, IntegerType, MapType, StringType, StructField, StructType} @@ -82,8 +83,8 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession { // items: Seq[Int] => items.map { item => Seq(Struct(item)) } val result = df.select( - Column(MapObjects( - (item: Expression) => array(struct(Column(item))).expr, + column(MapObjects( + (item: Expression) => array(struct(column(item))).expr, $"items".expr, df.schema("items").dataType.asInstanceOf[ArrayType].elementType )) as "items" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 830becc84c604..cf8dcc0b8b2f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.io.File import java.lang.reflect.Modifier import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} @@ -26,13 
+25,14 @@ import scala.util.Random import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkRuntimeException} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions.{Alias, ArraysZip, AttributeReference, Expression, NamedExpression, UnaryExpression} +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -974,75 +974,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } test("SPARK-35876: arrays_zip should retain field names") { - withTempDir { dir => - val df = spark.sparkContext.parallelize( - Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6)))).toDF("val1", "val2") - val qualifiedDF = df.as("foo") - - // Fields are UnresolvedAttribute - val zippedDF1 = - qualifiedDF.select(Column(ArraysZip(Seq($"foo.val1".expr, $"foo.val2".expr))) as "zipped") - val maybeAlias1 = zippedDF1.queryExecution.logical.expressions.head - assert(maybeAlias1.isInstanceOf[Alias]) - val maybeArraysZip1 = maybeAlias1.children.head - assert(maybeArraysZip1.isInstanceOf[ArraysZip]) - assert(maybeArraysZip1.children.forall(_.isInstanceOf[UnresolvedAttribute])) - val file1 = new File(dir, "arrays_zip1") - zippedDF1.write.parquet(file1.getAbsolutePath) - val restoredDF1 = spark.read.parquet(file1.getAbsolutePath) - val fieldNames1 = restoredDF1.schema.head.dataType.asInstanceOf[ArrayType] - .elementType.asInstanceOf[StructType].fieldNames - assert(fieldNames1.toSeq === Seq("val1", "val2")) - - // Fields are resolved NamedExpression - val zippedDF2 = - df.select(Column(ArraysZip(Seq(df("val1").expr, df("val2").expr))) as "zipped") - val maybeAlias2 = zippedDF2.queryExecution.logical.expressions.head - assert(maybeAlias2.isInstanceOf[Alias]) - val maybeArraysZip2 = maybeAlias2.children.head - assert(maybeArraysZip2.isInstanceOf[ArraysZip]) - assert(maybeArraysZip2.children.forall( - e => e.isInstanceOf[AttributeReference] && e.resolved)) - val file2 = new File(dir, "arrays_zip2") - zippedDF2.write.parquet(file2.getAbsolutePath) - val restoredDF2 = spark.read.parquet(file2.getAbsolutePath) - val fieldNames2 = restoredDF2.schema.head.dataType.asInstanceOf[ArrayType] - .elementType.asInstanceOf[StructType].fieldNames - assert(fieldNames2.toSeq === Seq("val1", "val2")) - - // Fields are unresolved NamedExpression - val zippedDF3 = df.select( - Column(ArraysZip(Seq(($"val1" as "val3").expr, ($"val2" as "val4").expr))) as "zipped") - val maybeAlias3 = zippedDF3.queryExecution.logical.expressions.head - assert(maybeAlias3.isInstanceOf[Alias]) - val maybeArraysZip3 = maybeAlias3.children.head - assert(maybeArraysZip3.isInstanceOf[ArraysZip]) - assert(maybeArraysZip3.children.forall(e => e.isInstanceOf[Alias] && !e.resolved)) - val file3 = new File(dir, "arrays_zip3") - zippedDF3.write.parquet(file3.getAbsolutePath) - val restoredDF3 = spark.read.parquet(file3.getAbsolutePath) - val 
fieldNames3 = restoredDF3.schema.head.dataType.asInstanceOf[ArrayType]
-        .elementType.asInstanceOf[StructType].fieldNames
-      assert(fieldNames3.toSeq === Seq("val3", "val4"))
-
-      // Fields are neither UnresolvedAttribute nor NamedExpression
-      val zippedDF4 = df.select(
-        Column(ArraysZip(Seq(array_sort($"val1").expr, array_sort($"val2").expr))) as "zipped")
-      val maybeAlias4 = zippedDF4.queryExecution.logical.expressions.head
-      assert(maybeAlias4.isInstanceOf[Alias])
-      val maybeArraysZip4 = maybeAlias4.children.head
-      assert(maybeArraysZip4.isInstanceOf[ArraysZip])
-      assert(maybeArraysZip4.children.forall {
-        case _: UnresolvedAttribute | _: NamedExpression => false
-        case _ => true
-      })
-      val file4 = new File(dir, "arrays_zip4")
-      zippedDF4.write.parquet(file4.getAbsolutePath)
-      val restoredDF4 = spark.read.parquet(file4.getAbsolutePath)
-      val fieldNames4 = restoredDF4.schema.head.dataType.asInstanceOf[ArrayType]
-        .elementType.asInstanceOf[StructType].fieldNames
-      assert(fieldNames4.toSeq === Seq("0", "1"))
-    }
+    val df = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2")
+    val qualifiedDF = df.as("foo")
+
+    // Fields are UnresolvedAttribute
+    val zippedDF1 = qualifiedDF.select(arrays_zip($"foo.val1", $"foo.val2") as "zipped")
+    val zippedDF1expectedSchema = new StructType()
+      .add("zipped", ArrayType(new StructType()
+        .add("val1", IntegerType)
+        .add("val2", IntegerType)))
+    val zippedDF1Schema = zippedDF1.queryExecution.executedPlan.schema.toNullable
+    assert(zippedDF1Schema == zippedDF1expectedSchema)
+
+    // Fields are resolved NamedExpression
+    val zippedDF2 = df.select(arrays_zip(df("val1"), df("val2")) as "zipped")
+    val zippedDF2Schema = zippedDF2.queryExecution.executedPlan.schema.toNullable
+    assert(zippedDF2Schema == zippedDF1expectedSchema)
+
+    // Fields are unresolved NamedExpression
+    val zippedDF3 = df.select(arrays_zip($"val1" as "val3", $"val2" as "val4") as "zipped")
+    val zippedDF3expectedSchema = new StructType()
+      .add("zipped", ArrayType(new StructType()
+        .add("val3", IntegerType)
+        .add("val4", IntegerType)))
+    val zippedDF3Schema = zippedDF3.queryExecution.executedPlan.schema.toNullable
+    assert(zippedDF3Schema == zippedDF3expectedSchema)
+
+    // Fields are neither UnresolvedAttribute nor NamedExpression
+    val zippedDF4 = df.select(arrays_zip(array_sort($"val1"), array_sort($"val2")) as "zipped")
+    val zippedDF4expectedSchema = new StructType()
+      .add("zipped", ArrayType(new StructType()
+        .add("0", IntegerType)
+        .add("1", IntegerType)))
+    val zippedDF4Schema = zippedDF4.queryExecution.executedPlan.schema.toNullable
+    assert(zippedDF4Schema == zippedDF4expectedSchema)
   }
   test("SPARK-40292: arrays_zip should retain field names in nested structs") {
@@ -5485,7 +5450,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
       val c = if (codegenFallback) {
-        Column(CodegenFallbackExpr(v.expr))
+        column(CodegenFallbackExpr(v.expr))
       } else {
         v
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 7dc40549a17bc..310b5a62c908a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending,
AttributeRef import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.{count, explode, sum, year} +import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.test.SQLTestData.TestData diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 301ab28b9124b..70d7797ba89f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, CreateMap, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LocalRelation, LogicalPlan, OneRowRelation} @@ -43,6 +43,7 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.expressions.{Aggregator, Window} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, StringWrapper, TestData2} @@ -1566,7 +1567,7 @@ class DataFrameSuite extends QueryTest test("SPARK-46794: exclude subqueries from LogicalRDD constraints") { withTempDir { checkpointDir => val subquery = - Column(ScalarSubquery(spark.range(10).selectExpr("max(id)").logicalPlan)) + column(ScalarSubquery(spark.range(10).selectExpr("max(id)").logicalPlan)) val df = spark.range(1000).filter($"id" === subquery) assert(df.logicalPlan.constraints.exists(_.exists(_.isInstanceOf[ScalarSubquery]))) @@ -2417,7 +2418,7 @@ class DataFrameSuite extends QueryTest | SELECT a, b FROM (SELECT a, b FROM VALUES (1, 2) AS t(a, b)) |) |""".stripMargin) - val stringCols = df.logicalPlan.output.map(Column(_).cast(StringType)) + val stringCols = df.logicalPlan.output.map(column(_).cast(StringType)) val castedDf = df.select(stringCols: _*) checkAnswer(castedDf, Row("1", "1") :: Row("1", "2") :: Nil) } @@ -2505,7 +2506,7 @@ class DataFrameSuite extends QueryTest assert(row.getInt(0).toString == row.getString(3)) } - val v3 = Column(CreateMap(Seq(Literal("key"), Literal("value")))) + val v3 = map(lit("key"), lit("value")) val v4 = to_csv(struct(v3.as("a"))) // to_csv is CodegenFallback df.select(v3, v3, v4, v4).collect().foreach { row => assert(row.getMap(0).toString() == row.getMap(1).toString()) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index e3aff9b36aece..ace4d5b294a78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, S import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -861,7 +862,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest lead($"value", 2, null, true).over(window), lead($"value", 3, null, true).over(window), lead(concat($"value", $"key"), 1, null, true).over(window), - Column(Lag($"value".expr, NonFoldableLiteral(1), Literal(null), true)).over(window), + column(Lag($"value".expr, NonFoldableLiteral(1), Literal(null), true)).over(window), lag($"value", 2).over(window), lag($"value", 0, null, true).over(window), lag($"value", 1, null, true).over(window), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 44709fd309cfb..cdea4446d9461 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource import org.apache.spark.sql.execution.python.{UserDefinedPythonFunction, UserDefinedPythonTableFunction} import org.apache.spark.sql.expressions.SparkUserDefinedFunction +import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType, StructType, VariantType} import org.apache.spark.util.ArrayImplicits._ @@ -1591,7 +1592,7 @@ object IntegratedUDFTestUtils extends SQLHelper { Cast(toScalaUDF(udf, Cast(expr, StringType) :: Nil), rt) } - def apply(exprs: Column*): Column = Column(builder(exprs.map(_.expr))) + def apply(exprs: Column*): Column = builder(exprs.map(expression)) val prettyName: String = "Scala UDF" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 67b4a3e319e40..6b4be982b3ecb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -25,10 +25,11 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.{SparkException, SparkRuntimeException} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Literal, StructsToJson} import org.apache.spark.sql.catalyst.expressions.Cast._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.functions._ +import 
org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -1395,17 +1396,18 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { val df = Seq(1).toDF("a") val schema = StructType(StructField("b", ObjectType(classOf[java.lang.Integer])) :: Nil) val row = InternalRow.fromSeq(Seq(Integer.valueOf(1))) - val structData = Literal.create(row, schema) + val structData = column(Literal.create(row, schema)) checkError( exception = intercept[AnalysisException] { - df.select($"a").withColumn("c", Column(StructsToJson(Map.empty, structData))).collect() + df.select($"a").withColumn("c", to_json(structData)).collect() }, errorClass = "DATATYPE_MISMATCH.CANNOT_CONVERT_TO_JSON", parameters = Map( "sqlExpr" -> "\"to_json(NAMED_STRUCT('b', 1))\"", "name" -> "`b`", "type" -> "\"JAVA.LANG.INTEGER\"" - ) + ), + context = ExpectedContext("to_json", getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala index 268e876c282f7..fcab4a7580445 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala @@ -536,27 +536,27 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { test("SPARK-45033: maps as parameters") { import org.apache.spark.util.ArrayImplicits._ def fromArr(keys: Array[_], values: Array[_]): Column = { - map_from_arrays(Column(Literal(keys)), Column(Literal(values))) + map_from_arrays(lit(keys), lit(values)) } def callFromArr(keys: Array[_], values: Array[_]): Column = { - call_function("map_from_arrays", Column(Literal(keys)), Column(Literal(values))) + call_function("map_from_arrays", lit(keys), lit(values)) } def createMap(keys: Array[_], values: Array[_]): Column = { - val zipped = keys.map(k => Column(Literal(k))).zip(values.map(v => Column(Literal(v)))) + val zipped = keys.map(k => lit(k)).zip(values.map(v => lit(v))) map(zipped.flatMap { case (k, v) => Seq(k, v) }.toImmutableArraySeq: _*) } def callMap(keys: Array[_], values: Array[_]): Column = { - val zipped = keys.map(k => Column(Literal(k))).zip(values.map(v => Column(Literal(v)))) + val zipped = keys.map(k => lit(k)).zip(values.map(v => lit(v))) call_function("map", zipped.flatMap { case (k, v) => Seq(k, v) }.toImmutableArraySeq: _*) } def fromEntries(keys: Array[_], values: Array[_]): Column = { val structures = keys.zip(values) - .map { case (k, v) => struct(Column(Literal(k)), Column(Literal(v)))} + .map { case (k, v) => struct(lit(k), lit(v))} map_from_entries(array(structures.toImmutableArraySeq: _*)) } def callFromEntries(keys: Array[_], values: Array[_]): Column = { val structures = keys.zip(values) - .map { case (k, v) => struct(Column(Literal(k)), Column(Literal(v)))} + .map { case (k, v) => struct(lit(k), lit(v))} call_function("map_from_entries", call_function("array", structures.toImmutableArraySeq: _*)) } @@ -590,8 +590,8 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { spark.sql("SELECT :m['a'][1]", Map("m" -> map_from_arrays( - Column(Literal(Array("a"))), - array(map_from_arrays(Column(Literal(Array(1))), Column(Literal(Array(2)))))))), + lit(Array("a")), + array(map_from_arrays(lit(Array(1)), lit(Array(2))))))), Row(2)) // `str_to_map` is not supported checkError( @@ -599,8 +599,8 @@ class 
ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { spark.sql("SELECT :m['a'][1]", Map("m" -> map_from_arrays( - Column(Literal(Array("a"))), - array(str_to_map(Column(Literal("a:1,b:2,c:3"))))))) + lit(Array("a")), + array(str_to_map(lit("a:1,b:2,c:3")))))) }, errorClass = "INVALID_SQL_ARG", parameters = Map("name" -> "m"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index 986e625137a77..624bae70ce09c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.ExpressionUtils.{column => toColumn, expression} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -88,10 +89,10 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSparkSession { test("dataframe aggregate with object aggregate buffer, should not use HashAggregate") { val df = data.toDF("a", "b") - val max = TypedMax($"a".expr) + val max = TypedMax($"a") // Always uses SortAggregateExec - val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan + val sparkPlan = df.select(max).queryExecution.sparkPlan assert(!sparkPlan.isInstanceOf[HashAggregateExec]) } @@ -211,15 +212,9 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSparkSession { checkAnswer(query, expected) } - private def typedMax(column: Column): Column = { - val max = TypedMax(column.expr, nullable = false) - Column(max.toAggregateExpression()) - } + private def typedMax(column: Column): Column = TypedMax(column) - private def nullableTypedMax(column: Column): Column = { - val max = TypedMax(column.expr, nullable = true) - Column(max.toAggregateExpression()) - } + private def nullableTypedMax(column: Column): Column = TypedMax(column, nullable = true) } object TypedImperativeAggregateSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala index c4dba850cf777..b5e22e71a8fa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala @@ -16,10 +16,10 @@ */ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, CreateNamedStruct, JsonToStructs, Literal, StructsToJson} -import org.apache.spark.sql.catalyst.expressions.variant.ParseJson +import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarArray @@ -32,7 +32,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { test("parse_json/to_json round-trip") { def check(input: String, output: String = null): Unit = { val df = Seq(input).toDF("v") - val variantDF = df.select(Column(StructsToJson(Map.empty, 
ParseJson(Column("v").expr)))) + val variantDF = df.select(to_json(parse_json(col("v")))) val expected = if (output != null) output else input checkAnswer(variantDF, Seq(Row(expected))) } @@ -62,8 +62,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { test("from_json/to_json round-trip") { def check(input: String, output: String = null): Unit = { val df = Seq(input).toDF("v") - val variantDF = df.select(Column(StructsToJson(Map.empty, - JsonToStructs(VariantType, Map.empty, Column("v").expr)))) + val variantDF = df.select(to_json(from_json(col("v"), VariantType))) val expected = if (output != null) output else input checkAnswer(variantDF, Seq(Row(expected))) } @@ -127,22 +126,24 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { test("to_json with nested variant") { val df = Seq(1).toDF("v") - val variantDF1 = df.select( - Column(StructsToJson(Map.empty, CreateArray(Seq( - ParseJson(Literal("{}")), ParseJson(Literal("\"\"")), ParseJson(Literal("[1, 2, 3]"))))))) + val variantDF1 = df.select(to_json(array( + parse_json(lit("{}")), + parse_json(lit("\"\"")), + parse_json(lit("[1, 2, 3]"))))) checkAnswer(variantDF1, Seq(Row("[{},\"\",[1,2,3]]"))) val variantDF2 = df.select( - Column(StructsToJson(Map.empty, CreateNamedStruct(Seq( - Literal("a"), ParseJson(Literal("""{ "x": 1, "y": null, "z": "str" }""")), - Literal("b"), ParseJson(Literal("[[]]")), - Literal("c"), ParseJson(Literal("false"))))))) + to_json(named_struct( + lit("a"), parse_json(lit("""{ "x": 1, "y": null, "z": "str" }""")), + lit("b"), parse_json(lit("[[]]")), + lit("c"), parse_json(lit("false")) + ))) checkAnswer(variantDF2, Seq(Row("""{"a":{"x":1,"y":null,"z":"str"},"b":[[]],"c":false}"""))) } test("parse_json - Codegen Support") { val df = Seq(("1", """{"a": 1}""")).toDF("key", "v").toDF() - val variantDF = df.select(Column(ParseJson(Column("v").expr))) + val variantDF = df.select(parse_json(col("v"))) val plan = variantDF.queryExecution.executedPlan assert(plan.isInstanceOf[WholeStageCodegenExec]) val v = VariantBuilder.parseJson("""{"a":1}""") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala index d9c3848d3b6b7..1401048cf705d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala @@ -21,15 +21,15 @@ import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddStaticMagic} import org.apache.spark.benchmark.Benchmark -import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, EvalMode, Expression} import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog} import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction} import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AbstractDataType, DataType, LongType, NumericType, StructType} @@ -81,7 +81,7 @@ object V2FunctionBenchmark extends SqlBasedBenchmark { s"codegen = $codegenEnabled" val benchmark = new Benchmark(name, N, output = output) benchmark.addCase(s"native_long_add", numIters = 3) { _ => - spark.range(N).select(Column(NativeAdd($"id".expr, $"id".expr, resultNullable))).noop() + spark.range(N).select(NativeAdd(col("id"), col("id"), resultNullable)).noop() } Seq("java_long_add_default", "java_long_add_magic", "java_long_add_static_magic", "scala_long_add_default", "scala_long_add_magic").foreach { functionName => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index b9f4e82cdd3c2..349b124970e32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -211,7 +211,7 @@ class QueryExecutionErrorsSuite test("UNSUPPORTED_FEATURE: unsupported types (map and struct) in lit()") { def checkUnsupportedTypeInLiteral(v: Any, literal: String, dataType: String): Unit = { checkError( - exception = intercept[SparkRuntimeException] { lit(v).expr }, + exception = intercept[SparkRuntimeException] { spark.expression(lit(v)) }, errorClass = "UNSUPPORTED_FEATURE.LITERAL_TYPE", parameters = Map("value" -> literal, "type" -> dataType), sqlState = "0A000") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala index e2ff7dc1c9aec..cf6553012ad86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution import org.apache.spark.benchmark.Benchmark -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Or} import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -87,9 +85,11 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { val numCols = 500 val schema = writeWideRow(path.getAbsolutePath, rowsNum, numCols) - val predicate = (0 until numCols).map { idx => - (from_json($"value", schema).getField(s"col$idx") >= Literal(100000)).expr - }.asInstanceOf[Seq[Expression]].reduce(Or) + val jsonValue = from_json($"value", schema) + val predicates = (0 until numCols).map { idx => + jsonValue.getField(s"col$idx") >= lit(100000) + } + val predicate = predicates.reduce(_ || _) Seq( ("false", "true", "CODEGEN_ONLY"), @@ -108,7 +108,7 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") { val df = spark.read .text(path.getAbsolutePath) - .where(Column(predicate)) + .where(predicate) df.write.mode("overwrite").format("noop").save() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index 83130abb80fec..f1431f2a81b8e 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -28,13 +28,14 @@ import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument, SearchA import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException} -import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.ExpressionUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -46,6 +47,7 @@ import org.apache.spark.util.ArrayImplicits._ */ @ExtendedSQLTest class OrcFilterSuite extends OrcTest with SharedSparkSession { + import testImplicits.toRichColumn override protected def sparkConf: SparkConf = super @@ -58,8 +60,8 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { checker: (SearchArgument) => Unit): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) + .select(output.map(e => ExpressionUtils.column(e)): _*) + .where(ExpressionUtils.column(predicate)) query.queryExecution.optimizedPlan match { case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _, _, _)) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index 0c696acdedafa..48b4f8d4bc015 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.FileBasedDataSourceTest import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan +import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.util.ArrayImplicits._ @@ -117,8 +118,8 @@ trait OrcTest extends QueryTest with FileBasedDataSourceTest with BeforeAndAfter (implicit df: DataFrame): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) + .select(output.map(e => column(e)): _*) + .where(predicate) query.queryExecution.optimizedPlan match { case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _, _, _)) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala index 6ca9f6cd525fa..1ccb77424d958 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala @@ -21,12 +21,12 @@ import scala.jdk.CollectionConverters._ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl import org.apache.spark.SparkConf -import org.apache.spark.sql.{Column, DataFrame} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, Predicate} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Predicate} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.orc.OrcShimUtils.{Operator, SearchArgument} +import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf import org.apache.spark.tags.ExtendedSQLTest @@ -44,15 +44,15 @@ class OrcV1FilterSuite extends OrcFilterSuite { checker: (SearchArgument) => Unit): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) + .select(output.map(e => column(e)): _*) + .where(predicate) var maybeRelation: Option[HadoopFsRelation] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => maybeRelation = Some(orcRelation) filters - }.flatten.reduceLeftOption(_ && _) + }.flatten.reduceLeftOption(And) assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") val (_, selectedFilters, _) = @@ -89,15 +89,15 @@ class OrcV1FilterSuite extends OrcFilterSuite { (implicit df: DataFrame): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) + .select(output.map(e => column(e)): _*) + .where(predicate) var maybeRelation: Option[HadoopFsRelation] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => maybeRelation = Some(orcRelation) filters - }.flatten.reduceLeftOption(_ && _) + }.flatten.reduceLeftOption(And) assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") val (_, selectedFilters, _) = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 795e9f46a8d1d..5c382b1858716 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -49,7 +49,7 @@ import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsR import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} +import org.apache.spark.sql.internal.{ExpressionUtils, LegacyBehaviorPolicy, SQLConf} import 
org.apache.spark.sql.internal.LegacyBehaviorPolicy.{CORRECTED, LEGACY} import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType.{INT96, TIMESTAMP_MICROS, TIMESTAMP_MILLIS} import org.apache.spark.sql.test.SharedSparkSession @@ -77,6 +77,7 @@ import org.apache.spark.util.ArrayImplicits._ * within the test. */ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSparkSession { + import testImplicits.toRichColumn protected def createParquetFilters( schema: MessageType, @@ -2259,8 +2260,8 @@ class ParquetV1FilterSuite extends ParquetFilterSuite { SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", SQLConf.NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST.key -> pushdownDsList) { val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) + .select(output.map(ExpressionUtils.column): _*) + .where(ExpressionUtils.column(predicate)) val nestedOrAttributes = predicate.collectFirst { case g: GetStructField => g @@ -2338,8 +2339,8 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> InferFiltersFromConstraints.ruleName, SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) + .select(output.map(ExpressionUtils.column): _*) + .where(ExpressionUtils.column(predicate)) query.queryExecution.optimizedPlan.collectFirst { case PhysicalOperation(_, filters, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index d3538cf65a50a..b704790e4296b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructType} class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession { + import testImplicits.toRichColumn private val EnsureRequirements = new EnsureRequirements() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index c283d39425812..5de106415ec68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { import testImplicits.newProductEncoder import testImplicits.localSeqToDatasetHolder + import testImplicits.toRichColumn private val EnsureRequirements = new EnsureRequirements() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 962021604e717..e4ea88067c7c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { + import 
testImplicits.toRichColumn private val EnsureRequirements = new EnsureRequirements() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 3573bafe482cc..8b11e0c69fa70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources import scala.util.Random import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ @@ -31,6 +32,7 @@ import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} @@ -221,12 +223,13 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti df) // Case 4: InSet - val inSetExpr = expressions.InSet($"j".expr, + val inSetExpr = expressions.InSet( + UnresolvedAttribute("j"), Set(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3)) checkPrunedAnswers( bucketSpec, bucketValues = Seq(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3), - filterCondition = Column(inSetExpr), + filterCondition = column(inSetExpr), df) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala index 1eae0fe9ef088..433bc6b4380b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateDistributionSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.streaming import org.apache.spark.sql.IntegratedUDFTestUtils.{shouldTestPandasUDFs, TestGroupedMapPandasUDFWithState} -import org.apache.spark.sql.catalyst.expressions.PythonUDF import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.MemoryStream @@ -83,8 +82,7 @@ class FlatMapGroupsInPandasWithStateDistributionSuite extends StreamTest .repartition($"key1") .groupBy($"key1", $"key2") .applyInPandasWithState( - pythonFunc(inputDataDS("key1"), inputDataDS("key2"), inputDataDS("timestamp")) - .expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDS("key1"), inputDataDS("key2"), inputDataDS("timestamp")), outputStructType, stateStructType, "Update", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index 3f36544b117de..49825f7cde839 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ 
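Editor's note: in BucketedReadSuite the $"j".expr shortcut is gone, so the InSet bucket-pruning filter is now built from an UnresolvedAttribute and wrapped back into a public Column with the same helper. A compile-level sketch of that pattern; bucketValue is a hypothetical stand-in for the randomly drawn value in the suite:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.InSet
    import org.apache.spark.sql.internal.ExpressionUtils.column

    object BucketPruningFilterSketch {
      val bucketValue = 0 // hypothetical; the suite draws this from Random
      // InSet is built directly on the unresolved attribute "j" ...
      val inSetExpr = InSet(
        UnresolvedAttribute("j"),
        Set(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3))
      // ... and wrapped into a Column so it can be passed as filterCondition.
      val filterCondition = column(inSetExpr)
    }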
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming import org.apache.spark.sql.IntegratedUDFTestUtils._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.PythonUDF import org.apache.spark.sql.catalyst.plans.logical.{NoTimeout, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Complete, Update} import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec @@ -87,7 +86,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { inputDataDS .groupBy("value") .applyInPandasWithState( - pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDS("value")), outputStructType, stateStructType, "Update", @@ -161,7 +160,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { inputDataDS .groupBy("key") .applyInPandasWithState( - pythonFunc(inputDataDS("key"), inputDataDS("value")).expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDS("key"), inputDataDS("value")), outputStructType, stateStructType, "Update", @@ -234,7 +233,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { inputDataDS .groupBy("value") .applyInPandasWithState( - pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDS("value")), outputStructType, stateStructType, "Append", @@ -318,7 +317,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { inputDataDS .groupBy("value") .applyInPandasWithState( - pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDS("value")), outputStructType, stateStructType, "Update", @@ -434,7 +433,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { .withWatermark("eventTime", "10 seconds") .groupBy("key") .applyInPandasWithState( - pythonFunc(inputDataDF("key"), inputDataDF("eventTime")).expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDF("key"), inputDataDF("eventTime")), outputStructType, stateStructType, "Update", @@ -519,7 +518,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { .withWatermark("timestamp", "10 second") .groupBy("key") .applyInPandasWithState( - pythonFunc(inputDataDF("key"), inputDataDF("timestamp")).expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDF("key"), inputDataDF("timestamp")), outputStructType, stateStructType, "Update", @@ -589,7 +588,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { inputDataDS .groupBy("value") .applyInPandasWithState( - pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDS("value")), outputStructType, stateStructType, "Update", @@ -656,7 +655,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { inputDataDS .groupBy("key") .applyInPandasWithState( - pythonFunc(inputDataDS("key")).expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDS("key")), outputStructType, stateStructType, "Update", @@ -713,7 +712,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { inputDataDS .groupBy("key") .applyInPandasWithState( - pythonFunc(inputDataDS("key")).expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDS("key")), outputStructType, stateStructType, "Update", @@ -789,8 +788,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { .groupBy("key1", "key2") .applyInPandasWithState( pythonFunc( - 
inputDataDS("key1"), inputDataDS("key2"), inputDataDS("val1"), inputDataDS("val2") - ).expr.asInstanceOf[PythonUDF], + inputDataDS("key1"), inputDataDS("key2"), inputDataDS("val1"), inputDataDS("val2")), outputStructType, stateStructType, "Update", @@ -877,8 +875,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { .applyInPandasWithState( pythonFunc( inputDataDS("val1"), inputDataDS("key2"), inputDataDS("val2"), inputDataDS("key1"), - inputDataDS("val3") - ).expr.asInstanceOf[PythonUDF], + inputDataDS("val3")), outputStructType, stateStructType, "Update", @@ -949,7 +946,7 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { inputDataDS .groupBy("value") .applyInPandasWithState( - pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + pythonFunc(inputDataDS("value")), outputStructType, stateStructType, "Update", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index b607b98ea078f..54d6840eb5775 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -241,6 +241,7 @@ private[sql] trait SQLTestUtilsBase */ protected object testImplicits extends SQLImplicits { protected override def session: SparkSession = self.spark + implicit def toRichColumn(c: Column): SparkSession#RichColumn = session.RichColumn(c) } protected override def withSQLConf[T](pairs: (String, String)*)(f: => T): T = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index 1a4700e7445b6..700a4984a4e39 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -23,10 +23,10 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{Column, DataFrame, SparkSession} -import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile +import org.apache.spark.sql.functions.{lit, percentile_approx => pa} import org.apache.spark.sql.hive.execution.TestingTypedCount import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.internal.ExpressionUtils.{column => toCol, expression} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.LongType @@ -117,8 +117,7 @@ object ObjectHashAggregateExecBenchmark extends SqlBasedBenchmark { output = output ) - def typed_count(column: Column): Column = - Column(TestingTypedCount(column.expr).toAggregateExpression()) + def typed_count(column: Column): Column = TestingTypedCount(column) val df = spark.range(N) @@ -205,10 +204,8 @@ object ObjectHashAggregateExecBenchmark extends SqlBasedBenchmark { benchmark.run() } - private def percentile_approx( - column: Column, percentage: Double, isDistinct: Boolean = false): Column = { - val approxPercentile = new ApproximatePercentile(column.expr, Literal(percentage)) - Column(approxPercentile.toAggregateExpression(isDistinct)) + private def percentile_approx(column: Column, percentage: Double): Column = { + pa(column, lit(percentage), lit(10000)) } override def 
runBenchmarkSuite(mainArgs: Array[String]): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala index 1e525c46a9cfb..2152a29b17ff4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -32,6 +32,7 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto with BeforeAndAfter with SQLTestUtils { import spark.implicits._ + import spark.RichColumn override def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 4e2db21403599..bcd0644af0782 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -23,12 +23,12 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax import org.scalatest.matchers.must.Matchers._ import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.functions.{col, count_distinct, first, lit, max, percentile_approx => pa} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.ExpressionUtils.{column => toCol, expression} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -177,14 +177,11 @@ class ObjectHashAggregateSuite } } - private def percentile_approx( - column: Column, percentage: Double, isDistinct: Boolean = false): Column = { - val approxPercentile = new ApproximatePercentile(column.expr, Literal(percentage)) - Column(approxPercentile.toAggregateExpression(isDistinct)) + private def percentile_approx(column: Column, percentage: Double): Column = { + pa(column, lit(percentage), lit(10000)) } - private def typed_count(column: Column): Column = - Column(TestingTypedCount(column.expr).toAggregateExpression()) + private def typed_count(column: Column): Column = TestingTypedCount(column) // Generates 50 random rows for a given schema. private def generateRandomRows(schemaForGenerator: StructType): Seq[Row] = {
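Editor's note: the two percentile_approx test helpers above now delegate to the public functions.percentile_approx instead of hand-building an ApproximatePercentile expression; the previously implicit default accuracy (10000) becomes an explicit lit(10000) argument. A small self-contained usage sketch of that public API; the object name and example values are illustrative only:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{lit, percentile_approx}

    object PercentileApproxExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("percentile-approx-sketch")
          .getOrCreate()
        import spark.implicits._

        val df = Seq(1.0, 2.0, 3.0, 4.0, 5.0).toDF("v")
        // 10000 is ApproximatePercentile's default accuracy; the three-argument
        // public function takes it explicitly, matching lit(10000) in the
        // rewritten test helpers above.
        df.select(percentile_approx($"v", lit(0.5), lit(10000)).as("approx_median")).show()

        spark.stop()
      }
    }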