From 04106ffebf3bf9fdc9186c029ce38e673eb25cbf Mon Sep 17 00:00:00 2001 From: jackylee-ch Date: Wed, 30 Nov 2022 10:50:01 +0800 Subject: [PATCH 1/8] Support merge parquet schema and read missing schema --- .../datasources/arrow/ArrowFileFormat.scala | 149 +++++++++++------ .../datasources/v2/arrow/ArrowFilters.scala | 65 ++++++- .../arrow/ArrowPartitionReaderFactory.scala | 158 +++++++++++------- .../datasources/v2/arrow/ArrowUtils.scala | 33 +++- 4 files changed, 290 insertions(+), 115 deletions(-) diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala index fb8f6ddc6..28e8f2868 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala @@ -28,7 +28,7 @@ import com.intel.oap.spark.sql.execution.datasources.v2.arrow.{ArrowFilters, Arr import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._ import com.intel.oap.vectorized.ArrowWritableColumnVector import org.apache.arrow.dataset.scanner.ScanOptions -import org.apache.arrow.vector.types.pojo.Schema +import org.apache.arrow.vector.types.pojo.{Field, Schema} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.Job @@ -40,6 +40,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.execution.datasources.OutputWriter +import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkVectorUtils} import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils.UnsafeItr import org.apache.spark.sql.internal.SQLConf @@ -51,7 +52,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab override def isSplitable(sparkSession: SparkSession, - options: Map[String, String], path: Path): Boolean = { + options: Map[String, String], path: Path): Boolean = { ArrowUtils.isOriginalFormatSplitable( new ArrowOptions(new CaseInsensitiveStringMap(options.asJava).asScala.toMap)) } @@ -61,15 +62,21 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab } override def inferSchema(sparkSession: SparkSession, - options: Map[String, String], - files: Seq[FileStatus]): Option[StructType] = { - convert(files, options) + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + val arrowOptions = new ArrowOptions(new CaseInsensitiveStringMap(options.asJava).asScala.toMap) + ArrowUtils.getFormat(arrowOptions) match { + case _: org.apache.arrow.dataset.file.format.ParquetFileFormat => + ParquetUtils.inferSchema(sparkSession, options, files) + case _ => + convert(files, options) + } } override def prepareWrite(sparkSession: SparkSession, - job: Job, - options: Map[String, String], - dataSchema: StructType): OutputWriterFactory = { + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { val arrowOptions = new ArrowOptions(new CaseInsensitiveStringMap(options.asJava).asScala.toMap) new OutputWriterFactory { override def 
getFileExtension(context: TaskAttemptContext): String = { @@ -90,7 +97,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab override def write(row: InternalRow): Unit = { val batch = row.asInstanceOf[FakeRow].batch writeQueue.enqueue(SparkVectorUtils - .toArrowRecordBatch(batch)) + .toArrowRecordBatch(batch)) } override def close(): Unit = { @@ -130,13 +137,14 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab // todo predicate validation / pushdown val parquetFileFields = factory.inspect().getFields.asScala val caseInsensitiveFieldMap = mutable.Map[String, String]() - val requiredFields = if (caseSensitive) { - new Schema(requiredSchema.map { field => - parquetFileFields.find(_.getName.equals(field.name)) - .getOrElse(ArrowUtils.toArrowField(field)) + // TODO: support array/map/struct types in out-of-order schema reading. + val requestColNames = requiredSchema.map(_.name) + val actualReadFields = if (caseSensitive) { + new Schema(parquetFileFields.filter { field => + requestColNames.exists(_.equals(field.getName)) }.asJava) } else { - new Schema(requiredSchema.map { readField => + requiredSchema.foreach { readField => // TODO: check schema inside of complex type val matchedFields = parquetFileFields.filter(_.getName.equalsIgnoreCase(readField.name)) @@ -147,57 +155,96 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab s""" |Found duplicate field(s) "${readField.name}": $fieldsString |in case-insensitive mode""".stripMargin.replaceAll("\n", " ")) - } else { - matchedFields - .map { field => - caseInsensitiveFieldMap += (readField.name -> field.getName) - field - }.headOption.getOrElse(ArrowUtils.toArrowField(readField)) } + } + new Schema(parquetFileFields.filter { field => + requestColNames.exists(_.equalsIgnoreCase(field.getName)) }.asJava) } - val dataset = factory.finish(requiredFields) + val actualReadFieldNames = actualReadFields.getFields.asScala.map(_.getName).toArray + val actualReadSchema = if (caseSensitive) { + new StructType(actualReadFieldNames.map(f => requiredSchema.find(_.name.equals(f)).get)) + } else { + new StructType( + actualReadFieldNames.map(f => requiredSchema.find(_.name.equalsIgnoreCase(f)).get)) + } + val dataset = factory.finish(actualReadFields) + val hashMissingColumns = actualReadFields.getFields.size() != requiredSchema.size val filter = if (enableFilterPushDown) { - ArrowFilters.translateFilters(filters, caseInsensitiveFieldMap.toMap) + val pushedFilters = if (hashMissingColumns) { + ArrowFilters.evaluateMissingFieldFilters(filters, actualReadFieldNames) + } else { + filters + } + if (pushedFilters == null) { + null + } else { + ArrowFilters.translateFilters( + pushedFilters, caseInsensitiveFieldMap.toMap) + } } else { org.apache.arrow.dataset.filter.Filter.EMPTY } - val scanOptions = new ScanOptions( - requiredFields.getFields.asScala.map(f => f.getName).toArray, - filter, - batchSize) - val scanner = dataset.newScan(scanOptions) + if (filter == null) { + new Iterator[InternalRow] { + override def hasNext: Boolean = false + override def next(): InternalRow = null + } + } else { + val scanOptions = new ScanOptions( + actualReadFieldNames, + filter, + batchSize) + val scanner = dataset.newScan(scanOptions) - val taskList = scanner + val taskList = scanner .scan() .iterator() .asScala .toList - val itrList = taskList - .map(task => task.execute()) - - Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => { - itrList.foreach(_.close()) - 
taskList.foreach(_.close()) - scanner.close() - dataset.close() - factory.close() - })) - - val partitionVectors = - ArrowUtils.loadPartitionColumns(batchSize, partitionSchema, file.partitionValues) - - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { - partitionVectors.foreach(_.close()) - }) - - val itr = itrList - .toIterator - .flatMap(itr => itr.asScala) - .map(batch => ArrowUtils.loadBatch(batch, requiredSchema, partitionVectors)) - new UnsafeItr(itr).asInstanceOf[Iterator[InternalRow]] + val itrList = taskList + .map(task => task.execute()) + + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => { + itrList.foreach(_.close()) + taskList.foreach(_.close()) + scanner.close() + dataset.close() + factory.close() + })) + + val partitionVectors = + ArrowUtils.loadPartitionColumns(batchSize, partitionSchema, file.partitionValues) + + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { + partitionVectors.foreach(_.close()) + }) + + val nullVectors = if (hashMissingColumns) { + val vectors = + ArrowWritableColumnVector.allocateColumns(batchSize, requiredSchema) + vectors.foreach { vector => + vector.putNulls(0, batchSize) + vector.setValueCount(batchSize) + } + + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { + vectors.foreach(_.close()) + }) + vectors + } else { + Array.empty[ArrowWritableColumnVector] + } + + val itr = itrList + .toIterator + .flatMap(itr => itr.asScala) + .map(batch => ArrowUtils.loadBatch( + batch, actualReadSchema, partitionVectors, nullVectors)) + new UnsafeItr(itr).asInstanceOf[Iterator[InternalRow]] + } } } diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala index 0bcfd3812..a4685f286 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala @@ -20,7 +20,6 @@ package com.intel.oap.spark.sql.execution.datasources.v2.arrow import org.apache.arrow.dataset.DatasetTypes import org.apache.arrow.dataset.DatasetTypes.TreeNode import org.apache.arrow.dataset.filter.FilterImpl -import org.apache.arrow.vector.types.pojo.Field import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType @@ -57,6 +56,70 @@ object ArrowFilters { false } + def evaluateMissingFieldFilters( + pushedFilters: Seq[Filter], + requiredFields: Seq[String]): Seq[Filter] = { + val evaluatedFilters = evaluateFilters(pushedFilters, requiredFields) + if (evaluatedFilters.exists(_._2 == false)) { + null + } else { + evaluatedFilters.map(_._1).filterNot(_ == null) + } + } + + def evaluateFilters( + pushedFilters: Seq[Filter], + requiredFields: Seq[String]): Seq[(Filter, Boolean)] = { + pushedFilters.map { + case r @ EqualTo(attribute, value) if !requiredFields.contains(attribute) => + (null, null == value) + case r @ GreaterThan(attribute, value) if !requiredFields.contains(attribute) => + (null, false) + case r @ GreaterThanOrEqual(attribute, value) if !requiredFields.contains(attribute) => + (null, null == value) + case LessThan(attribute, value) if !requiredFields.contains(attribute) => + (null, false) + case r @ LessThanOrEqual(attribute, value) if !requiredFields.contains(attribute) => + (null, 
null == value) + case r @ Not(child) => + evaluateFilters(Seq(child), requiredFields).head match { + case (null, false) => (null, true) + case (null, true) => (null, false) + case (_, true) => (r, true) + } + case r @ And(left, right) => + val evaluatedFilters = evaluateFilters(Seq(left, right), requiredFields) + val filters = evaluatedFilters.map(_._1).filterNot(_ == null) + if (evaluatedFilters.forall(_._2)) { + if (filters.size > 1) { + (r, true) + } else { + (filters.head, true) + } + } else { + (null, false) + } + case r @ Or(left, right) => + val evaluatedFilters = evaluateFilters(Seq(left, right), requiredFields) + val filters = evaluatedFilters.map(_._1).filterNot(_ == null) + if (evaluatedFilters.exists(_._2)) { + if (filters.size > 1) { + (r, true) + } else { + (filters.head, true) + } + } else { + (null, false) + } + case IsNotNull(attribute) if !requiredFields.contains(attribute) => + (null, false) + case IsNull(attribute) if !requiredFields.contains(attribute) => + (null, true) + case r => + (r, true) + } + } + def translateFilters( pushedFilters: Seq[Filter], caseInsensitiveFieldMap: Map[String, String]): org.apache.arrow.dataset.filter.Filter = { diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala index 2551d0fa6..cd951a7e5 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala @@ -21,12 +21,14 @@ import java.net.URLDecoder import scala.collection.JavaConverters._ import scala.collection.mutable +import com.google.common.collect.Lists import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowPartitionReaderFactory.ColumnarBatchRetainer import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._ +import com.intel.oap.vectorized.ArrowWritableColumnVector import org.apache.arrow.dataset.scanner.ScanOptions -import org.apache.arrow.vector.types.pojo.Schema - +import org.apache.arrow.vector.types.pojo.{Field, Schema} import org.apache.spark.TaskContext + import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} @@ -66,13 +68,14 @@ case class ArrowPartitionReaderFactory( partitionedFile.start, partitionedFile.length, options) val parquetFileFields = factory.inspect().getFields.asScala val caseInsensitiveFieldMap = mutable.Map[String, String]() - val requiredFields = if (caseSensitive) { - new Schema(readDataSchema.map { field => - parquetFileFields.find(_.getName.equals(field.name)) - .getOrElse(ArrowUtils.toArrowField(field)) + // TODO: support array/map/struct types in out-of-order schema reading. 
+ val requestColNames = readDataSchema.map(_.name) + val actualReadFields = if (caseSensitive) { + new Schema(parquetFileFields.filter { field => + requestColNames.exists(_.equals(field.getName)) }.asJava) } else { - new Schema(readDataSchema.map { readField => + readDataSchema.foreach { readField => // TODO: check schema inside of complex type val matchedFields = parquetFileFields.filter(_.getName.equalsIgnoreCase(readField.name)) @@ -83,69 +86,110 @@ case class ArrowPartitionReaderFactory( s""" |Found duplicate field(s) "${readField.name}": $fieldsString |in case-insensitive mode""".stripMargin.replaceAll("\n", " ")) - } else { - matchedFields - .map { field => - caseInsensitiveFieldMap += (readField.name -> field.getName) - field - }.headOption.getOrElse(ArrowUtils.toArrowField(readField)) } + } + new Schema(parquetFileFields.filter { field => + requestColNames.exists(_.equalsIgnoreCase(field.getName)) }.asJava) } - val dataset = factory.finish(requiredFields) + val actualReadFieldNames = actualReadFields.getFields.asScala.map(_.getName).toArray + val actualReadSchema = if (caseSensitive) { + new StructType(actualReadFieldNames.map(f => readDataSchema.find(_.name.equals(f)).get)) + } else { + new StructType( + actualReadFieldNames.map(f => readDataSchema.find(_.name.equalsIgnoreCase(f)).get)) + } + val dataset = factory.finish(actualReadFields) + + val hashMissingColumns = actualReadFields.getFields.size() != readDataSchema.size val filter = if (enableFilterPushDown) { - ArrowFilters.translateFilters( - ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema), - caseInsensitiveFieldMap.toMap) + val filters = if (hashMissingColumns) { + ArrowFilters.evaluateMissingFieldFilters(pushedFilters, actualReadFieldNames).toArray + } else { + pushedFilters + } + if (filters == null) { + null + } else { + ArrowFilters.translateFilters( + ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema), + caseInsensitiveFieldMap.toMap) + } } else { org.apache.arrow.dataset.filter.Filter.EMPTY } - val scanOptions = new ScanOptions(readDataSchema.map(f => f.name).toArray, - filter, batchSize) - val scanner = dataset.newScan(scanOptions) - - val taskList = scanner - .scan() - .iterator() - .asScala - .toList - - val vsrItrList = taskList - .map(task => task.execute()) - - val partitionVectors = ArrowUtils.loadPartitionColumns( - batchSize, readPartitionSchema, partitionedFile.partitionValues) + if (filter == null) { + new PartitionReader[ColumnarBatch] { + override def next(): Boolean = false + override def get(): ColumnarBatch = null + override def close(): Unit = { + // Nothing will be done + } + } + } else { + val scanOptions = new ScanOptions(actualReadFieldNames, filter, batchSize) + val scanner = dataset.newScan(scanOptions) + + val taskList = scanner + .scan() + .iterator() + .asScala + .toList + + val vsrItrList = taskList + .map(task => task.execute()) + + val partitionVectors = ArrowUtils.loadPartitionColumns( + batchSize, readPartitionSchema, partitionedFile.partitionValues) + + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { + partitionVectors.foreach(_.close()) + }) + + val nullVectors = if (hashMissingColumns) { + val vectors = + ArrowWritableColumnVector.allocateColumns(batchSize, readDataSchema) + vectors.foreach { vector => + vector.putNulls(0, batchSize) + vector.setValueCount(batchSize) + } - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { - partitionVectors.foreach(_.close()) - }) + 
SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { + vectors.foreach(_.close()) + }) + vectors + } else { + Array.empty[ArrowWritableColumnVector] + } - val batchItr = vsrItrList - .toIterator - .flatMap(itr => itr.asScala) - .map(batch => ArrowUtils.loadBatch(batch, readDataSchema, partitionVectors)) + val batchItr = vsrItrList + .toIterator + .flatMap(itr => itr.asScala) + .map(batch => ArrowUtils.loadBatch( + batch, actualReadSchema, partitionVectors, nullVectors)) - new PartitionReader[ColumnarBatch] { - val holder = new ColumnarBatchRetainer() + new PartitionReader[ColumnarBatch] { + val holder = new ColumnarBatchRetainer() - override def next(): Boolean = { - holder.release() - batchItr.hasNext - } + override def next(): Boolean = { + holder.release() + batchItr.hasNext + } - override def get(): ColumnarBatch = { - val batch = batchItr.next() - holder.retain(batch) - batch - } + override def get(): ColumnarBatch = { + val batch = batchItr.next() + holder.retain(batch) + batch + } - override def close(): Unit = { - holder.release() - vsrItrList.foreach(itr => itr.close()) - taskList.foreach(task => task.close()) - scanner.close() - dataset.close() - factory.close() + override def close(): Unit = { + holder.release() + vsrItrList.foreach(itr => itr.close()) + taskList.foreach(task => task.close()) + scanner.close() + dataset.close() + factory.close() + } } } } diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala index b42dd70be..0a9978b91 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala @@ -22,6 +22,7 @@ import java.nio.charset.StandardCharsets import java.time.ZoneId import scala.collection.JavaConverters._ +import scala.collection.mutable import com.intel.oap.vectorized.{ArrowColumnVectorUtils, ArrowWritableColumnVector} import org.apache.arrow.dataset.file.FileSystemDatasetFactory @@ -70,7 +71,7 @@ object ArrowUtils { } def makeArrowDiscovery(encodedUri: String, startOffset: Long, length: Long, - options: ArrowOptions): FileSystemDatasetFactory = { + options: ArrowOptions): FileSystemDatasetFactory = { val format = getFormat(options) val allocator = SparkMemoryUtils.contextAllocator() @@ -104,7 +105,8 @@ object ArrowUtils { def loadBatch( input: ArrowRecordBatch, dataSchema: StructType, - partitionVectors: Array[ArrowWritableColumnVector]): ColumnarBatch = { + partitionVectors: Array[ArrowWritableColumnVector], + nullVectors: Array[ArrowWritableColumnVector]): ColumnarBatch = { val rowCount: Int = input.getLength val vectors = try { @@ -113,8 +115,27 @@ object ArrowUtils { input.close() } + val totalVectors = if (dataSchema.size != nullVectors.length) { + val finalVectors = + mutable.ArrayBuffer[ArrowWritableColumnVector]() + val nullIterator = nullVectors.iterator + while (nullIterator.hasNext) { + val nullVector = nullIterator.next() + finalVectors.append( + vectors.find(_.getValueVector.getName.equals(nullVector.getValueVector.getName)) + .getOrElse { + nullVector.setValueCount(rowCount) + nullVector.retain() + nullVector + }) + } + finalVectors.toArray + } else { + vectors + } + val batch = new ColumnarBatch( - vectors.map(_.asInstanceOf[ColumnVector]) ++ + 
totalVectors.map(_.asInstanceOf[ColumnVector]) ++ partitionVectors .map { vector => vector.setValueCount(rowCount) @@ -132,7 +153,7 @@ object ArrowUtils { } def loadBatch(input: ArrowRecordBatch, partitionValues: InternalRow, - partitionSchema: StructType, dataSchema: StructType): ColumnarBatch = { + partitionSchema: StructType, dataSchema: StructType): ColumnarBatch = { val rowCount: Int = input.getLength val vectors = try { @@ -149,13 +170,13 @@ object ArrowUtils { val batch = new ColumnarBatch( vectors.map(_.asInstanceOf[ColumnVector]) ++ - partitionColumns.map(_.asInstanceOf[ColumnVector]), + partitionColumns.map(_.asInstanceOf[ColumnVector]), rowCount) batch } def getFormat( - options: ArrowOptions): org.apache.arrow.dataset.file.format.FileFormat = { + options: ArrowOptions): org.apache.arrow.dataset.file.format.FileFormat = { val paramMap = options.parameters.toMap.asJava options.originalFormat match { case "parquet" => org.apache.arrow.dataset.file.format.ParquetFileFormat.create(paramMap) From 496340da71bd5989f93f6edd882b6ffb8b759be6 Mon Sep 17 00:00:00 2001 From: jackylee-ch Date: Fri, 2 Dec 2022 20:48:06 +0800 Subject: [PATCH 2/8] fix error --- .../datasources/v2/arrow/ArrowUtils.scala | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala index 0a9978b91..ee524ab22 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala @@ -115,19 +115,31 @@ object ArrowUtils { input.close() } - val totalVectors = if (dataSchema.size != nullVectors.length) { + val totalVectors = if (nullVectors.nonEmpty) { val finalVectors = mutable.ArrayBuffer[ArrowWritableColumnVector]() val nullIterator = nullVectors.iterator + val caseSensitive = SQLConf.get.caseSensitiveAnalysis while (nullIterator.hasNext) { val nullVector = nullIterator.next() finalVectors.append( - vectors.find(_.getValueVector.getName.equals(nullVector.getValueVector.getName)) - .getOrElse { - nullVector.setValueCount(rowCount) - nullVector.retain() - nullVector - }) + if (caseSensitive) { + vectors.find( + _.getValueVector.getName.equals(nullVector.getValueVector.getName)) + .getOrElse { + nullVector.setValueCount(rowCount) + nullVector.retain() + nullVector + } + } else { + vectors.find( + _.getValueVector.getName.equalsIgnoreCase(nullVector.getValueVector.getName)) + .getOrElse { + nullVector.setValueCount(rowCount) + nullVector.retain() + nullVector + } + }) } finalVectors.toArray } else { From 461cff896e10b3e3f19c66ad3ae095bcf50ac19f Mon Sep 17 00:00:00 2001 From: jackylee-ch Date: Mon, 5 Dec 2022 16:51:22 +0800 Subject: [PATCH 3/8] optimize null vectors --- .../datasources/arrow/ArrowFileFormat.scala | 12 +++++----- .../arrow/ArrowPartitionReaderFactory.scala | 12 +++++----- .../datasources/v2/arrow/ArrowUtils.scala | 22 +++++++++++-------- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala index 28e8f2868..076082c46 100644 --- 
a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala @@ -168,11 +168,13 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab new StructType( actualReadFieldNames.map(f => requiredSchema.find(_.name.equalsIgnoreCase(f)).get)) } + val missingSchema = + new StructType(requiredSchema.filterNot(actualReadSchema.contains).toArray) val dataset = factory.finish(actualReadFields) - val hashMissingColumns = actualReadFields.getFields.size() != requiredSchema.size + val hasMissingColumns = actualReadFields.getFields.size() != requiredSchema.size val filter = if (enableFilterPushDown) { - val pushedFilters = if (hashMissingColumns) { + val pushedFilters = if (hasMissingColumns) { ArrowFilters.evaluateMissingFieldFilters(filters, actualReadFieldNames) } else { filters @@ -222,9 +224,9 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab partitionVectors.foreach(_.close()) }) - val nullVectors = if (hashMissingColumns) { + val nullVectors = if (hasMissingColumns) { val vectors = - ArrowWritableColumnVector.allocateColumns(batchSize, requiredSchema) + ArrowWritableColumnVector.allocateColumns(batchSize, missingSchema) vectors.foreach { vector => vector.putNulls(0, batchSize) vector.setValueCount(batchSize) @@ -242,7 +244,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab .toIterator .flatMap(itr => itr.asScala) .map(batch => ArrowUtils.loadBatch( - batch, actualReadSchema, partitionVectors, nullVectors)) + batch, actualReadSchema, requiredSchema, partitionVectors, nullVectors)) new UnsafeItr(itr).asInstanceOf[Iterator[InternalRow]] } } diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala index cd951a7e5..35227fe0c 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala @@ -99,11 +99,13 @@ case class ArrowPartitionReaderFactory( new StructType( actualReadFieldNames.map(f => readDataSchema.find(_.name.equalsIgnoreCase(f)).get)) } + val missingSchema = + new StructType(readDataSchema.filterNot(actualReadSchema.contains).toArray) val dataset = factory.finish(actualReadFields) - val hashMissingColumns = actualReadFields.getFields.size() != readDataSchema.size + val hasMissingColumns = actualReadFields.getFields.size() != readDataSchema.size val filter = if (enableFilterPushDown) { - val filters = if (hashMissingColumns) { + val filters = if (hasMissingColumns) { ArrowFilters.evaluateMissingFieldFilters(pushedFilters, actualReadFieldNames).toArray } else { pushedFilters @@ -146,9 +148,9 @@ case class ArrowPartitionReaderFactory( partitionVectors.foreach(_.close()) }) - val nullVectors = if (hashMissingColumns) { + val nullVectors = if (hasMissingColumns) { val vectors = - ArrowWritableColumnVector.allocateColumns(batchSize, readDataSchema) + ArrowWritableColumnVector.allocateColumns(batchSize, missingSchema) vectors.foreach { vector => vector.putNulls(0, batchSize) 
vector.setValueCount(batchSize) @@ -166,7 +168,7 @@ case class ArrowPartitionReaderFactory( .toIterator .flatMap(itr => itr.asScala) .map(batch => ArrowUtils.loadBatch( - batch, actualReadSchema, partitionVectors, nullVectors)) + batch, actualReadSchema, readDataSchema, partitionVectors, nullVectors)) new PartitionReader[ColumnarBatch] { val holder = new ColumnarBatchRetainer() diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala index ee524ab22..9bf7065fd 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala @@ -105,8 +105,9 @@ object ArrowUtils { def loadBatch( input: ArrowRecordBatch, dataSchema: StructType, - partitionVectors: Array[ArrowWritableColumnVector], - nullVectors: Array[ArrowWritableColumnVector]): ColumnarBatch = { + requiredSchema: StructType, + partitionVectors: Array[ArrowWritableColumnVector] = Array.empty, + nullVectors: Array[ArrowWritableColumnVector] = Array.empty): ColumnarBatch = { val rowCount: Int = input.getLength val vectors = try { @@ -118,23 +119,26 @@ object ArrowUtils { val totalVectors = if (nullVectors.nonEmpty) { val finalVectors = mutable.ArrayBuffer[ArrowWritableColumnVector]() - val nullIterator = nullVectors.iterator + val requiredIterator = requiredSchema.iterator val caseSensitive = SQLConf.get.caseSensitiveAnalysis - while (nullIterator.hasNext) { - val nullVector = nullIterator.next() + while (requiredIterator.hasNext) { + val field = requiredIterator.next() finalVectors.append( if (caseSensitive) { - vectors.find( - _.getValueVector.getName.equals(nullVector.getValueVector.getName)) + vectors.find(_.getValueVector.getName.equals(field.name)) .getOrElse { + // The missing column need to be find in nullVectors + val nullVector = nullVectors.find(_.getValueVector.getName.equals(field.name)).get nullVector.setValueCount(rowCount) nullVector.retain() nullVector } } else { - vectors.find( - _.getValueVector.getName.equalsIgnoreCase(nullVector.getValueVector.getName)) + vectors.find(_.getValueVector.getName.equalsIgnoreCase(field.name)) .getOrElse { + // The missing column need to be find in nullVectors + val nullVector = + nullVectors.find(_.getValueVector.getName.equalsIgnoreCase(field.name)).get nullVector.setValueCount(rowCount) nullVector.retain() nullVector From 7570eee1fd016a815e57bcb408f03b01d1e37849 Mon Sep 17 00:00:00 2001 From: jackylee-ch Date: Tue, 6 Dec 2022 21:43:45 +0800 Subject: [PATCH 4/8] optimize code --- .../datasources/arrow/ArrowFileFormat.scala | 55 +++-------------- .../arrow/ArrowPartitionReaderFactory.scala | 55 +++-------------- .../datasources/v2/arrow/ArrowUtils.scala | 59 +++++++++++++++++++ 3 files changed, 77 insertions(+), 92 deletions(-) diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala index 076082c46..77a3ca422 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala +++ 
b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala @@ -138,38 +138,13 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab val parquetFileFields = factory.inspect().getFields.asScala val caseInsensitiveFieldMap = mutable.Map[String, String]() // TODO: support array/map/struct types in out-of-order schema reading. - val requestColNames = requiredSchema.map(_.name) - val actualReadFields = if (caseSensitive) { - new Schema(parquetFileFields.filter { field => - requestColNames.exists(_.equals(field.getName)) - }.asJava) - } else { - requiredSchema.foreach { readField => - // TODO: check schema inside of complex type - val matchedFields = - parquetFileFields.filter(_.getName.equalsIgnoreCase(readField.name)) - if (matchedFields.size > 1) { - // Need to fail if there is ambiguity, i.e. more than one field is matched - val fieldsString = matchedFields.map(_.getName).mkString("[", ", ", "]") - throw new RuntimeException( - s""" - |Found duplicate field(s) "${readField.name}": $fieldsString - |in case-insensitive mode""".stripMargin.replaceAll("\n", " ")) - } - } - new Schema(parquetFileFields.filter { field => - requestColNames.exists(_.equalsIgnoreCase(field.getName)) - }.asJava) - } + val actualReadFields = + ArrowUtils.getRequestedField(requiredSchema, parquetFileFields, caseSensitive) + + val compare = ArrowUtils.compareStringFunc(caseSensitive) val actualReadFieldNames = actualReadFields.getFields.asScala.map(_.getName).toArray - val actualReadSchema = if (caseSensitive) { - new StructType(actualReadFieldNames.map(f => requiredSchema.find(_.name.equals(f)).get)) - } else { - new StructType( - actualReadFieldNames.map(f => requiredSchema.find(_.name.equalsIgnoreCase(f)).get)) - } - val missingSchema = - new StructType(requiredSchema.filterNot(actualReadSchema.contains).toArray) + val actualReadSchema = new StructType( + actualReadFieldNames.map(f => requiredSchema.find(field => compare(f, field.name)).get)) val dataset = factory.finish(actualReadFields) val hasMissingColumns = actualReadFields.getFields.size() != requiredSchema.size @@ -220,22 +195,10 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab val partitionVectors = ArrowUtils.loadPartitionColumns(batchSize, partitionSchema, file.partitionValues) - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { - partitionVectors.foreach(_.close()) - }) - val nullVectors = if (hasMissingColumns) { - val vectors = - ArrowWritableColumnVector.allocateColumns(batchSize, missingSchema) - vectors.foreach { vector => - vector.putNulls(0, batchSize) - vector.setValueCount(batchSize) - } - - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { - vectors.foreach(_.close()) - }) - vectors + val missingSchema = + new StructType(requiredSchema.filterNot(actualReadSchema.contains).toArray) + ArrowUtils.loadMissingColumns(batchSize, missingSchema) } else { Array.empty[ArrowWritableColumnVector] } diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala index 35227fe0c..2ae2051e5 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala +++ 
b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala @@ -69,38 +69,13 @@ case class ArrowPartitionReaderFactory( val parquetFileFields = factory.inspect().getFields.asScala val caseInsensitiveFieldMap = mutable.Map[String, String]() // TODO: support array/map/struct types in out-of-order schema reading. - val requestColNames = readDataSchema.map(_.name) - val actualReadFields = if (caseSensitive) { - new Schema(parquetFileFields.filter { field => - requestColNames.exists(_.equals(field.getName)) - }.asJava) - } else { - readDataSchema.foreach { readField => - // TODO: check schema inside of complex type - val matchedFields = - parquetFileFields.filter(_.getName.equalsIgnoreCase(readField.name)) - if (matchedFields.size > 1) { - // Need to fail if there is ambiguity, i.e. more than one field is matched - val fieldsString = matchedFields.map(_.getName).mkString("[", ", ", "]") - throw new RuntimeException( - s""" - |Found duplicate field(s) "${readField.name}": $fieldsString - |in case-insensitive mode""".stripMargin.replaceAll("\n", " ")) - } - } - new Schema(parquetFileFields.filter { field => - requestColNames.exists(_.equalsIgnoreCase(field.getName)) - }.asJava) - } + val actualReadFields = + ArrowUtils.getRequestedField(readDataSchema, parquetFileFields, caseSensitive) + + val compare = ArrowUtils.compareStringFunc(caseSensitive) val actualReadFieldNames = actualReadFields.getFields.asScala.map(_.getName).toArray - val actualReadSchema = if (caseSensitive) { - new StructType(actualReadFieldNames.map(f => readDataSchema.find(_.name.equals(f)).get)) - } else { - new StructType( - actualReadFieldNames.map(f => readDataSchema.find(_.name.equalsIgnoreCase(f)).get)) - } - val missingSchema = - new StructType(readDataSchema.filterNot(actualReadSchema.contains).toArray) + val actualReadSchema = new StructType( + actualReadFieldNames.map(f => readDataSchema.find(field => compare(f, field.name)).get)) val dataset = factory.finish(actualReadFields) val hasMissingColumns = actualReadFields.getFields.size() != readDataSchema.size @@ -144,22 +119,10 @@ case class ArrowPartitionReaderFactory( val partitionVectors = ArrowUtils.loadPartitionColumns( batchSize, readPartitionSchema, partitionedFile.partitionValues) - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { - partitionVectors.foreach(_.close()) - }) - val nullVectors = if (hasMissingColumns) { - val vectors = - ArrowWritableColumnVector.allocateColumns(batchSize, missingSchema) - vectors.foreach { vector => - vector.putNulls(0, batchSize) - vector.setValueCount(batchSize) - } - - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { - vectors.foreach(_.close()) - }) - vectors + val missingSchema = + new StructType(readDataSchema.filterNot(actualReadSchema.contains).toArray) + ArrowUtils.loadMissingColumns(batchSize, missingSchema) } else { Array.empty[ArrowWritableColumnVector] } diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala index 9bf7065fd..5a24a90d3 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala @@ -29,6 +29,7 @@ import 
org.apache.arrow.dataset.file.FileSystemDatasetFactory import org.apache.arrow.vector.ipc.message.ArrowRecordBatch import org.apache.arrow.vector.types.pojo.{Field, Schema} import org.apache.hadoop.fs.FileStatus +import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -89,6 +90,24 @@ object ArrowUtils { SparkSchemaUtils.toArrowSchema(t, SparkSchemaUtils.getLocalTimezoneID()) } + def loadMissingColumns( + rowCount: Int, + missingSchema: StructType): Array[ArrowWritableColumnVector] = { + + val vectors = + ArrowWritableColumnVector.allocateColumns(rowCount, missingSchema) + vectors.foreach { vector => + vector.putNulls(0, rowCount) + vector.setValueCount(rowCount) + } + + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { + vectors.foreach(_.close()) + }) + + vectors + } + def loadPartitionColumns( rowCount: Int, partitionSchema: StructType, @@ -99,6 +118,11 @@ object ArrowUtils { partitionColumns(i).setValueCount(rowCount) partitionColumns(i).setIsConstant() }) + + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => { + partitionColumns.foreach(_.close()) + }) + partitionColumns } @@ -221,4 +245,39 @@ object ArrowUtils { val rewritten = new URI(sch, ssp, uri.getFragment) rewritten.toString } + + def compareStringFunc(caseSensitive: Boolean): (String, String) => Boolean = + if (caseSensitive) { + (str1: String, str2: String) => str1.equals(str2) + } else { + (str1: String, str2: String) => str1.equalsIgnoreCase(str2) + } + + def getRequestedField( + requiredSchema: StructType, + parquetFileFields: mutable.Buffer[Field], + caseSensitive: Boolean): Schema = { + val compareFunc = compareStringFunc(caseSensitive) + if (!caseSensitive) { + requiredSchema.foreach { readField => + // TODO: check schema inside of complex type + val matchedFields = + parquetFileFields.filter(field => compareFunc(field.getName, readField.name)) + if (matchedFields.size > 1) { + // Need to fail if there is ambiguity, i.e. 
more than one field is matched + val fieldsString = matchedFields.map(_.getName).mkString("[", ", ", "]") + throw new RuntimeException( + s""" + |Found duplicate field(s) "${readField.name}": $fieldsString + + |in case-insensitive mode""".stripMargin.replaceAll("\n" + , " ")) + } + } + } + val requestColNames = requiredSchema.map(_.name) + new Schema(parquetFileFields.filter { field => + requestColNames.exists(col => compareFunc(col, field.getName)) + }.asJava) + } } From f7fa7a58995a5e91e2365c99c6b875eb31bc914f Mon Sep 17 00:00:00 2001 From: jackylee-ch Date: Tue, 6 Dec 2022 21:58:29 +0800 Subject: [PATCH 5/8] optimize code --- .../datasources/v2/arrow/ArrowUtils.scala | 31 ++++++------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala index 5a24a90d3..9f56b75c0 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala @@ -144,30 +144,19 @@ object ArrowUtils { val finalVectors = mutable.ArrayBuffer[ArrowWritableColumnVector]() val requiredIterator = requiredSchema.iterator - val caseSensitive = SQLConf.get.caseSensitiveAnalysis + val compareFunc = compareStringFunc(SQLConf.get.caseSensitiveAnalysis) while (requiredIterator.hasNext) { val field = requiredIterator.next() finalVectors.append( - if (caseSensitive) { - vectors.find(_.getValueVector.getName.equals(field.name)) - .getOrElse { - // The missing column need to be find in nullVectors - val nullVector = nullVectors.find(_.getValueVector.getName.equals(field.name)).get - nullVector.setValueCount(rowCount) - nullVector.retain() - nullVector - } - } else { - vectors.find(_.getValueVector.getName.equalsIgnoreCase(field.name)) - .getOrElse { - // The missing column need to be find in nullVectors - val nullVector = - nullVectors.find(_.getValueVector.getName.equalsIgnoreCase(field.name)).get - nullVector.setValueCount(rowCount) - nullVector.retain() - nullVector - } - }) + vectors.find(_.getValueVector.getName.equals(field.name)) + .getOrElse { + // The missing column need to be find in nullVectors + val nullVector = nullVectors.find(vector => + compareFunc(vector.getValueVector.getName, field.name)).get + nullVector.setValueCount(rowCount) + nullVector.retain() + nullVector + }) } finalVectors.toArray } else { From 761d41b6b5803bd94ad0d79b21471fdca401a9ad Mon Sep 17 00:00:00 2001 From: jackylee-ch Date: Tue, 6 Dec 2022 22:00:46 +0800 Subject: [PATCH 6/8] change code --- .../spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala index 9f56b75c0..7bb943a06 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala @@ -148,7 +148,7 @@ object ArrowUtils { while (requiredIterator.hasNext) { val field = requiredIterator.next() 
finalVectors.append( - vectors.find(_.getValueVector.getName.equals(field.name)) + vectors.find(vector => compareFunc(vector.getValueVector.getName, field.name)) .getOrElse { // The missing column need to be find in nullVectors val nullVector = nullVectors.find(vector => From 1e1610595053b1f6f9f3c031550c386c1b8acdfb Mon Sep 17 00:00:00 2001 From: jackylee-ch Date: Fri, 9 Dec 2022 11:08:34 +0800 Subject: [PATCH 7/8] add schema merge suite tests --- .../datasources/arrow/ArrowFileFormat.scala | 22 +++++-------------- .../datasources/v2/arrow/ArrowFilters.scala | 2 +- .../arrow/ArrowPartitionReaderFactory.scala | 14 +++--------- .../datasources/v2/arrow/ArrowUtils.scala | 16 +++++++++----- .../arrow/ArrowDataSourceTest.scala | 22 +++++++++++++++++++ 5 files changed, 42 insertions(+), 34 deletions(-) diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala index efb40b7e5..3a67f4979 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala @@ -43,7 +43,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions -import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils.UnsafeItr import org.apache.spark.sql.execution.datasources.v2.arrow.SparkVectorUtils import org.apache.spark.sql.internal.SQLConf @@ -67,13 +66,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi override def inferSchema(sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val arrowOptions = new ArrowOptions(new CaseInsensitiveStringMap(options.asJava).asScala.toMap) - ArrowUtils.getFormat(arrowOptions) match { - case _: org.apache.arrow.dataset.file.format.ParquetFileFormat => - ParquetUtils.inferSchema(sparkSession, options, files) - case _ => - convert(files, options) - } + convert(files, options) } override def prepareWrite(sparkSession: SparkSession, @@ -86,7 +79,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi val conf = ContextUtil.getConfiguration(job) conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) logInfo(s"write parquet with codec: ${parquetOptions.compressionCodecClassName}") - + new OutputWriterFactory { override def getFileExtension(context: TaskAttemptContext): String = { ArrowUtils.getFormat(arrowOptions) match { @@ -100,7 +93,8 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi context: TaskAttemptContext): OutputWriter = { val originPath = path val writeQueue = new ArrowWriteQueue(ArrowUtils.toArrowSchema(dataSchema), - ArrowUtils.getFormat(arrowOptions), parquetOptions.compressionCodecClassName.toLowerCase(), originPath) + ArrowUtils.getFormat(arrowOptions), + parquetOptions.compressionCodecClassName.toLowerCase(), originPath) new OutputWriter { override def write(row: InternalRow): Unit = { @@ -163,12 +157,8 @@ class 
ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi } else { filters } - if (pushedFilters == null) { - null - } else { - ArrowFilters.translateFilters( - pushedFilters, caseInsensitiveFieldMap.toMap) - } + ArrowFilters.translateFilters( + pushedFilters, caseInsensitiveFieldMap.toMap) } else { org.apache.arrow.dataset.filter.Filter.EMPTY } diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala index a4685f286..aecaf8511 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala @@ -61,7 +61,7 @@ object ArrowFilters { requiredFields: Seq[String]): Seq[Filter] = { val evaluatedFilters = evaluateFilters(pushedFilters, requiredFields) if (evaluatedFilters.exists(_._2 == false)) { - null + Seq.empty[Filter] } else { evaluatedFilters.map(_._1).filterNot(_ == null) } diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala index 2ae2051e5..eed3bc527 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala @@ -21,20 +21,16 @@ import java.net.URLDecoder import scala.collection.JavaConverters._ import scala.collection.mutable -import com.google.common.collect.Lists import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowPartitionReaderFactory.ColumnarBatchRetainer import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._ import com.intel.oap.vectorized.ArrowWritableColumnVector import org.apache.arrow.dataset.scanner.ScanOptions -import org.apache.arrow.vector.types.pojo.{Field, Schema} -import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory -import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -85,13 +81,9 @@ case class ArrowPartitionReaderFactory( } else { pushedFilters } - if (filters == null) { - null - } else { - ArrowFilters.translateFilters( - ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema), - caseInsensitiveFieldMap.toMap) - } + ArrowFilters.translateFilters( + ArrowFilters.pruneWithSchema(filters, readDataSchema), + caseInsensitiveFieldMap.toMap) } else { org.apache.arrow.dataset.filter.Filter.EMPTY } diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala index 7bb943a06..f0bdfb833 
100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala @@ -18,8 +18,6 @@ package com.intel.oap.spark.sql.execution.datasources.v2.arrow import java.net.URI -import java.nio.charset.StandardCharsets -import java.time.ZoneId import scala.collection.JavaConverters._ import scala.collection.mutable @@ -29,12 +27,12 @@ import org.apache.arrow.dataset.file.FileSystemDatasetFactory import org.apache.arrow.vector.ipc.message.ArrowRecordBatch import org.apache.arrow.vector.types.pojo.{Field, Schema} import org.apache.hadoop.fs.FileStatus -import org.apache.spark.TaskContext +import org.apache.spark.TaskContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkSchemaUtils} -import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -58,7 +56,13 @@ object ArrowUtils { if (files.isEmpty) { throw new IllegalArgumentException("No input file specified") } - readSchema(files.toList.head, options) // todo merge schema + val arrowOptions = new ArrowOptions(options.asScala.toMap) + ArrowUtils.getFormat(arrowOptions) match { + case _: org.apache.arrow.dataset.file.format.ParquetFileFormat => + ParquetUtils.inferSchema(SparkSession.active, options.asScala.toMap, files) + case _ => + readSchema(files.toList.head, options) // todo merge schema + } } def isOriginalFormatSplitable(options: ArrowOptions): Boolean = { diff --git a/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala b/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala index e08396d64..d3397d180 100644 --- a/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala +++ b/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala @@ -482,6 +482,28 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession { .arrow(path), 2, 3) } + test("Test schema merge on arrow datasource") { + import testImplicits._ + withTempPath { dir => + val path1 = s"${dir.getCanonicalPath}/table1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.arrow(path1) + val path2 = s"${dir.getCanonicalPath}/table2" + (1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.arrow(path2) + + Seq("arrow", "").foreach { v1SourceList => + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1SourceList, + SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { + + // No matter "c = 1" gets pushed down or not, this query should work without exception. 
+ val df = spark.read.arrow(path1, path2).filter("c = 1").selectExpr("c", "b", "a") + checkAnswer( + df, + Row(1, "1", null)) + } + } + } + } + def verifyFrame(frame: DataFrame, rowCount: Int, columnCount: Int): Unit = { assert(frame.schema.length === columnCount) assert(frame.collect().length === rowCount) From 28c51efffb259a9c7084e98de8af9c6457dc4384 Mon Sep 17 00:00:00 2001 From: jackylee-ch Date: Fri, 9 Dec 2022 11:36:56 +0800 Subject: [PATCH 8/8] add test for struct type --- .../datasources/arrow/ArrowFileFormat.scala | 8 +++++-- .../datasources/v2/arrow/ArrowFilters.scala | 4 ++-- .../arrow/ArrowPartitionReaderFactory.scala | 14 +++++++---- .../arrow/ArrowDataSourceTest.scala | 24 ++++++++++++++----- 4 files changed, 35 insertions(+), 15 deletions(-) diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala index 3a67f4979..ee2b26b0b 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala @@ -157,8 +157,12 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Logging wi } else { filters } - ArrowFilters.translateFilters( - pushedFilters, caseInsensitiveFieldMap.toMap) + if (pushedFilters == null) { + null + } else { + ArrowFilters.translateFilters( + pushedFilters, caseInsensitiveFieldMap.toMap) + } } else { org.apache.arrow.dataset.filter.Filter.EMPTY } diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala index aecaf8511..d92b1c555 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType object ArrowFilters { - def pruneWithSchema(pushedFilters: Array[Filter], schema: StructType): Seq[Filter] = { + def pruneWithSchema(pushedFilters: Seq[Filter], schema: StructType): Seq[Filter] = { pushedFilters.filter(pushedFilter => { isToBeAccepted(pushedFilter, schema) }) @@ -61,7 +61,7 @@ object ArrowFilters { requiredFields: Seq[String]): Seq[Filter] = { val evaluatedFilters = evaluateFilters(pushedFilters, requiredFields) if (evaluatedFilters.exists(_._2 == false)) { - Seq.empty[Filter] + null } else { evaluatedFilters.map(_._1).filterNot(_ == null) } diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala index eed3bc527..f3f263cf9 100644 --- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala +++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala @@ -77,13 +77,17 @@ case class ArrowPartitionReaderFactory( val hasMissingColumns = 
actualReadFields.getFields.size() != readDataSchema.size val filter = if (enableFilterPushDown) { val filters = if (hasMissingColumns) { - ArrowFilters.evaluateMissingFieldFilters(pushedFilters, actualReadFieldNames).toArray + ArrowFilters.evaluateMissingFieldFilters(pushedFilters, actualReadFieldNames) } else { - pushedFilters + pushedFilters.toSeq + } + if (filters == null) { + null + } else { + ArrowFilters.translateFilters( + ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema), + caseInsensitiveFieldMap.toMap) } - ArrowFilters.translateFilters( - ArrowFilters.pruneWithSchema(filters, readDataSchema), - caseInsensitiveFieldMap.toMap) } else { org.apache.arrow.dataset.filter.Filter.EMPTY } diff --git a/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala b/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala index d3397d180..8561fd320 100644 --- a/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala +++ b/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala @@ -27,16 +27,14 @@ import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowOptions import com.sun.management.UnixOperatingSystemMXBean import org.apache.commons.io.FileUtils -import org.apache.spark.SparkConf -import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode} import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, struct} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.SPARK_SESSION_EXTENSIONS import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.IntegerType -import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} class ArrowDataSourceTest extends QueryTest with SharedSparkSession { import testImplicits._ @@ -490,6 +488,13 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession { val path2 = s"${dir.getCanonicalPath}/table2" (1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.arrow(path2) + val path3 = s"${dir.getCanonicalPath}/table3" + val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + dfStruct.select(struct("a").as("s")).write.parquet(path3) + val path4 = s"${dir.getCanonicalPath}/table4" + val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b") + dfStruct2.select(struct("c").as("s")).write.parquet(path4) + Seq("arrow", "").foreach { v1SourceList => withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1SourceList, SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { @@ -499,6 +504,13 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession { checkAnswer( df, Row(1, "1", null)) + + // Not support schema merge and fiter pushdown for struct type + val expr = intercept[SparkException] { + spark.read.arrow(path3, path4).filter("s.c = 1").selectExpr("s").show() + } + assert(expr.getCause.getMessage.contains( + """no more field nodes for for field c""")) } } }
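
The filter handling in the patches above hinges on one observation: a column that is missing from a particular Parquet file is materialized as an all-null vector, so a pushed-down filter that references it can be folded before it ever reaches the Arrow scanner — it is either vacuously satisfiable (drop the filter) or can never match (skip the file entirely). The standalone sketch below illustrates that folding under deliberately simplified semantics: every comparison against a missing column is treated as non-matching, and an Option result replaces the null sentinel used in the patch. The object name MissingColumnFilters and its helper types are hypothetical and are not part of the patch series.

import org.apache.spark.sql.sources._

// Illustrative sketch only (not the patched code): decide what a pushed-down
// filter means when it references a column that is absent from one file.
// A missing column reads as all-NULL, so filters on it fold to constants:
// IsNull is always true, comparisons never match, IsNotNull never matches.
// None signals "this file cannot match at all"; Some(filters) keeps only the
// filters that still need to be pushed to the scanner.
object MissingColumnFilters {
  def evaluate(filters: Seq[Filter], presentColumns: Set[String]): Option[Seq[Filter]] = {
    val results = filters.map(eval(_, presentColumns))
    if (results.contains(AlwaysFalse)) None                // whole file pruned
    else Some(results.collect { case Keep(f) => f })       // drop folded-away filters
  }

  private sealed trait Outcome
  private case object AlwaysTrue extends Outcome
  private case object AlwaysFalse extends Outcome
  private case class Keep(f: Filter) extends Outcome

  private def eval(filter: Filter, present: Set[String]): Outcome = filter match {
    case IsNull(a) if !present.contains(a)                 => AlwaysTrue
    case IsNotNull(a) if !present.contains(a)              => AlwaysFalse
    case EqualTo(a, _) if !present.contains(a)             => AlwaysFalse
    case GreaterThan(a, _) if !present.contains(a)         => AlwaysFalse
    case GreaterThanOrEqual(a, _) if !present.contains(a)  => AlwaysFalse
    case LessThan(a, _) if !present.contains(a)            => AlwaysFalse
    case LessThanOrEqual(a, _) if !present.contains(a)     => AlwaysFalse
    case Not(child) => eval(child, present) match {
      case AlwaysTrue  => AlwaysFalse
      case AlwaysFalse => AlwaysTrue
      case Keep(_)     => Keep(filter)
    }
    case And(l, r) => (eval(l, present), eval(r, present)) match {
      case (AlwaysFalse, _) | (_, AlwaysFalse) => AlwaysFalse
      case (AlwaysTrue, o)                     => o
      case (o, AlwaysTrue)                     => o
      case _                                   => Keep(filter)
    }
    case Or(l, r) => (eval(l, present), eval(r, present)) match {
      case (AlwaysTrue, _) | (_, AlwaysTrue) => AlwaysTrue
      case (AlwaysFalse, o)                  => o
      case (o, AlwaysFalse)                  => o
      case _                                 => Keep(filter)
    }
    case other => Keep(other)
  }
}

For example, with presentColumns = Set("b"), the pushed filters Seq(EqualTo("c", 1)) fold to None (the file can be skipped outright, which is what the schema-merge test above relies on), while Seq(IsNull("c"), GreaterThan("b", 0)) fold to Some(Seq(GreaterThan("b", 0))).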