
Commit 4341ec4

resolve

itholic committed Aug 10, 2023
2 parents b15a08e + 1846991

Showing 127 changed files with 11,299 additions and 124 deletions.
@@ -128,11 +128,13 @@ import org.apache.spark.util.SparkClassUtils
class Dataset[T] private[sql] (
val sparkSession: SparkSession,
@DeveloperApi val plan: proto.Plan,
val encoder: AgnosticEncoder[T])
val encoder: Encoder[T])
extends Serializable {
// Make sure we don't forget to set plan id.
assert(plan.getRoot.getCommon.hasPlanId)

private[sql] val agnosticEncoder: AgnosticEncoder[T] = encoderFor(encoder)

override def toString: String = {
try {
val builder = new mutable.StringBuilder
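This first hunk widens the constructor parameter from the internal AgnosticEncoder[T] to the user-facing Encoder[T] and normalizes it once through encoderFor; the remaining hunks in this file then switch each plan-building call from encoder to agnosticEncoder. A minimal sketch of that normalize-once shape, with stand-in traits and a stubbed encoderFor (assumptions for illustration, not the real Spark Connect definitions):

object EncoderSketch {
  // Hypothetical stand-ins for the user-facing and internal encoder types.
  trait Encoder[T]
  trait AgnosticEncoder[T] extends Encoder[T] { def isStruct: Boolean = true }

  // Stub: accept any Encoder and resolve the internal AgnosticEncoder form.
  def encoderFor[T](e: Encoder[T]): AgnosticEncoder[T] = e match {
    case agnostic: AgnosticEncoder[T @unchecked] => agnostic // already the internal form
    case other => sys.error(s"cannot resolve an AgnosticEncoder for $other")
  }

  class Dataset[T](val encoder: Encoder[T]) {
    // Normalized once; internal plan building reads this field rather than
    // the user-supplied encoder.
    private[this] val agnosticEncoder: AgnosticEncoder[T] = encoderFor(encoder)
  }
}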
@@ -828,7 +830,7 @@ class Dataset[T] private[sql] (
}

private def buildSort(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
sparkSession.newDataset(encoder) { builder =>
sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getSortBuilder
.setInput(plan.getRoot)
.setIsGlobal(global)
@@ -878,8 +880,8 @@ class Dataset[T] private[sql] (
ProductEncoder[(T, U)](
ClassTag(SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")),
Seq(
EncoderField(s"_1", this.encoder, leftNullable, Metadata.empty),
EncoderField(s"_2", other.encoder, rightNullable, Metadata.empty)))
EncoderField(s"_1", this.agnosticEncoder, leftNullable, Metadata.empty),
EncoderField(s"_2", other.agnosticEncoder, rightNullable, Metadata.empty)))

sparkSession.newDataset(tupleEncoder) { builder =>
val joinBuilder = builder.getJoinBuilder
@@ -889,8 +891,8 @@ class Dataset[T] private[sql] (
.setJoinType(joinTypeValue)
.setJoinCondition(condition.expr)
.setJoinDataType(joinBuilder.getJoinDataTypeBuilder
.setIsLeftStruct(this.encoder.isStruct)
.setIsRightStruct(other.encoder.isStruct))
.setIsLeftStruct(this.agnosticEncoder.isStruct)
.setIsRightStruct(other.agnosticEncoder.isStruct))
}
}

@@ -1010,13 +1012,13 @@ class Dataset[T] private[sql] (
* @since 3.4.0
*/
@scala.annotation.varargs
def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset(encoder) {
builder =>
def hint(name: String, parameters: Any*): Dataset[T] =
sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getHintBuilder
.setInput(plan.getRoot)
.setName(name)
.addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
}
}

private def getPlanId: Option[Long] =
if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) {
@@ -1056,7 +1058,7 @@ class Dataset[T] private[sql] (
* @group typedrel
* @since 3.4.0
*/
def as(alias: String): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
def as(alias: String): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getSubqueryAliasBuilder
.setInput(plan.getRoot)
.setAlias(alias)
@@ -1238,8 +1240,9 @@ class Dataset[T] private[sql] (
* @group typedrel
* @since 3.4.0
*/
def filter(condition: Column): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
def filter(condition: Column): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
builder =>
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
}

/**
@@ -1355,12 +1358,12 @@ class Dataset[T] private[sql] (
def reduce(func: (T, T) => T): T = {
val udf = ScalarUserDefinedFunction(
function = func,
inputEncoders = encoder :: encoder :: Nil,
outputEncoder = encoder)
inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil,
outputEncoder = agnosticEncoder)
val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr

val result = sparkSession
.newDataset(encoder) { builder =>
.newDataset(agnosticEncoder) { builder =>
builder.getAggregateBuilder
.setInput(plan.getRoot)
.addAggregateExpressions(reduceExpr)
@@ -1718,7 +1721,7 @@ class Dataset[T] private[sql] (
* @group typedrel
* @since 3.4.0
*/
def limit(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getLimitBuilder
.setInput(plan.getRoot)
.setLimit(n)
@@ -1730,7 +1733,7 @@ class Dataset[T] private[sql] (
* @group typedrel
* @since 3.4.0
*/
def offset(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
def offset(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getOffsetBuilder
.setInput(plan.getRoot)
.setOffset(n)
@@ -1739,7 +1742,7 @@ class Dataset[T] private[sql] (
private def buildSetOp(right: Dataset[T], setOpType: proto.SetOperation.SetOpType)(
f: proto.SetOperation.Builder => Unit): Dataset[T] = {
checkSameSparkSession(right)
sparkSession.newDataset(encoder) { builder =>
sparkSession.newDataset(agnosticEncoder) { builder =>
f(
builder.getSetOpBuilder
.setSetOpType(setOpType)
@@ -2012,7 +2015,7 @@ class Dataset[T] private[sql] (
* @since 3.4.0
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
sparkSession.newDataset(encoder) { builder =>
sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getSampleBuilder
.setInput(plan.getRoot)
.setWithReplacement(withReplacement)
@@ -2080,7 +2083,7 @@ class Dataset[T] private[sql] (
normalizedCumWeights
.sliding(2)
.map { case Array(low, high) =>
sparkSession.newDataset(encoder) { builder =>
sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getSampleBuilder
.setInput(sortedInput)
.setWithReplacement(false)
@@ -2401,15 +2404,16 @@ class Dataset[T] private[sql] (

private def buildDropDuplicates(
columns: Option[Seq[String]],
withinWaterMark: Boolean): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
val dropBuilder = builder.getDeduplicateBuilder
.setInput(plan.getRoot)
.setWithinWatermark(withinWaterMark)
if (columns.isDefined) {
dropBuilder.addAllColumnNames(columns.get.asJava)
} else {
dropBuilder.setAllColumnsAsKeys(true)
}
withinWaterMark: Boolean): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
builder =>
val dropBuilder = builder.getDeduplicateBuilder
.setInput(plan.getRoot)
.setWithinWatermark(withinWaterMark)
if (columns.isDefined) {
dropBuilder.addAllColumnNames(columns.get.asJava)
} else {
dropBuilder.setAllColumnsAsKeys(true)
}
}

/**
@@ -2630,9 +2634,9 @@ class Dataset[T] private[sql] (
def filter(func: T => Boolean): Dataset[T] = {
val udf = ScalarUserDefinedFunction(
function = func,
inputEncoders = encoder :: Nil,
inputEncoders = agnosticEncoder :: Nil,
outputEncoder = PrimitiveBooleanEncoder)
sparkSession.newDataset[T](encoder) { builder =>
sparkSession.newDataset[T](agnosticEncoder) { builder =>
builder.getFilterBuilder
.setInput(plan.getRoot)
.setCondition(udf.apply(col("*")).expr)
@@ -2683,7 +2687,7 @@ class Dataset[T] private[sql] (
val outputEncoder = encoderFor[U]
val udf = ScalarUserDefinedFunction(
function = func,
inputEncoders = encoder :: Nil,
inputEncoders = agnosticEncoder :: Nil,
outputEncoder = outputEncoder)
sparkSession.newDataset(outputEncoder) { builder =>
builder.getMapPartitionsBuilder
@@ -2785,7 +2789,7 @@ class Dataset[T] private[sql] (
* @since 3.4.0
*/
def tail(n: Int): Array[T] = {
val lastN = sparkSession.newDataset(encoder) { builder =>
val lastN = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getTailBuilder
.setInput(plan.getRoot)
.setLimit(n)
@@ -2856,7 +2860,7 @@ class Dataset[T] private[sql] (
}

private def buildRepartition(numPartitions: Int, shuffle: Boolean): Dataset[T] = {
sparkSession.newDataset(encoder) { builder =>
sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getRepartitionBuilder
.setInput(plan.getRoot)
.setNumPartitions(numPartitions)
@@ -2866,11 +2870,12 @@ class Dataset[T] private[sql] (

private def buildRepartitionByExpression(
numPartitions: Option[Int],
partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
val repartitionBuilder = builder.getRepartitionByExpressionBuilder
.setInput(plan.getRoot)
.addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
numPartitions.foreach(repartitionBuilder.setNumPartitions)
partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
builder =>
val repartitionBuilder = builder.getRepartitionByExpressionBuilder
.setInput(plan.getRoot)
.addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
numPartitions.foreach(repartitionBuilder.setNumPartitions)
}

/**
@@ -3183,7 +3188,7 @@ class Dataset[T] private[sql] (
* @since 3.5.0
*/
def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = {
sparkSession.newDataset(encoder) { builder =>
sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getWithWatermarkBuilder
.setInput(plan.getRoot)
.setEventTime(eventTime)
@@ -3251,7 +3256,7 @@ class Dataset[T] private[sql] (
sparkSession.analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA)
}

def collectResult(): SparkResult[T] = sparkSession.execute(plan, encoder)
def collectResult(): SparkResult[T] = sparkSession.execute(plan, agnosticEncoder)

private[sql] def withResult[E](f: SparkResult[T] => E): E = {
val result = collectResult()
@@ -988,15 +988,15 @@ private object KeyValueGroupedDatasetImpl {
groupingFunc: V => K): KeyValueGroupedDatasetImpl[K, V, K, V] = {
val gf = ScalarUserDefinedFunction(
function = groupingFunc,
inputEncoders = ds.encoder :: Nil, // Using the original value and key encoders
inputEncoders = ds.agnosticEncoder :: Nil, // Using the original value and key encoders
outputEncoder = kEncoder)
new KeyValueGroupedDatasetImpl(
ds.sparkSession,
ds.plan,
kEncoder,
kEncoder,
ds.encoder,
ds.encoder,
ds.agnosticEncoder,
ds.agnosticEncoder,
Arrays.asList(gf.apply(col("*")).expr),
UdfUtils.identical(),
() => ds.map(groupingFunc)(kEncoder))
@@ -781,6 +781,7 @@ object SparkSession extends Logging {
class Builder() extends Logging {
private val builder = SparkConnectClient.builder()
private var client: SparkConnectClient = _
private[this] val options = new scala.collection.mutable.HashMap[String, String]

def remote(connectionString: String): Builder = {
builder.connectionString(connectionString)
@@ -804,6 +805,84 @@ object SparkSession extends Logging {
this
}

/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(key: String, value: String): Builder = synchronized {
options += key -> value
this
}

/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(key: String, value: Long): Builder = synchronized {
options += key -> value.toString
this
}

/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(key: String, value: Double): Builder = synchronized {
options += key -> value.toString
this
}

/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(key: String, value: Boolean): Builder = synchronized {
options += key -> value.toString
this
}

/**
* Sets config options from a map. Options set using this method are automatically propagated
* to the Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(map: Map[String, Any]): Builder = synchronized {
map.foreach { kv: (String, Any) =>
{
options += kv._1 -> kv._2.toString
}
}
this
}

/**
* Sets a config option. Options set using this method are automatically propagated to both
* `SparkConf` and SparkSession's own configuration.
*
* @since 3.5.0
*/
def config(map: java.util.Map[String, Any]): Builder = synchronized {
config(map.asScala.toMap)
}
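Taken together, the new overloads let runtime configuration be supplied through the builder before the session is obtained. A hedged usage sketch; the connection string and configuration keys below are placeholders rather than values taken from this change:

import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .remote("sc://localhost:15002")                 // placeholder Spark Connect endpoint
  .config("spark.sql.session.timeZone", "UTC")    // String overload
  .config("spark.sql.shuffle.partitions", 10L)    // Long overload
  .config(Map("spark.sql.ansi.enabled" -> true))  // Map[String, Any] overload
  .getOrCreate()                                  // options applied via session.conf.set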

@deprecated("enableHiveSupport does not work in Spark Connect")
def enableHiveSupport(): Builder = this

@deprecated("master does not work in Spark Connect, please use remote instead")
def master(master: String): Builder = this

@deprecated("appName does not work in Spark Connect")
def appName(name: String): Builder = this

private def tryCreateSessionFromClient(): Option[SparkSession] = {
if (client != null) {
Option(new SparkSession(client, cleaner, planIdGenerator))
@@ -812,6 +891,12 @@ object SparkSession extends Logging {
}
}

private def applyOptions(session: SparkSession): Unit = {
options.foreach { case (key, value) =>
session.conf.set(key, value)
}
}

/**
* Build the [[SparkSession]].
*
@@ -833,6 +918,7 @@ object SparkSession extends Logging {
val session = tryCreateSessionFromClient()
.getOrElse(SparkSession.this.create(builder.configuration))
setDefaultAndActiveSession(session)
applyOptions(session)
session
}

@@ -842,14 +928,17 @@ object SparkSession extends Logging {
* If a session exists with the same configuration, it is returned instead of creating a new
* session.
*
* This method will update the default and/or active session if they are not set.
* This method will update the default and/or active session if they are not set. This method
* will always set the specified configuration options on the session, even when it is not
* newly created.
*
* @since 3.5.0
*/
def getOrCreate(): SparkSession = {
val session = tryCreateSessionFromClient()
.getOrElse(sessions.get(builder.configuration))
setDefaultAndActiveSession(session)
applyOptions(session)
session
}
}
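Because getOrCreate() now applies the builder's options even when it returns a cached session, calling it again with different values updates the existing session's runtime configuration instead of ignoring it. A small illustration, under the assumption that identical builder configurations resolve to the same cached session (keys and values are placeholders):

val first = SparkSession.builder()
  .remote("sc://localhost:15002")
  .config("spark.sql.session.timeZone", "UTC")
  .getOrCreate()

val second = SparkSession.builder()
  .remote("sc://localhost:15002")
  .config("spark.sql.session.timeZone", "America/Los_Angeles")
  .getOrCreate()

// Assumption: the same connection configuration returns the cached session,
// but its options are re-applied, so the time zone now reflects the second call.
assert(first eq second)
assert(first.conf.get("spark.sql.session.timeZone") == "America/Los_Angeles")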