[SPARK-35798][SQL] Fix SparkPlan.sqlContext usage
### What changes were proposed in this pull request?
There might be `SparkPlan` nodes where canonicalization on the executor side can cause issues, because the node's `sqlContext` is derived from the active session and is therefore null on executors. This is a follow-up fix to the review conversation at https://github.com/apache/spark/pull/32885/files#r651019687.

### Why are the changes needed?
To avoid potential NPEs when canonicalization happens on executors.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing UTs.

Closes #32947 from peter-toth/SPARK-35798-fix-sparkplan.sqlcontext-usage.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
peter-toth authored and cloud-fan committed Jun 17, 2021
1 parent b86a69f commit abf9675
Showing 35 changed files with 97 additions and 99 deletions.
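
To make the hunks below easier to follow, here is a minimal, self-contained sketch of the failure mode being fixed. The names (`Conf`, `ActiveSession`, `MyPlan`, `MyExec`) are hypothetical stand-ins, not the real Spark classes: a plan captures a handle derived from the active session when it is constructed on the driver, that handle is null once the plan is deserialized on an executor, and any executor-side code path (such as canonicalization) that dereferences it throws a NullPointerException unless config access goes through a fallback accessor.

```scala
// Hypothetical stand-ins (Conf, ActiveSession, MyPlan, MyExec) -- not the real
// Spark classes -- illustrating why executor-side canonicalization could NPE.
final case class Conf(caseSensitive: Boolean)

object ActiveSession {
  // Populated on the driver; empty on executors.
  @volatile private var current: Option[Conf] = None
  def set(c: Conf): Unit = current = Some(c)
  def get: Option[Conf] = current
}

abstract class MyPlan extends Serializable {
  // Old pattern: captured from the active session at construction time;
  // after deserialization on an executor this field ends up null.
  @transient final val sessionConf: Conf = ActiveSession.get.orNull

  // New pattern: prefer the session's conf, otherwise fall back to a safe
  // default (Spark falls back to the thread-local SQLConf via super.conf).
  def conf: Conf = Option(sessionConf).getOrElse(Conf(caseSensitive = false))
}

final case class MyExec(name: String) extends MyPlan {
  // Canonicalization may run on executors; going through `conf` avoids the
  // NullPointerException that `sessionConf.caseSensitive` would throw there.
  def canonicalized: MyExec =
    copy(name = if (conf.caseSensitive) name else name.toLowerCase)
}
```

The actual change in `SparkPlan` below follows the same shape: it keeps the captured `session` (`SparkSession.getActiveSession.orNull`) and routes configuration reads through `conf`, which falls back to `super.conf` when no session is available.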
@@ -62,7 +62,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode {

  override def doExecute(): RDD[InternalRow] = {
    val broadcastedHadoopConf =
-     new SerializableConfiguration(sqlContext.sessionState.newHadoopConf())
+     new SerializableConfiguration(session.sessionState.newHadoopConf())

    child.execute().mapPartitions { iter =>
      if (iter.hasNext) {
@@ -441,7 +441,7 @@ case class RowToColumnarExec(child: SparkPlan) extends RowToColumnarTransition {
  )

  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
-   val enableOffHeapColumnVector = sqlContext.conf.offHeapColumnVectorEnabled
+   val enableOffHeapColumnVector = conf.offHeapColumnVectorEnabled
    val numInputRows = longMetric("numInputRows")
    val numOutputBatches = longMetric("numOutputBatches")
    // Instead of creating a new config we are reusing columnBatchSize. In the future if we do
@@ -51,11 +51,11 @@ case class CommandResultExec(

  @transient private lazy val rdd: RDD[InternalRow] = {
    if (rows.isEmpty) {
-     sqlContext.sparkContext.emptyRDD
+     sparkContext.emptyRDD
    } else {
      val numSlices = math.min(
-       unsafeRows.length, sqlContext.sparkSession.leafNodeDefaultParallelism)
-     sqlContext.sparkContext.parallelize(unsafeRows, numSlices)
+       unsafeRows.length, session.leafNodeDefaultParallelism)
+     sparkContext.parallelize(unsafeRows, numSlices)
    }
  }

@@ -54,7 +54,7 @@ trait DataSourceScanExec extends LeafExecNode {
  // Metadata that describes more details of this scan.
  protected def metadata: Map[String, String]

- protected val maxMetadataValueLength = sqlContext.sessionState.conf.maxMetadataStringLength
+ protected val maxMetadataValueLength = conf.maxMetadataStringLength

  override def simpleString(maxFields: Int): String = {
    val metadataEntries = metadata.toSeq.sorted.map {
@@ -86,7 +86,7 @@ trait DataSourceScanExec extends LeafExecNode {
   * Shorthand for calling redactString() without specifying redacting rules
   */
  protected def redact(text: String): String = {
-   Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text)
+   Utils.redact(conf.stringRedactionPattern, text)
  }

  /**
@@ -179,7 +179,7 @@ case class FileSourceScanExec(

  private lazy val needsUnsafeRowConversion: Boolean = {
    if (relation.fileFormat.isInstanceOf[ParquetSource]) {
-     sqlContext.conf.parquetVectorizedReaderEnabled
+     conf.parquetVectorizedReaderEnabled
    } else {
      false
    }
@@ -47,11 +47,11 @@ case class LocalTableScanExec(

  @transient private lazy val rdd: RDD[InternalRow] = {
    if (rows.isEmpty) {
-     sqlContext.sparkContext.emptyRDD
+     sparkContext.emptyRDD
    } else {
      val numSlices = math.min(
-       unsafeRows.length, sqlContext.sparkSession.leafNodeDefaultParallelism)
-     sqlContext.sparkContext.parallelize(unsafeRows, numSlices)
+       unsafeRows.length, session.leafNodeDefaultParallelism)
+     sparkContext.parallelize(unsafeRows, numSlices)
    }
  }

@@ -55,7 +55,7 @@ case class SortExec(
  override def requiredChildDistribution: Seq[Distribution] =
    if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

- private val enableRadixSort = sqlContext.conf.enableRadixSort
+ private val enableRadixSort = conf.enableRadixSort

  override lazy val metrics = Map(
    "sortTime" -> SQLMetrics.createTimingMetric(sparkContext, "sort time"),
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike}
import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch

object SparkPlan {
@@ -55,15 +56,17 @@ object SparkPlan {
 * The naming convention is that physical operators end with "Exec" suffix, e.g. [[ProjectExec]].
 */
abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
+ @transient final val session = SparkSession.getActiveSession.orNull

- /**
-  * A handle to the SQL Context that was used to create this plan. Since many operators need
-  * access to the sqlContext for RDD operations or configuration this field is automatically
-  * populated by the query planning infrastructure.
-  */
- @transient final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull
+ protected def sparkContext = session.sparkContext

- protected def sparkContext = sqlContext.sparkContext
+ override def conf: SQLConf = {
+   if (session != null) {
+     session.sessionState.conf
+   } else {
+     super.conf
+   }
+ }

  val id: Int = SparkPlan.newPlanId()

@@ -80,8 +83,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {

  /** Overridden make copy also propagates sqlContext to copied plan. */
  override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
-   if (sqlContext != null) {
-     SparkSession.setActiveSession(sqlContext.sparkSession)
+   if (session != null) {
+     SparkSession.setActiveSession(session)
    }
    super.makeCopy(newArgs)
  }
@@ -448,7 +451,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
        // If we didn't find any rows after the previous iteration, quadruple and retry.
        // Otherwise, interpolate the number of partitions we need to try, but overestimate
        // it by 50%. We also cap the estimation in the end.
-       val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2)
+       val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
        if (buf.isEmpty) {
          numPartsToTry = partsScanned * limitScaleUpFactor
        } else {
@@ -467,7 +470,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
      } else {
        parts
      }
-     val sc = sqlContext.sparkContext
+     val sc = sparkContext
      val res = sc.runJob(childRDD, (it: Iterator[(Long, Array[Byte])]) =>
        if (it.hasNext) it.next() else (0L, Array.emptyByteArray), partsToScan)
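
With the `conf` override above in place, operators are expected to read settings through `conf` (safe even when no session was captured, for example during executor-side canonicalization) and to touch `session` or `sparkContext` only where the session or an RDD is genuinely needed. Below is a minimal, hypothetical operator sketch under that assumption; `DummyLeafExec` is not part of Spark.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.LeafExecNode

// Hypothetical operator showing the intended access pattern after this change.
case class DummyLeafExec(override val output: Seq[Attribute]) extends LeafExecNode {

  // Executor-safe: `conf` resolves to the session's SQLConf on the driver and
  // falls back to the thread-local SQLConf when no active session was captured.
  private def maxMetaLen: Int = conf.maxMetadataStringLength

  // Driver-side: `sparkContext` (like `session`) is only touched where an RDD
  // actually has to be produced.
  override protected def doExecute(): RDD[InternalRow] =
    sparkContext.emptyRDD[InternalRow]
}
```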

@@ -74,7 +74,7 @@ case class SubqueryBroadcastExec(
    Future {
      // This will run in another thread. Set the execution id so that we can connect these jobs
      // with the correct execution.
-     SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+     SQLExecution.withExecutionId(session, executionId) {
        val beforeCollect = System.nanoTime()

        val broadcastRelation = child.executeBroadcast[HashedRelation]().value
@@ -724,17 +724,17 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
    val (_, compiledCodeStats) = try {
      CodeGenerator.compile(cleanedSource)
    } catch {
-     case NonFatal(_) if !Utils.isTesting && sqlContext.conf.codegenFallback =>
+     case NonFatal(_) if !Utils.isTesting && conf.codegenFallback =>
        // We should already saw the error message
        logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
        return child.execute()
    }

    // Check if compiled code has a too large function
-   if (compiledCodeStats.maxMethodCodeSize > sqlContext.conf.hugeMethodLimit) {
+   if (compiledCodeStats.maxMethodCodeSize > conf.hugeMethodLimit) {
      logInfo(s"Found too long generated codes and JIT optimization might not work: " +
        s"the bytecode size (${compiledCodeStats.maxMethodCodeSize}) is above the limit " +
-       s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
+       s"${conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
        s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " +
        s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString")
      return child.execute()
@@ -62,7 +62,7 @@ object AggUtils {
        resultExpressions = resultExpressions,
        child = child)
    } else {
-     val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
+     val objectHashEnabled = child.conf.useObjectHashAggregation
      val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)

      if (objectHashEnabled && useObjectHash) {
@@ -73,8 +73,8 @@ case class HashAggregateExec(
  // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
  // map and/or the sort-based aggregation once it has processed a given number of input rows.
  private val testFallbackStartsAt: Option[(Int, Int)] = {
-   Option(sqlContext).map { sc =>
-     sc.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null)
+   Option(session).map { s =>
+     s.conf.get("spark.sql.TungstenAggregate.testFallbackStartsAt", null)
    }.orNull match {
      case null | "" => None
      case fallbackStartsAt =>
@@ -679,15 +679,15 @@ case class HashAggregateExec(

      // This is for testing/benchmarking only.
      // We enforce to first level to be a vectorized hashmap, instead of the default row-based one.
-     isVectorizedHashMapEnabled = sqlContext.conf.enableVectorizedHashMap
+     isVectorizedHashMapEnabled = conf.enableVectorizedHashMap
    }
  }

  private def doProduceWithKeys(ctx: CodegenContext): String = {
    val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
-   if (sqlContext.conf.enableTwoLevelAggMap) {
+   if (conf.enableTwoLevelAggMap) {
      enableTwoLevelHashMap(ctx)
-   } else if (sqlContext.conf.enableVectorizedHashMap) {
+   } else if (conf.enableVectorizedHashMap) {
      logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.")
    }
    val bitMaxCapacity = testFallbackStartsAt match {
@@ -700,7 +700,7 @@ case class HashAggregateExec(
      } else {
        (math.log10(fastMapCounter) / math.log10(2)).floor.toInt
      }
-     case _ => sqlContext.conf.fastHashAggregateRowMaxCapacityBit
+     case _ => conf.fastHashAggregateRowMaxCapacityBit
    }

    val thisPlan = ctx.addReferenceObj("plan", this)
@@ -83,7 +83,7 @@ case class ObjectHashAggregateExec(
    val aggTime = longMetric("aggTime")
    val spillSize = longMetric("spillSize")
    val numTasksFallBacked = longMetric("numTasksFallBacked")
-   val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold
+   val fallbackCountThreshold = conf.objectAggSortBasedFallbackThreshold

    child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
      val beforeAgg = System.nanoTime()
@@ -46,8 +46,8 @@ case class UpdatingSessionsExec(
    groupingWithoutSessionExpression.map(_.toAttribute)

  override protected def doExecute(): RDD[InternalRow] = {
-   val inMemoryThreshold = sqlContext.conf.sessionWindowBufferInMemoryThreshold
-   val spillThreshold = sqlContext.conf.sessionWindowBufferSpillThreshold
+   val inMemoryThreshold = conf.sessionWindowBufferInMemoryThreshold
+   val spillThreshold = conf.sessionWindowBufferSpillThreshold

    child.execute().mapPartitions { iter =>
      new UpdatingSessionsIterator(iter, groupingExpression, sessionExpression,
@@ -413,7 +413,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
  val start: Long = range.start
  val end: Long = range.end
  val step: Long = range.step
- val numSlices: Int = range.numSlices.getOrElse(sqlContext.sparkSession.leafNodeDefaultParallelism)
+ val numSlices: Int = range.numSlices.getOrElse(session.leafNodeDefaultParallelism)
  val numElements: BigInt = range.numElements
  val isEmptyRange: Boolean = start == end || (start < end ^ 0 < step)

@@ -442,9 +442,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)

  override def inputRDDs(): Seq[RDD[InternalRow]] = {
    val rdd = if (isEmptyRange) {
-     new EmptyRDD[InternalRow](sqlContext.sparkContext)
+     new EmptyRDD[InternalRow](sparkContext)
    } else {
-     sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
+     sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
    }
    rdd :: Nil
  }
@@ -608,10 +608,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    if (isEmptyRange) {
-     new EmptyRDD[InternalRow](sqlContext.sparkContext)
+     new EmptyRDD[InternalRow](sparkContext)
    } else {
-     sqlContext
-       .sparkContext
+     sparkContext
        .parallelize(0 until numSlices, numSlices)
        .mapPartitionsWithIndex { (i, _) =>
          val partitionStart = (i * numElements) / numSlices * step + start
@@ -814,11 +813,11 @@ case class SubqueryExec(name: String, child: SparkPlan, maxNumRows: Option[Int])
    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
    SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
-     sqlContext.sparkSession,
+     session,
      SubqueryExec.executionContext) {
      // This will run in another thread. Set the execution id so that we can connect these jobs
      // with the correct execution.
-     SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+     SQLExecution.withExecutionId(session, executionId) {
        val beforeCollect = System.nanoTime()
        // Note that we use .executeCollect() because we don't want to convert data to Scala types
        val rows: Array[InternalRow] = if (maxNumRows.isDefined) {
@@ -208,8 +208,8 @@ case class CachedRDDBuilder(

  @transient @volatile private var _cachedColumnBuffers: RDD[CachedBatch] = null

- val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator
- val rowCountStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator
+ val sizeInBytesStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator
+ val rowCountStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator

  val cachedName = tableName.map(n => s"In-memory table $n")
    .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024))
@@ -132,13 +132,13 @@ case class InMemoryTableScanExec(
  override def outputOrdering: Seq[SortOrder] =
    relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])

- lazy val enableAccumulatorsForTest: Boolean = sqlContext.conf.inMemoryTableScanStatisticsEnabled
+ lazy val enableAccumulatorsForTest: Boolean = conf.inMemoryTableScanStatisticsEnabled

  // Accumulators used for testing purposes
  lazy val readPartitions = sparkContext.longAccumulator
  lazy val readBatches = sparkContext.longAccumulator

- private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
+ private val inMemoryPartitionPruningEnabled = conf.inMemoryPartitionPruning

  private def filteredCachedBatches(): RDD[CachedBatch] = {
    val buffers = relation.cacheBuilder.cachedColumnBuffers
@@ -72,7 +72,7 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode {
   */
  protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
-   cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow])
+   cmd.run(session).map(converter(_).asInstanceOf[InternalRow])
  }

  override def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil
@@ -92,7 +92,7 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode {
  }

  protected override def doExecute(): RDD[InternalRow] = {
-   sqlContext.sparkContext.parallelize(sideEffectResult, 1)
+   sparkContext.parallelize(sideEffectResult, 1)
  }
}

@@ -110,7 +110,7 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)

  protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
-   val rows = cmd.run(sqlContext.sparkSession, child)
+   val rows = cmd.run(session, child)

    rows.map(converter(_).asInstanceOf[InternalRow])
  }
@@ -133,7 +133,7 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)
  }

  protected override def doExecute(): RDD[InternalRow] = {
-   sqlContext.sparkContext.parallelize(sideEffectResult, 1)
+   sparkContext.parallelize(sideEffectResult, 1)
  }

  override protected def withNewChildInternal(newChild: SparkPlan): DataWritingCommandExec =
(Diff truncated; the remaining changed files are not shown here.)
