This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-1171] Support merge parquet schema and read missing schema (#1175)
* Support merge parquet schema and read missing schema

* fix error

* optimize null vectors

* optimize code

* optimize code

* change code

* add schema merge suite tests

* add test for struct type
jackylee-ch authored Dec 9, 2022
1 parent d2f77ad commit cf11842
Showing 5 changed files with 358 additions and 162 deletions.
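
For context, the feature mirrors Spark's built-in parquet schema merging: when files under one table path were written with different but compatible schemas, the reader exposes the merged schema and returns null for any column a given file does not contain. A minimal illustration of that contract using stock Spark APIs (a sketch only: it assumes a local `SparkSession` and a scratch path `/tmp/arrow_merge_demo`, and enabling this project's native Arrow reader is configuration-specific and not shown):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("schema-merge-demo").master("local[*]").getOrCreate()
import spark.implicits._

// Two parquet files written with different, compatible schemas.
Seq((1L, "a"), (2L, "b")).toDF("id", "name")
  .write.mode("overwrite").parquet("/tmp/arrow_merge_demo/p=1")
Seq((3L, "c", 30), (4L, "d", 40)).toDF("id", "name", "age")
  .write.mode("overwrite").parquet("/tmp/arrow_merge_demo/p=2")

// With schema merging the table schema is (id, name, age, p); the first file has no
// "age" column, so its rows read back with age = null instead of failing the scan.
val merged = spark.read.option("mergeSchema", "true").parquet("/tmp/arrow_merge_demo")
merged.orderBy("id").show()
```

The diff below teaches the Arrow-based `ArrowFileFormat` reader the same contract: the required schema may contain fields the file lacks, and those columns are filled with nulls while filters on them are evaluated up front.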
@@ -28,12 +28,11 @@ import com.intel.oap.spark.sql.execution.datasources.v2.arrow.{ArrowFilters, Arr
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
import com.intel.oap.vectorized.ArrowWritableColumnVector
import org.apache.arrow.dataset.scanner.ScanOptions
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat}
import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.parquet.hadoop.codec.CodecConfig
import org.apache.parquet.hadoop.util.ContextUtil

@@ -44,8 +44,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkVectorUtils}
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils.UnsafeItr
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkVectorUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
import org.apache.spark.sql.types.StructType
@@ -55,7 +54,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi


override def isSplitable(sparkSession: SparkSession,
options: Map[String, String], path: Path): Boolean = {
options: Map[String, String], path: Path): Boolean = {
ArrowUtils.isOriginalFormatSplitable(
new ArrowOptions(new CaseInsensitiveStringMap(options.asJava).asScala.toMap))
}
@@ -65,22 +64,22 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi
}

override def inferSchema(sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
convert(files, options)
}

override def prepareWrite(sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
val arrowOptions = new ArrowOptions(new CaseInsensitiveStringMap(options.asJava).asScala.toMap)
val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf)

val conf = ContextUtil.getConfiguration(job)
conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName)
logInfo(s"write parquet with codec: ${parquetOptions.compressionCodecClassName}")

new OutputWriterFactory {
override def getFileExtension(context: TaskAttemptContext): String = {
ArrowUtils.getFormat(arrowOptions) match {
@@ -94,13 +93,14 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi
context: TaskAttemptContext): OutputWriter = {
val originPath = path
val writeQueue = new ArrowWriteQueue(ArrowUtils.toArrowSchema(dataSchema),
ArrowUtils.getFormat(arrowOptions), parquetOptions.compressionCodecClassName.toLowerCase(), originPath)
ArrowUtils.getFormat(arrowOptions),
parquetOptions.compressionCodecClassName.toLowerCase(), originPath)

new OutputWriter {
override def write(row: InternalRow): Unit = {
val batch = row.asInstanceOf[FakeRow].batch
writeQueue.enqueue(SparkVectorUtils
.toArrowRecordBatch(batch))
.toArrowRecordBatch(batch))
}

override def close(): Unit = {
@@ -140,74 +140,79 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi
// todo predicate validation / pushdown
val parquetFileFields = factory.inspect().getFields.asScala
val caseInsensitiveFieldMap = mutable.Map[String, String]()
val requiredFields = if (caseSensitive) {
new Schema(requiredSchema.map { field =>
parquetFileFields.find(_.getName.equals(field.name))
.getOrElse(ArrowUtils.toArrowField(field))
}.asJava)
} else {
new Schema(requiredSchema.map { readField =>
// TODO: check schema inside of complex type
val matchedFields =
parquetFileFields.filter(_.getName.equalsIgnoreCase(readField.name))
if (matchedFields.size > 1) {
// Need to fail if there is ambiguity, i.e. more than one field is matched
val fieldsString = matchedFields.map(_.getName).mkString("[", ", ", "]")
throw new RuntimeException(
s"""
|Found duplicate field(s) "${readField.name}": $fieldsString
|in case-insensitive mode""".stripMargin.replaceAll("\n", " "))
} else {
matchedFields
.map { field =>
caseInsensitiveFieldMap += (readField.name -> field.getName)
field
}.headOption.getOrElse(ArrowUtils.toArrowField(readField))
}
}.asJava)
}
val dataset = factory.finish(requiredFields)
// TODO: support array/map/struct types in out-of-order schema reading.
val actualReadFields =
ArrowUtils.getRequestedField(requiredSchema, parquetFileFields, caseSensitive)

val compare = ArrowUtils.compareStringFunc(caseSensitive)
val actualReadFieldNames = actualReadFields.getFields.asScala.map(_.getName).toArray
val actualReadSchema = new StructType(
actualReadFieldNames.map(f => requiredSchema.find(field => compare(f, field.name)).get))
val dataset = factory.finish(actualReadFields)

val hasMissingColumns = actualReadFields.getFields.size() != requiredSchema.size
val filter = if (enableFilterPushDown) {
ArrowFilters.translateFilters(filters, caseInsensitiveFieldMap.toMap)
val pushedFilters = if (hasMissingColumns) {
ArrowFilters.evaluateMissingFieldFilters(filters, actualReadFieldNames)
} else {
filters
}
if (pushedFilters == null) {
null
} else {
ArrowFilters.translateFilters(
pushedFilters, caseInsensitiveFieldMap.toMap)
}
} else {
org.apache.arrow.dataset.filter.Filter.EMPTY
}

val scanOptions = new ScanOptions(
requiredFields.getFields.asScala.map(f => f.getName).toArray,
filter,
batchSize)
val scanner = dataset.newScan(scanOptions)
if (filter == null) {
new Iterator[InternalRow] {
override def hasNext: Boolean = false
override def next(): InternalRow = null
}
} else {
val scanOptions = new ScanOptions(
actualReadFieldNames,
filter,
batchSize)
val scanner = dataset.newScan(scanOptions)

val taskList = scanner
val taskList = scanner
.scan()
.iterator()
.asScala
.toList
val itrList = taskList
.map(task => task.execute())

Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => {
itrList.foreach(_.close())
taskList.foreach(_.close())
scanner.close()
dataset.close()
factory.close()
}))

val partitionVectors =
ArrowUtils.loadPartitionColumns(batchSize, partitionSchema, file.partitionValues)

SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => {
partitionVectors.foreach(_.close())
})

val itr = itrList
.toIterator
.flatMap(itr => itr.asScala)
.map(batch => ArrowUtils.loadBatch(batch, requiredSchema, partitionVectors))
new UnsafeItr(itr).asInstanceOf[Iterator[InternalRow]]
val itrList = taskList
.map(task => task.execute())

Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => {
itrList.foreach(_.close())
taskList.foreach(_.close())
scanner.close()
dataset.close()
factory.close()
}))

val partitionVectors =
ArrowUtils.loadPartitionColumns(batchSize, partitionSchema, file.partitionValues)

val nullVectors = if (hasMissingColumns) {
val missingSchema =
new StructType(requiredSchema.filterNot(actualReadSchema.contains).toArray)
ArrowUtils.loadMissingColumns(batchSize, missingSchema)
} else {
Array.empty[ArrowWritableColumnVector]
}

val itr = itrList
.toIterator
.flatMap(itr => itr.asScala)
.map(batch => ArrowUtils.loadBatch(
batch, actualReadSchema, requiredSchema, partitionVectors, nullVectors))
new UnsafeItr(itr).asInstanceOf[Iterator[InternalRow]]
}
}
}
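
To make the new read path above concrete: when `hasMissingColumns` is true, the fields of `requiredSchema` that the parquet file does not contain are served from vectors holding only nulls (`nullVectors`), and `ArrowUtils.loadBatch` combines them with the columns read from the file and the partition columns. The snippet below is a minimal sketch of that idea written against Spark's generic `OnHeapColumnVector`, not the project's `ArrowWritableColumnVector`/`ArrowUtils.loadMissingColumns` internals, whose exact signatures are not visible in this diff:

```scala
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.StructType

// Build one all-null vector per required-but-missing column. As in the diff, the vectors
// can be created once per partition (before the batch iterator) because their contents
// never change from batch to batch.
def nullColumnsFor(missingSchema: StructType, batchSize: Int): Array[OnHeapColumnVector] =
  missingSchema.fields.map { field =>
    val vector = new OnHeapColumnVector(batchSize, field.dataType)
    vector.putNulls(0, batchSize) // every row of a missing column reads as null
    vector
  }
```

In the actual change the pre-built vectors are handed to `ArrowUtils.loadBatch(batch, actualReadSchema, requiredSchema, partitionVectors, nullVectors)`, so downstream operators always see every column of `requiredSchema`.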

@@ -20,13 +20,12 @@ package com.intel.oap.spark.sql.execution.datasources.v2.arrow
import org.apache.arrow.dataset.DatasetTypes
import org.apache.arrow.dataset.DatasetTypes.TreeNode
import org.apache.arrow.dataset.filter.FilterImpl
import org.apache.arrow.vector.types.pojo.Field

import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType

object ArrowFilters {
def pruneWithSchema(pushedFilters: Array[Filter], schema: StructType): Seq[Filter] = {
def pruneWithSchema(pushedFilters: Seq[Filter], schema: StructType): Seq[Filter] = {
pushedFilters.filter(pushedFilter => {
isToBeAccepted(pushedFilter, schema)
})
@@ -57,6 +56,70 @@ object ArrowFilters {
false
}

def evaluateMissingFieldFilters(
pushedFilters: Seq[Filter],
requiredFields: Seq[String]): Seq[Filter] = {
val evaluatedFilters = evaluateFilters(pushedFilters, requiredFields)
if (evaluatedFilters.exists(_._2 == false)) {
null
} else {
evaluatedFilters.map(_._1).filterNot(_ == null)
}
}

def evaluateFilters(
pushedFilters: Seq[Filter],
requiredFields: Seq[String]): Seq[(Filter, Boolean)] = {
pushedFilters.map {
case r @ EqualTo(attribute, value) if !requiredFields.contains(attribute) =>
(null, null == value)
case r @ GreaterThan(attribute, value) if !requiredFields.contains(attribute) =>
(null, false)
case r @ GreaterThanOrEqual(attribute, value) if !requiredFields.contains(attribute) =>
(null, null == value)
case LessThan(attribute, value) if !requiredFields.contains(attribute) =>
(null, false)
case r @ LessThanOrEqual(attribute, value) if !requiredFields.contains(attribute) =>
(null, null == value)
case r @ Not(child) =>
evaluateFilters(Seq(child), requiredFields).head match {
case (null, false) => (null, true)
case (null, true) => (null, false)
case (_, true) => (r, true)
}
case r @ And(left, right) =>
val evaluatedFilters = evaluateFilters(Seq(left, right), requiredFields)
val filters = evaluatedFilters.map(_._1).filterNot(_ == null)
if (evaluatedFilters.forall(_._2)) {
if (filters.size > 1) {
(r, true)
} else {
(filters.head, true)
}
} else {
(null, false)
}
case r @ Or(left, right) =>
val evaluatedFilters = evaluateFilters(Seq(left, right), requiredFields)
val filters = evaluatedFilters.map(_._1).filterNot(_ == null)
if (evaluatedFilters.exists(_._2)) {
if (filters.size > 1) {
(r, true)
} else {
(filters.head, true)
}
} else {
(null, false)
}
case IsNotNull(attribute) if !requiredFields.contains(attribute) =>
(null, false)
case IsNull(attribute) if !requiredFields.contains(attribute) =>
(null, true)
case r =>
(r, true)
}
}

def translateFilters(
pushedFilters: Seq[Filter],
caseInsensitiveFieldMap: Map[String, String]): org.apache.arrow.dataset.filter.Filter = {
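The other half of the change, in the `ArrowFilters` hunk above, decides what a pushed-down filter means once one of its columns is known to be absent (and therefore all-null) in a file: predicates that can never hold on an all-null column (`EqualTo` against a non-null value, `GreaterThan`, `LessThan`, `IsNotNull`, ...) allow the whole file to be skipped, which `evaluateMissingFieldFilters` signals by returning null and the reader answers with an empty row iterator, while trivially true predicates such as `IsNull` are simply dropped before translation. A usage sketch, assuming only the methods shown in this diff plus Spark's `org.apache.spark.sql.sources` filter classes:

```scala
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowFilters
import org.apache.spark.sql.sources.{EqualTo, Filter, IsNull}

// Required columns that are physically present in this particular parquet file.
val fileFields: Seq[String] = Seq("id", "name")

// "age" is missing from the file, so it behaves as an all-null column here.
// EqualTo("age", 30) can never be satisfied, so the whole file can be skipped (null result).
val skipped = ArrowFilters.evaluateMissingFieldFilters(Seq(EqualTo("age", 30)), fileFields)
assert(skipped == null)

// IsNull("age") is trivially true for a missing column, so it is dropped, while the
// filter on a column the file does have is kept for translation and pushdown.
val kept: Seq[Filter] =
  ArrowFilters.evaluateMissingFieldFilters(Seq(IsNull("age"), EqualTo("id", 1)), fileFields)
assert(kept == Seq(EqualTo("id", 1)))
```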