
Commit

Spark35
zhztheplayer committed Mar 8, 2024
1 parent 438f53c commit b13fc32
Showing 5 changed files with 265 additions and 32 deletions.
@@ -39,7 +39,7 @@ import org.apache.spark.util.collection.BitSet
* @param disableBucketedScan Disable bucketed scan based on physical query plan, see rule
* [[DisableUnnecessaryBucketedScan]] for details.
*/
case class AbstractFileSourceScanExec(
abstract class AbstractFileSourceScanExec(
@transient override val relation: HadoopFsRelation,
override val output: Seq[Attribute],
override val requiredSchema: StructType,
@@ -252,22 +252,8 @@ case class AbstractFileSourceScanExec(

// Filters unused DynamicPruningExpression expressions - the ones that have been replaced
// with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning
private def filterUnusedDynamicPruningExpressions(
protected def filterUnusedDynamicPruningExpressions(
predicates: Seq[Expression]): Seq[Expression] = {
predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral))
}

override def doCanonicalize(): AbstractFileSourceScanExec = {
AbstractFileSourceScanExec(
relation,
output.map(QueryPlan.normalizeExpressions(_, output)),
requiredSchema,
QueryPlan.normalizePredicates(
filterUnusedDynamicPruningExpressions(partitionFilters), output),
optionalBucketSet,
optionalNumCoalescedBuckets,
QueryPlan.normalizePredicates(dataFilters, output),
None,
disableBucketedScan)
}
}
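
With doCanonicalize removed and the node made abstract, canonicalization has to come from whatever concrete scan node each Spark version supplies on top of this base. A minimal sketch of such a subclass follows; it is not part of this commit, the class name is hypothetical, and it assumes the remaining constructor parameters are exposed as overridable vals in the same way as relation, output and requiredSchema above.

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.BitSet

// Hypothetical concrete node mirroring the constructor of AbstractFileSourceScanExec.
case class MyFileSourceScanExec(
    @transient override val relation: HadoopFsRelation,
    override val output: Seq[Attribute],
    override val requiredSchema: StructType,
    override val partitionFilters: Seq[Expression],
    override val optionalBucketSet: Option[BitSet],
    override val optionalNumCoalescedBuckets: Option[Int],
    override val dataFilters: Seq[Expression],
    override val tableIdentifier: Option[TableIdentifier],
    override val disableBucketedScan: Boolean = false)
  extends AbstractFileSourceScanExec(
    relation, output, requiredSchema, partitionFilters, optionalBucketSet,
    optionalNumCoalescedBuckets, dataFilters, tableIdentifier, disableBucketedScan) {

  // Same canonicalization the base used to perform: normalize output and predicates,
  // drop pruned DynamicPruningExpression placeholders, and clear the table identifier.
  override def doCanonicalize(): MyFileSourceScanExec = copy(
    output = output.map(QueryPlan.normalizeExpressions(_, output)),
    partitionFilters = QueryPlan.normalizePredicates(
      filterUnusedDynamicPruningExpressions(partitionFilters), output),
    dataFilters = QueryPlan.normalizePredicates(dataFilters, output),
    tableIdentifier = None)
}
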
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.BitSet

class FileSourceScanExecShim(
abstract class FileSourceScanExecShim(
@transient relation: HadoopFsRelation,
output: Seq[Attribute],
requiredSchema: StructType,
@@ -36,7 +36,7 @@ class FileSourceScanExecShim(
dataFilters: Seq[Expression],
tableIdentifier: Option[TableIdentifier],
disableBucketedScan: Boolean = false)
extends FileSourceScanExec(
extends AbstractFileSourceScanExec(
relation,
output,
requiredSchema,
@@ -33,13 +33,13 @@ import org.apache.spark.sql.internal.SQLConf
/**
* Physical plan node for scanning a batch of data from a data source v2.
*/
case class AbstractBatchScanExec(
abstract class AbstractBatchScanExec(
output: Seq[AttributeReference],
@transient scan: Scan,
runtimeFilters: Seq[Expression],
val runtimeFilters: Seq[Expression],
ordering: Option[Seq[SortOrder]] = None,
@transient table: Table,
spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams()
val spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams()
) extends DataSourceV2ScanExecBase {

@transient lazy val batch: Batch = if (scan == null) null else scan.toBatch
@@ -231,14 +231,6 @@ case class AbstractBatchScanExec(
override def keyGroupedPartitioning: Option[Seq[Expression]] =
spjParams.keyGroupedPartitioning

override def doCanonicalize(): AbstractBatchScanExec = {
this.copy(
output = output.map(QueryPlan.normalizeExpressions(_, output)),
runtimeFilters = QueryPlan.normalizePredicates(
runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)),
output))
}

override def simpleString(maxFields: Int): String = {
val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields)
val runtimeFiltersString = s"RuntimeFilters: ${runtimeFilters.mkString("[", ",", "]")}"
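
The doCanonicalize removed here depended on the case-class copy method, which the now-abstract base no longer generates, so the equivalent logic is expected to sit in a concrete subclass. A minimal sketch under that assumption; the class name is hypothetical and the constructor mirrors the base's parameters shown above.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, DynamicPruningExpression, Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.datasources.v2.StoragePartitionJoinParams

// Hypothetical concrete node; passes its parameters straight through to the abstract base.
case class MyBatchScanExec(
    output: Seq[AttributeReference],
    @transient scan: Scan,
    override val runtimeFilters: Seq[Expression],
    ordering: Option[Seq[SortOrder]] = None,
    @transient table: Table,
    override val spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams())
  extends AbstractBatchScanExec(output, scan, runtimeFilters, ordering, table, spjParams) {

  // Same canonicalization the removed doCanonicalize performed via this.copy.
  override def doCanonicalize(): MyBatchScanExec = copy(
    output = output.map(QueryPlan.normalizeExpressions(_, output)),
    runtimeFilters = QueryPlan.normalizePredicates(
      runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)),
      output))
}
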
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
@@ -30,12 +31,15 @@ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.vectorized.ColumnarBatch

class BatchScanExecShim(
abstract class BatchScanExecShim(
output: Seq[AttributeReference],
@transient scan: Scan,
runtimeFilters: Seq[Expression],
@transient table: Table)
extends BatchScanExec(output, scan, runtimeFilters, table = table) {
@transient val table: Table,
val commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
val applyPartialClustering: Boolean = false,
val replicatePartitions: Boolean = false)
extends AbstractBatchScanExec(output, scan, runtimeFilters, table = table) {

// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@transient override lazy val metrics: Map[String, SQLMetric] = Map()
@@ -0,0 +1,251 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.execution

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.ql.io.{DelegateSymlinkTextInputFormat, SymlinkTextInputFormat}
import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition}
import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
import org.apache.hadoop.mapred.InputFormat

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.client.HiveClientImpl
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType}
import org.apache.spark.util.Utils

/**
* The Hive table scan operator. Column and partition pruning are both handled.
*
* @param requestedAttributes Attributes to be fetched from the Hive table.
* @param relation The Hive table to be scanned.
* @param partitionPruningPred An optional partition pruning predicate for partitioned tables.
*/
private[hive]
abstract class AbstractHiveTableScanExec(
requestedAttributes: Seq[Attribute],
relation: HiveTableRelation,
partitionPruningPred: Seq[Expression])(
@transient protected val sparkSession: SparkSession)
extends LeafExecNode with CastSupport {

require(partitionPruningPred.isEmpty || relation.isPartitioned,
"Partition pruning predicates only supported for partitioned tables.")

override def conf: SQLConf = sparkSession.sessionState.conf

override def nodeName: String = s"Scan hive ${relation.tableMeta.qualifiedName}"

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

override def producedAttributes: AttributeSet = outputSet ++
AttributeSet(partitionPruningPred.flatMap(_.references))

private val originalAttributes = AttributeMap(relation.output.map(a => a -> a))

override val output: Seq[Attribute] = {
// Retrieve the original attributes based on expression ID so that capitalization matches.
requestedAttributes.map(originalAttributes)
}

// Bind all partition key attribute references in the partition pruning predicate for later
// evaluation.
private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
require(pred.dataType == BooleanType,
s"Data type of predicate $pred must be ${BooleanType.catalogString} rather than " +
s"${pred.dataType.catalogString}.")

BindReferences.bindReference(pred, relation.partitionCols)
}

@transient private lazy val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta)
@transient private lazy val tableDesc = new TableDesc(
getInputFormat(hiveQlTable.getInputFormatClass, conf),
hiveQlTable.getOutputFormatClass,
hiveQlTable.getMetadata)

// Create a local copy of hadoopConf, so that scan-specific modifications do not impact
// other queries
@transient private lazy val hadoopConf = {
val c = sparkSession.sessionState.newHadoopConf()
// append column ids and names before broadcast
addColumnMetadataToConf(c)
c
}

@transient private lazy val hadoopReader = new HadoopTableReader(
output,
relation.partitionCols,
tableDesc,
sparkSession,
hadoopConf)

private def castFromString(value: String, dataType: DataType) = {
cast(Literal(value), dataType).eval(null)
}

private def addColumnMetadataToConf(hiveConf: Configuration): Unit = {
// Specifies needed column IDs for those non-partitioning columns.
val columnOrdinals = AttributeMap(relation.dataCols.zipWithIndex)
val neededColumnIDs = output.flatMap(columnOrdinals.get).map(o => o: Integer)
val neededColumnNames = output.filter(columnOrdinals.contains).map(_.name)

HiveShim.appendReadColumns(hiveConf, neededColumnIDs, neededColumnNames)

val deserializer = tableDesc.getDeserializerClass.getConstructor().newInstance()
deserializer.initialize(hiveConf, tableDesc.getProperties)

// Specifies types and object inspectors of columns to be scanned.
val structOI = ObjectInspectorUtils
.getStandardObjectInspector(
deserializer.getObjectInspector,
ObjectInspectorCopyOption.JAVA)
.asInstanceOf[StructObjectInspector]

val columnTypeNames = structOI
.getAllStructFieldRefs.asScala
.map(_.getFieldObjectInspector)
.map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName)
.mkString(",")

hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames)
hiveConf.set(serdeConstants.LIST_COLUMNS, relation.dataCols.map(_.name).mkString(","))
}

/**
* Prunes partitions not involved in the query plan.
*
* @param partitions All partitions of the relation.
* @return Partitions that are involved in the query plan.
*/
private[hive] def prunePartitions(partitions: Seq[HivePartition]): Seq[HivePartition] = {
boundPruningPred match {
case None => partitions
case Some(shouldKeep) => partitions.filter { part =>
val dataTypes = relation.partitionCols.map(_.dataType)
val castedValues = part.getValues.asScala.zip(dataTypes)
.map { case (value, dataType) => castFromString(value, dataType) }

// Only partition values are needed here, since the predicate has already been bound to
// partition key attribute references.
val row = InternalRow.fromSeq(castedValues.toSeq)
shouldKeep.eval(row).asInstanceOf[Boolean]
}
}
}

@transient lazy val prunedPartitions: Seq[HivePartition] = {
if (relation.prunedPartitions.nonEmpty) {
val hivePartitions =
relation.prunedPartitions.get.map(HiveClientImpl.toHivePartition(_, hiveQlTable))
if (partitionPruningPred.forall(!ExecSubqueryExpression.hasSubquery(_))) {
hivePartitions
} else {
prunePartitions(hivePartitions)
}
} else {
if (sparkSession.sessionState.conf.metastorePartitionPruning &&
partitionPruningPred.nonEmpty) {
rawPartitions
} else {
prunePartitions(rawPartitions)
}
}
}

// exposed for tests
@transient lazy val rawPartitions: Seq[HivePartition] = {
val prunedPartitions =
if (sparkSession.sessionState.conf.metastorePartitionPruning &&
partitionPruningPred.nonEmpty) {
// Retrieve the original attributes based on expression ID so that capitalization matches.
val normalizedFilters = partitionPruningPred.map(_.transform {
case a: AttributeReference => originalAttributes(a)
})
sparkSession.sessionState.catalog
.listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters)
} else {
sparkSession.sessionState.catalog.listPartitions(relation.tableMeta.identifier)
}
prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable))
}

protected override def doExecute(): RDD[InternalRow] = {
// Using dummyCallSite, as getCallSite can turn out to be expensive with
// multiple partitions.
val rdd = if (!relation.isPartitioned) {
Utils.withDummyCallSite(sparkContext) {
hadoopReader.makeRDDForTable(hiveQlTable)
}
} else {
Utils.withDummyCallSite(sparkContext) {
hadoopReader.makeRDDForPartitionedTable(prunedPartitions)
}
}
val numOutputRows = longMetric("numOutputRows")
// Avoid serializing MetastoreRelation because schema is lazy. (see SPARK-15649)
val outputSchema = schema
rdd.mapPartitionsWithIndexInternal { (index, iter) =>
val proj = UnsafeProjection.create(outputSchema)
proj.initialize(index)
iter.map { r =>
numOutputRows += 1
proj(r)
}
}
}

// Filters unused DynamicPruningExpression expressions - the ones that have been replaced
// with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning
private def filterUnusedDynamicPruningExpressions(
predicates: Seq[Expression]): Seq[Expression] = {
predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral))
}

// Optionally returns a delegate input format based on the provided input format class.
// This is currently used to replace SymlinkTextInputFormat with DelegateSymlinkTextInputFormat
// in order to fix SPARK-40815.
private def getInputFormat(
inputFormatClass: Class[_ <: InputFormat[_, _]],
conf: SQLConf): Class[_ <: InputFormat[_, _]] = {
if (inputFormatClass == classOf[SymlinkTextInputFormat] &&
conf != null && conf.getConf(HiveUtils.USE_DELEGATE_FOR_SYMLINK_TEXT_INPUT_FORMAT)) {
classOf[DelegateSymlinkTextInputFormat]
} else {
inputFormatClass
}
}

override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
}
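
The new file imports QueryPlan but, with doCanonicalize left out, never uses it; as with the other abstract nodes above, a concrete subclass is expected to supply the canonicalization. A minimal sketch of such a subclass, with a hypothetical name, assuming the same imports as the file above and mirroring the normalization Spark's own HiveTableScanExec performs:

// Hypothetical concrete node; not part of this commit.
private[hive] case class MyHiveTableScanExec(
    requestedAttributes: Seq[Attribute],
    relation: HiveTableRelation,
    partitionPruningPred: Seq[Expression])(
    session: SparkSession)
  extends AbstractHiveTableScanExec(
    requestedAttributes, relation, partitionPruningPred)(session) {

  // Canonicalize the requested attributes and pruning predicates against the relation's
  // output, and canonicalize the relation itself; sparkSession is the protected val
  // defined by the abstract base.
  override def doCanonicalize(): MyHiveTableScanExec = {
    val input: AttributeSeq = relation.output
    MyHiveTableScanExec(
      requestedAttributes.map(QueryPlan.normalizeExpressions(_, input)),
      relation.canonicalized.asInstanceOf[HiveTableRelation],
      QueryPlan.normalizePredicates(partitionPruningPred, input))(sparkSession)
  }
}
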
