This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-1171] Support merge parquet schema and read missing schema #1175

Merged
@@ -28,12 +28,11 @@ import com.intel.oap.spark.sql.execution.datasources.v2.arrow.{ArrowFilters, Arr
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
import com.intel.oap.vectorized.ArrowWritableColumnVector
import org.apache.arrow.dataset.scanner.ScanOptions
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat}
import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.parquet.hadoop.codec.CodecConfig
import org.apache.parquet.hadoop.util.ContextUtil

@@ -44,8 +43,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkVectorUtils}
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils.UnsafeItr
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkVectorUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
import org.apache.spark.sql.types.StructType
@@ -55,7 +54,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi


override def isSplitable(sparkSession: SparkSession,
options: Map[String, String], path: Path): Boolean = {
ArrowUtils.isOriginalFormatSplitable(
new ArrowOptions(new CaseInsensitiveStringMap(options.asJava).asScala.toMap))
}
@@ -65,22 +64,22 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi
}

override def inferSchema(sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
convert(files, options)
}

override def prepareWrite(sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
val arrowOptions = new ArrowOptions(new CaseInsensitiveStringMap(options.asJava).asScala.toMap)
val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf)

val conf = ContextUtil.getConfiguration(job)
conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName)
logInfo(s"write parquet with codec: ${parquetOptions.compressionCodecClassName}")

new OutputWriterFactory {
override def getFileExtension(context: TaskAttemptContext): String = {
ArrowUtils.getFormat(arrowOptions) match {
@@ -94,13 +93,14 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi
context: TaskAttemptContext): OutputWriter = {
val originPath = path
val writeQueue = new ArrowWriteQueue(ArrowUtils.toArrowSchema(dataSchema),
ArrowUtils.getFormat(arrowOptions),
parquetOptions.compressionCodecClassName.toLowerCase(), originPath)

new OutputWriter {
override def write(row: InternalRow): Unit = {
val batch = row.asInstanceOf[FakeRow].batch
writeQueue.enqueue(SparkVectorUtils
.toArrowRecordBatch(batch))
}

override def close(): Unit = {
@@ -140,74 +140,79 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi
// todo predicate validation / pushdown
val parquetFileFields = factory.inspect().getFields.asScala
val caseInsensitiveFieldMap = mutable.Map[String, String]()
val requiredFields = if (caseSensitive) {
new Schema(requiredSchema.map { field =>
parquetFileFields.find(_.getName.equals(field.name))
.getOrElse(ArrowUtils.toArrowField(field))
}.asJava)
} else {
new Schema(requiredSchema.map { readField =>
// TODO: check schema inside of complex type
val matchedFields =
parquetFileFields.filter(_.getName.equalsIgnoreCase(readField.name))
if (matchedFields.size > 1) {
// Need to fail if there is ambiguity, i.e. more than one field is matched
val fieldsString = matchedFields.map(_.getName).mkString("[", ", ", "]")
throw new RuntimeException(
s"""
|Found duplicate field(s) "${readField.name}": $fieldsString
|in case-insensitive mode""".stripMargin.replaceAll("\n", " "))
} else {
matchedFields
.map { field =>
caseInsensitiveFieldMap += (readField.name -> field.getName)
field
}.headOption.getOrElse(ArrowUtils.toArrowField(readField))
}
}.asJava)
}
val dataset = factory.finish(requiredFields)
// TODO: support array/map/struct types in out-of-order schema reading.
val actualReadFields =
ArrowUtils.getRequestedField(requiredSchema, parquetFileFields, caseSensitive)

val compare = ArrowUtils.compareStringFunc(caseSensitive)
val actualReadFieldNames = actualReadFields.getFields.asScala.map(_.getName).toArray
val actualReadSchema = new StructType(
actualReadFieldNames.map(f => requiredSchema.find(field => compare(f, field.name)).get))
val dataset = factory.finish(actualReadFields)

val hasMissingColumns = actualReadFields.getFields.size() != requiredSchema.size
val filter = if (enableFilterPushDown) {
ArrowFilters.translateFilters(filters, caseInsensitiveFieldMap.toMap)
val pushedFilters = if (hasMissingColumns) {
ArrowFilters.evaluateMissingFieldFilters(filters, actualReadFieldNames)
} else {
filters
}
if (pushedFilters == null) {
null
} else {
ArrowFilters.translateFilters(
pushedFilters, caseInsensitiveFieldMap.toMap)
}
} else {
org.apache.arrow.dataset.filter.Filter.EMPTY
}

val scanOptions = new ScanOptions(
requiredFields.getFields.asScala.map(f => f.getName).toArray,
filter,
batchSize)
val scanner = dataset.newScan(scanOptions)
if (filter == null) {
new Iterator[InternalRow] {
override def hasNext: Boolean = false
override def next(): InternalRow = null
}
} else {
val scanOptions = new ScanOptions(
actualReadFieldNames,
filter,
batchSize)
val scanner = dataset.newScan(scanOptions)

val taskList = scanner
.scan()
.iterator()
.asScala
.toList
val itrList = taskList
.map(task => task.execute())

Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => {
itrList.foreach(_.close())
taskList.foreach(_.close())
scanner.close()
dataset.close()
factory.close()
}))

val partitionVectors =
ArrowUtils.loadPartitionColumns(batchSize, partitionSchema, file.partitionValues)

SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => {
partitionVectors.foreach(_.close())
})

val itr = itrList
.toIterator
.flatMap(itr => itr.asScala)
.map(batch => ArrowUtils.loadBatch(batch, requiredSchema, partitionVectors))
new UnsafeItr(itr).asInstanceOf[Iterator[InternalRow]]
val itrList = taskList
.map(task => task.execute())

Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => {
itrList.foreach(_.close())
taskList.foreach(_.close())
scanner.close()
dataset.close()
factory.close()
}))

val partitionVectors =
ArrowUtils.loadPartitionColumns(batchSize, partitionSchema, file.partitionValues)

val nullVectors = if (hasMissingColumns) {
val missingSchema =
new StructType(requiredSchema.filterNot(actualReadSchema.contains).toArray)
ArrowUtils.loadMissingColumns(batchSize, missingSchema)
} else {
Array.empty[ArrowWritableColumnVector]
}

val itr = itrList
.toIterator
.flatMap(itr => itr.asScala)
.map(batch => ArrowUtils.loadBatch(
batch, actualReadSchema, requiredSchema, partitionVectors, nullVectors))
new UnsafeItr(itr).asInstanceOf[Iterator[InternalRow]]
}
}
}
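
The reader above merges the query's required schema with the fields actually present in each Parquet file: columns found in the file are read through actualReadSchema, while columns the file lacks are served from null vectors (ArrowUtils.loadMissingColumns). Below is a minimal, self-contained sketch of that null-column idea using Spark's stock OnHeapColumnVector; the object name and the loadMissingColumns signature here are illustrative assumptions, not the ArrowWritableColumnVector-based implementation this patch actually calls.

import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object NullColumnSketch {
  // Allocate one all-null vector per column that the query requests but the file lacks.
  def loadMissingColumns(batchSize: Int, missingSchema: StructType): Array[OnHeapColumnVector] =
    missingSchema.fields.map { field =>
      val vector = new OnHeapColumnVector(batchSize, field.dataType)
      vector.putNulls(0, batchSize) // every row of a missing column reads back as null
      vector
    }

  def main(args: Array[String]): Unit = {
    val requiredSchema = StructType(Seq(StructField("id", IntegerType), StructField("age", IntegerType)))
    // Only "id" exists in this particular file; "age" was added to the table schema later.
    val actualReadSchema = StructType(Seq(StructField("id", IntegerType)))
    val missingSchema = StructType(requiredSchema.filterNot(actualReadSchema.contains).toArray)
    val nullVectors = loadMissingColumns(4096, missingSchema)
    println(nullVectors.map(_.isNullAt(0)).mkString(", ")) // prints: true
  }
}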

@@ -20,13 +20,12 @@ package com.intel.oap.spark.sql.execution.datasources.v2.arrow
import org.apache.arrow.dataset.DatasetTypes
import org.apache.arrow.dataset.DatasetTypes.TreeNode
import org.apache.arrow.dataset.filter.FilterImpl
import org.apache.arrow.vector.types.pojo.Field

import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType

object ArrowFilters {
def pruneWithSchema(pushedFilters: Array[Filter], schema: StructType): Seq[Filter] = {
def pruneWithSchema(pushedFilters: Seq[Filter], schema: StructType): Seq[Filter] = {
pushedFilters.filter(pushedFilter => {
isToBeAccepted(pushedFilter, schema)
})
@@ -57,6 +56,70 @@ object ArrowFilters {
false
}

def evaluateMissingFieldFilters(
pushedFilters: Seq[Filter],
requiredFields: Seq[String]): Seq[Filter] = {
val evaluatedFilters = evaluateFilters(pushedFilters, requiredFields)
if (evaluatedFilters.exists(_._2 == false)) {
null
} else {
evaluatedFilters.map(_._1).filterNot(_ == null)
}
}

def evaluateFilters(
pushedFilters: Seq[Filter],
requiredFields: Seq[String]): Seq[(Filter, Boolean)] = {
pushedFilters.map {
case r @ EqualTo(attribute, value) if !requiredFields.contains(attribute) =>
(null, null == value)
case r @ GreaterThan(attribute, value) if !requiredFields.contains(attribute) =>
(null, false)
case r @ GreaterThanOrEqual(attribute, value) if !requiredFields.contains(attribute) =>
(null, null == value)
case LessThan(attribute, value) if !requiredFields.contains(attribute) =>
(null, false)
case r @ LessThanOrEqual(attribute, value) if !requiredFields.contains(attribute) =>
(null, null == value)
case r @ Not(child) =>
evaluateFilters(Seq(child), requiredFields).head match {
case (null, false) => (null, true)
case (null, true) => (null, false)
case (_, true) => (r, true)
}
case r @ And(left, right) =>
val evaluatedFilters = evaluateFilters(Seq(left, right), requiredFields)
val filters = evaluatedFilters.map(_._1).filterNot(_ == null)
if (evaluatedFilters.forall(_._2)) {
if (filters.size > 1) {
(r, true)
} else {
(filters.head, true)
}
} else {
(null, false)
}
case r @ Or(left, right) =>
val evaluatedFilters = evaluateFilters(Seq(left, right), requiredFields)
val filters = evaluatedFilters.map(_._1).filterNot(_ == null)
if (evaluatedFilters.exists(_._2)) {
if (filters.size > 1) {
(r, true)
} else {
(filters.head, true)
}
} else {
(null, false)
}
case IsNotNull(attribute) if !requiredFields.contains(attribute) =>
(null, false)
case IsNull(attribute) if !requiredFields.contains(attribute) =>
(null, true)
case r =>
(r, true)
}
}

def translateFilters(
pushedFilters: Seq[Filter],
caseInsensitiveFieldMap: Map[String, String]): org.apache.arrow.dataset.filter.Filter = {
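
evaluateMissingFieldFilters above decides, per pushed-down predicate, whether it could ever hold once the missing column is read back as null: if any predicate is unsatisfiable the method returns null and the reader short-circuits to an empty iterator; predicates that become trivially true are simply dropped. The sketch below mirrors those rules against Spark's public org.apache.spark.sql.sources filter classes; the object and method names are hypothetical, and Not and other cases are handled conservatively rather than folded as precisely as the patch does.

import org.apache.spark.sql.sources._

object MissingColumnFilterSketch {
  // Columns actually present in the Parquet file being scanned.
  val presentColumns: Seq[String] = Seq("id", "name")

  // "Can this predicate possibly be satisfied?" A missing column only ever yields null,
  // so comparisons and IsNotNull over it never match, while IsNull (and, per the rules
  // above, EqualTo / *OrEqual against a null literal) always matches.
  def canMatch(filter: Filter): Boolean = filter match {
    case EqualTo(a, v) if !presentColumns.contains(a)            => v == null
    case GreaterThan(a, _) if !presentColumns.contains(a)        => false
    case GreaterThanOrEqual(a, v) if !presentColumns.contains(a) => v == null
    case LessThan(a, _) if !presentColumns.contains(a)           => false
    case LessThanOrEqual(a, v) if !presentColumns.contains(a)    => v == null
    case IsNotNull(a) if !presentColumns.contains(a)             => false
    case IsNull(a) if !presentColumns.contains(a)                => true
    case And(left, right) => canMatch(left) && canMatch(right)   // both sides must be satisfiable
    case Or(left, right)  => canMatch(left) || canMatch(right)   // one satisfiable side is enough
    case _                => true                                 // present columns, Not, etc.: assume satisfiable
  }

  def main(args: Array[String]): Unit = {
    println(canMatch(GreaterThan("age", 21)))                  // false: "age" is missing, null > 21 never holds
    println(canMatch(IsNull("age")))                           // true: a missing column reads back as null
    println(canMatch(And(EqualTo("id", 1), IsNotNull("age")))) // false: the right-hand side can never match
  }
}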