[SPARK-46124][CORE][SQL][SS][CONNECT][DSTREAM][MLLIB][ML][PYTHON][R][AVRO][K8S][YARN][UI] Replace explicit `ArrayOps#toSeq` with `s.c.immutable.ArraySeq.unsafeWrapArray`

### What changes were proposed in this pull request?
There is a behavioral difference between Scala 2.13 and 2.12 for explicit `ArrayOps.toSeq` calls, similar to the difference in the implicit conversion from `Array` to `Seq`.

In Scala 2.12, `ArrayOps.toSeq` returns `thisCollection` and relies on the implicit conversion rules to wrap the `Array` as a `mutable.WrappedArray`; this process does not involve any collection copy:

```scala
Welcome to Scala 2.12.18 (OpenJDK 64-Bit Server VM, Java 17.0.9).
Type in expressions for evaluation. Or try :help.

scala> Array(1,2,3).toSeq
res0: Seq[Int] = WrappedArray(1, 2, 3)
```

However, in Scala 2.13 it returns an `immutable.ArraySeq`, which involves a collection copy:
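For comparison, the same call on Scala 2.13 (a REPL session reproduced for illustration, so the banner and result numbering are approximate) shows the immutable, copied result:

```scala
Welcome to Scala 2.13.12 (OpenJDK 64-Bit Server VM, Java 17.0.9).
Type in expressions for evaluation. Or try :help.

scala> val arr = Array(1, 2, 3)
val arr: Array[Int] = Array(1, 2, 3)

scala> val seq = arr.toSeq
val seq: Seq[Int] = ArraySeq(1, 2, 3)

scala> arr(0) = 99      // mutate the original array

scala> seq              // unaffected: toSeq copied the elements
val res0: Seq[Int] = ArraySeq(1, 2, 3)
```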

Since this explicit conversion never copied the collection in the Scala 2.12 era, it is safe to assume that no collection copy is needed for Scala 2.13 either.

Therefore, this PR replaces explicit `ArrayOps.toSeq` calls in the Spark code with `s.c.immutable.ArraySeq.unsafeWrapArray` to avoid the collection copy. Only production code is changed; test code is not touched.
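Most call sites in the diff go through the `toImmutableArraySeq` extension imported from `org.apache.spark.util.ArrayImplicits`. As a rough sketch (the object and method names are taken from the call sites below; the exact implementation in Spark may differ), such a wrapper is just a thin layer over `unsafeWrapArray`:

```scala
import scala.collection.immutable.ArraySeq

object ArrayImplicits {
  implicit class SparkArrayOps[T](private val xs: Array[T]) extends AnyVal {
    // Wraps the array as an immutable.ArraySeq without copying its elements.
    def toImmutableArraySeq: ArraySeq[T] =
      if (xs eq null) null else ArraySeq.unsafeWrapArray(xs)
  }
}

// Usage, mirroring the changed call sites:
//   import ArrayImplicits._
//   Array("a", "b", "c").toImmutableArraySeq
```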

### Why are the changes needed?
Replacing `ArrayOps#toSeq` with `s.c.immutable.ArraySeq.unsafeWrapArray` saves a collection copy, which has potential performance benefits.
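To illustrate why this saves work (an ad-hoc snippet, not part of the patch): `unsafeWrapArray` only wraps the existing array, so later mutations of that array remain visible through the returned `ArraySeq`, which is both why the copy is avoided and why the method carries the `unsafe` prefix.

```scala
import scala.collection.immutable.ArraySeq

val arr = Array(1, 2, 3)
val wrapped = ArraySeq.unsafeWrapArray(arr) // no element copy, just a wrapper
val copied  = arr.toSeq                     // Scala 2.13: copies the elements

arr(0) = 99
println(wrapped) // ArraySeq(99, 2, 3) -- sees the mutation through the shared array
println(copied)  // ArraySeq(1, 2, 3)  -- owns its own copy, unaffected
```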

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Pass GitHub Actions.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#44041 from LuciferYang/ArrayToSeq-2-ArrayToImmutableArraySeq.

Authored-by: yangjie01 <yangjie01@baidu.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
LuciferYang authored and srowen committed Nov 29, 2023
1 parent f5e4e84 commit 8a4890d
Showing 126 changed files with 349 additions and 231 deletions.
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.DataSourceUtils
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
+import org.apache.spark.util.ArrayImplicits._

 /**
 * A serializer to serialize data in catalyst format to data in avro format.
@@ -309,7 +310,7 @@ private[sql] class AvroSerializer(
 avroPath: Seq[String]): InternalRow => Any = {
 val nonNullTypes = nonNullUnionBranches(unionType)
 val expectedFieldNames = nonNullTypes.indices.map(i => s"member$i")
-val catalystFieldNames = catalystStruct.fieldNames.toSeq
+val catalystFieldNames = catalystStruct.fieldNames.toImmutableArraySeq
 if (positionalFieldMatch) {
 if (expectedFieldNames.length != catalystFieldNames.length) {
 throw new IncompatibleSchemaException(s"Generic Avro union at ${toFieldStr(avroPath)} " +
@@ -29,6 +29,7 @@ import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
+import org.apache.spark.util.ArrayImplicits._

 /**
 * A column that will be computed based on the data in a `DataFrame`.
@@ -1004,7 +1005,7 @@ class Column private[sql] (@DeveloperApi val expr: proto.Expression) extends Log
 * @group expr_ops
 * @since 3.4.0
 */
-def as(aliases: Array[String]): Column = as(aliases.toSeq)
+def as(aliases: Array[String]): Column = as(aliases.toImmutableArraySeq)

 /**
 * Gives the column an alias.
@@ -24,6 +24,7 @@ import scala.jdk.CollectionConverters._
 import org.apache.spark.connect.proto.{NAReplace, Relation}
 import org.apache.spark.connect.proto.Expression.{Literal => GLiteral}
 import org.apache.spark.connect.proto.NAReplace.Replacement
+import org.apache.spark.util.ArrayImplicits._

 /**
 * Functionality for working with missing data in `DataFrame`s.
@@ -57,7 +58,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
 *
 * @since 3.4.0
 */
-def drop(cols: Array[String]): DataFrame = drop(cols.toSeq)
+def drop(cols: Array[String]): DataFrame = drop(cols.toImmutableArraySeq)

 /**
 * (Scala-specific) Returns a new `DataFrame` that drops rows containing any null or NaN values
@@ -76,7 +77,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
 *
 * @since 3.4.0
 */
-def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq)
+def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toImmutableArraySeq)

 /**
 * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values in
@@ -107,7 +108,8 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
 *
 * @since 3.4.0
 */
-def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toSeq)
+def drop(minNonNulls: Int, cols: Array[String]): DataFrame =
+  drop(minNonNulls, cols.toImmutableArraySeq)

 /**
 * (Scala-specific) Returns a new `DataFrame` that drops rows containing less than `minNonNulls`
@@ -152,7 +154,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
 *
 * @since 3.4.0
 */
-def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq)

 /**
 * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
@@ -179,7 +181,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
 *
 * @since 3.4.0
 */
-def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq)

 /**
 * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
@@ -206,7 +208,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
 *
 * @since 3.4.0
 */
-def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq)

 /**
 * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified string
@@ -233,7 +235,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
 *
 * @since 3.4.0
 */
-def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq)

 /**
 * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified boolean
@@ -374,7 +376,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
 * @since 3.4.0
 */
 def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = {
-replace(cols.toSeq, replacement.asScala.toMap)
+replace(cols.toImmutableArraySeq, replacement.asScala.toMap)
 }

 /**
@@ -39,6 +39,7 @@ import org.apache.spark.sql.functions.{struct, to_json}
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types.{Metadata, StructType}
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.SparkClassUtils

 /**
@@ -650,7 +651,7 @@ class Dataset[T] private[sql] (
 * @since 3.4.0
 */
 def join(right: Dataset[_], usingColumns: Array[String]): DataFrame = {
-join(right, usingColumns.toSeq)
+join(right, usingColumns.toImmutableArraySeq)
 }

 /**
@@ -729,7 +730,7 @@ class Dataset[T] private[sql] (
 * @since 3.4.0
 */
 def join(right: Dataset[_], usingColumns: Array[String], joinType: String): DataFrame = {
-join(right, usingColumns.toSeq, joinType)
+join(right, usingColumns.toImmutableArraySeq, joinType)
 }

 /**
@@ -1306,12 +1307,12 @@ class Dataset[T] private[sql] (
 valueColumnName: String): DataFrame = sparkSession.newDataFrame { builder =>
 val unpivot = builder.getUnpivotBuilder
 .setInput(plan.getRoot)
-.addAllIds(ids.toSeq.map(_.expr).asJava)
+.addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
 .setValueColumnName(variableColumnName)
 .setValueColumnName(valueColumnName)
 valuesOption.foreach { values =>
 unpivot.getValuesBuilder
-.addAllValues(values.toSeq.map(_.expr).asJava)
+.addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
 }
 }

@@ -2496,7 +2497,8 @@
 * @group typedrel
 * @since 3.4.0
 */
-def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq)
+def dropDuplicates(colNames: Array[String]): Dataset[T] =
+  dropDuplicates(colNames.toImmutableArraySeq)

 /**
 * Returns a new [[Dataset]] with duplicate rows removed, considering only the subset of
@@ -2518,7 +2520,7 @@
 }

 def dropDuplicatesWithinWatermark(colNames: Array[String]): Dataset[T] = {
-dropDuplicatesWithinWatermark(colNames.toSeq)
+dropDuplicatesWithinWatermark(colNames.toImmutableArraySeq)
 }

 @scala.annotation.varargs
@@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._

 import org.apache.spark.connect.proto
 import org.apache.spark.sql.types._
+import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.SparkClassUtils

 /**
@@ -225,7 +226,7 @@ object DataTypeProtoConverter {
 .build()

 case StructType(fields: Array[StructField]) =>
-val protoFields = fields.toSeq.map {
+val protoFields = fields.toImmutableArraySeq.map {
 case StructField(
 name: String,
 dataType: DataType,
@@ -29,6 +29,7 @@ import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
 import org.apache.spark.sql.connect.planner.{SaveModeConverter, TableSaveMethodConverter}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.Utils

 /**
@@ -473,8 +474,8 @@ package object dsl {
 proto.StatApproxQuantile
 .newBuilder()
 .setInput(logicalPlan)
-.addAllCols(cols.toSeq.asJava)
-.addAllProbabilities(probabilities.toSeq.map(Double.box).asJava)
+.addAllCols(cols.toImmutableArraySeq.asJava)
+.addAllProbabilities(probabilities.toImmutableArraySeq.map(Double.box).asJava)
 .setRelativeError(relativeError)
 .build())
 .build()
@@ -500,7 +501,7 @@
 proto.StatFreqItems
 .newBuilder()
 .setInput(logicalPlan)
-.addAllCols(cols.toSeq.asJava)
+.addAllCols(cols.toImmutableArraySeq.asJava)
 .setSupport(support)
 .build())
 .build()
@@ -1082,7 +1083,7 @@ package object dsl {
 weights.sum > 0,
 s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}")

-val sum = weights.toSeq.sum
+val sum = weights.toImmutableArraySeq.sum
 val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
 normalizedCumWeights
 .sliding(2)
@@ -379,7 +379,7 @@ class SparkConnectPlanner(
 val values = rel.getValuesList.asScala.toArray
 if (values.length == 1) {
 val value = LiteralValueProtoConverter.toCatalystValue(values.head)
-val columns = if (cols.nonEmpty) Some(cols.toSeq) else None
+val columns = if (cols.nonEmpty) Some(cols.toImmutableArraySeq) else None
 dataset.na.fillValue(value, columns).logicalPlan
 } else {
 val valueMap = mutable.Map.empty[String, Any]
@@ -2434,7 +2434,7 @@ class SparkConnectPlanner(
 .sort(pivotCol) // ensure that the output columns are in a consistent logical order
 .collect()
 .map(_.get(0))
-.toSeq
+.toImmutableArraySeq
 .map(expressions.Literal.apply)
 }

@@ -3073,7 +3073,7 @@ class SparkConnectPlanner(
 val progressReports = if (command.getLastProgress) {
 Option(query.lastProgress).toSeq
 } else {
-query.recentProgress.toSeq
+query.recentProgress.toImmutableArraySeq
 }
 respBuilder.setRecentProgress(
 StreamingQueryCommandResult.RecentProgressResult
@@ -27,6 +27,7 @@ import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
+import org.apache.spark.util.ArrayImplicits._

 private[connect] class SparkConnectAnalyzeHandler(
 responseObserver: StreamObserver[proto.AnalyzePlanResponse])
@@ -128,7 +129,7 @@ private[connect] class SparkConnectAnalyzeHandler(
 builder.setInputFiles(
 proto.AnalyzePlanResponse.InputFiles
 .newBuilder()
-.addAllFiles(inputFiles.toSeq.asJava)
+.addAllFiles(inputFiles.toImmutableArraySeq.asJava)
 .build())

 case proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION =>
@@ -419,7 +419,7 @@ private[kafka010] class KafkaOffsetReaderAdmin(
 val end = splitOffsetRanges.last.copy(untilOffset = untilOffsetsMap(tp))
 Seq(first) ++ splitOffsetRanges.drop(1).dropRight(1) :+ end
 }
-}.toArray.toSeq
+}.toArray.toImmutableArraySeq
 } else {
 offsetRangesBase
 }
@@ -466,7 +466,7 @@ private[kafka010] class KafkaOffsetReaderConsumer(
 val end = splitOffsetRanges.last.copy(untilOffset = untilOffsetsMap(tp))
 Seq(first) ++ splitOffsetRanges.drop(1).dropRight(1) :+ end
 }
-}.toArray.toSeq
+}.toArray.toImmutableArraySeq
 } else {
 offsetRangesBase
 }
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -42,6 +42,7 @@ import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus}
 import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
 import org.apache.spark.util._
+import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

@@ -1054,7 +1055,7 @@ private[spark] class MapOutputTrackerMaster(
 val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
 dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
 if (blockManagerIds.nonEmpty) {
-blockManagerIds.get.map(_.host).distinct.toSeq
+blockManagerIds.get.map(_.host).distinct.toImmutableArraySeq
 } else {
 Nil
 }
@@ -1142,7 +1143,7 @@ private[spark] class MapOutputTrackerMaster(
 if (startMapIndex < endMapIndex &&
 (startMapIndex >= 0 && endMapIndex <= statuses.length)) {
 val statusesPicked = statuses.slice(startMapIndex, endMapIndex).filter(_ != null)
-statusesPicked.map(_.location.host).distinct.toSeq
+statusesPicked.map(_.location.host).distinct.toImmutableArraySeq
 } else {
 Nil
 }
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -129,7 +129,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria

 /** Set JAR files to distribute to the cluster. (Java-friendly version.) */
 def setJars(jars: Array[String]): SparkConf = {
-setJars(jars.toSeq)
+setJars(jars.toImmutableArraySeq)
 }

 /**
@@ -158,7 +158,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
 * (Java-friendly version.)
 */
 def setExecutorEnv(variables: Array[(String, String)]): SparkConf = {
-setExecutorEnv(variables.toSeq)
+setExecutorEnv(variables.toImmutableArraySeq)
 }

 /**