[SPARK-49025][CONNECT] Make Column implementation agnostic
### What changes were proposed in this pull request?
This makes the Column API implementation agnostic. We do this by:
- Removing `Column.expr`. This has been replaced in the source by the use of `Column` itself, by an `Expression` that wraps a `ColumnNode`, or by (implicit) conversions (a short sketch of this pattern follows the list).
- Removing `Column.apply(e: Expression)`. This has been replaced in the source by the implicit `ExpressionUtils.column` method, or by the use of `Column`.
- Removing `TypedColumn.withTypedColumn(..)`. This has been replaced by direct calls to `TypedAggUtils.withInputType(...)`.
- Removing `Column.named` and `Column.generateAlias`. These have been moved to `ExpressionUtils`.
- Making a number of pandas and arrow operators use a `Column` instead of an `Expression`.
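
As an illustration (not part of the diff), the sketch below shows the call-site pattern these changes converge on. It assumes a Spark build that already contains this patch, so that the implicit `ExpressionUtils.column` and `ExpressionUtils.expression` conversions are available; the Catalyst `Upper` expression is used purely as an example.

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Upper
// Implicit conversions between Column and Catalyst Expression used by the changed call sites.
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}

object ColumnAgnosticSketch {
  // Before this patch a helper would be written as:
  //   Column(Upper(data.expr))
  // After it, the Column is passed where an Expression is expected (implicit `expression`),
  // and the resulting Expression is wrapped back into a Column (implicit `column`).
  def upperColumn(data: Column): Column = Upper(data)
}
```

The point of the pattern is that call sites no longer touch `Column.expr` or `Column.apply(Expression)` directly; the Catalyst-specific plumbing stays behind `ExpressionUtils`, which is what lets the public `Column` class itself remain implementation agnostic.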

### Why are the changes needed?
This is one of the last steps in our effort to unify the Scala Column API for Classic and Connect.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#47785 from hvanhovell/SPARK-49025.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
hvanhovell committed Aug 20, 2024
1 parent 5d2d6a3 commit 8fbbcb0
Showing 61 changed files with 488 additions and 492 deletions.
@@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Column
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}


// scalastyle:off: object.name
@@ -41,7 +42,7 @@ object functions {
def from_avro(
data: Column,
jsonFormatSchema: String): Column = {
Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty))
AvroDataToCatalyst(data, jsonFormatSchema, Map.empty)
}

/**
@@ -62,7 +63,7 @@ object functions {
data: Column,
jsonFormatSchema: String,
options: java.util.Map[String, String]): Column = {
Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap))
AvroDataToCatalyst(data, jsonFormatSchema, options.asScala.toMap)
}

/**
@@ -74,7 +75,7 @@
*/
@Experimental
def to_avro(data: Column): Column = {
Column(CatalystDataToAvro(data.expr, None))
CatalystDataToAvro(data, None)
}

/**
@@ -87,6 +88,6 @@
*/
@Experimental
def to_avro(data: Column, jsonFormatSchema: String): Column = {
Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema)))
CatalystDataToAvro(data, Some(jsonFormatSchema))
}
}
@@ -306,6 +306,12 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.TypedColumn.expr"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.TypedColumn$"),

// ColumnNode conversions
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.Converter"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$Converter$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$RichColumn"),

// Datasource V2 partition transforms
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"),
@@ -433,6 +439,9 @@ object CheckConnectJvmClientCompatibility {
// SQLImplicits
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.session"),

// Column API
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Column.expr"),

// Streaming API
ProblemFilters.exclude[MissingTypesProblem](
"org.apache.spark.sql.streaming.DataStreamWriter" // Client version extends Logging
@@ -24,9 +24,9 @@ import java.time.LocalDateTime
import java.util.Properties

import org.apache.spark.SparkException
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.tags.DockerTest
@@ -303,7 +303,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
ArrayType(DecimalType(2, 2), true))
// Test write null values.
df.select(df.queryExecution.analyzed.output.map { a =>
Column(Literal.create(null, a.dataType)).as(a.name)
lit(null).cast(a.dataType).as(a.name)
}: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
}

@@ -20,6 +20,7 @@ import scala.jdk.CollectionConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Column
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
import org.apache.spark.sql.protobuf.utils.ProtobufUtils

// scalastyle:off: object.name
@@ -66,15 +67,11 @@
*/
@Experimental
def from_protobuf(
data: Column,
messageName: String,
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]): Column = {
Column(
ProtobufDataToCatalyst(
data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap
)
)
data: Column,
messageName: String,
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]): Column = {
ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap)
}

/**
@@ -93,7 +90,7 @@
@Experimental
def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent)))
ProtobufDataToCatalyst(data, messageName, Some(fileContent))
}

/**
@@ -112,7 +109,7 @@
@Experimental
def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte])
: Column = {
Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet)))
ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet))
}

/**
@@ -132,7 +129,7 @@
*/
@Experimental
def from_protobuf(data: Column, messageClassName: String): Column = {
Column(ProtobufDataToCatalyst(data.expr, messageClassName))
ProtobufDataToCatalyst(data, messageClassName)
}

/**
@@ -156,7 +153,7 @@
data: Column,
messageClassName: String,
options: java.util.Map[String, String]): Column = {
Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap))
ProtobufDataToCatalyst(data, messageClassName, None, options.asScala.toMap)
}

/**
@@ -194,7 +191,7 @@ object functions {
@Experimental
def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte])
: Column = {
Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet)))
CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet))
}
/**
* Converts a column into binary of protobuf format. The Protobuf definition is provided
@@ -216,9 +213,7 @@
descFilePath: String,
options: java.util.Map[String, String]): Column = {
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
Column(
CatalystDataToProtobuf(data.expr, messageName, Some(fileContent), options.asScala.toMap)
)
CatalystDataToProtobuf(data, messageName, Some(fileContent), options.asScala.toMap)
}

/**
@@ -242,11 +237,7 @@
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]
): Column = {
Column(
CatalystDataToProtobuf(
data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap
)
)
CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap)
}

/**
@@ -266,7 +257,7 @@
*/
@Experimental
def to_protobuf(data: Column, messageClassName: String): Column = {
Column(CatalystDataToProtobuf(data.expr, messageClassName))
CatalystDataToProtobuf(data, messageClassName)
}

/**
@@ -288,6 +279,6 @@
@Experimental
def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String])
: Column = {
Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap))
CatalystDataToProtobuf(data, messageClassName, None, options.asScala.toMap)
}
}
@@ -118,7 +118,7 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
}

val mappedOutputCols = inputColNames.zip(tds).map { case (colName, td) =>
dataset.col(colName).expr.dataType match {
SchemaUtils.getSchemaField(dataset.schema, colName).dataType match {
case DoubleType =>
when(!col(colName).isNaN && col(colName) > td, lit(1.0))
.otherwise(lit(0.0))
@@ -189,13 +189,12 @@ class StringIndexer @Since("1.4.0") (
private def getSelectedCols(dataset: Dataset[_], inputCols: Seq[String]): Seq[Column] = {
inputCols.map { colName =>
val col = dataset.col(colName)
if (col.expr.dataType == StringType) {
col
} else {
// We don't count for NaN values. Because `StringIndexerAggregator` only processes strings,
// we replace NaNs with null in advance.
when(!isnan(col), col).cast(StringType)
}
// We don't count for NaN values. Because `StringIndexerAggregator` only processes strings,
// we replace NaNs with null in advance.
val fpTypes = Seq(DoubleType, FloatType).map(_.catalogString)
when(typeof(col).isin(fpTypes: _*) && isnan(col), lit(null))
.otherwise(col)
.cast(StringType)
}
}

@@ -87,17 +87,17 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
// Schema transformation.
val schema = dataset.schema

val vectorCols = $(inputCols).filter { c =>
dataset.col(c).expr.dataType match {
case _: VectorUDT => true
case _ => false
}
val inputColsWithField = $(inputCols).map { c =>
c -> SchemaUtils.getSchemaField(schema, c)
}

val vectorCols = inputColsWithField.collect {
case (c, field) if field.dataType.isInstanceOf[VectorUDT] => c
}
val vectorColsLengths = VectorAssembler.getLengths(
dataset, vectorCols.toImmutableArraySeq, $(handleInvalid))

val featureAttributesMap = $(inputCols).map { c =>
val field = SchemaUtils.getSchemaField(schema, c)
val featureAttributesMap = inputColsWithField.map { case (c, field) =>
field.dataType match {
case DoubleType =>
val attribute = Attribute.fromStructField(field)
@@ -144,8 +144,8 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
val assembleFunc = udf { r: Row =>
VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
}.asNondeterministic()
val args = $(inputCols).map { c =>
dataset(c).expr.dataType match {
val args = inputColsWithField.map { case (c, field) =>
field.dataType match {
case DoubleType => dataset(c)
case _: VectorUDT => dataset(c)
case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
10 changes: 4 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputT
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -248,16 +249,13 @@ private[ml] class SummaryBuilderImpl(
) extends SummaryBuilder {

override def summary(featuresCol: Column, weightCol: Column): Column = {

val agg = SummaryBuilderImpl.MetricsAggregate(
SummaryBuilderImpl.MetricsAggregate(
requestedMetrics,
requestedCompMetrics,
featuresCol.expr,
weightCol.expr,
featuresCol,
weightCol,
mutableAggBufferOffset = 0,
inputAggBufferOffset = 0)

Column(agg.toAggregateExpression())
}
}

1 change: 1 addition & 0 deletions project/MimaExcludes.scala
@@ -110,6 +110,7 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy"),
// SPARK-49022: Use Column API
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.TypedColumn.this"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.TypedColumn.this"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.WindowSpec.this")
)

4 changes: 0 additions & 4 deletions python/pyspark/sql/classic/column.py
@@ -75,10 +75,6 @@ def _to_java_column(col: "ColumnOrName") -> "JavaObject":
return jcol


def _to_java_expr(col: "ColumnOrName") -> "JavaObject":
return _to_java_column(col).expr()


@overload
def _to_seq(sc: "SparkContext", cols: Iterable["JavaObject"]) -> "JavaObject":
...
12 changes: 6 additions & 6 deletions python/pyspark/sql/pandas/group_ops.py
@@ -238,7 +238,7 @@ def applyInPandas(
udf = pandas_udf(func, returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc)
return DataFrame(jdf, self.session)

def applyInPandasWithState(
Expand Down Expand Up @@ -356,7 +356,7 @@ def applyInPandasWithState(
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.applyInPandasWithState(
udf_column._jc.expr(),
udf_column._jc,
self.session._jsparkSession.parseDataType(outputStructType.json()),
self.session._jsparkSession.parseDataType(stateStructType.json()),
outputMode,
@@ -523,7 +523,7 @@ def transformWithStateUDF(
udf_column = udf(*[df[col] for col in df.columns])

jdf = self._jgd.transformWithStateInPandas(
udf_column._jc.expr(),
udf_column._jc,
self.session._jsparkSession.parseDataType(outputStructType.json()),
outputMode,
timeMode,
@@ -653,7 +653,7 @@ def applyInArrow(
) # type: ignore[call-overload]
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInArrow(udf_column._jc.expr())
jdf = self._jgd.flatMapGroupsInArrow(udf_column._jc)
return DataFrame(jdf, self.session)

def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
@@ -793,7 +793,7 @@ def applyInPandas(

all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
udf_column = udf(*all_cols)
jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr())
jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc)
return DataFrame(jdf, self._gd1.session)

def applyInArrow(
@@ -891,7 +891,7 @@ def applyInArrow(

all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
udf_column = udf(*all_cols)
jdf = self._gd1._jgd.flatMapCoGroupsInArrow(self._gd2._jgd, udf_column._jc.expr())
jdf = self._gd1._jgd.flatMapCoGroupsInArrow(self._gd2._jgd, udf_column._jc)
return DataFrame(jdf, self._gd1.session)

@staticmethod
4 changes: 2 additions & 2 deletions python/pyspark/sql/pandas/map_ops.py
@@ -53,7 +53,7 @@ def mapInPandas(
udf_column = udf(*[self[col] for col in self.columns])

jrp = self._build_java_profile(profile)
jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier, jrp)
jdf = self._jdf.mapInPandas(udf_column._jc, barrier, jrp)
return DataFrame(jdf, self.sparkSession)

def mapInArrow(
@@ -75,7 +75,7 @@ def mapInArrow(
udf_column = udf(*[self[col] for col in self.columns])

jrp = self._build_java_profile(profile)
jdf = self._jdf.mapInArrow(udf_column._jc.expr(), barrier, jrp)
jdf = self._jdf.mapInArrow(udf_column._jc, barrier, jrp)
return DataFrame(jdf, self.sparkSession)

def _build_java_profile(