diff --git a/.travis.yml b/.travis.yml
index 6ab1db883..d41228055 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -31,6 +31,7 @@ jobs:
- echo ${TRAVIS_COMMIT_MESSAGE}
#- if [[ ${TRAVIS_COMMIT_MESSAGE} != \[oap-native-sql\]* ]]; then travis_terminate 0 ; fi ;
- sudo apt-get install cmake
+ - sudo apt-get install libboost-all-dev
- export | grep JAVA_HOME
install:
- # Download spark 3.0
diff --git a/oap-data-source/arrow/pom.xml b/oap-data-source/arrow/pom.xml
index 4d65be407..769411cc8 100644
--- a/oap-data-source/arrow/pom.xml
+++ b/oap-data-source/arrow/pom.xml
@@ -8,7 +8,7 @@
    <scala.version>2.12.10</scala.version>
    <scala.binary.version>2.12</scala.binary.version>
    <spark.version>3.1.0-SNAPSHOT</spark.version>
-   <arrow.version>0.17.0</arrow.version>
+   <arrow.version>0.17.0</arrow.version>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
@@ -112,40 +112,43 @@
        <testSourceDirectory>src/test/scala</testSourceDirectory>
        <plugins>
            <plugin>
-               <groupId>org.codehaus.mojo</groupId>
-               <artifactId>build-helper-maven-plugin</artifactId>
+               <groupId>org.apache.maven.plugins</groupId>
+               <artifactId>maven-resources-plugin</artifactId>
                <executions>
                    <execution>
+                       <id>copy-writable-vector-source</id>
                        <phase>generate-sources</phase>
                        <goals>
-                           <goal>add-source</goal>
+                           <goal>copy-resources</goal>
                        </goals>
                        <configuration>
-                           <sources>
-                               <source>${project.build.directory}/generated-sources/downloaded/java/</source>
-                           </sources>
+                           <outputDirectory>${project.build.directory}/generated-sources/copied/java/</outputDirectory>
+                           <resources>
+                               <resource>
+                                   <directory>${project.basedir}/../../oap-native-sql/core/src/main/java/</directory>
+                                   <includes>
+                                       <include>**/ArrowWritableColumnVector.java</include>
+                                   </includes>
+                               </resource>
+                           </resources>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
            <plugin>
                <groupId>org.codehaus.mojo</groupId>
-               <artifactId>exec-maven-plugin</artifactId>
+               <artifactId>build-helper-maven-plugin</artifactId>
                <executions>
                    <execution>
-                       <id>download-writable-column-vector</id>
+                       <id>add-src-1</id>
                        <phase>generate-sources</phase>
                        <goals>
-                           <goal>exec</goal>
+                           <goal>add-source</goal>
                        </goals>
                        <configuration>
-                           <executable>curl</executable>
-                           <arguments>
-                               <argument>https://raw.githubusercontent.com/Intel-bigdata/OAP/branch-nativesql-spark-3.0.0/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java</argument>
-                               <argument>--create-dirs</argument>
-                               <argument>-o</argument>
-                               <argument>${project.build.directory}/generated-sources/downloaded/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java</argument>
-                           </arguments>
+                           <sources>
+                               <source>${project.build.directory}/generated-sources/copied/java/</source>
+                           </sources>
                        </configuration>
                    </execution>
diff --git a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala
index b120fbf3f..b3cf69dcf 100644
--- a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala
+++ b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala
@@ -17,12 +17,15 @@
package com.intel.oap.spark.sql.execution.datasources.arrow
+import java.net.URLDecoder
+
import scala.collection.JavaConverters._
import com.intel.oap.spark.sql.execution.datasources.arrow.ArrowFileFormat.UnsafeItr
-import com.intel.oap.spark.sql.execution.datasources.v2.arrow.{ArrowFilters, ArrowOptions}
+import com.intel.oap.spark.sql.execution.datasources.v2.arrow.{ArrowFilters, ArrowOptions, ExecutionMemoryAllocationListener}
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
import org.apache.arrow.dataset.scanner.ScanOptions
+import org.apache.arrow.memory.AllocationListener
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.Job
@@ -71,10 +74,12 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab
val sqlConf = sparkSession.sessionState.conf;
val enableFilterPushDown = sqlConf.arrowFilterPushDown
+ val taskMemoryManager = ArrowUtils.getTaskMemoryManager()
val factory = ArrowUtils.makeArrowDiscovery(
- file.filePath, new ArrowOptions(
+ URLDecoder.decode(file.filePath, "UTF-8"), new ArrowOptions(
new CaseInsensitiveStringMap(
- options.asJava).asScala.toMap))
+ options.asJava).asScala.toMap),
+ new ExecutionMemoryAllocationListener(taskMemoryManager))
// todo predicate validation / pushdown
val dataset = factory.finish();
@@ -98,7 +103,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab
val itr = itrList
.toIterator
.flatMap(itr => itr.asScala)
- .map(vsr => ArrowUtils.loadVsr(vsr, file.partitionValues, partitionSchema, dataSchema))
+ .map(vsr => ArrowUtils.loadVectors(vsr, file.partitionValues, partitionSchema, dataSchema))
new UnsafeItr(itr).asInstanceOf[Iterator[InternalRow]]
}
}
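The `URLDecoder.decode` call introduced above is needed because `PartitionedFile.filePath` may arrive percent-encoded; without decoding, paths containing special characters (covered by the "special characters in path" test added later in this patch) would not resolve. A minimal sketch with an illustrative path:

```scala
import java.net.URLDecoder

// Illustrative value only: Spark percent-encodes special characters in file paths.
val encoded = "hdfs://host:9000/tpch/part/p%20type%3Dbrass/part-00000.parquet"
val decoded = URLDecoder.decode(encoded, "UTF-8")
// decoded == "hdfs://host:9000/tpch/part/p type=brass/part-00000.parquet"
```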
diff --git a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala
index 6c2431aaf..47b56f965 100644
--- a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala
+++ b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala
@@ -16,6 +16,8 @@
*/
package com.intel.oap.spark.sql.execution.datasources.v2.arrow
+import java.net.URLDecoder
+
import scala.collection.JavaConverters._
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowPartitionReaderFactory.ColumnarBatchRetainer
@@ -23,6 +25,7 @@ import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
import org.apache.arrow.dataset.scanner.ScanOptions
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
import org.apache.spark.sql.execution.datasources.PartitionedFile
@@ -43,7 +46,7 @@ case class ArrowPartitionReaderFactory(
options: ArrowOptions)
extends FilePartitionReaderFactory {
- private val batchSize = 4096
+ private val batchSize = sqlConf.parquetVectorizedReaderBatchSize
private val enableFilterPushDown: Boolean = sqlConf.arrowFilterPushDown
override def supportColumnarReads(partition: InputPartition): Boolean = true
@@ -56,7 +59,9 @@ case class ArrowPartitionReaderFactory(
override def buildColumnarReader(
partitionedFile: PartitionedFile): PartitionReader[ColumnarBatch] = {
val path = partitionedFile.filePath
- val factory = ArrowUtils.makeArrowDiscovery(path, options)
+ val taskMemoryManager = ArrowUtils.getTaskMemoryManager()
+ val factory = ArrowUtils.makeArrowDiscovery(URLDecoder.decode(path, "UTF-8"), options,
+ new ExecutionMemoryAllocationListener(taskMemoryManager))
val dataset = factory.finish()
val filter = if (enableFilterPushDown) {
ArrowFilters.translateFilters(ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema))
@@ -79,7 +84,7 @@ case class ArrowPartitionReaderFactory(
val batchItr = vsrItrList
.toIterator
.flatMap(itr => itr.asScala)
- .map(vsr => ArrowUtils.loadVsr(vsr, partitionedFile.partitionValues,
+ .map(bundledVectors => ArrowUtils.loadVectors(bundledVectors, partitionedFile.partitionValues,
readPartitionSchema, readDataSchema))
new PartitionReader[ColumnarBatch] {
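With the change above, the Arrow reader's batch size now follows the same session setting as Spark's vectorized Parquet reader instead of a hard-coded 4096. A sketch of tuning it, assuming the standard conf key behind `parquetVectorizedReaderBatchSize` and a `SparkSession` named `spark`:

```scala
// Assumption: spark.sql.parquet.columnarReaderBatchSize backs
// sqlConf.parquetVectorizedReaderBatchSize; larger batches trade memory for throughput.
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "8192")
```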
diff --git a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowTable.scala b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowTable.scala
index 7702f0283..ed4e4b183 100644
--- a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowTable.scala
+++ b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowTable.scala
@@ -16,6 +16,7 @@
*/
package com.intel.oap.spark.sql.execution.datasources.v2.arrow
+import org.apache.arrow.memory.AllocationListener
import org.apache.hadoop.fs.FileStatus
import org.apache.spark.sql.SparkSession
diff --git a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ExecutionMemoryAllocationListener.scala b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ExecutionMemoryAllocationListener.scala
new file mode 100644
index 000000000..4e2eb96bc
--- /dev/null
+++ b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ExecutionMemoryAllocationListener.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.intel.oap.spark.sql.execution.datasources.v2.arrow
+
+import org.apache.arrow.memory.{AllocationListener, OutOfMemoryException}
+
+import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager}
+
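+/**
+ * Routes Arrow buffer allocations through Spark's TaskMemoryManager so that off-heap
+ * memory requested by the native readers is accounted as task execution memory.
+ * Allocation requests that cannot be fully granted fail with an OutOfMemoryException.
+ */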
+class ExecutionMemoryAllocationListener(mm: TaskMemoryManager)
+ extends MemoryConsumer(mm, mm.pageSizeBytes(), MemoryMode.OFF_HEAP) with AllocationListener {
+
+ override def onPreAllocation(size: Long): Unit = {
+ if (size == 0) {
+ return
+ }
+ val granted = acquireMemory(size)
+ if (granted < size) {
+      throw new OutOfMemoryException("Failed to acquire Spark execution memory. Requested: " +
+        size + ", granted: " + granted)
+ }
+ }
+
+ override def onRelease(size: Long): Unit = {
+ freeMemory(size)
+ }
+
+ override def spill(size: Long, trigger: MemoryConsumer): Long = {
+ // not spillable
+ 0L
+ }
+}
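A minimal usage sketch (mirroring the wiring done in `ArrowUtils.makeArrowDiscovery` later in this patch; allocator names are illustrative) showing how the listener attaches to a child allocator so that native reads are charged to the running task:

```scala
import org.apache.spark.TaskContext

val listener = new ExecutionMemoryAllocationListener(TaskContext.get().taskMemoryManager())
val parent = org.apache.spark.sql.util.ArrowUtils.rootAllocator
// Allocations made through this child trigger onPreAllocation/onRelease above,
// acquiring and freeing Spark execution memory for the current task.
val child = parent.newChildAllocator("Spark Managed Allocator - example", listener, 0, parent.getLimit)
```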
diff --git a/oap-data-source/arrow/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala b/oap-data-source/arrow/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala
index 101dff283..589fb5e93 100644
--- a/oap-data-source/arrow/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala
+++ b/oap-data-source/arrow/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala
@@ -17,28 +17,38 @@
package org.apache.spark.sql.execution.datasources.v2.arrow
import java.net.URI
-import java.util.TimeZone
+import java.util.{TimeZone, UUID}
import scala.collection.JavaConverters._
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowOptions
import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
import org.apache.arrow.dataset.file.{FileSystem, SingleFileDatasetFactory}
-import org.apache.arrow.memory.BaseAllocator
-import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.dataset.scanner.ScanTask
+import org.apache.arrow.memory.{AllocationListener, BaseAllocator}
+import org.apache.arrow.vector.FieldVector
+import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.hadoop.fs.FileStatus
+import org.apache.spark.TaskContext
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.vectorized.{ColumnVectorUtils, OnHeapColumnVector}
+import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.vectorized.ColumnarBatch
object ArrowUtils {
+
+ def getTaskMemoryManager(): TaskMemoryManager = {
+ TaskContext.get().taskMemoryManager()
+ }
+
def readSchema(file: FileStatus, options: CaseInsensitiveStringMap): Option[StructType] = {
val factory: SingleFileDatasetFactory =
- makeArrowDiscovery(file.getPath.toString, new ArrowOptions(options.asScala.toMap))
+ makeArrowDiscovery(file.getPath.toString, new ArrowOptions(options.asScala.toMap),
+ AllocationListener.NOOP)
val schema = factory.inspect()
try {
Option(org.apache.spark.sql.util.ArrowUtils.fromArrowSchema(schema))
@@ -50,13 +60,17 @@ object ArrowUtils {
def readSchema(files: Seq[FileStatus], options: CaseInsensitiveStringMap): Option[StructType] =
readSchema(files.toList.head, options) // todo merge schema
- def makeArrowDiscovery(file: String, options: ArrowOptions): SingleFileDatasetFactory = {
+ def makeArrowDiscovery(file: String, options: ArrowOptions,
+ al: AllocationListener): SingleFileDatasetFactory = {
val format = getFormat(options).getOrElse(throw new IllegalStateException)
val fs = getFs(options).getOrElse(throw new IllegalStateException)
-
+ val parent = defaultAllocator()
+ val allocator = parent
+ .newChildAllocator("Spark Managed Allocator - " + UUID.randomUUID().toString, al,
+ 0, parent.getLimit)
val factory = new SingleFileDatasetFactory(
- org.apache.spark.sql.util.ArrowUtils.rootAllocator,
+ allocator,
format,
fs,
rewriteFilePath(file))
@@ -82,12 +96,14 @@ object ArrowUtils {
org.apache.spark.sql.util.ArrowUtils.toArrowSchema(t, TimeZone.getDefault.getID)
}
- def loadVsr(vsr: VectorSchemaRoot, partitionValues: InternalRow,
- partitionSchema: StructType, dataSchema: StructType): ColumnarBatch = {
- val fvs = getDataVectors(vsr, dataSchema)
+ def loadVectors(bundledVectors: ScanTask.ArrowBundledVectors, partitionValues: InternalRow,
+ partitionSchema: StructType, dataSchema: StructType): ColumnarBatch = {
+ val rowCount: Int = getRowCount(bundledVectors)
+ val dataVectors = getDataVectors(bundledVectors, dataSchema)
+ val dictionaryVectors = getDictionaryVectors(bundledVectors, dataSchema)
- val rowCount = vsr.getRowCount
- val vectors = ArrowWritableColumnVector.loadColumns(rowCount, fvs.asJava)
+ val vectors = ArrowWritableColumnVector.loadColumns(rowCount, dataVectors.asJava,
+ dictionaryVectors.asJava)
val partitionColumns = ArrowWritableColumnVector.allocateColumns(rowCount, partitionSchema)
(0 until partitionColumns.length).foreach(i => {
ColumnVectorUtils.populate(partitionColumns(i), partitionValues, i)
@@ -98,16 +114,23 @@ object ArrowUtils {
batch
}
- def rootAllocator(): BaseAllocator = {
+ def defaultAllocator(): BaseAllocator = {
org.apache.spark.sql.util.ArrowUtils.rootAllocator
}
- private def getDataVectors(vsr: VectorSchemaRoot,
+ private def getRowCount(bundledVectors: ScanTask.ArrowBundledVectors) = {
+ val valueVectors = bundledVectors.valueVectors
+ val rowCount = valueVectors.getRowCount
+ rowCount
+ }
+
+ private def getDataVectors(bundledVectors: ScanTask.ArrowBundledVectors,
dataSchema: StructType): List[FieldVector] = {
// TODO Deprecate following (bad performance maybe brought).
// TODO Assert vsr strictly matches dataSchema instead.
+ val valueVectors = bundledVectors.valueVectors
dataSchema.map(f => {
- val vector = vsr.getVector(f.name)
+ val vector = valueVectors.getVector(f.name)
if (vector == null) {
throw new IllegalStateException("Error: no vector named " + f.name + " in record bach")
}
@@ -115,6 +138,30 @@ object ArrowUtils {
}).toList
}
+ private def getDictionaryVectors(bundledVectors: ScanTask.ArrowBundledVectors,
+ dataSchema: StructType): List[FieldVector] = {
+ val valueVectors = bundledVectors.valueVectors
+ val dictionaryVectorMap = bundledVectors.dictionaryVectors
+
+ val fieldNameToDictionaryEncoding = valueVectors.getSchema.getFields.asScala.map(f => {
+ f.getName -> f.getDictionary
+ }).toMap
+
+ val dictionaryVectorsWithNulls = dataSchema.map(f => {
+ val de = fieldNameToDictionaryEncoding(f.name)
+
+ Option(de) match {
+ case None => null
+ case _ =>
+ if (de.getIndexType.getTypeID != ArrowTypeID.Int) {
+ throw new IllegalArgumentException("Wrong index type: " + de.getIndexType)
+ }
+ dictionaryVectorMap.get(de.getId).getVector
+ }
+ }).toList
+ dictionaryVectorsWithNulls
+ }
+
private def getFormat(
options: ArrowOptions): Option[org.apache.arrow.dataset.file.FileFormat] = {
Option(options.originalFormat match {
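For reference, the per-field lookup performed by `getDictionaryVectors` above can be sketched as a standalone helper; the column name is hypothetical and `ScanTask.ArrowBundledVectors` is what the Arrow Dataset scan task delivers:

```scala
import org.apache.arrow.dataset.scanner.ScanTask
import org.apache.arrow.vector.FieldVector
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID

// Sketch: resolve the dictionary vector backing one column, or null if the column
// is not dictionary-encoded (same per-field logic as getDictionaryVectors).
def dictionaryFor(bundled: ScanTask.ArrowBundledVectors, column: String): FieldVector = {
  val encoding = bundled.valueVectors.getSchema.findField(column).getDictionary
  if (encoding == null) return null
  require(encoding.getIndexType.getTypeID == ArrowTypeID.Int, "dictionary index must be int32")
  // The value vector for `column` holds int32 indices into this dictionary vector.
  bundled.dictionaryVectors.get(encoding.getId).getVector
}
```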
diff --git a/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTPCHBasedTest.scala b/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTPCHBasedTest.scala
index 402c76a88..976475836 100644
--- a/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTPCHBasedTest.scala
+++ b/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTPCHBasedTest.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.{Executors, TimeUnit}
import com.intel.oap.spark.sql.DataFrameReaderImplicits._
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowOptions
+import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.datasources.v2.arrow.ArrowUtils
import org.apache.spark.sql.internal.SQLConf
@@ -38,6 +39,13 @@ class ArrowDataSourceTPCHBasedTest extends QueryTest with SharedSparkSession {
private val orders = prefix + tpchFolder + "/orders"
private val nation = prefix + tpchFolder + "/nation"
+
+ override protected def sparkConf: SparkConf = {
+ val conf = super.sparkConf
+ conf.set("spark.memory.offHeap.size", String.valueOf(128 * 1024 * 1024))
+ conf
+ }
+
ignore("tpch lineitem - desc") {
val frame = spark.read
.option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet")
@@ -48,6 +56,16 @@ class ArrowDataSourceTPCHBasedTest extends QueryTest with SharedSparkSession {
spark.sql("describe lineitem").show()
}
+ ignore("tpch part - special characters in path") {
+ val frame = spark.read
+ .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet")
+ .option(ArrowOptions.KEY_FILESYSTEM, "hdfs")
+ .arrow(part)
+ frame.createOrReplaceTempView("part")
+
+ spark.sql("select * from part limit 100").show()
+ }
+
ignore("tpch lineitem - read partition values") {
val frame = spark.read
.option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet")
@@ -110,7 +128,7 @@ class ArrowDataSourceTPCHBasedTest extends QueryTest with SharedSparkSession {
val aPrev = System.currentTimeMillis()
(0 until iterations).foreach(_ => {
// scalastyle:off println
- println(ArrowUtils.rootAllocator().getAllocatedMemory())
+ println(ArrowUtils.defaultAllocator().getAllocatedMemory())
// scalastyle:on println
spark.sql("select\n\tsum(l_extendedprice * l_discount) as revenue\n" +
"from\n\tlineitem_arrow\n" +
@@ -258,7 +276,7 @@ class ArrowDataSourceTPCHBasedTest extends QueryTest with SharedSparkSession {
})
Executors.newSingleThreadScheduledExecutor().scheduleWithFixedDelay(() => {
println("[org.apache.spark.sql.util.ArrowUtils.rootAllocator] " +
- "Allocated memory amount: " + ArrowUtils.rootAllocator().getAllocatedMemory)
+ "Allocated memory amount: " + ArrowUtils.defaultAllocator().getAllocatedMemory)
println("[com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector.allocator] " +
"Allocated memory amount: " + com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector.allocator.getAllocatedMemory)
}, 0L, 100L, TimeUnit.MILLISECONDS)
diff --git a/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala b/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
index efaa81423..a648438cd 100644
--- a/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
+++ b/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
@@ -22,16 +22,26 @@ import java.lang.management.ManagementFactory
import com.intel.oap.spark.sql.DataFrameReaderImplicits._
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowOptions
+import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
import com.sun.management.UnixOperatingSystemMXBean
import org.apache.commons.io.FileUtils
+import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.execution.datasources.v2.arrow.ArrowUtils
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
private val parquetFile1 = "parquet-1.parquet"
private val parquetFile2 = "parquet-2.parquet"
+ private val parquetFile3 = "parquet-3.parquet"
+
+ override protected def sparkConf: SparkConf = {
+ val conf = super.sparkConf
+ conf.set("spark.memory.offHeap.size", String.valueOf(1 * 1024 * 1024))
+ conf
+ }
override def beforeAll(): Unit = {
super.beforeAll()
@@ -47,11 +57,19 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
.toDS())
.repartition(1)
.write.parquet(ArrowDataSourceTest.locateResourcePath(parquetFile2))
+
+ spark.read
+ .json(Seq("{\"col1\": \"apple\", \"col2\": 100}", "{\"col1\": \"pear\", \"col2\": 200}",
+ "{\"col1\": \"apple\", \"col2\": 300}")
+ .toDS())
+ .repartition(1)
+ .write.parquet(ArrowDataSourceTest.locateResourcePath(parquetFile3))
}
override def afterAll(): Unit = {
delete(ArrowDataSourceTest.locateResourcePath(parquetFile1))
delete(ArrowDataSourceTest.locateResourcePath(parquetFile2))
+ delete(ArrowDataSourceTest.locateResourcePath(parquetFile3))
super.afterAll()
}
@@ -64,7 +82,7 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
.arrow(path))
}
- test("simple SQL query on parquet file") {
+ test("simple SQL query on parquet file - 1") {
val path = ArrowDataSourceTest.locateResourcePath(parquetFile1)
val frame = spark.read
.option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet")
@@ -76,6 +94,27 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
verifyParquet(spark.sql("select col from ptab where col is not null or col is null"))
}
+ test("simple SQL query on parquet file - 2") {
+ val path = ArrowDataSourceTest.locateResourcePath(parquetFile3)
+ val frame = spark.read
+ .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet")
+ .option(ArrowOptions.KEY_FILESYSTEM, "hdfs")
+ .arrow(path)
+ frame.createOrReplaceTempView("ptab")
+ val sqlFrame = spark.sql("select * from ptab")
+ assert(
+ sqlFrame.schema ===
+ StructType(Seq(StructField("col1", StringType), StructField("col2", LongType))))
+ val rows = sqlFrame.collect()
+ assert(rows(0).get(0) == "apple")
+ assert(rows(0).get(1) == 100)
+ assert(rows(1).get(0) == "pear")
+ assert(rows(1).get(1) == 200)
+ assert(rows(2).get(0) == "apple")
+ assert(rows(2).get(1) == 300)
+ assert(rows.length === 3)
+ }
+
test("simple SQL query on parquet file with pushed filters") {
val path = ArrowDataSourceTest.locateResourcePath(parquetFile1)
val frame = spark.read
@@ -162,6 +201,11 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
def delete(path: String): Unit = {
FileUtils.forceDelete(new File(path))
}
+
+ def closeAllocators(): Unit = {
+ ArrowUtils.defaultAllocator().close()
+ ArrowWritableColumnVector.allocator.close()
+ }
}
object ArrowDataSourceTest {
diff --git a/oap-native-sql/README.md b/oap-native-sql/README.md
index 6ba286707..82ffc6bcf 100644
--- a/oap-native-sql/README.md
+++ b/oap-native-sql/README.md
@@ -1,161 +1,44 @@
-# SparkColumnarPlugin
+# Spark Native SQL Engine
-## Contents
-
-- [Introduction](#introduction)
-- [Installation](#installation)
-- [Benchmark](#benchmark)
-- [Contact](#contact)
+A Native Engine for Spark SQL with vectorized SIMD optimizations
## Introduction
-### Key concepts of this project:
-1. Using Apache Arrow as column vector format as intermediate data among spark operator.
-2. Enable Apache Arrow native readers for Parquet and other formats.
-3. Leverage Apache Arrow Gandiva/Compute to evaluate columnar operator expressions.
-4. (WIP)New native columnar shuffle operator with efficient compression support.
-
-![Overview](/oap-native-sql/resource/Native_SQL_Engine_Intro.jpg)
-
-## Installation
-
-For detailed testing scripts, please refer to [solution guide](https://github.com/Intel-bigdata/Solution_navigator/tree/master/nativesql)
-
-### Installation option 1: For evaluation, simple and fast
-
-#### install spark 3.0.0 or above
-
-[spark download](https://spark.apache.org/downloads.html)
-
-Remove original Arrow Jars inside Spark assemply folder
-``` shell
-yes | rm assembly/target/scala-2.12/jars/arrow-format-0.15.1.jar
-yes | rm assembly/target/scala-2.12/jars/arrow-vector-0.15.1.jar
-yes | rm assembly/target/scala-2.12/jars/arrow-memory-0.15.1.jar
-```
-
-#### install arrow 0.17.0
-
-```
-git clone https://github.com/apache/arrow && cd arrow & git checkout arrow-0.17.0
-vim ci/conda_env_gandiva.yml
-clangdev=7
-llvmdev=7
-
-conda create -y -n pyarrow-dev -c conda-forge \
- --file ci/conda_env_unix.yml \
- --file ci/conda_env_cpp.yml \
- --file ci/conda_env_python.yml \
- --file ci/conda_env_gandiva.yml \
- compilers \
- python=3.7 \
- pandas
-conda activate pyarrow-dev
-```
-
-#### Build native-sql cpp
-
-``` shell
-git clone https://github.com/Intel-bigdata/OAP.git
-cd OAP && git checkout branch-nativesql-spark-3.0.0
-cd oap-native-sql
-cp cpp/src/resources/libhdfs.so ${HADOOP_HOME}/lib/native/
-cp cpp/src/resources/libprotobuf.so.13 /usr/lib64/
-```
-
-Download spark-columnar-core-1.0-jar-with-dependencies.jar to local, add classPath to spark.driver.extraClassPath and spark.executor.extraClassPath
-``` shell
-Internal Location: vsr602://mnt/nvme2/chendi/000000/spark-columnar-core-1.0-jar-with-dependencies.jar
-```
-
-Download spark-sql_2.12-3.1.0-SNAPSHOT.jar to ${SPARK_HOME}/assembly/target/scala-2.12/jars/spark-sql_2.12-3.1.0-SNAPSHOT.jar
-``` shell
-Internal Location: vsr602://mnt/nvme2/chendi/000000/spark-sql_2.12-3.1.0-SNAPSHOT.jar
-```
-
-### Installation option 2: For contribution, Patch and build
-
-#### install spark 3.0.0 or above
-
-Please refer this link to install Spark.
-[Apache Spark Installation](/oap-native-sql/resource/SparkInstallation.md)
-
-Remove original Arrow Jars inside Spark assemply folder
-``` shell
-yes | rm assembly/target/scala-2.12/jars/arrow-format-0.15.1.jar
-yes | rm assembly/target/scala-2.12/jars/arrow-vector-0.15.1.jar
-yes | rm assembly/target/scala-2.12/jars/arrow-memory-0.15.1.jar
-```
-
-#### install arrow 0.17.0
+![Overview](/oap-native-sql/resource/nativesql_arch.png)
-Please refer this markdown to install Apache Arrow and Gandiva.
-[Apache Arrow Installation](/oap-native-sql/resource/ApacheArrowInstallation.md)
+Spark SQL works very well with structured row-based data. It uses WholeStageCodeGen to improve performance through Java JIT code. However, Java JIT is usually not very effective at utilizing the latest SIMD instructions, especially under complicated queries. [Apache Arrow](https://arrow.apache.org/) provides a CPU-cache friendly columnar in-memory layout, and its SIMD-optimized kernels and LLVM-based SQL engine, Gandiva, are also very efficient. Native SQL Engine uses these technologies to bring better performance to Spark SQL.
-#### compile and install oap-native-sql
+## Key Features
-##### Install Googletest and Googlemock
+### Apache Arrow formatted intermediate data among Spark operators
-``` shell
-yum install gtest-devel
-yum install gmock
-```
+![Overview](/oap-native-sql/resource/columnar.png)
-##### Build this project
+With [SPARK-27396](https://issues.apache.org/jira/browse/SPARK-27396) it is possible to pass an RDD of ColumnarBatch to operators. We implemented this API with the Arrow columnar format.
-``` shell
-git clone https://github.com/Intel-bigdata/OAP.git
-cd OAP && git checkout branch-nativesql-spark-3.0.0
-cd oap-native-sql
-cd cpp/
-mkdir build/
-cd build/
-cmake .. -DTESTS=ON
-make -j
-make install
-#when deploying on multiple node, make sure all nodes copied libhdfs.so and libprotobuf.so.13
-```
+### Apache Arrow based Native Readers for Parquet and other formats
-``` shell
-cd SparkColumnarPlugin/core/
-mvn clean package -DskipTests
-```
-### Additonal Notes
-[Notes for Installation Issues](/oap-native-sql/resource/InstallationNotes.md)
-
+![Overview](/oap-native-sql/resource/dataset.png)
-## Spark Configuration
+A native Parquet reader was developed to speed up data loading. It is based on Apache Arrow Dataset. For details please check [Arrow Data Source](../oap-data-source/README.md).
-Add below configuration to spark-defaults.conf
-
-```
-##### Columnar Process Configuration
+### Apache Arrow Compute/Gandiva based operators
-spark.sql.parquet.columnarReaderBatchSize 4096
-spark.sql.sources.useV1SourceList avro
-spark.sql.join.preferSortMergeJoin false
-spark.sql.extensions com.intel.sparkColumnarPlugin.ColumnarPlugin
+![Overview](/oap-native-sql/resource/kernel.png)
-spark.driver.extraClassPath ${PATH_TO_OAP_NATIVE_SQL}/core/target/spark-columnar-core-1.0-jar-with-dependencies.jar
-spark.executor.extraClassPath ${PATH_TO_OAP_NATIVE_SQL}/core/target/spark-columnar-core-1.0-jar-with-dependencies.jar
+We implemented common operators based on Apache Arrow Compute and Gandiva. SQL expressions are compiled into an expression tree with protobuf and passed to the native kernels, which then evaluate these expressions against the input columnar batch.
-######
-```
-## Benchmark
+### Native Columnar Shuffle Operator with efficient compression support
-For initial microbenchmark performance, we add 10 fields up with spark, data size is 200G data
+![Overview](/oap-native-sql/resource/shuffle.png)
-![Performance](/oap-native-sql/resource/performance.png)
+We implemented a columnar shuffle to improve shuffle performance. With the columnar layout we can do very efficient data compression for different data formats.
-## Coding Style
+## Testing
-* For Java code, we used [google-java-format](https://github.com/google/google-java-format)
-* For Scala code, we used [Spark Scala Format](https://github.com/apache/spark/blob/master/dev/.scalafmt.conf), please use [scalafmt](https://github.com/scalameta/scalafmt) or run ./scalafmt for scala codes format
-* For Cpp codes, we used Clang-Format, check on this link [google-vim-codefmt](https://github.com/google/vim-codefmt) for details.
+Check out the detailed [installation/testing guide](/oap-native-sql/resource/installation.md) for quick testing.
-## contact
+## Contact
chendi.xue@intel.com
-yuan.zhou@intel.com
-binwei.yang@intel.com
-jian.zhang@intel.com
+binwei.yang@intel.com
\ No newline at end of file
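For a quick end-to-end check of the data source touched by this patch, the read path exercised by the tests looks like this (the path is illustrative; a `SparkSession` named `spark` is assumed):

```scala
import com.intel.oap.spark.sql.DataFrameReaderImplicits._
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowOptions

// Reads Parquet files through the Arrow Dataset based native reader.
val df = spark.read
  .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet")
  .option(ArrowOptions.KEY_FILESYSTEM, "hdfs")
  .arrow("hdfs://host:9000/tpch/lineitem")
df.createOrReplaceTempView("lineitem")
spark.sql("select count(*) from lineitem").show()
```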
diff --git a/oap-native-sql/core/pom.xml b/oap-native-sql/core/pom.xml
index abb37bedf..9d452b94b 100644
--- a/oap-native-sql/core/pom.xml
+++ b/oap-native-sql/core/pom.xml
@@ -17,6 +17,7 @@
    <arrow.version>0.17.0</arrow.version>
+   <cpp.dir>../cpp/</cpp.dir>
    <cpp.build.dir>../cpp/build/releases/</cpp.build.dir>
@@ -33,18 +34,39 @@
      <version>${spark.version}</version>
      <scope>provided</scope>
    </dependency>
+   <dependency>
+     <groupId>com.intel.spark</groupId>
+     <artifactId>spark-core_${scala.binary.version}</artifactId>
+     <version>${spark.version}</version>
+     <type>test-jar</type>
+     <scope>test</scope>
+   </dependency>
    <dependency>
      <groupId>com.intel.spark</groupId>
      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
      <scope>provided</scope>
    </dependency>
+   <dependency>
+     <groupId>com.intel.spark</groupId>
+     <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+     <version>${spark.version}</version>
+     <type>test-jar</type>
+     <scope>test</scope>
+   </dependency>
    <dependency>
      <groupId>com.intel.spark</groupId>
      <artifactId>spark-sql_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
      <scope>provided</scope>
    </dependency>
+   <dependency>
+     <groupId>com.intel.spark</groupId>
+     <artifactId>spark-sql_${scala.binary.version}</artifactId>
+     <version>${spark.version}</version>
+     <type>test-jar</type>
+     <scope>test</scope>
+   </dependency>
    <dependency>
      <groupId>com.intel.spark</groupId>
      <artifactId>spark-network-shuffle_${scala.binary.version}</artifactId>
@@ -83,6 +105,22 @@
+   <dependency>
+     <groupId>org.apache.arrow</groupId>
+     <artifactId>arrow-dataset</artifactId>
+     <version>${arrow.version}</version>
+     <exclusions>
+       <exclusion>
+         <groupId>io.netty</groupId>
+         <artifactId>netty-common</artifactId>
+       </exclusion>
+       <exclusion>
+         <groupId>io.netty</groupId>
+         <artifactId>netty-buffer</artifactId>
+       </exclusion>
+     </exclusions>
+   </dependency>
+
    <dependency>
      <groupId>org.apache.arrow.gandiva</groupId>
      <artifactId>arrow-gandiva</artifactId>
@@ -136,19 +174,19 @@
    <dependency>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-core</artifactId>
-     <version>2.9.8</version>
+     <version>2.10.0</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-annotations</artifactId>
-     <version>2.9.8</version>
+     <version>2.10.0</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-databind</artifactId>
-     <version>2.9.8</version>
+     <version>2.10.0</version>
      <scope>test</scope>
    </dependency>
@@ -157,14 +195,28 @@
        <directory>${cpp.build.dir}</directory>
-       <includes>
-         <include>**/libspark_columnar_jni.so</include>
-       </includes>
      </resource>
    </resources>
    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
    <plugins>
+     <plugin>
+       <artifactId>exec-maven-plugin</artifactId>
+       <groupId>org.codehaus.mojo</groupId>
+       <version>1.6.0</version>
+       <executions>
+         <execution>
+           <id>Build cpp</id>
+           <phase>generate-resources</phase>
+           <goals>
+             <goal>exec</goal>
+           </goals>
+           <configuration>
+             <executable>${cpp.dir}/compile.sh</executable>
+           </configuration>
+         </execution>
+       </executions>
+     </plugin>
      <plugin>
        <groupId>net.alchim31.maven</groupId>
        <artifactId>scala-maven-plugin</artifactId>
@@ -278,6 +330,18 @@
+     <plugin>
+       <groupId>org.scalatest</groupId>
+       <artifactId>scalatest-maven-plugin</artifactId>
+       <executions>
+         <execution>
+           <id>test</id>
+           <goals>
+             <goal>test</goal>
+           </goals>
+         </execution>
+       </executions>
+     </plugin>
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/VectorizedParquetArrowReader.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/VectorizedParquetArrowReader.java
index 7a81f3640..b55dcfa93 100644
--- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/VectorizedParquetArrowReader.java
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/VectorizedParquetArrowReader.java
@@ -47,8 +47,8 @@
import org.slf4j.LoggerFactory;
public class VectorizedParquetArrowReader extends VectorizedParquetRecordReader {
-
- private static final Logger LOG = LoggerFactory.getLogger(VectorizedParquetArrowReader.class);
+ private static final Logger LOG =
+ LoggerFactory.getLogger(VectorizedParquetArrowReader.class);
private ParquetReader reader = null;
private String path;
private long capacity;
@@ -58,45 +58,40 @@ public class VectorizedParquetArrowReader extends VectorizedParquetRecordReader
private int numLoaded = 0;
private int numReaded = 0;
private long totalLength;
+ private String tmp_dir;
private ArrowRecordBatch next_batch;
- //private ColumnarBatch last_columnar_batch;
+ // private ColumnarBatch last_columnar_batch;
private StructType sourceSchema;
private StructType readDataSchema;
private Schema schema = null;
- public VectorizedParquetArrowReader(
- String path,
- ZoneId convertTz,
- boolean useOffHeap,
- int capacity,
- StructType sourceSchema,
- StructType readDataSchema
- ) {
+ public VectorizedParquetArrowReader(String path, ZoneId convertTz, boolean useOffHeap,
+ int capacity, StructType sourceSchema, StructType readDataSchema, String tmp_dir) {
super(convertTz, useOffHeap, capacity);
this.capacity = capacity;
this.path = path;
+ this.tmp_dir = tmp_dir;
this.sourceSchema = sourceSchema;
this.readDataSchema = readDataSchema;
}
@Override
- public void initBatch(StructType partitionColumns, InternalRow partitionValues) {
- }
+ public void initBatch(StructType partitionColumns, InternalRow partitionValues) {}
@Override
public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
- throws IOException, InterruptedException, UnsupportedOperationException {
+ throws IOException, InterruptedException, UnsupportedOperationException {
final ParquetInputSplit parquetInputSplit = toParquetSplit(inputSplit);
final Configuration configuration = ContextUtil.getConfiguration(taskAttemptContext);
initialize(parquetInputSplit, configuration);
}
public void initialize(ParquetInputSplit inputSplit, Configuration configuration)
- throws IOException, InterruptedException, UnsupportedOperationException {
+ throws IOException, InterruptedException, UnsupportedOperationException {
this.totalLength = inputSplit.getLength();
int ordinal = 0;
@@ -104,7 +99,7 @@ public void initialize(ParquetInputSplit inputSplit, Configuration configuration
int[] column_indices = new int[readDataSchema.size()];
List targetSchema = Arrays.asList(readDataSchema.names());
- for (String fieldName: sourceSchema.names()) {
+ for (String fieldName : sourceSchema.names()) {
if (targetSchema.contains(fieldName)) {
column_indices[cur_index++] = ordinal;
}
@@ -116,16 +111,17 @@ public void initialize(ParquetInputSplit inputSplit, Configuration configuration
if (uriPath.contains("hdfs")) {
uriPath = this.path + "?user=root&replication=1";
}
- ParquetInputSplit split = (ParquetInputSplit)inputSplit;
- LOG.info("ParquetReader uri path is " + uriPath + ", rowGroupIndices is " + Arrays.toString(rowGroupIndices) + ", column_indices is " + Arrays.toString(column_indices));
- this.reader = new ParquetReader(uriPath,
- split.getStart(), split.getEnd(), column_indices, capacity, ArrowWritableColumnVector.getNewAllocator());
+ ParquetInputSplit split = (ParquetInputSplit) inputSplit;
+ LOG.info("ParquetReader uri path is " + uriPath + ", rowGroupIndices is "
+ + Arrays.toString(rowGroupIndices) + ", column_indices is "
+ + Arrays.toString(column_indices));
+ this.reader = new ParquetReader(uriPath, split.getStart(), split.getEnd(),
+ column_indices, capacity, ArrowWritableColumnVector.getNewAllocator(), tmp_dir);
}
@Override
- public void initialize(String path, List columns) throws IOException,
- UnsupportedOperationException {
- }
+ public void initialize(String path, List columns)
+ throws IOException, UnsupportedOperationException {}
@Override
public boolean nextKeyValue() throws IOException {
@@ -155,7 +151,7 @@ public Object getCurrentValue() {
}
numReaded += lastReadLength;
ArrowWritableColumnVector[] columnVectors =
- ArrowWritableColumnVector.loadColumns(next_batch.getLength(), schema, next_batch);
+ ArrowWritableColumnVector.loadColumns(next_batch.getLength(), schema, next_batch);
next_batch.close();
return new ColumnarBatch(columnVectors, next_batch.getLength());
}
@@ -170,10 +166,11 @@ public void close() throws IOException {
@Override
public float getProgress() {
- return (float) (numReaded/totalLength);
+ return (float) (numReaded / totalLength);
}
- private int[] filterRowGroups(ParquetInputSplit parquetInputSplit, Configuration configuration) throws IOException {
+ private int[] filterRowGroups(ParquetInputSplit parquetInputSplit,
+ Configuration configuration) throws IOException {
final long[] rowGroupOffsets = parquetInputSplit.getRowGroupOffsets();
if (rowGroupOffsets != null) {
throw new UnsupportedOperationException();
@@ -184,30 +181,34 @@ private int[] filterRowGroups(ParquetInputSplit parquetInputSplit, Configuration
final List filteredRowGroups;
final List unfilteredRowGroups;
- try (ParquetFileReader reader = ParquetFileReader.open(
- HadoopInputFile.fromPath(path, configuration), createOptions(parquetInputSplit, configuration))) {
+ try (ParquetFileReader reader =
+ ParquetFileReader.open(HadoopInputFile.fromPath(path, configuration),
+ createOptions(parquetInputSplit, configuration))) {
unfilteredRowGroups = reader.getFooter().getBlocks();
filteredRowGroups = reader.getRowGroups();
}
final int[] acc = {0};
- final Map dict = unfilteredRowGroups.stream()
- .collect(Collectors.toMap(BlockMetaDataWrapper::wrap, b -> acc[0]++));
+ final Map dict = unfilteredRowGroups.stream().collect(
+ Collectors.toMap(BlockMetaDataWrapper::wrap, b -> acc[0]++));
return filteredRowGroups.stream()
- .map(BlockMetaDataWrapper::wrap)
- .map(b -> {
- if (!dict.containsKey(b)) {
- // This should not happen
- throw new IllegalStateException("Unrecognizable filtered row group: " + b);
- }
- return dict.get(b);
- }).mapToInt(n -> n).toArray();
+ .map(BlockMetaDataWrapper::wrap)
+ .map(b -> {
+ if (!dict.containsKey(b)) {
+ // This should not happen
+ throw new IllegalStateException("Unrecognizable filtered row group: " + b);
+ }
+ return dict.get(b);
+ })
+ .mapToInt(n -> n)
+ .toArray();
}
- private ParquetReadOptions createOptions(ParquetInputSplit split, Configuration configuration) {
+ private ParquetReadOptions createOptions(
+ ParquetInputSplit split, Configuration configuration) {
return HadoopReadOptions.builder(configuration)
- .withRange(split.getStart(), split.getEnd())
- .build();
+ .withRange(split.getStart(), split.getEnd())
+ .build();
}
private ParquetInputSplit toParquetSplit(InputSplit split) throws IOException {
@@ -215,11 +216,12 @@ private ParquetInputSplit toParquetSplit(InputSplit split) throws IOException {
return (ParquetInputSplit) split;
} else {
throw new IllegalArgumentException(
- "Invalid split (not a ParquetInputSplit): " + split);
+ "Invalid split (not a ParquetInputSplit): " + split);
}
}
- // ID for BlockMetaData, to prevent from resulting in mutable BlockMetaData instances after being filtered
+ // ID for BlockMetaData, to prevent from resulting in mutable BlockMetaData instances
+ // after being filtered
private static class BlockMetaDataWrapper {
private BlockMetaData m;
@@ -233,8 +235,10 @@ public static BlockMetaDataWrapper wrap(BlockMetaData m) {
@Override
public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
+ if (this == o)
+ return true;
+ if (o == null || getClass() != o.getClass())
+ return false;
BlockMetaDataWrapper that = (BlockMetaDataWrapper) o;
return equals(m, that.m);
}
@@ -252,6 +256,4 @@ private int hash(BlockMetaData m) {
return Objects.hash(m.getStartingPos());
}
}
-
}
-
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReader.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReader.java
index a0d5530bb..80af1d356 100644
--- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReader.java
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReader.java
@@ -36,7 +36,6 @@
/** Parquet Reader Class. */
public class ParquetReader implements AutoCloseable {
-
/** reference to native reader instance. */
private long nativeInstanceId;
@@ -56,14 +55,9 @@ public class ParquetReader implements AutoCloseable {
* @param allocator A BufferAllocator reference.
* @throws IOException throws io exception in case of native failure.
*/
- public ParquetReader(
- String path,
- int[] rowGroupIndices,
- int[] columnIndices,
- long batchSize,
- BufferAllocator allocator)
- throws IOException {
- this.jniWrapper = new ParquetReaderJniWrapper();
+ public ParquetReader(String path, int[] rowGroupIndices, int[] columnIndices,
+ long batchSize, BufferAllocator allocator, String tmp_dir) throws IOException {
+ this.jniWrapper = new ParquetReaderJniWrapper(tmp_dir);
this.allocator = allocator;
this.nativeInstanceId = jniWrapper.nativeOpenParquetReader(path, batchSize);
jniWrapper.nativeInitParquetReader(nativeInstanceId, columnIndices, rowGroupIndices);
@@ -80,18 +74,13 @@ public ParquetReader(
* @param allocator A BufferAllocator reference.
* @throws IOException throws io exception in case of native failure.
*/
- public ParquetReader(
- String path,
- long startPos,
- long endPos,
- int[] columnIndices,
- long batchSize,
- BufferAllocator allocator)
- throws IOException {
- this.jniWrapper = new ParquetReaderJniWrapper();
+ public ParquetReader(String path, long startPos, long endPos, int[] columnIndices,
+ long batchSize, BufferAllocator allocator, String tmp_dir) throws IOException {
+ this.jniWrapper = new ParquetReaderJniWrapper(tmp_dir);
this.allocator = allocator;
this.nativeInstanceId = jniWrapper.nativeOpenParquetReader(path, batchSize);
- jniWrapper.nativeInitParquetReader2(nativeInstanceId, columnIndices, startPos, endPos);
+ jniWrapper.nativeInitParquetReader2(
+ nativeInstanceId, columnIndices, startPos, endPos);
}
/**
@@ -103,10 +92,9 @@ public ParquetReader(
public Schema getSchema() throws IOException {
byte[] schemaBytes = jniWrapper.nativeGetSchema(nativeInstanceId);
- try (MessageChannelReader schemaReader =
- new MessageChannelReader(
- new ReadChannel(new ByteArrayReadableSeekableByteChannel(schemaBytes)), allocator)) {
-
+ try (MessageChannelReader schemaReader = new MessageChannelReader(
+ new ReadChannel(new ByteArrayReadableSeekableByteChannel(schemaBytes)),
+ allocator)) {
MessageResult result = schemaReader.readNext();
if (result == null) {
throw new IOException("Unexpected end of input. Missing schema.");
@@ -123,7 +111,8 @@ public Schema getSchema() throws IOException {
* @throws IOException throws io exception in case of native failure
*/
public ArrowRecordBatch readNext() throws IOException {
- ArrowRecordBatchBuilder recordBatchBuilder = jniWrapper.nativeReadNext(nativeInstanceId);
+ ArrowRecordBatchBuilder recordBatchBuilder =
+ jniWrapper.nativeReadNext(nativeInstanceId);
if (recordBatchBuilder == null) {
return null;
}
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReaderJniWrapper.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReaderJniWrapper.java
index f52951ddc..efc4d6ead 100644
--- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReaderJniWrapper.java
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReaderJniWrapper.java
@@ -22,13 +22,11 @@
import java.io.IOException;
-
/** Wrapper for Parquet Reader native API. */
public class ParquetReaderJniWrapper {
-
/** Construct a Jni Instance. */
- ParquetReaderJniWrapper() throws IOException {
- JniUtils.getInstance();
+ ParquetReaderJniWrapper(String tmp_dir) throws IOException {
+ JniUtils.getInstance(tmp_dir);
}
/**
@@ -39,7 +37,8 @@ public class ParquetReaderJniWrapper {
* @return long id of the parquet reader instance
* @throws IOException throws exception in case of any io exception in native codes
*/
- public native long nativeOpenParquetReader(String path, long batchSize) throws IOException;
+ public native long nativeOpenParquetReader(String path, long batchSize)
+ throws IOException;
/**
* Init a parquet file reader by specifying columns and rowgroups.
@@ -49,8 +48,8 @@ public class ParquetReaderJniWrapper {
* @param rowGroupIndices a array of indexes indicate which row groups to be read
* @throws IOException throws exception in case of any io exception in native codes
*/
- public native void nativeInitParquetReader(long id, int[] columnIndices, int[] rowGroupIndices)
- throws IOException;
+ public native void nativeInitParquetReader(
+ long id, int[] columnIndices, int[] rowGroupIndices) throws IOException;
/**
* Init a parquet file reader by specifying columns and rowgroups.
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/AdaptorReferenceManager.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/AdaptorReferenceManager.java
index 0a46b8deb..67f6a875b 100644
--- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/AdaptorReferenceManager.java
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/AdaptorReferenceManager.java
@@ -17,13 +17,11 @@
package com.intel.sparkColumnarPlugin.vectorized;
+import io.netty.buffer.ArrowBuf;
import java.io.IOException;
-import java.lang.UnsupportedOperationException;
import java.util.concurrent.atomic.AtomicInteger;
-import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.memory.OwnershipTransferResult;
-import org.apache.arrow.memory.ReferenceManager;
+import org.apache.arrow.memory.*;
import org.apache.arrow.util.Preconditions;
import io.netty.buffer.ArrowBuf;
@@ -31,8 +29,8 @@
import org.slf4j.LoggerFactory;
/**
- * A simple reference manager implementation for memory allocated by native
- * code. The underlying memory will be released when reference count reach zero.
+ * A simple reference manager implementation for memory allocated by native code. The underlying
+ * memory will be released when the reference count reaches zero.
*/
public class AdaptorReferenceManager implements ReferenceManager {
private native void nativeRelease(long nativeMemoryHolder);
@@ -42,10 +40,14 @@ public class AdaptorReferenceManager implements ReferenceManager {
private long nativeMemoryHolder;
private int size = 0;
+ // Required by netty dependencies, but is never used.
+ private BaseAllocator allocator;
+
AdaptorReferenceManager(long nativeMemoryHolder, int size) throws IOException {
JniUtils.getInstance();
this.nativeMemoryHolder = nativeMemoryHolder;
this.size = size;
+ this.allocator = new RootAllocator(0);
}
@Override
@@ -60,7 +62,8 @@ public boolean release() {
@Override
public boolean release(int decrement) {
- Preconditions.checkState(decrement >= 1, "ref count decrement should be greater than or equal to 1");
+ Preconditions.checkState(
+ decrement >= 1, "ref count decrement should be greater than or equal to 1");
// decrement the ref count
final int refCnt;
synchronized (this) {
@@ -104,13 +107,14 @@ public ArrowBuf deriveBuffer(ArrowBuf sourceBuffer, long index, long length) {
}
@Override
- public OwnershipTransferResult transferOwnership(ArrowBuf sourceBuffer, BufferAllocator targetAllocator) {
- throw new UnsupportedOperationException();
+ public OwnershipTransferResult transferOwnership(
+ ArrowBuf sourceBuffer, BufferAllocator targetAllocator) {
+ return NO_OP.transferOwnership(sourceBuffer, targetAllocator);
}
@Override
public BufferAllocator getAllocator() {
- return null;
+ return allocator;
}
@Override
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowCompressedStreamReader.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowCompressedStreamReader.java
new file mode 100644
index 000000000..f5fbf52f4
--- /dev/null
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowCompressedStreamReader.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.sparkColumnarPlugin.vectorized;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.DictionaryUtility;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.*;
+
+/**
+ * This class reads from an input stream containing compressed buffers and produces
+ * ArrowRecordBatches.
+ */
+public class ArrowCompressedStreamReader extends ArrowStreamReader {
+
+ public ArrowCompressedStreamReader(InputStream in, BufferAllocator allocator) {
+ super(in, allocator);
+ }
+
+ protected void initialize() throws IOException {
+ Schema originalSchema = readSchema();
+ List fields = new ArrayList<>();
+ List vectors = new ArrayList<>();
+ Map dictionaries = new HashMap<>();
+
+ // Convert fields with dictionaries to have the index type
+ for (Field field : originalSchema.getFields()) {
+ Field updated = DictionaryUtility.toMemoryFormat(field, allocator, dictionaries);
+ fields.add(updated);
+ vectors.add(updated.createVector(allocator));
+ }
+ Schema schema = new Schema(fields, originalSchema.getCustomMetadata());
+
+ this.root = new VectorSchemaRoot(schema, vectors, 0);
+ this.loader = new CompressedVectorLoader(root);
+ this.dictionaries = Collections.unmodifiableMap(dictionaries);
+ }
+
+ @Override
+ protected void loadRecordBatch(ArrowRecordBatch batch) {
+ try {
+ ((CompressedVectorLoader) loader).loadCompressed(batch);
+ } finally {
+ batch.close();
+ }
+ }
+}
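A usage sketch for the new reader (the stream and allocator setup are assumed); it follows the standard `ArrowStreamReader` API, with `CompressedVectorLoader` performing the batch loading:

```scala
import java.io.FileInputStream

import org.apache.arrow.memory.RootAllocator

val allocator = new RootAllocator(Long.MaxValue)
val in = new FileInputStream("/tmp/compressed-batches.arrow") // illustrative path
val reader = new ArrowCompressedStreamReader(in, allocator)
try {
  while (reader.loadNextBatch()) {
    val root = reader.getVectorSchemaRoot
    println(s"read batch with ${root.getRowCount} rows")
  }
} finally {
  reader.close()
}
```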
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java
index c0f6e1cf0..bb876ac3f 100644
--- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java
@@ -29,6 +29,7 @@
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.*;
+import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.holders.NullableVarCharHolder;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Schema;
@@ -62,6 +63,7 @@ public final class ArrowWritableColumnVector extends WritableColumnVector {
private int ordinal;
private ValueVector vector;
+ private ValueVector dictionaryVector;
public static BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
public static BufferAllocator getNewAllocator() {
@@ -90,7 +92,24 @@ public static ArrowWritableColumnVector[] allocateColumns(int capacity, StructTy
return vectors;
}
- public static ArrowWritableColumnVector[] loadColumns(int capacity, List fieldVectors) {
+ public static ArrowWritableColumnVector[] loadColumns(int capacity,
+ List fieldVectors,
+ List dictionaryVectors) {
+ if (fieldVectors.size() != dictionaryVectors.size()) {
+ throw new IllegalArgumentException("Mismatched field vectors and dictionary vectors. " +
+ "Field vector count: " + fieldVectors.size() + ", " +
+ "dictionary vector count: " + dictionaryVectors.size());
+ }
+ ArrowWritableColumnVector[] vectors = new ArrowWritableColumnVector[fieldVectors.size()];
+ for (int i = 0; i < fieldVectors.size(); i++) {
+ vectors[i] = new ArrowWritableColumnVector(fieldVectors.get(i), dictionaryVectors.get(i),
+ i, capacity, false);
+ }
+ return vectors;
+ }
+
+ public static ArrowWritableColumnVector[] loadColumns(int capacity,
+ List fieldVectors) {
ArrowWritableColumnVector[] vectors = new ArrowWritableColumnVector[fieldVectors.size()];
for (int i = 0; i < fieldVectors.size(); i++) {
vectors[i] = new ArrowWritableColumnVector(fieldVectors.get(i), i, capacity, false);
@@ -106,19 +125,26 @@ public static ArrowWritableColumnVector[] loadColumns(int capacity, Schema arrow
return loadColumns(capacity, root.getFieldVectors());
}
- public ArrowWritableColumnVector(ValueVector vector, int ordinal, int capacity, boolean init){
+ @Deprecated
+ public ArrowWritableColumnVector(ValueVector vector, int ordinal, int capacity, boolean init) {
+ this(vector, null, ordinal, capacity, init);
+ }
+
+  public ArrowWritableColumnVector(ValueVector vector, ValueVector dictionaryVector,
+ int ordinal, int capacity, boolean init) {
super(capacity, ArrowUtils.fromArrowField(vector.getField()));
vectorCount.getAndIncrement();
refCnt.getAndIncrement();
this.ordinal = ordinal;
this.vector = vector;
+    this.dictionaryVector = dictionaryVector;
if (init) {
vector.setInitialCapacity(capacity);
vector.allocateNew();
}
writer = createVectorWriter(vector);
- createVectorAccessor(vector);
+    createVectorAccessor(vector, dictionaryVector);
}
public ArrowWritableColumnVector(int capacity, DataType dataType) {
@@ -135,15 +161,30 @@ public ArrowWritableColumnVector(int capacity, DataType dataType) {
vector.setInitialCapacity(capacity);
vector.allocateNew();
this.writer = createVectorWriter(vector);
- createVectorAccessor(vector);
-
+ createVectorAccessor(vector, null);
}
public ValueVector getValueVector() {
return vector;
}
- private void createVectorAccessor(ValueVector vector) {
+ private void createVectorAccessor(ValueVector vector, ValueVector dictionary) {
+ if (dictionary != null) {
+ if (!(vector instanceof IntVector)) {
+        throw new IllegalArgumentException("Expected an int32 index vector. Found: " +
+ vector.getMinorType());
+ }
+ IntVector index = (IntVector) vector;
+ if (dictionary instanceof VarBinaryVector) {
+ accessor = new DictionaryEncodedBinaryAccessor(index, (VarBinaryVector) dictionary);
+ } else if (dictionary instanceof VarCharVector) {
+ accessor = new DictionaryEncodedStringAccessor(index, (VarCharVector) dictionary);
+ } else {
+        throw new IllegalArgumentException("Unrecognized dictionary value type: " +
+ dictionary.getMinorType());
+ }
+ return;
+ }
if (vector instanceof BitVector) {
accessor = new BooleanAccessor((BitVector) vector);
} else if (vector instanceof TinyIntVector) {
@@ -271,6 +312,9 @@ public void close() {
childColumns = null;
}
vector.close();
+ if (dictionaryVector != null) {
+ dictionaryVector.close();
+ }
}
public static String stat() {
@@ -675,7 +719,8 @@ public void putArray(int rowId, int offset, int length) {
@Override
public int putByteArray(int rowId, byte[] value, int offset, int length) {
- throw new UnsupportedOperationException();
+ writer.setBytes(rowId, length, value, offset);
+ return length;
}
//
@@ -963,6 +1008,32 @@ final UTF8String getUTF8String(int rowId) {
}
}
+ private static class DictionaryEncodedStringAccessor extends ArrowVectorAccessor {
+
+ private final IntVector index;
+ private final VarCharVector dictionary;
+ private final NullableVarCharHolder stringResult = new NullableVarCharHolder();
+
+ DictionaryEncodedStringAccessor(IntVector index, VarCharVector dictionary) {
+ super(index);
+ this.index = index;
+ this.dictionary = dictionary;
+ }
+
+ @Override
+ final UTF8String getUTF8String(int rowId) {
+ int idx = index.get(rowId);
+ dictionary.get(idx, stringResult);
+ if (stringResult.isSet == 0) {
+ return null;
+ } else {
+ return UTF8String.fromAddress(null,
+ stringResult.buffer.memoryAddress() + stringResult.start,
+ stringResult.end - stringResult.start);
+ }
+ }
+ }
+
private static class BinaryAccessor extends ArrowVectorAccessor {
private final VarBinaryVector accessor;
@@ -978,6 +1049,23 @@ final byte[] getBinary(int rowId) {
}
}
+ private static class DictionaryEncodedBinaryAccessor extends ArrowVectorAccessor {
+ private final IntVector index;
+ private final VarBinaryVector dictionary;
+
+ DictionaryEncodedBinaryAccessor(IntVector index, VarBinaryVector dictionary) {
+ super(index);
+ this.index = index;
+ this.dictionary = dictionary;
+ }
+
+ @Override
+ final byte[] getBinary(int rowId) {
+ int idx = index.get(rowId);
+ return dictionary.getObject(idx);
+ }
+ }
+
private static class DateAccessor extends ArrowVectorAccessor {
private final DateDayVector accessor;
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/CompressedVectorLoader.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/CompressedVectorLoader.java
new file mode 100644
index 000000000..52398cad0
--- /dev/null
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/CompressedVectorLoader.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.sparkColumnarPlugin.vectorized;
+
+import java.util.Iterator;
+
+import org.apache.arrow.util.Collections2;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+
+import io.netty.buffer.ArrowBuf;
+
+/** Loads compressed buffers into vectors. */
+public class CompressedVectorLoader extends VectorLoader {
+
+ /**
+ * Construct with a root to load; children will be created in the root based on the schema.
+ *
+ * @param root the root to add vectors to based on schema
+ */
+ public CompressedVectorLoader(VectorSchemaRoot root) {
+ super(root);
+ }
+
+ /**
+ * Loads the record batch into the vectors. Will not close the record batch.
+ *
+ * @param recordBatch the batch to load
+ */
+ public void loadCompressed(ArrowRecordBatch recordBatch) {
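+ // Walk the batch's buffers and field nodes in schema order while loading each
+ // root field vector; leftovers indicate a malformed record batch.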
+ Iterator<ArrowBuf> buffers = recordBatch.getBuffers().iterator();
+ Iterator<ArrowFieldNode> nodes = recordBatch.getNodes().iterator();
+ for (FieldVector fieldVector : root.getFieldVectors()) {
+ loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes);
+ }
+ root.setRootRowCount(recordBatch.getLength());
+ if (nodes.hasNext() || buffers.hasNext()) {
+ throw new IllegalArgumentException(
+ "not all nodes and buffers were consumed. nodes: "
+ + Collections2.toList(nodes).toString()
+ + " buffers: "
+ + Collections2.toList(buffers).toString());
+ }
+ }
+}
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluator.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluator.java
index 4076d8079..8311ea93c 100644
--- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluator.java
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluator.java
@@ -17,6 +17,7 @@
package com.intel.sparkColumnarPlugin.vectorized;
+import com.intel.sparkColumnarPlugin.ColumnarPluginConfig;
import io.netty.buffer.ArrowBuf;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -33,59 +34,75 @@
import org.apache.arrow.vector.types.pojo.Schema;
public class ExpressionEvaluator implements AutoCloseable {
-
private long nativeHandler = 0;
private ExpressionEvaluatorJniWrapper jniWrapper;
/** Wrapper for native API. */
public ExpressionEvaluator() throws IOException {
- jniWrapper = new ExpressionEvaluatorJniWrapper();
+ String tmp_dir = ColumnarPluginConfig.getTempFile();
+ if (tmp_dir == null) {
+ tmp_dir = System.getProperty("java.io.tmpdir");
+ }
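+ // JniUtils stages the native library (and bundled headers) under tmp_dir.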
+ jniWrapper = new ExpressionEvaluatorJniWrapper(tmp_dir);
+ jniWrapper.nativeSetJavaTmpDir(tmp_dir);
+ jniWrapper.nativeSetBatchSize(ColumnarPluginConfig.getBatchSize());
}
/** Convert ExpressionTree into native function. */
public void build(Schema schema, List exprs)
throws RuntimeException, IOException, GandivaException {
- nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), null, false);
+ nativeHandler = jniWrapper.nativeBuild(
+ getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), null, false);
}
/** Convert ExpressionTree into native function. */
public void build(Schema schema, List exprs, boolean finishReturn)
throws RuntimeException, IOException, GandivaException {
- nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), null, finishReturn);
+ nativeHandler = jniWrapper.nativeBuild(
+ getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), null, finishReturn);
}
/** Convert ExpressionTree into native function. */
public void build(Schema schema, List exprs, Schema resSchema)
throws RuntimeException, IOException, GandivaException {
- nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), getSchemaBytesBuf(resSchema), false);
+ nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema),
+ getExprListBytesBuf(exprs), getSchemaBytesBuf(resSchema), false);
}
/** Convert ExpressionTree into native function. */
- public void build(Schema schema, List exprs, Schema resSchema, boolean finishReturn)
- throws RuntimeException, IOException, GandivaException {
- nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), getSchemaBytesBuf(resSchema), finishReturn);
+ public void build(Schema schema, List exprs, Schema resSchema,
+ boolean finishReturn) throws RuntimeException, IOException, GandivaException {
+ nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema),
+ getExprListBytesBuf(exprs), getSchemaBytesBuf(resSchema), finishReturn);
}
/** Convert ExpressionTree into native function. */
- public void build(Schema schema, List exprs, List finish_exprs)
+ public void build(
+ Schema schema, List exprs, List finish_exprs)
throws RuntimeException, IOException, GandivaException {
- nativeHandler = jniWrapper.nativeBuildWithFinish(getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), getExprListBytesBuf(finish_exprs));
+ nativeHandler = jniWrapper.nativeBuildWithFinish(getSchemaBytesBuf(schema),
+ getExprListBytesBuf(exprs), getExprListBytesBuf(finish_exprs));
}
/** Set result Schema in some special cases */
- public void setReturnFields(Schema schema) throws RuntimeException, IOException, GandivaException {
+ public void setReturnFields(Schema schema)
+ throws RuntimeException, IOException, GandivaException {
jniWrapper.nativeSetReturnFields(nativeHandler, getSchemaBytesBuf(schema));
}
- /** Evaluate input data using builded native function, and output as recordBatch. */
+ /**
+ * Evaluate input data using the built native function, and output as a recordBatch.
+ */
public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch)
throws RuntimeException, IOException {
return evaluate(recordBatch, null);
}
- /** Evaluate input data using builded native function, and output as recordBatch. */
- public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch, SelectionVectorInt16 selectionVector)
- throws RuntimeException, IOException {
+ /**
+ * Evaluate input data using the built native function, and output as a recordBatch.
+ */
+ public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch,
+ SelectionVectorInt16 selectionVector) throws RuntimeException, IOException {
List buffers = recordBatch.getBuffers();
List buffersLayout = recordBatch.getBuffersLayout();
long[] bufAddrs = new long[buffers.size()];
@@ -105,14 +122,15 @@ public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch, SelectionVector
int selectionVectorRecordCount = selectionVector.getRecordCount();
long selectionVectorAddr = selectionVector.getBuffer().memoryAddress();
long selectionVectorSize = selectionVector.getBuffer().capacity();
- resRecordBatchBuilderList =
- jniWrapper.nativeEvaluateWithSelection(nativeHandler, recordBatch.getLength(), bufAddrs, bufSizes,
- selectionVectorRecordCount, selectionVectorAddr, selectionVectorSize);
+ resRecordBatchBuilderList = jniWrapper.nativeEvaluateWithSelection(nativeHandler,
+ recordBatch.getLength(), bufAddrs, bufSizes, selectionVectorRecordCount,
+ selectionVectorAddr, selectionVectorSize);
} else {
- resRecordBatchBuilderList =
- jniWrapper.nativeEvaluate(nativeHandler, recordBatch.getLength(), bufAddrs, bufSizes);
+ resRecordBatchBuilderList = jniWrapper.nativeEvaluate(
+ nativeHandler, recordBatch.getLength(), bufAddrs, bufSizes);
}
- ArrowRecordBatch[] recordBatchList = new ArrowRecordBatch[resRecordBatchBuilderList.length];
+ ArrowRecordBatch[] recordBatchList =
+ new ArrowRecordBatch[resRecordBatchBuilderList.length];
for (int i = 0; i < resRecordBatchBuilderList.length; i++) {
if (resRecordBatchBuilderList[i] == null) {
recordBatchList[i] = null;
@@ -125,7 +143,9 @@ public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch, SelectionVector
return recordBatchList;
}
- /** Evaluate input data using builded native function, and output as recordBatch. */
+ /**
+ * Evaluate input data using the built native function, and output as a recordBatch.
+ */
public void SetMember(ArrowRecordBatch recordBatch)
throws RuntimeException, IOException {
List buffers = recordBatch.getBuffers();
@@ -142,12 +162,15 @@ public void SetMember(ArrowRecordBatch recordBatch)
bufSizes[idx++] = bufLayout.getSize();
}
- jniWrapper.nativeSetMember(nativeHandler, recordBatch.getLength(), bufAddrs, bufSizes);
+ jniWrapper.nativeSetMember(
+ nativeHandler, recordBatch.getLength(), bufAddrs, bufSizes);
}
public ArrowRecordBatch[] finish() throws RuntimeException, IOException {
- ArrowRecordBatchBuilder[] resRecordBatchBuilderList = jniWrapper.nativeFinish(nativeHandler);
- ArrowRecordBatch[] recordBatchList = new ArrowRecordBatch[resRecordBatchBuilderList.length];
+ ArrowRecordBatchBuilder[] resRecordBatchBuilderList =
+ jniWrapper.nativeFinish(nativeHandler);
+ ArrowRecordBatch[] recordBatchList =
+ new ArrowRecordBatch[resRecordBatchBuilderList.length];
for (int i = 0; i < resRecordBatchBuilderList.length; i++) {
if (resRecordBatchBuilderList[i] == null) {
recordBatchList[i] = null;
@@ -169,7 +192,8 @@ public void setDependency(BatchIterator child) throws RuntimeException, IOExcept
jniWrapper.nativeSetDependency(nativeHandler, child.getInstanceId(), -1);
}
- public void setDependency(BatchIterator child, int index) throws RuntimeException, IOException {
+ public void setDependency(BatchIterator child, int index)
+ throws RuntimeException, IOException {
jniWrapper.nativeSetDependency(nativeHandler, child.getInstanceId(), index);
}
@@ -185,7 +209,8 @@ byte[] getSchemaBytesBuf(Schema schema) throws IOException {
}
byte[] getExprListBytesBuf(List exprs) throws GandivaException {
- GandivaTypes.ExpressionList.Builder builder = GandivaTypes.ExpressionList.newBuilder();
+ GandivaTypes.ExpressionList.Builder builder =
+ GandivaTypes.ExpressionList.newBuilder();
for (ExpressionTree expr : exprs) {
builder.addExprs(expr.toProtobuf());
}
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluatorJniWrapper.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluatorJniWrapper.java
index de8a14357..b6b0d7997 100644
--- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluatorJniWrapper.java
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluatorJniWrapper.java
@@ -20,102 +20,133 @@
import java.io.IOException;
/**
- * This class is implemented in JNI. This provides the Java interface to invoke functions in JNI.
- * This file is used to generated the .h files required for jni. Avoid all external dependencies in
- * this file.
+ * This class is implemented in JNI. This provides the Java interface to invoke
+ * functions in JNI. This file is used to generate the .h files required for
+ * JNI. Avoid all external dependencies in this file.
*/
public class ExpressionEvaluatorJniWrapper {
-
/** Wrapper for native API. */
- public ExpressionEvaluatorJniWrapper() throws IOException {
- JniUtils.getInstance();
+ public ExpressionEvaluatorJniWrapper(String tmp_dir) throws IOException {
+ JniUtils.getInstance(tmp_dir);
}
/**
- * Generates the projector module to evaluate the expressions with custom configuration.
+ * Set the native env variable NATIVE_TMP_DIR.
+ *
+ * @param path tmp path for native code; falls back to java.io.tmpdir
+ */
+ native void nativeSetJavaTmpDir(String path);
+
+ /**
+ * Set the native env variable NATIVE_BATCH_SIZE.
*
- * @param schemaBuf The schema serialized as a protobuf. See Types.proto to see the protobuf
- * specification
- * @param exprListBuf The serialized protobuf of the expression vector. Each expression is created
- * using TreeBuilder::MakeExpression.
- * @param resSchemaBuf The schema serialized as a protobuf. See Types.proto to see the protobuf
- * specification
- * @param finishReturn This parameter is used to indicate that this expression should return when calling finish
- * @return A nativeHandler that is passed to the evaluateProjector() and closeProjector() methods
+ * @param batch_size number of rows per batch, taken from
+ * spark.sql.execution.arrow.maxRecordsPerBatch
*/
- native long nativeBuild(byte[] schemaBuf, byte[] exprListBuf, byte[] resSchemaBuf, boolean finishReturn) throws RuntimeException;
+ native void nativeSetBatchSize(int batch_size);
/**
- * Generates the projector module to evaluate the expressions with custom configuration.
+ * Generates the projector module to evaluate the expressions with custom
+ * configuration.
*
- * @param schemaBuf The schema serialized as a protobuf. See Types.proto to see the protobuf
- * specification
- * @param exprListBuf The serialized protobuf of the expression vector. Each expression is created
- * using TreeBuilder::MakeExpression.
- * @param finishExprListBuf The serialized protobuf of the expression vector. Each expression is created
- * using TreeBuilder::MakeExpression.
- * @return A nativeHandler that is passed to the evaluateProjector() and closeProjector() methods
+ * @param schemaBuf The schema serialized as a protobuf. See Types.proto to
+ * see the protobuf specification
+ * @param exprListBuf The serialized protobuf of the expression vector. Each
+ * expression is created using TreeBuilder::MakeExpression.
+ * @param resSchemaBuf The schema serialized as a protobuf. See Types.proto to
+ * see the protobuf specification
+ * @param finishReturn indicates whether this expression should return its
+ * result when finish is called
+ * @return A nativeHandler that is passed to the evaluateProjector() and
+ * closeProjector() methods
*/
- native long nativeBuildWithFinish(byte[] schemaBuf, byte[] exprListBuf, byte[] finishExprListBuf) throws RuntimeException;
+ native long nativeBuild(byte[] schemaBuf, byte[] exprListBuf, byte[] resSchemaBuf,
+ boolean finishReturn) throws RuntimeException;
+
+ /**
+ * Generates the projector module to evaluate the expressions with custom
+ * configuration.
+ *
+ * @param schemaBuf The schema serialized as a protobuf. See Types.proto
+ * to see the protobuf specification
+ * @param exprListBuf The serialized protobuf of the expression vector.
+ * Each expression is created using
+ * TreeBuilder::MakeExpression.
+ * @param finishExprListBuf The serialized protobuf of the expression vector.
+ * Each expression is created using
+ * TreeBuilder::MakeExpression.
+ * @return A nativeHandler that is passed to the evaluateProjector() and
+ * closeProjector() methods
+ */
+ native long nativeBuildWithFinish(byte[] schemaBuf, byte[] exprListBuf,
+ byte[] finishExprListBuf) throws RuntimeException;
/**
* Set return schema for this expressionTree.
*
- * @param nativeHandler nativeHandler representing expressions. Created using a call to
- * buildNativeCode
- * @param schemaBuf The schema serialized as a protobuf. See Types.proto to see the protobuf
- * specification
+ * @param nativeHandler nativeHandler representing expressions. Created using a
+ * call to buildNativeCode
+ * @param schemaBuf The schema serialized as a protobuf. See Types.proto to
+ * see the protobuf specification
*/
- native void nativeSetReturnFields(long nativeHandler, byte[] schemaBuf) throws RuntimeException;
+ native void nativeSetReturnFields(long nativeHandler, byte[] schemaBuf)
+ throws RuntimeException;
/**
- * Evaluate the expressions represented by the nativeHandler on a record batch and store the
- * output in ValueVectors. Throws an exception in case of errors
+ * Evaluate the expressions represented by the nativeHandler on a record batch
+ * and store the output in ValueVectors. Throws an exception in case of errors
*
- * @param nativeHandler nativeHandler representing expressions. Created using a call to
- * buildNativeCode
- * @param numRows Number of rows in the record batch
- * @param bufAddrs An array of memory addresses. Each memory address points to a validity vector
- * or a data vector (will add support for offset vectors later).
- * @param bufSizes An array of buffer sizes. For each memory address in bufAddrs, the size of the
- * buffer is present in bufSizes
- * @return A list of ArrowRecordBatchBuilder which can be used to build a List of ArrowRecordBatch
+ * @param nativeHandler nativeHandler representing expressions. Created using a
+ * call to buildNativeCode
+ * @param numRows Number of rows in the record batch
+ * @param bufAddrs An array of memory addresses. Each memory address points
+ * to a validity vector or a data vector (will add support
+ * for offset vectors later).
+ * @param bufSizes An array of buffer sizes. For each memory address in
+ * bufAddrs, the size of the buffer is present in bufSizes
+ * @return A list of ArrowRecordBatchBuilder which can be used to build a List
+ * of ArrowRecordBatch
*/
- native ArrowRecordBatchBuilder[] nativeEvaluate(
- long nativeHandler, int numRows, long[] bufAddrs, long[] bufSizes) throws RuntimeException;
+ native ArrowRecordBatchBuilder[] nativeEvaluate(long nativeHandler, int numRows,
+ long[] bufAddrs, long[] bufSizes) throws RuntimeException;
/**
- * Evaluate the expressions represented by the nativeHandler on a record batch and store the
- * output in ValueVectors. Throws an exception in case of errors
+ * Evaluate the expressions represented by the nativeHandler on a record batch
+ * and store the output in ValueVectors. Throws an exception in case of errors
*
- * @param nativeHandler nativeHandler representing expressions. Created using a call to
- * buildNativeCode
- * @param numRows Number of rows in the record batch
- * @param bufAddrs An array of memory addresses. Each memory address points to a validity vector
- * or a data vector (will add support for offset vectors later).
- * @param bufSizes An array of buffer sizes. For each memory address in bufAddrs, the size of the
- * buffer is present in bufSizes
- * @param selectionVector valid selected item record count
- * @param selectionVector selectionVector memory address
+ * @param nativeHandler nativeHandler representing expressions. Created
+ * using a call to buildNativeCode
+ * @param numRows Number of rows in the record batch
+ * @param bufAddrs An array of memory addresses. Each memory address
+ * points to a validity vector or a data vector (will
+ * add support for offset vectors later).
+ * @param bufSizes An array of buffer sizes. For each memory address
+ * in bufAddrs, the size of the buffer is present in
+ * bufSizes
+ * @param selectionVectorRecordCount number of valid selected records
+ * @param selectionVectorAddr selection vector memory address
* @param selectionVectorSize selectionVector total size
- * @return A list of ArrowRecordBatchBuilder which can be used to build a List of ArrowRecordBatch
+ * @return A list of ArrowRecordBatchBuilder which can be used to build a List
+ * of ArrowRecordBatch
*/
- native ArrowRecordBatchBuilder[] nativeEvaluateWithSelection(
- long nativeHandler, int numRows, long[] bufAddrs, long[] bufSizes,
- int selectionVectorRecordCount, long selectionVectorAddr, long selectionVectorSize) throws RuntimeException;
+ native ArrowRecordBatchBuilder[] nativeEvaluateWithSelection(long nativeHandler,
+ int numRows, long[] bufAddrs, long[] bufSizes, int selectionVectorRecordCount,
+ long selectionVectorAddr, long selectionVectorSize) throws RuntimeException;
native void nativeSetMember(
long nativeHandler, int numRows, long[] bufAddrs, long[] bufSizes);
/**
- * Evaluate the expressions represented by the nativeHandler on a record batch and store the
- * output in ValueVectors. Throws an exception in case of errors
+ * Evaluate the expressions represented by the nativeHandler on a record batch
+ * and store the output in ValueVectors. Throws an exception in case of errors
*
- * @param nativeHandler nativeHandler representing expressions. Created using a call to
- * buildNativeCode
- * @return A list of ArrowRecordBatchBuilder which can be used to build a List of ArrowRecordBatch
+ * @param nativeHandler nativeHandler representing expressions. Created using a
+ * call to buildNativeCode
+ * @return A list of ArrowRecordBatchBuilder which can be used to build a List
+ * of ArrowRecordBatch
*/
- native ArrowRecordBatchBuilder[] nativeFinish(long nativeHandler) throws RuntimeException;
+ native ArrowRecordBatchBuilder[] nativeFinish(long nativeHandler)
+ throws RuntimeException;
/**
* Call Finish to get result, result will be as a iterator.
@@ -128,11 +159,12 @@ native void nativeSetMember(
/**
* Set another evaluator's iterator as this one's dependency.
*
- * @param nativeHandler nativeHandler of this expression
+ * @param nativeHandler nativeHandler of this expression
* @param childInstanceId childInstanceId of a child BatchIterator
- * @param index exptected index of the output of BatchIterator
+ * @param index expected index of the output of BatchIterator
*/
- native void nativeSetDependency(long nativeHandler, long childInstanceId, int index) throws RuntimeException;
+ native void nativeSetDependency(long nativeHandler, long childInstanceId, int index)
+ throws RuntimeException;
/**
* Closes the projector referenced by nativeHandler.
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/JniUtils.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/JniUtils.java
index 5f80c6bfa..c750ffc4e 100644
--- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/JniUtils.java
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/JniUtils.java
@@ -19,64 +19,146 @@
import java.io.File;
import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
+import java.net.JarURLConnection;
+import java.net.URL;
+import java.net.URLConnection;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
+import java.util.Enumeration;
+import java.util.jar.JarEntry;
+import java.util.jar.JarFile;
/** Helper class for JNI related operations. */
public class JniUtils {
private static final String LIBRARY_NAME = "spark_columnar_jni";
private static boolean isLoaded = false;
private static volatile JniUtils INSTANCE;
+ private String tmp_dir;
public static JniUtils getInstance() throws IOException {
+ String tmp_dir = System.getProperty("java.io.tmpdir");
+ return getInstance(tmp_dir);
+ }
+
+ public static JniUtils getInstance(String tmp_dir) throws IOException {
if (INSTANCE == null) {
synchronized (JniUtils.class) {
if (INSTANCE == null) {
try {
- INSTANCE = new JniUtils();
+ INSTANCE = new JniUtils(tmp_dir);
} catch (IllegalAccessException ex) {
throw new IOException("IllegalAccess");
}
}
}
}
-
return INSTANCE;
}
- private JniUtils() throws IOException, IllegalAccessException {
- try {
- loadLibraryFromJar();
- } catch (IOException ex) {
- System.loadLibrary(LIBRARY_NAME);
+ private JniUtils(String tmp_dir)
+ throws IOException, IllegalAccessException, IllegalStateException {
+ if (!isLoaded) {
+ try {
+ loadLibraryFromJar(tmp_dir);
+ } catch (IOException ex) {
+ System.loadLibrary(LIBRARY_NAME);
+ }
+ loadIncludeFromJar(tmp_dir);
+ isLoaded = true;
}
}
- static void loadLibraryFromJar() throws IOException, IllegalAccessException {
+ static void loadLibraryFromJar(String tmp_dir)
+ throws IOException, IllegalAccessException {
synchronized (JniUtils.class) {
- if (!isLoaded) {
- final String libraryToLoad = System.mapLibraryName(LIBRARY_NAME);
- final File libraryFile =
- moveFileFromJarToTemp(System.getProperty("java.io.tmpdir"), libraryToLoad);
- System.load(libraryFile.getAbsolutePath());
- isLoaded = true;
+ if (tmp_dir == null) {
+ tmp_dir = System.getProperty("java.io.tmpdir");
+ System.out.println("loadLibraryFromJar " + tmp_dir);
}
+ final String libraryToLoad = System.mapLibraryName(LIBRARY_NAME);
+ final File libraryFile = moveFileFromJarToTemp(tmp_dir, libraryToLoad);
+ System.load(libraryFile.getAbsolutePath());
}
}
- private static File moveFileFromJarToTemp(final String tmpDir, String libraryToLoad)
+ private static void loadIncludeFromJar(String tmp_dir)
+ throws IOException, IllegalAccessException {
+ synchronized (JniUtils.class) {
+ if (tmp_dir == null) {
+ tmp_dir = System.getProperty("java.io.tmpdir");
+ System.out.println("loadIncludeFromJar " + tmp_dir);
+ }
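+ // Extract the bundled include/ directory from the jar into
+ // <tmp_dir>/nativesql_include, presumably for use by native code generation.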
+ final String folderToLoad = "include";
+ final URLConnection urlConnection =
+ JniUtils.class.getClassLoader().getResource("include").openConnection();
+ if (urlConnection instanceof JarURLConnection) {
+ final JarFile jarFile = ((JarURLConnection) urlConnection).getJarFile();
+ copyResourcesToDirectory(jarFile, folderToLoad, tmp_dir + "/" + "nativesql_include");
+ } else {
+ throw new IOException(urlConnection.toString() + " is not a JarURLConnection");
+ }
+ }
+ }
+
+ private static File moveFileFromJarToTemp(String tmpDir, String libraryToLoad)
throws IOException {
- final File temp = File.createTempFile(tmpDir, libraryToLoad);
+ // final File temp = File.createTempFile(tmpDir, libraryToLoad);
+ final File temp = new File(tmpDir + "/" + libraryToLoad);
try (final InputStream is =
- JniUtils.class.getClassLoader().getResourceAsStream(libraryToLoad)) {
+ JniUtils.class.getClassLoader().getResourceAsStream(libraryToLoad)) {
if (is == null) {
throw new FileNotFoundException(libraryToLoad);
} else {
- Files.copy(is, temp.toPath(), StandardCopyOption.REPLACE_EXISTING);
+ try {
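+ // Copy without REPLACE_EXISTING and swallow failures; presumably this tolerates
+ // a library already extracted into tmp_dir by another process on the node.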
+ Files.copy(is, temp.toPath());
+ } catch (Exception e) {
+ }
}
}
return temp;
}
+
+ /**
+ * Copies a directory from a jar file to an external directory.
+ */
+ public static void copyResourcesToDirectory(
+ JarFile fromJar, String jarDir, String destDir) throws IOException {
+ for (Enumeration<JarEntry> entries = fromJar.entries(); entries.hasMoreElements();) {
+ JarEntry entry = entries.nextElement();
+ if (entry.getName().startsWith(jarDir + "/") && !entry.isDirectory()) {
+ File dest =
+ new File(destDir + "/" + entry.getName().substring(jarDir.length() + 1));
+ File parent = dest.getParentFile();
+ if (parent != null) {
+ parent.mkdirs();
+ }
+
+ FileOutputStream out = new FileOutputStream(dest);
+ InputStream in = fromJar.getInputStream(entry);
+
+ try {
+ byte[] buffer = new byte[8 * 1024];
+
+ int s = 0;
+ while ((s = in.read(buffer)) > 0) {
+ out.write(buffer, 0, s);
+ }
+ } catch (IOException e) {
+ throw new IOException("Could not copy asset from jar file", e);
+ } finally {
+ try {
+ in.close();
+ } catch (IOException ignored) {
+ }
+ try {
+ out.close();
+ } catch (IOException ignored) {
+ }
+ }
+ }
+ }
+ }
}
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/PartitionFileInfo.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/PartitionFileInfo.java
new file mode 100644
index 000000000..bedab8f15
--- /dev/null
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/PartitionFileInfo.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.sparkColumnarPlugin.vectorized;
+
+public class PartitionFileInfo {
+ private int pid;
+ private String filePath;
+
+ public PartitionFileInfo(int pid, String filePath) {
+ this.pid = pid;
+ this.filePath = filePath;
+ }
+
+ public int getPid() {
+ return pid;
+ }
+
+ public String getFilePath() {
+ return filePath;
+ }
+}
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleDecompressionJniWrapper.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleDecompressionJniWrapper.java
new file mode 100644
index 000000000..5bc46f7c8
--- /dev/null
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleDecompressionJniWrapper.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.sparkColumnarPlugin.vectorized;
+
+import java.io.IOException;
+
+public class ShuffleDecompressionJniWrapper {
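+ // Usage sketch: make() a schema holder once, call decompress() for each
+ // compressed batch, then close() the holder to release native resources.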
+
+ public ShuffleDecompressionJniWrapper() throws IOException {
+ JniUtils.getInstance();
+ }
+
+ /**
+ * Create a schema holder that can be reused for multiple decompressions with the same schema.
+ *
+ * @param schemaBuf serialized arrow schema
+ * @return native schema holder id
+ * @throws RuntimeException
+ */
+ public native long make(byte[] schemaBuf) throws RuntimeException;
+
+ public native ArrowRecordBatchBuilder decompress(
+ long schemaHolderId,
+ String compressionCodec,
+ int numRows,
+ long[] bufAddrs,
+ long[] bufSizes,
+ long[] bufMask)
+ throws RuntimeException;
+
+ /**
+ * Release resources associated with designated schema holder instance.
+ *
+ * @param schemaHolderId of the schema holder instance.
+ */
+ public native void close(long schemaHolderId);
+}
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleSplitterJniWrapper.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleSplitterJniWrapper.java
new file mode 100644
index 000000000..1cdc300a5
--- /dev/null
+++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleSplitterJniWrapper.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.sparkColumnarPlugin.vectorized;
+
+import java.io.IOException;
+
+public class ShuffleSplitterJniWrapper {
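+ // Usage sketch: make() a splitter per schema, split() each record batch, stop()
+ // to flush remaining buffers, then read getPartitionFileInfo() and
+ // getTotalBytesWritten() before close().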
+
+ public ShuffleSplitterJniWrapper() throws IOException {
+ JniUtils.getInstance();
+ }
+
+ /**
+ * Construct a native splitter for shuffled RecordBatches.
+ *
+ * @param schemaBuf serialized arrow schema
+ * @param bufferSize size of native buffers hold by each partition writer
+ * @return native splitter instance id if created successfully.
+ * @throws RuntimeException
+ */
+ public native long make(byte[] schemaBuf, long bufferSize, String localDirs)
+ throws RuntimeException;
+
+ /**
+ * Split one record batch represented by bufAddrs and bufSizes into several batches. The batch is
+ * split according to the first column as partition id. During splitting, the data in native
+ * buffers will be written to disk when the buffers are full.
+ *
+ * @param splitterId
+ * @param numRows Rows per batch
+ * @param bufAddrs Addresses of buffers
+ * @param bufSizes Sizes of buffers
+ * @throws RuntimeException
+ */
+ public native void split(long splitterId, int numRows, long[] bufAddrs, long[] bufSizes)
+ throws RuntimeException;
+
+ /**
+ * Write the data remaining in the buffers held by the native splitter to each partition's
+ * temporary file, and stop splitting.
+ *
+ * @param splitterId
+ * @throws RuntimeException
+ */
+ public native void stop(long splitterId) throws RuntimeException;
+
+ /**
+ * Set the output buffer size for each partition. The splitter maintains one buffer per partition
+ * id encountered and writes its data to file when the buffer is full. The default buffer size is
+ * 4096 rows.
+ *
+ * @param splitterId
+ * @param bufferSize buffer size in rows, not bytes; defaults to 4096 rows
+ */
+ public native void setPartitionBufferSize(long splitterId, long bufferSize);
+
+ /**
+ * Set the compression codec for the splitter's output. Defaults to uncompressed.
+ *
+ * @param splitterId
+ * @param codec "lz4", "zstd", "uncompressed"
+ */
+ public native void setCompressionCodec(long splitterId, String codec);
+
+ /**
+ * Get information on all files created by the splitter. Used by {@link
+ * org.apache.spark.shuffle.ColumnarShuffleWriter}. These files exist only temporarily and are
+ * deleted after they are combined.
+ *
+ * @param splitterId
+ * @return an array of partition file information
+ */
+ public native PartitionFileInfo[] getPartitionFileInfo(long splitterId);
+
+ /**
+ * Get the total bytes written to disk.
+ *
+ * @param splitterId
+ * @return
+ */
+ public native long getTotalBytesWritten(long splitterId);
+
+ /**
+ * Release resources associated with designated splitter instance.
+ *
+ * @param splitterId of the splitter instance.
+ */
+ public native void close(long splitterId);
+}
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPlugin.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPlugin.scala
index 21bca93d7..5842150fb 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPlugin.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPlugin.scala
@@ -17,8 +17,9 @@
package com.intel.sparkColumnarPlugin
-import com.intel.sparkColumnarPlugin.execution._
+import java.util.Locale
+import com.intel.sparkColumnarPlugin.execution._
import org.apache.spark.internal.Logging
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.rules.Rule
@@ -73,12 +74,20 @@ case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] {
logDebug(s"Columnar Processing for ${plan.getClass} is not currently supported.")
plan.withNewChildren(children)
}
- /*case plan: ShuffleExchangeExec =>
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- new ColumnarShuffleExchangeExec(
- plan.outputPartitioning,
- replaceWithColumnarPlan(plan.child),
- plan.canChangeNumPartitions)*/
+ case plan: ShuffleExchangeExec =>
+ if (columnarConf.enableColumnarShuffle) {
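+ // Wrap the columnar exchange in CoalesceBatchesExec to re-batch its
+ // (typically small) post-shuffle output before downstream operators.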
+ val child = replaceWithColumnarPlan(plan.child)
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ CoalesceBatchesExec(
+ new ColumnarShuffleExchangeExec(
+ plan.outputPartitioning,
+ child,
+ plan.canChangeNumPartitions))
+ } else {
+ val children = plan.children.map(replaceWithColumnarPlan)
+ logDebug(s"Columnar Processing for ${plan.getClass} is not currently supported.")
+ plan.withNewChildren(children)
+ }
case plan: ShuffledHashJoinExec =>
val left = replaceWithColumnarPlan(plan.left)
val right = replaceWithColumnarPlan(plan.right)
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPluginConfig.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPluginConfig.scala
index fd72fcaab..ee1cf7ac1 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPluginConfig.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPluginConfig.scala
@@ -22,6 +22,13 @@ import org.apache.spark.SparkConf
class ColumnarPluginConfig(conf: SparkConf) {
val enableColumnarSort: Boolean =
conf.getBoolean("spark.sql.columnar.sort", defaultValue = false)
+ val enableColumnarShuffle: Boolean = conf
+ .get("spark.shuffle.manager", "sort")
+ .equals("org.apache.spark.shuffle.sort.ColumnarShuffleManager")
+ val batchSize: Int =
+ conf.getInt("spark.sql.execution.arrow.maxRecordsPerBatch", defaultValue = 10000)
+ val tmpFile: String =
+ conf.getOption("spark.sql.columnar.tmp_dir").getOrElse(null)
}
object ColumnarPluginConfig {
@@ -41,4 +48,18 @@ object ColumnarPluginConfig {
ins
}
}
+ def getBatchSize: Int = synchronized {
+ if (ins == null) {
+ 10000
+ } else {
+ ins.batchSize
+ }
+ }
+ def getTempFile: String = synchronized {
+ if (ins != null) {
+ ins.tmpFile
+ } else {
+ System.getProperty("java.io.tmpdir")
+ }
+ }
}
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/CoalesceBatchesExec.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/CoalesceBatchesExec.scala
new file mode 100644
index 000000000..f92eea89a
--- /dev/null
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/CoalesceBatchesExec.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.sparkColumnarPlugin.execution
+
+import java.util.concurrent.TimeUnit
+
+import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
+import org.apache.arrow.vector.util.VectorBatchAppender
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import scala.collection.mutable.ListBuffer
+
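+/**
+ * Coalesces adjacent small ColumnarBatches into larger ones, up to
+ * spark.sql.execution.arrow.maxRecordsPerBatch rows per output batch, by
+ * appending the Arrow vectors of trailing batches onto the first batch.
+ */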
+case class CoalesceBatchesExec(child: SparkPlan) extends UnaryExecNode {
+
+ override def output: Seq[Attribute] = child.output
+
+ override def supportsColumnar: Boolean = true
+
+ override def nodeName: String = "CoalesceBatches"
+
+ override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException
+
+ override lazy val metrics: Map[String, SQLMetric] = Map(
+ "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"),
+ "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"),
+ "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"),
+ "concatTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "concat batch time total"),
+ "collectTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "collect batch time total"),
+ "avgCoalescedNumRows" -> SQLMetrics
+ .createAverageMetric(sparkContext, "avg coalesced batch num rows"))
+ // TODO: peak device memory total
+
+ override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ import CoalesceBatchesExec._
+
+ val recordsPerBatch = conf.arrowMaxRecordsPerBatch
+ val numInputRows = longMetric("numInputRows")
+ val numInputBatches = longMetric("numInputBatches")
+ val numOutputBatches = longMetric("numOutputBatches")
+ val concatTime = longMetric("concatTime")
+ val collectTime = longMetric("collectTime")
+ val avgCoalescedNumRows = longMetric("avgCoalescedNumRows")
+
+ child.executeColumnar().mapPartitions { iter =>
+ if (iter.hasNext) {
+ new Iterator[ColumnarBatch] {
+ var target: ColumnarBatch = _
+ var numBatchesTotal: Long = _
+ var numRowsTotal: Long = _
+
+ TaskContext.get().addTaskCompletionListener[Unit] { _ =>
+ closePrevious()
+ if (numBatchesTotal > 0) {
+ avgCoalescedNumRows.set(numRowsTotal.toDouble / numBatchesTotal)
+ }
+ }
+
+ private def closePrevious(): Unit = {
+ if (target != null) {
+ target.close()
+ target = null
+ }
+ }
+
+ override def hasNext: Boolean = {
+ iter.hasNext
+ }
+
+ override def next(): ColumnarBatch = {
+ closePrevious()
+ var rowCount = 0
+ val batchesToAppend = ListBuffer[ColumnarBatch]()
+
+ val beforeCollect = System.nanoTime
+ target = iter.next()
+ target.retain()
+ rowCount += target.numRows
+
+ while (iter.hasNext && rowCount < recordsPerBatch) {
+ val delta = iter.next()
+ delta.retain()
+ rowCount += delta.numRows
+ batchesToAppend += delta
+ }
+
+ val beforeConcat = System.nanoTime
+ collectTime += beforeConcat - beforeCollect
+
+ coalesce(target, batchesToAppend.toList)
+ target.setNumRows(rowCount)
+
+ concatTime += System.nanoTime - beforeConcat
+ numInputRows += rowCount
+ numInputBatches += (1 + batchesToAppend.length)
+ numOutputBatches += 1
+
+ // used for calculating avgCoalescedNumRows
+ numRowsTotal += rowCount
+ numBatchesTotal += 1
+
+ batchesToAppend.foreach(cb => cb.close())
+
+ target
+ }
+ }
+ } else {
+ Iterator.empty
+ }
+ }
+ }
+}
+
+object CoalesceBatchesExec {
+ implicit class ArrowColumnarBatchRetainer(val cb: ColumnarBatch) {
+ def retain(): Unit = {
+ (0 until cb.numCols).toList.foreach(i =>
+ cb.column(i).asInstanceOf[ArrowWritableColumnVector].retain())
+ }
+ }
+
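+ // Append each column of the trailing batches onto the corresponding vector of
+ // the target batch; the caller sets the combined row count afterwards.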
+ def coalesce(targetBatch: ColumnarBatch, batchesToAppend: List[ColumnarBatch]): Unit = {
+ (0 until targetBatch.numCols).toList.foreach { i =>
+ val targetVector =
+ targetBatch.column(i).asInstanceOf[ArrowWritableColumnVector].getValueVector
+ val vectorsToAppend = batchesToAppend.map { cb =>
+ cb.column(i).asInstanceOf[ArrowWritableColumnVector].getValueVector
+ }
+ VectorBatchAppender.batchAppend(targetVector, vectorsToAppend: _*)
+ }
+ }
+}
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarBatchScanExec.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarBatchScanExec.scala
index 6acef2272..45501a95a 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarBatchScanExec.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarBatchScanExec.scala
@@ -17,6 +17,7 @@
package com.intel.sparkColumnarPlugin.execution
+import com.intel.sparkColumnarPlugin.ColumnarPluginConfig
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan}
@@ -24,9 +25,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
-class ColumnarBatchScanExec(
- output: Seq[AttributeReference],
- @transient scan: Scan) extends BatchScanExec(output, scan) {
+class ColumnarBatchScanExec(output: Seq[AttributeReference], @transient scan: Scan)
+ extends BatchScanExec(output, scan) {
+ val tmpDir = ColumnarPluginConfig.getConf(sparkContext.getConf).tmpFile
override def supportsColumnar(): Boolean = true
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
@@ -34,7 +35,8 @@ class ColumnarBatchScanExec(
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
val numOutputRows = longMetric("numOutputRows")
val scanTime = longMetric("scanTime")
- val inputColumnarRDD = new ColumnarDataSourceRDD(sparkContext, partitions, readerFactory, true, scanTime)
+ val inputColumnarRDD =
+ new ColumnarDataSourceRDD(sparkContext, partitions, readerFactory, true, scanTime, tmpDir)
inputColumnarRDD.map { r =>
numOutputRows += r.numRows()
r
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarDataSourceRDD.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarDataSourceRDD.scala
index 748deba7d..9a8cbfefc 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarDataSourceRDD.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarDataSourceRDD.scala
@@ -20,7 +20,11 @@ package com.intel.sparkColumnarPlugin.execution
import com.intel.sparkColumnarPlugin.vectorized._
import org.apache.spark._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, PartitionReader}
+import org.apache.spark.sql.connector.read.{
+ InputPartition,
+ PartitionReaderFactory,
+ PartitionReader
+}
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -28,7 +32,8 @@ import org.apache.spark.sql.execution.datasources.v2.VectorizedFilePartitionRead
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetPartitionReaderFactory
class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition)
- extends Partition with Serializable
+ extends Partition
+ with Serializable
// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for
// columnar scan.
@@ -37,8 +42,9 @@ class ColumnarDataSourceRDD(
@transient private val inputPartitions: Seq[InputPartition],
partitionReaderFactory: PartitionReaderFactory,
columnarReads: Boolean,
- scanTime: SQLMetric)
- extends RDD[ColumnarBatch](sc, Nil) {
+ scanTime: SQLMetric,
+ tmp_dir: String)
+ extends RDD[ColumnarBatch](sc, Nil) {
override protected def getPartitions: Array[Partition] = {
inputPartitions.zipWithIndex.map {
@@ -56,7 +62,7 @@ class ColumnarDataSourceRDD(
val reader = if (columnarReads) {
partitionReaderFactory match {
case factory: ParquetPartitionReaderFactory =>
- VectorizedFilePartitionReaderHandler.get(inputPartition, factory)
+ VectorizedFilePartitionReaderHandler.get(inputPartition, factory, tmp_dir)
case _ => partitionReaderFactory.createColumnarReader(inputPartition)
}
} else {
@@ -92,7 +98,8 @@ class ColumnarDataSourceRDD(
reader.get()
}
}
- val closeableColumnarBatchIterator = new CloseableColumnBatchIterator(iter.asInstanceOf[Iterator[ColumnarBatch]])
+ val closeableColumnarBatchIterator = new CloseableColumnBatchIterator(
+ iter.asInstanceOf[Iterator[ColumnarBatch]])
// TODO: SPARK-25083 remove the type erasure hack in data source scan
new InterruptibleIterator(context, closeableColumnarBatchIterator)
}
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarHashAggregateExec.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarHashAggregateExec.scala
index b3e2b438f..cf79c5d55 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarHashAggregateExec.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarHashAggregateExec.scala
@@ -67,6 +67,7 @@ class ColumnarHashAggregateExec(
resultExpressions,
child) {
+ val sparkConf = sparkContext.getConf
override def supportsColumnar = true
// Disable code generation
@@ -77,7 +78,8 @@ class ColumnarHashAggregateExec(
"numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"),
"numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of Input batches"),
"aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation process"),
- "elapseTime" -> SQLMetrics.createTimingMetric(sparkContext, "elapse time from very begin to this process"))
+ "elapseTime" -> SQLMetrics
+ .createTimingMetric(sparkContext, "elapse time from very begin to this process"))
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
val numOutputRows = longMetric("numOutputRows")
@@ -108,10 +110,13 @@ class ColumnarHashAggregateExec(
numOutputBatches,
numOutputRows,
aggTime,
- elapseTime)
- TaskContext.get().addTaskCompletionListener[Unit](_ => {
- aggregation.close()
- })
+ elapseTime,
+ sparkConf)
+ TaskContext
+ .get()
+ .addTaskCompletionListener[Unit](_ => {
+ aggregation.close()
+ })
new CloseableColumnBatchIterator(aggregation.createIterator(iter))
}
res
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarShuffledHashJoinExec.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarShuffledHashJoinExec.scala
index abc079b6c..dbd106c15 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarShuffledHashJoinExec.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarShuffledHashJoinExec.scala
@@ -61,15 +61,10 @@ class ColumnarShuffledHashJoinExec(
buildSide: BuildSide,
condition: Option[Expression],
left: SparkPlan,
- right: SparkPlan) extends ShuffledHashJoinExec(
- leftKeys,
- rightKeys,
- joinType,
- buildSide,
- condition,
- left,
- right) {
+ right: SparkPlan)
+ extends ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) {
+ val sparkConf = sparkContext.getConf
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"joinTime" -> SQLMetrics.createTimingMetric(sparkContext, "join time"),
@@ -85,15 +80,30 @@ class ColumnarShuffledHashJoinExec(
val joinTime = longMetric("joinTime")
val buildTime = longMetric("buildTime")
val resultSchema = this.schema
- streamedPlan.executeColumnar().zipPartitions(buildPlan.executeColumnar()) { (streamIter, buildIter) =>
- //val hashed = buildHashedRelation(buildIter)
- //join(streamIter, hashed, numOutputRows)
- val vjoin = ColumnarShuffledHashJoin.create(leftKeys, rightKeys, resultSchema, joinType, buildSide, condition, left, right, buildTime, joinTime, numOutputRows)
- val vjoinResult = vjoin.columnarInnerJoin(streamIter, buildIter)
- TaskContext.get().addTaskCompletionListener[Unit](_ => {
- vjoin.close()
- })
- new CloseableColumnBatchIterator(vjoinResult)
+ streamedPlan.executeColumnar().zipPartitions(buildPlan.executeColumnar()) {
+ (streamIter, buildIter) =>
+ //val hashed = buildHashedRelation(buildIter)
+ //join(streamIter, hashed, numOutputRows)
+ val vjoin = ColumnarShuffledHashJoin.create(
+ leftKeys,
+ rightKeys,
+ resultSchema,
+ joinType,
+ buildSide,
+ condition,
+ left,
+ right,
+ buildTime,
+ joinTime,
+ numOutputRows,
+ sparkConf)
+ val vjoinResult = vjoin.columnarInnerJoin(streamIter, buildIter)
+ TaskContext
+ .get()
+ .addTaskCompletionListener[Unit](_ => {
+ vjoin.close()
+ })
+ new CloseableColumnBatchIterator(vjoinResult)
}
}
}
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarSortExec.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarSortExec.scala
index ff2b2fb3a..6e58aa73b 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarSortExec.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarSortExec.scala
@@ -40,6 +40,8 @@ class ColumnarSortExec(
child: SparkPlan,
testSpillFrequency: Int = 0)
extends SortExec(sortOrder, global, child, testSpillFrequency) {
+
+ val sparkConf = sparkContext.getConf
override def supportsColumnar = true
// Disable code generation
@@ -72,7 +74,8 @@ class ColumnarSortExec(
numOutputBatches,
numOutputRows,
shuffleTime,
- elapse)
+ elapse,
+ sparkConf)
TaskContext
.get()
.addTaskCompletionListener[Unit](_ => {
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/CodeGeneration.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/CodeGeneration.scala
index 221562a9e..70a9c4f2c 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/CodeGeneration.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/CodeGeneration.scala
@@ -30,16 +30,8 @@ object CodeGeneration {
val timeZoneId = SQLConf.get.sessionLocalTimeZone
def getResultType(left: ArrowType, right: ArrowType): ArrowType = {
- if (left.equals(right)) {
- left
- } else {
- val left_precise_level = getPreciseLevel(left)
- val right_precise_level = getPreciseLevel(right)
- if (left_precise_level > right_precise_level)
- left
- else
- right
- }
+ //TODO(): remove this API
+ left
}
def getResultType(dataType: DataType): ArrowType = {
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregateExpression.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregateExpression.scala
index a2ed66892..7ea630549 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregateExpression.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregateExpression.scala
@@ -80,7 +80,7 @@ class ColumnarAggregateExpression(
var aggrFieldList: List[Field] = _
val (funcName, argSize, resSize) = mode match {
- case Partial =>
+ case Partial | PartialMerge =>
aggregateFunction.prettyName match {
case "avg" => ("sum_count", 1, 2)
case "count" => {
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregation.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregation.scala
index 4b0fb0810..24c4043d3 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregation.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregation.scala
@@ -23,6 +23,7 @@ import java.util.Collections
import java.util.concurrent.TimeUnit._
import util.control.Breaks._
+import com.intel.sparkColumnarPlugin.ColumnarPluginConfig
import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
import org.apache.spark.sql.util.ArrowUtils
import com.intel.sparkColumnarPlugin.vectorized.ExpressionEvaluator
@@ -30,6 +31,7 @@ import com.intel.sparkColumnarPlugin.vectorized.BatchIterator
import com.google.common.collect.Lists
import org.apache.hadoop.mapreduce.TaskAttemptID
+import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -68,9 +70,11 @@ class ColumnarAggregation(
numOutputBatches: SQLMetric,
numOutputRows: SQLMetric,
aggrTime: SQLMetric,
- elapseTime: SQLMetric)
+ elapseTime: SQLMetric,
+ sparkConf: SparkConf)
extends Logging {
// build gandiva projection here.
+ ColumnarPluginConfig.getConf(sparkConf)
var elapseTime_make: Long = 0
var rowId: Int = 0
var processedNumRows: Int = 0
@@ -103,12 +107,12 @@ class ColumnarAggregation(
// 3. map original input to aggregate input
var beforeAggregateProjector: ColumnarProjection = _
- var projectOrdinalList : List[Int] = _
- var aggregateInputAttributes : List[AttributeReference] = _
+ var projectOrdinalList: List[Int] = _
+ var aggregateInputAttributes: List[AttributeReference] = _
if (mode == null) {
projectOrdinalList = List[Int]()
- aggregateInputAttributes = List[AttributeReference]()
+ aggregateInputAttributes = List[AttributeReference]()
} else {
mode match {
case Partial => {
@@ -119,7 +123,7 @@ class ColumnarAggregation(
projectOrdinalList = beforeAggregateProjector.getOrdinalList
aggregateInputAttributes = beforeAggregateProjector.output
}
- case Final => {
+ case Final | PartialMerge => {
val ordinal_attr_list = originalInputAttributes.toList.zipWithIndex
.filter{case(expr, i) => !groupingOrdinalList.contains(i)}
.map{case(expr, i) => {
@@ -160,7 +164,14 @@ class ColumnarAggregation(
// 5. create nativeAggregate evaluator
val allNativeExpressions = groupingNativeExpression ::: aggregateNativeExpressions
val allAggregateInputFieldList = groupingFieldList ::: aggregateFieldList
- val allAggregateResultAttributes = groupingAttributes ::: aggregateAttributes.toList
+ var allAggregateResultAttributes : List[Attribute] = _
+ mode match {
+ case Partial | PartialMerge =>
+ val aggregateResultAttributes = getAttrForAggregateExpr(aggregateExpressions)
+ allAggregateResultAttributes = groupingAttributes ::: aggregateResultAttributes
+ case _ =>
+ allAggregateResultAttributes = groupingAttributes ::: aggregateAttributes.toList
+ }
val aggregateResultFieldList = allAggregateResultAttributes.map(attr => {
Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
})
@@ -169,7 +180,7 @@ class ColumnarAggregation(
val resultType = CodeGeneration.getResultType()
val resultField = Field.nullable(s"dummy_res", resultType)
- val expressionTree: List[ExpressionTree] = allNativeExpressions.map( expr => {
+ val expressionTree: List[ExpressionTree] = allNativeExpressions.map(expr => {
val node =
expr.doColumnarCodeGen_ext((groupingFieldList, allAggregateInputFieldList, resultType, resultField))
TreeBuilder.makeExpression(node, resultField)
@@ -201,6 +212,47 @@ class ColumnarAggregation(
}
}
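+  // For Partial/PartialMerge aggregation, the result attributes are taken from each
+  // aggregate function's inputAggBufferAttributes (e.g. sum and count for avg)
+  // rather than from the declared aggregateAttributes.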
+ def getAttrForAggregateExpr(aggregateExpressions: Seq[AggregateExpression]): List[Attribute] = {
+ var aggregateAttr = new ListBuffer[Attribute]()
+ val size = aggregateExpressions.size
+ for (expIdx <- 0 until size) {
+ val exp: AggregateExpression = aggregateExpressions(expIdx)
+ val aggregateFunc = exp.aggregateFunction
+ aggregateFunc match {
+ case Average(_) =>
+ val avg = aggregateFunc.asInstanceOf[Average]
+ val aggBufferAttr = avg.inputAggBufferAttributes
+ for (index <- 0 until aggBufferAttr.size) {
+ val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
+ aggregateAttr += attr
+ }
+ case Sum(_) =>
+ val sum = aggregateFunc.asInstanceOf[Sum]
+ val aggBufferAttr = sum.inputAggBufferAttributes
+ val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(0))
+ aggregateAttr += attr
+ case Count(_) =>
+ val count = aggregateFunc.asInstanceOf[Count]
+ val aggBufferAttr = count.inputAggBufferAttributes
+ val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(0))
+ aggregateAttr += attr
+ case Max(_) =>
+ val max = aggregateFunc.asInstanceOf[Max]
+ val aggBufferAttr = max.inputAggBufferAttributes
+ val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(0))
+ aggregateAttr += attr
+ case Min(_) =>
+ val min = aggregateFunc.asInstanceOf[Min]
+ val aggBufferAttr = min.inputAggBufferAttributes
+ val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(0))
+ aggregateAttr += attr
+ case other =>
+ throw new UnsupportedOperationException(s"not currently supported: $other.")
+ }
+ }
+ aggregateAttr.toList
+ }
+
def updateAggregationResult(columnarBatch: ColumnarBatch): Unit = {
val numRows = columnarBatch.numRows
val groupingProjectCols = groupingOrdinalList.map(i => {
@@ -290,7 +342,7 @@ class ColumnarAggregation(
if (nextCalled == false && resultColumnarBatch != null) {
return true
}
- if ( !nextBatch ) {
+ if (!nextBatch) {
return false
}
@@ -358,7 +410,8 @@ object ColumnarAggregation {
numOutputBatches: SQLMetric,
numOutputRows: SQLMetric,
aggrTime: SQLMetric,
- elapseTime: SQLMetric): ColumnarAggregation = synchronized {
+ elapseTime: SQLMetric,
+ sparkConf: SparkConf): ColumnarAggregation = synchronized {
columnarAggregation = new ColumnarAggregation(
partIndex,
groupingExpressions,
@@ -371,7 +424,8 @@ object ColumnarAggregation {
numOutputBatches,
numOutputRows,
aggrTime,
- elapseTime)
+ elapseTime,
+ sparkConf)
columnarAggregation
}
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarCoalesceOperator.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarCoalesceOperator.scala
index 6dad05b4e..4e0d6d486 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarCoalesceOperator.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarCoalesceOperator.scala
@@ -37,9 +37,9 @@ import scala.collection.mutable.ListBuffer
* coalesce(1, 2) => 1
* coalesce(null, 1, 2) => 1
* coalesce(null, null, 2) => 2
+ * coalesce(null, null, null) => null
* }}}
**/
-//TODO(): coalesce(null, null, null) => null
class ColumnarCoalesce(exps: Seq[Expression], original: Expression)
extends Coalesce(exps: Seq[Expression])
@@ -47,7 +47,7 @@ class ColumnarCoalesce(exps: Seq[Expression], original: Expression)
with Logging {
override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val iter: Iterator[Expression] = exps.iterator
- val exp = exps.head
+ val exp = iter.next()
val (exp_node, expType): (TreeNode, ArrowType) =
exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
@@ -63,10 +63,7 @@ class ColumnarCoalesce(exps: Seq[Expression], original: Expression)
// Return the last element no matter if it is null
val (exp_node, expType): (TreeNode, ArrowType) =
exps.last.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
- val isnotnullNode =
- TreeBuilder.makeFunction("isnotnull", Lists.newArrayList(exp_node), new ArrowType.Bool())
- val funcNode = TreeBuilder.makeIf(isnotnullNode, exp_node, exp_node, expType)
- funcNode
+ exp_node
} else {
val exp = iter.next()
val (exp_node, expType): (TreeNode, ArrowType) =
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarConcatOperator.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarConcatOperator.scala
new file mode 100644
index 000000000..ff0f70999
--- /dev/null
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarConcatOperator.scala
@@ -0,0 +1,64 @@
+package com.intel.sparkColumnarPlugin.expression
+
+import com.google.common.collect.Lists
+import com.google.common.collect.Sets
+import org.apache.arrow.gandiva.evaluator._
+import org.apache.arrow.gandiva.exceptions.GandivaException
+import org.apache.arrow.gandiva.expression._
+import org.apache.arrow.vector.types.pojo.ArrowType
+import org.apache.arrow.vector.types.pojo.Field
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+import scala.collection.mutable.ListBuffer
+
+class ColumnarConcat(exps: Seq[Expression], original: Expression)
+ extends Concat(exps: Seq[Expression])
+ with ColumnarExpression
+ with Logging {
+ override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
+ val iter: Iterator[Expression] = exps.iterator
+ val exp = iter.next()
+ val iterFaster: Iterator[Expression] = exps.iterator
+ iterFaster.next()
+ iterFaster.next()
+
+ val (exp_node, expType): (TreeNode, ArrowType) =
+ exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
+
+ val resultType = new ArrowType.Utf8()
+ val funcNode = TreeBuilder.makeFunction("concat",
+ Lists.newArrayList(exp_node, rightNode(args, exps, iter, iterFaster)), resultType)
+ (funcNode, expType)
+ }
+
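+ // Folds the remaining expressions into nested two-argument gandiva "concat"
+ // calls; iterFaster stays one element ahead of iter, so the final expression
+ // ends up as the innermost right operand.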
+ def rightNode(args: java.lang.Object, exps: Seq[Expression],
+ iter: Iterator[Expression], iterFaster: Iterator[Expression]): TreeNode = {
+ if (!iterFaster.hasNext) {
+ // When iter reaches the last but one expression
+ val (exp_node, expType): (TreeNode, ArrowType) =
+ exps.last.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
+ exp_node
+ } else {
+ val exp = iter.next()
+ iterFaster.next()
+ val (exp_node, expType): (TreeNode, ArrowType) =
+ exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
+ val resultType = new ArrowType.Utf8()
+ val funcNode = TreeBuilder.makeFunction("concat",
+ Lists.newArrayList(exp_node, rightNode(args, exps, iter, iterFaster)), resultType)
+ funcNode
+ }
+ }
+}
+
+object ColumnarConcatOperator {
+
+ def create(exps: Seq[Expression], original: Expression): Expression = original match {
+ case c: Concat =>
+ new ColumnarConcat(exps, original)
+ case other =>
+ throw new UnsupportedOperationException(s"not currently supported: $other.")
+ }
+}
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarExpressionConverter.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarExpressionConverter.scala
index ff9cf131f..3dffba3f6 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarExpressionConverter.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarExpressionConverter.scala
@@ -25,10 +25,10 @@ object ColumnarExpressionConverter extends Logging {
var check_if_no_calculation = true
- def replaceWithColumnarExpression(expr: Expression, attributeSeq: Seq[Attribute] = null): Expression = expr match {
+ def replaceWithColumnarExpression(expr: Expression, attributeSeq: Seq[Attribute] = null, expIdx : Int = -1): Expression = expr match {
case a: Alias =>
logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
- new ColumnarAlias(replaceWithColumnarExpression(a.child, attributeSeq), a.name)(
+ new ColumnarAlias(replaceWithColumnarExpression(a.child, attributeSeq, expIdx), a.name)(
a.exprId,
a.qualifier,
a.explicitMetadata)
@@ -37,7 +37,11 @@ object ColumnarExpressionConverter extends Logging {
if (attributeSeq != null) {
val bindReference = BindReferences.bindReference(expr, attributeSeq, true)
if (bindReference == expr) {
- return new ColumnarAttributeReference(a.name, a.dataType, a.nullable, a.metadata)(a.exprId, a.qualifier)
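+          // Binding did not resolve the attribute; if the caller supplied the
+          // expression's ordinal (expIdx), fall back to a bound reference at that
+          // ordinal instead of keeping the unresolved attribute.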
+ if (expIdx == -1) {
+ return new ColumnarAttributeReference(a.name, a.dataType, a.nullable, a.metadata)(a.exprId, a.qualifier)
+ } else {
+ return new ColumnarBoundReference(expIdx, a.dataType, a.nullable)
+ }
}
val b = bindReference.asInstanceOf[BoundReference]
new ColumnarBoundReference(b.ordinal, b.dataType, b.nullable)
@@ -144,6 +148,22 @@ object ColumnarExpressionConverter extends Logging {
case s: org.apache.spark.sql.execution.ScalarSubquery =>
logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
new ColumnarScalarSubquery(s)
+ case c: Concat =>
+ check_if_no_calculation = false
+ logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
+ val exps = c.children.map{ expr =>
+ replaceWithColumnarExpression(expr, attributeSeq)
+ }
+ ColumnarConcatOperator.create(
+ exps,
+ expr)
+ case r: Round =>
+ check_if_no_calculation = false
+ logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
+ ColumnarRoundOperator.create(
+ replaceWithColumnarExpression(r.child, attributeSeq),
+ replaceWithColumnarExpression(r.scale),
+ expr)
case expr =>
logWarning(s"${expr.getClass} ${expr} is not currently supported.")
expr
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarLiterals.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarLiterals.scala
index b071b1aa0..b1ae938b0 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarLiterals.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarLiterals.scala
@@ -39,13 +39,23 @@ class ColumnarLiteral(lit: Literal)
val resultType = CodeGeneration.getResultType(dataType)
dataType match {
case t: StringType =>
- (TreeBuilder.makeStringLiteral(value.toString().asInstanceOf[String]), resultType)
+ value match {
+ case null =>
+ (TreeBuilder.makeStringLiteral("null": java.lang.String), resultType)
+ case _ =>
+ (TreeBuilder.makeStringLiteral(value.toString().asInstanceOf[String]), resultType)
+ }
case t: IntegerType =>
(TreeBuilder.makeLiteral(value.asInstanceOf[Integer]), resultType)
case t: LongType =>
(TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Long]), resultType)
case t: DoubleType =>
- (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Double]), resultType)
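+          // A null Double literal is substituted with 0.0 as placeholder handling.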
+ value match {
+ case null =>
+ (TreeBuilder.makeLiteral(0.0: java.lang.Double), resultType)
+ case _ =>
+ (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Double]), resultType)
+ }
case d: DecimalType =>
val v = value.asInstanceOf[Decimal]
(TreeBuilder.makeDecimalLiteral(v.toString, v.precision, v.scale), resultType)
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarProjection.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarProjection.scala
index f317d0407..2e899a72c 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarProjection.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarProjection.scala
@@ -73,10 +73,7 @@ class ColumnarProjection (
case (expr, i) => {
ColumnarExpressionConverter.reset()
var columnarExpr: Expression =
- ColumnarExpressionConverter.replaceWithColumnarExpression(expr, originalInputAttributes)
- if (columnarExpr.isInstanceOf[AttributeReference]) {
- columnarExpr = new ColumnarBoundReference(i, columnarExpr.dataType, columnarExpr.nullable)
- }
+ ColumnarExpressionConverter.replaceWithColumnarExpression(expr, originalInputAttributes, i)
if (ColumnarExpressionConverter.ifNoCalculation == false) {
check_if_no_calculation = false
}
@@ -87,7 +84,8 @@ class ColumnarProjection (
}
}
- val (ordinalList, arrowSchema) = if (projPrepareList.size > 0 && check_if_no_calculation == false) {
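+    // Build a projector when there is real computation, or when the number of
+    // projected expressions differs from the number of referenced input columns.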
+ val (ordinalList, arrowSchema) = if (projPrepareList.size > 0 &&
+ (!check_if_no_calculation || projPrepareList.size != inputList.size)) {
val inputFieldList = inputList.asScala.toList.distinct
val schema = new Schema(inputFieldList.asJava)
projector = Projector.make(schema, projPrepareList.toList.asJava)
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarRoundOperator.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarRoundOperator.scala
new file mode 100644
index 000000000..506ad4849
--- /dev/null
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarRoundOperator.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.sparkColumnarPlugin.expression
+
+import com.google.common.collect.Lists
+
+import org.apache.arrow.gandiva.evaluator._
+import org.apache.arrow.gandiva.exceptions.GandivaException
+import org.apache.arrow.gandiva.expression._
+import org.apache.arrow.vector.types.pojo.ArrowType
+import org.apache.arrow.vector.types.FloatingPointPrecision
+import org.apache.arrow.vector.types.pojo.Field
+import org.apache.arrow.vector.types.DateUnit
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.types._
+
+import scala.collection.mutable.ListBuffer
+
+class ColumnarRound(child: Expression, scale: Expression, original: Expression)
+ extends Round(child: Expression, scale: Expression)
+ with ColumnarExpression
+ with Logging {
+ override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
+ val (child_node, childType): (TreeNode, ArrowType) =
+ child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
+ val (scale_node, scaleType): (TreeNode, ArrowType) =
+ scale.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
+ //TODO(): get precision and scale from decimal
+ val precision = 8
+ val decimalScale = 2
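+    // Rounding is implemented as castDECIMAL -> round -> castFLOAT8, using the
+    // hard-coded precision/scale above until they can be derived from the input.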
+ val castNode = TreeBuilder.makeFunction("castDECIMAL", Lists.newArrayList(child_node),
+ new ArrowType.Decimal(precision, decimalScale))
+ val funcNode = TreeBuilder.makeFunction("round", Lists.newArrayList(castNode, scale_node),
+ new ArrowType.Decimal(precision, decimalScale))
+
+ val resultType = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
+ val castFloat = TreeBuilder.makeFunction("castFLOAT8",
+ Lists.newArrayList(funcNode), resultType)
+ (castFloat, resultType)
+ }
+}
+
+object ColumnarRoundOperator {
+
+ def create(child: Expression, scale: Expression, original: Expression): Expression = original match {
+ case r: Round =>
+ new ColumnarRound(child, scale, original)
+ case other =>
+ throw new UnsupportedOperationException(s"not currently supported: $other.")
+ }
+}
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarShuffledHashJoin.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarShuffledHashJoin.scala
index f66eb1994..26be8f85e 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarShuffledHashJoin.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarShuffledHashJoin.scala
@@ -19,6 +19,7 @@ package com.intel.sparkColumnarPlugin.expression
import java.util.concurrent.TimeUnit._
+import com.intel.sparkColumnarPlugin.ColumnarPluginConfig
import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
import org.apache.spark.TaskContext
@@ -34,6 +35,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import scala.collection.JavaConverters._
+import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import scala.collection.mutable.ListBuffer
@@ -66,184 +68,192 @@ class ColumnarShuffledHashJoin(
right: SparkPlan,
buildTime: SQLMetric,
joinTime: SQLMetric,
- totalOutputNumRows: SQLMetric
- ) extends Logging{
-
- var build_cb : ColumnarBatch = null
- var last_cb: ColumnarBatch = null
-
- val inputBatchHolder = new ListBuffer[ColumnarBatch]()
- // TODO
- val l_input_schema: List[Attribute] = left.output.toList
- val r_input_schema: List[Attribute] = right.output.toList
- logInfo(s"\nleft_schema is ${l_input_schema}, right_schema is ${r_input_schema}, \nleftKeys is ${leftKeys}, rightKeys is ${rightKeys}, \nresultSchema is ${resultSchema}, \njoinType is ${joinType}, buildSide is ${buildSide}, condition is ${conditionOption}")
-
- val l_input_field_list: List[Field] = l_input_schema.toList.map(attr => {
- Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
- })
- val r_input_field_list: List[Field] = r_input_schema.toList.map(attr => {
- Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
- })
-
- val resultFieldList = resultSchema.map(field => {
+ totalOutputNumRows: SQLMetric,
+ sparkConf: SparkConf)
+ extends Logging {
+ ColumnarPluginConfig.getConf(sparkConf)
+
+ var build_cb: ColumnarBatch = null
+ var last_cb: ColumnarBatch = null
+
+ val inputBatchHolder = new ListBuffer[ColumnarBatch]()
+ // TODO
+ val l_input_schema: List[Attribute] = left.output.toList
+ val r_input_schema: List[Attribute] = right.output.toList
+ logInfo(
+ s"\nleft_schema is ${l_input_schema}, right_schema is ${r_input_schema}, \nleftKeys is ${leftKeys}, rightKeys is ${rightKeys}, \nresultSchema is ${resultSchema}, \njoinType is ${joinType}, buildSide is ${buildSide}, condition is ${conditionOption}")
+
+ val l_input_field_list: List[Field] = l_input_schema.toList.map(attr => {
+ Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
+ })
+ val r_input_field_list: List[Field] = r_input_schema.toList.map(attr => {
+ Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
+ })
+
+ val resultFieldList = resultSchema
+ .map(field => {
Field.nullable(field.name, CodeGeneration.getResultType(field.dataType))
- }).toList
-
- val leftKeyAttributes = leftKeys.toList.map(expr => {
- ConverterUtils.getAttrFromExpr(expr).asInstanceOf[Expression]
- })
- val rightKeyAttributes = rightKeys.toList.map(expr => {
- ConverterUtils.getAttrFromExpr(expr).asInstanceOf[Expression]
- })
-
- val lkeyFieldList: List[Field] = leftKeyAttributes.toList.map(expr => {
- val attr = ConverterUtils.getAttrFromExpr(expr)
- Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(expr.dataType))
})
+ .toList
+
+ val leftKeyAttributes = leftKeys.toList.map(expr => {
+ ConverterUtils.getAttrFromExpr(expr).asInstanceOf[Expression]
+ })
+ val rightKeyAttributes = rightKeys.toList.map(expr => {
+ ConverterUtils.getAttrFromExpr(expr).asInstanceOf[Expression]
+ })
+
+ val lkeyFieldList: List[Field] = leftKeyAttributes.toList.map(expr => {
+ val attr = ConverterUtils.getAttrFromExpr(expr)
+ Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(expr.dataType))
+ })
+
+ val rkeyFieldList: List[Field] = rightKeyAttributes.toList.map(expr => {
+ val attr = ConverterUtils.getAttrFromExpr(expr)
+ Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(expr.dataType))
+ })
+
+ val (
+ build_key_field_list,
+ stream_key_field_list,
+ stream_key_ordinal_list,
+ build_input_field_list,
+ stream_input_field_list) = buildSide match {
+ case BuildLeft =>
+ val stream_key_expr_list = bindReferences(rightKeyAttributes, r_input_schema)
+ (
+ lkeyFieldList,
+ rkeyFieldList,
+ stream_key_expr_list.toList.map(_.asInstanceOf[BoundReference].ordinal),
+ l_input_field_list,
+ r_input_field_list)
+
+ case BuildRight =>
+ val stream_key_expr_list = bindReferences(leftKeyAttributes, l_input_schema)
+ (
+ rkeyFieldList,
+ lkeyFieldList,
+ stream_key_expr_list.toList.map(_.asInstanceOf[BoundReference].ordinal),
+ r_input_field_list,
+ l_input_field_list)
- val rkeyFieldList: List[Field] = rightKeyAttributes.toList.map(expr => {
- val attr = ConverterUtils.getAttrFromExpr(expr)
- Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(expr.dataType))
- })
-
- val (build_key_field_list, stream_key_field_list, stream_key_ordinal_list, build_input_field_list, stream_input_field_list)
- = buildSide match {
- case BuildLeft =>
- val stream_key_expr_list = bindReferences(rightKeyAttributes, r_input_schema)
- (lkeyFieldList, rkeyFieldList, stream_key_expr_list.toList.map(_.asInstanceOf[BoundReference].ordinal), l_input_field_list, r_input_field_list)
-
- case BuildRight =>
- val stream_key_expr_list = bindReferences(leftKeyAttributes, l_input_schema)
- (rkeyFieldList, lkeyFieldList, stream_key_expr_list.toList.map(_.asInstanceOf[BoundReference].ordinal), r_input_field_list, l_input_field_list)
-
- }
-
- val (probe_func_name, build_output_field_list, stream_output_field_list) = joinType match {
- case _: InnerLike =>
- ("conditionedProbeArraysInner", build_input_field_list, stream_input_field_list)
- case LeftSemi =>
- ("conditionedProbeArraysSemi", List[Field](), stream_input_field_list)
- case LeftOuter =>
- ("conditionedProbeArraysOuter", build_input_field_list, stream_input_field_list)
- case RightOuter =>
- ("conditionedProbeArraysOuter", build_input_field_list, stream_input_field_list)
- case LeftAnti =>
- ("conditionedProbeArraysAnti", List[Field](), stream_input_field_list)
- case _ =>
- throw new UnsupportedOperationException(s"Join Type ${joinType} is not supported yet.")
- }
+ }
- val build_input_arrow_schema: Schema = new Schema(build_input_field_list.asJava)
+ val (probe_func_name, build_output_field_list, stream_output_field_list) = joinType match {
+ case _: InnerLike =>
+ ("conditionedProbeArraysInner", build_input_field_list, stream_input_field_list)
+ case LeftSemi =>
+ ("conditionedProbeArraysSemi", List[Field](), stream_input_field_list)
+ case LeftOuter =>
+ ("conditionedProbeArraysOuter", build_input_field_list, stream_input_field_list)
+ case RightOuter =>
+ ("conditionedProbeArraysOuter", build_input_field_list, stream_input_field_list)
+ case LeftAnti =>
+ ("conditionedProbeArraysAnti", List[Field](), stream_input_field_list)
+ case _ =>
+ throw new UnsupportedOperationException(s"Join Type ${joinType} is not supported yet.")
+ }
- val stream_input_arrow_schema: Schema = new Schema(stream_input_field_list.asJava)
+ val build_input_arrow_schema: Schema = new Schema(build_input_field_list.asJava)
- val build_output_arrow_schema: Schema = new Schema(build_output_field_list.asJava)
- val stream_output_arrow_schema: Schema = new Schema(stream_output_field_list.asJava)
+ val stream_input_arrow_schema: Schema = new Schema(stream_input_field_list.asJava)
- val stream_key_arrow_schema: Schema = new Schema(stream_key_field_list.asJava)
- var output_arrow_schema: Schema = _
+ val build_output_arrow_schema: Schema = new Schema(build_output_field_list.asJava)
+ val stream_output_arrow_schema: Schema = new Schema(stream_output_field_list.asJava)
+ val stream_key_arrow_schema: Schema = new Schema(stream_key_field_list.asJava)
+ var output_arrow_schema: Schema = _
- logInfo(s"\nbuild_key_field_list is ${build_key_field_list}, stream_key_field_list is ${stream_key_field_list}, stream_key_ordinal_list is ${stream_key_ordinal_list}, \nbuild_input_field_list is ${build_input_field_list}, stream_input_field_list is ${stream_input_field_list}, \nbuild_output_field_list is ${build_output_field_list}, stream_output_field_list is ${stream_output_field_list}")
+ logInfo(
+ s"\nbuild_key_field_list is ${build_key_field_list}, stream_key_field_list is ${stream_key_field_list}, stream_key_ordinal_list is ${stream_key_ordinal_list}, \nbuild_input_field_list is ${build_input_field_list}, stream_input_field_list is ${stream_input_field_list}, \nbuild_output_field_list is ${build_output_field_list}, stream_output_field_list is ${stream_output_field_list}")
- /////////////////////////////// Create Prober /////////////////////////////
- // Prober is used to insert left table primary key into hashMap
- // Then use iterator to probe right table primary key from hashmap
- // to get corresponding indices of left table
- //
- val condition = conditionOption match {
- case Some(c) =>
- c
- case None =>
- null
- }
- val (conditionInputFieldList, conditionOutputFieldList) = buildSide match {
- case BuildLeft =>
- (build_input_field_list, build_output_field_list ::: stream_output_field_list)
- case BuildRight =>
- (build_input_field_list, stream_output_field_list ::: build_output_field_list)
- }
- val conditionArrowSchema = new Schema(conditionInputFieldList.asJava)
- output_arrow_schema = new Schema(conditionOutputFieldList.asJava)
- var conditionInputList : java.util.List[Field] = Lists.newArrayList()
- val build_args_node = TreeBuilder.makeFunction(
- "codegen_left_schema",
- build_input_field_list.map(field => {
+ /////////////////////////////// Create Prober /////////////////////////////
+ // The prober inserts the build-side keys into a hash map; an iterator then
+ // probes the stream-side keys against that map to obtain the matching
+ // build-side row indices.
+ //
+ val condition = conditionOption match {
+ case Some(c) =>
+ c
+ case None =>
+ null
+ }
+ val (conditionInputFieldList, conditionOutputFieldList) = buildSide match {
+ case BuildLeft =>
+ (build_input_field_list, build_output_field_list ::: stream_output_field_list)
+ case BuildRight =>
+ (build_input_field_list, stream_output_field_list ::: build_output_field_list)
+ }
+ val conditionArrowSchema = new Schema(conditionInputFieldList.asJava)
+ output_arrow_schema = new Schema(conditionOutputFieldList.asJava)
+ var conditionInputList: java.util.List[Field] = Lists.newArrayList()
+ val build_args_node = TreeBuilder.makeFunction(
+ "codegen_left_schema",
+ build_input_field_list
+ .map(field => {
TreeBuilder.makeField(field)
- }).asJava,
- new ArrowType.Int(32, true)/*dummy ret type, won't be used*/)
- val stream_args_node = TreeBuilder.makeFunction(
- "codegen_right_schema",
- stream_input_field_list.map(field => {
+ })
+ .asJava,
+ new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ )
+ val stream_args_node = TreeBuilder.makeFunction(
+ "codegen_right_schema",
+ stream_input_field_list
+ .map(field => {
TreeBuilder.makeField(field)
- }).asJava,
- new ArrowType.Int(32, true)/*dummy ret type, won't be used*/)
- val build_keys_node = TreeBuilder.makeFunction(
- "codegen_left_key_schema",
- build_key_field_list.map(field => {
+ })
+ .asJava,
+ new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ )
+ val build_keys_node = TreeBuilder.makeFunction(
+ "codegen_left_key_schema",
+ build_key_field_list
+ .map(field => {
TreeBuilder.makeField(field)
- }).asJava,
- new ArrowType.Int(32, true)/*dummy ret type, won't be used*/)
- val stream_keys_node = TreeBuilder.makeFunction(
- "codegen_right_key_schema",
- stream_key_field_list.map(field => {
+ })
+ .asJava,
+ new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ )
+ val stream_keys_node = TreeBuilder.makeFunction(
+ "codegen_right_key_schema",
+ stream_key_field_list
+ .map(field => {
TreeBuilder.makeField(field)
- }).asJava,
- new ArrowType.Int(32, true)/*dummy ret type, won't be used*/)
- val condition_expression_node_list : java.util.List[org.apache.arrow.gandiva.expression.TreeNode] =
- if (condition != null) {
- val columnarExpression: Expression =
- ColumnarExpressionConverter.replaceWithColumnarExpression(condition)
- val (condition_expression_node, resultType) =
- columnarExpression.asInstanceOf[ColumnarExpression].doColumnarCodeGen(conditionInputList)
- Lists.newArrayList(build_keys_node, stream_keys_node, condition_expression_node)
- } else {
- Lists.newArrayList(build_keys_node, stream_keys_node)
- }
- val retType = Field.nullable("res", new ArrowType.Int(32, true))
-
- // Make Expresion for conditionedProbe
- var prober = new ExpressionEvaluator()
- val condition_probe_node = TreeBuilder.makeFunction(
- probe_func_name,
- condition_expression_node_list,
- new ArrowType.Int(32, true)/*dummy ret type, won't be used*/)
- val codegen_probe_node = TreeBuilder.makeFunction(
- "codegen_withTwoInputs",
- Lists.newArrayList(condition_probe_node, build_args_node, stream_args_node),
- new ArrowType.Int(32, true)/*dummy ret type, won't be used*/)
- val condition_probe_expr = TreeBuilder.makeExpression(
- codegen_probe_node,
- retType)
-
- prober.build(build_input_arrow_schema, Lists.newArrayList(condition_probe_expr), true)
-
- /////////////////////////////// Create Shuffler /////////////////////////////
- // Shuffler will use input indices array to shuffle current table
- // output a new table with new sequence.
- //
- var build_shuffler = new ExpressionEvaluator()
- var probe_iterator: BatchIterator = _
- var build_shuffle_iterator: BatchIterator = _
-
- // Make Expresion for conditionedShuffle
- val condition_shuffle_node = TreeBuilder.makeFunction(
- "conditionedShuffleArrayList",
- Lists.newArrayList(),
- new ArrowType.Int(32, true)/*dummy ret type, won't be used*/)
- val codegen_shuffle_node = TreeBuilder.makeFunction(
- "codegen_withTwoInputs",
- Lists.newArrayList(condition_shuffle_node, build_args_node, stream_args_node),
- new ArrowType.Int(32, true)/*dummy ret type, won't be used*/)
- val condition_shuffle_expr = TreeBuilder.makeExpression(
- codegen_shuffle_node,
- retType)
- build_shuffler.build(conditionArrowSchema, Lists.newArrayList(condition_shuffle_expr), output_arrow_schema, true)
-
+ })
+ .asJava,
+ new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ )
+ val condition_expression_node_list
+ : java.util.List[org.apache.arrow.gandiva.expression.TreeNode] =
+ if (condition != null) {
+ val columnarExpression: Expression =
+ ColumnarExpressionConverter.replaceWithColumnarExpression(condition)
+ val (condition_expression_node, resultType) =
+ columnarExpression.asInstanceOf[ColumnarExpression].doColumnarCodeGen(conditionInputList)
+ Lists.newArrayList(build_keys_node, stream_keys_node, condition_expression_node)
+ } else {
+ Lists.newArrayList(build_keys_node, stream_keys_node)
+ }
+ val retType = Field.nullable("res", new ArrowType.Int(32, true))
+
+ // Make Expression for conditionedProbe
+ var prober = new ExpressionEvaluator()
+ val condition_probe_node = TreeBuilder.makeFunction(
+ probe_func_name,
+ condition_expression_node_list,
+ new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ )
+ val codegen_probe_node = TreeBuilder.makeFunction(
+ "codegen_withTwoInputs",
+ Lists.newArrayList(condition_probe_node, build_args_node, stream_args_node),
+ new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ )
+ val condition_probe_expr = TreeBuilder.makeExpression(codegen_probe_node, retType)
+
+ prober.build(
+ build_input_arrow_schema,
+ Lists.newArrayList(condition_probe_expr),
+ output_arrow_schema,
+ true)
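+  // The prober is built with the output schema and its iterator emits the joined
+  // batches directly, so no separate shuffler stage is needed.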
+ var probe_iterator: BatchIterator = _
def columnarInnerJoin(
- streamIter: Iterator[ColumnarBatch],
- buildIter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
+ streamIter: Iterator[ColumnarBatch],
+ buildIter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
val beforeBuild = System.nanoTime()
@@ -253,10 +263,10 @@ class ColumnarShuffledHashJoin(
}
build_cb = buildIter.next()
val build_rb = ConverterUtils.createArrowRecordBatch(build_cb)
- (0 until build_cb.numCols).toList.foreach(i => build_cb.column(i).asInstanceOf[ArrowWritableColumnVector].retain())
+ (0 until build_cb.numCols).toList.foreach(i =>
+ build_cb.column(i).asInstanceOf[ArrowWritableColumnVector].retain())
inputBatchHolder += build_cb
prober.evaluate(build_rb)
- build_shuffler.evaluate(build_rb)
ConverterUtils.releaseArrowRecordBatch(build_rb)
}
if (build_cb != null) {
@@ -279,13 +289,11 @@ class ColumnarShuffledHashJoin(
// there will be different when condition is null or not null
probe_iterator = prober.finishByIterator()
- build_shuffler.setDependency(probe_iterator)
- build_shuffle_iterator = build_shuffler.finishByIterator()
buildTime += NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)
-
+
new Iterator[ColumnarBatch] {
override def hasNext: Boolean = {
- if(streamIter.hasNext) {
+ if (streamIter.hasNext) {
true
} else {
inputBatchHolder.foreach(cb => cb.close())
@@ -298,9 +306,8 @@ class ColumnarShuffledHashJoin(
last_cb = cb
val beforeJoin = System.nanoTime()
val stream_rb: ArrowRecordBatch = ConverterUtils.createArrowRecordBatch(cb)
- probe_iterator.processAndCacheOne(stream_input_arrow_schema, stream_rb)
- val output_rb = build_shuffle_iterator.process(stream_input_arrow_schema, stream_rb)
-
+ val output_rb = probe_iterator.process(stream_input_arrow_schema, stream_rb)
+
ConverterUtils.releaseArrowRecordBatch(stream_rb)
joinTime += NANOSECONDS.toMillis(System.nanoTime() - beforeJoin)
if (output_rb == null) {
@@ -311,22 +318,14 @@ class ColumnarShuffledHashJoin(
val outputNumRows = output_rb.getLength
val output = ConverterUtils.fromArrowRecordBatch(output_arrow_schema, output_rb)
ConverterUtils.releaseArrowRecordBatch(output_rb)
- totalOutputNumRows += outputNumRows
+ totalOutputNumRows += outputNumRows
new ColumnarBatch(output.map(v => v.asInstanceOf[ColumnVector]).toArray, outputNumRows)
- }
+ }
}
}
}
def close(): Unit = {
- if (build_shuffler != null) {
- build_shuffler.close()
- build_shuffler = null
- }
- if (build_shuffle_iterator != null) {
- build_shuffle_iterator.close()
- build_shuffle_iterator = null
- }
if (prober != null) {
prober.close()
prober = null
@@ -341,19 +340,31 @@ class ColumnarShuffledHashJoin(
object ColumnarShuffledHashJoin {
var columnarShuffedHahsJoin: ColumnarShuffledHashJoin = _
def create(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- resultSchema: StructType,
- joinType: JoinType,
- buildSide: BuildSide,
- condition: Option[Expression],
- left: SparkPlan,
- right: SparkPlan,
- buildTime: SQLMetric,
- joinTime: SQLMetric,
- numOutputRows: SQLMetric
- ): ColumnarShuffledHashJoin = synchronized {
- columnarShuffedHahsJoin = new ColumnarShuffledHashJoin(leftKeys, rightKeys, resultSchema, joinType, buildSide, condition, left, right, buildTime, joinTime, numOutputRows)
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ resultSchema: StructType,
+ joinType: JoinType,
+ buildSide: BuildSide,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan,
+ buildTime: SQLMetric,
+ joinTime: SQLMetric,
+ numOutputRows: SQLMetric,
+ sparkConf: SparkConf): ColumnarShuffledHashJoin = synchronized {
+ columnarShuffedHahsJoin = new ColumnarShuffledHashJoin(
+ leftKeys,
+ rightKeys,
+ resultSchema,
+ joinType,
+ buildSide,
+ condition,
+ left,
+ right,
+ buildTime,
+ joinTime,
+ numOutputRows,
+ sparkConf)
columnarShuffedHahsJoin
}
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSorter.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSorter.scala
index 39a0091ed..629e66253 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSorter.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSorter.scala
@@ -21,11 +21,13 @@ import java.nio.ByteBuffer
import java.util.concurrent.TimeUnit._
import com.google.common.collect.Lists
+import com.intel.sparkColumnarPlugin.ColumnarPluginConfig
import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
import com.intel.sparkColumnarPlugin.vectorized.ExpressionEvaluator
import com.intel.sparkColumnarPlugin.vectorized.BatchIterator
import org.apache.spark.internal.Logging
+import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -57,10 +59,12 @@ class ColumnarSorter(
outputBatches: SQLMetric,
outputRows: SQLMetric,
shuffleTime: SQLMetric,
- elapse: SQLMetric)
+ elapse: SQLMetric,
+ sparkConf: SparkConf)
extends Logging {
logInfo(s"ColumnarSorter sortOrder is ${sortOrder}, outputAttributes is ${outputAttributes}")
+ ColumnarPluginConfig.getConf(sparkConf)
/////////////// Prepare ColumnarSorter //////////////
var processedNumRows: Long = 0
var sort_elapse: Long = 0
@@ -216,7 +220,8 @@ object ColumnarSorter {
outputBatches: SQLMetric,
outputRows: SQLMetric,
shuffleTime: SQLMetric,
- elapse: SQLMetric): ColumnarSorter = synchronized {
+ elapse: SQLMetric,
+ sparkConf: SparkConf): ColumnarSorter = synchronized {
new ColumnarSorter(
sortOrder,
outputAsColumnar,
@@ -225,7 +230,8 @@ object ColumnarSorter {
outputBatches,
outputRows,
shuffleTime,
- elapse)
+ elapse,
+ sparkConf)
}
}
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSubquery.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSubquery.scala
index 5991fc467..bed28e169 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSubquery.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSubquery.scala
@@ -52,13 +52,23 @@ class ColumnarScalarSubquery(
val resultType = CodeGeneration.getResultType(query.dataType)
query.dataType match {
case t: StringType =>
- (TreeBuilder.makeStringLiteral(value.toString().asInstanceOf[String]), resultType)
+ value match {
+ case null =>
+ (TreeBuilder.makeStringLiteral("null": java.lang.String), resultType)
+ case _ =>
+ (TreeBuilder.makeStringLiteral(value.toString().asInstanceOf[String]), resultType)
+ }
case t: IntegerType =>
(TreeBuilder.makeLiteral(value.asInstanceOf[Integer]), resultType)
case t: LongType =>
(TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Long]), resultType)
case t: DoubleType =>
- (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Double]), resultType)
+ value match {
+ case null =>
+ (TreeBuilder.makeLiteral(0.0: java.lang.Double), resultType)
+ case _ =>
+ (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Double]), resultType)
+ }
case d: DecimalType =>
val v = value.asInstanceOf[Decimal]
(TreeBuilder.makeDecimalLiteral(v.toString, v.precision, v.scale), resultType)
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarUnaryOperator.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarUnaryOperator.scala
index 4071a23e9..0d25f88cc 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarUnaryOperator.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarUnaryOperator.scala
@@ -132,6 +132,47 @@ class ColumnarUpper(child: Expression, original: Expression)
}
}
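+// Columnar Cast: maps Spark's Cast to gandiva cast functions
+// (castINT / castBIGINT / castFLOAT4 / castFLOAT8 / castDATE); casting to
+// StringType currently passes the child through unchanged, and DecimalType is
+// not supported yet.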
+class ColumnarCast(child: Expression, datatype: DataType, timeZoneId: Option[String], original: Expression)
+ extends Cast(child: Expression, datatype: DataType, timeZoneId: Option[String])
+ with ColumnarExpression
+ with Logging {
+ override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
+ val (child_node, childType): (TreeNode, ArrowType) =
+ child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
+
+ val resultType = CodeGeneration.getResultType(dataType)
+ if (dataType == StringType) {
+ //TODO: fix cast utf8
+ (child_node, childType)
+ } else if (dataType == IntegerType) {
+ val funcNode =
+ TreeBuilder.makeFunction("castINT", Lists.newArrayList(child_node), resultType)
+ (funcNode, resultType)
+ } else if (dataType == LongType) {
+ val funcNode =
+ TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(child_node), resultType)
+ (funcNode, resultType)
+ //(child_node, childType)
+ } else if (dataType == FloatType) {
+ val funcNode =
+ TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(child_node), resultType)
+ (funcNode, resultType)
+ } else if (dataType == DoubleType) {
+ val funcNode =
+ TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(child_node), resultType)
+ (funcNode, resultType)
+ } else if (dataType == DateType) {
+ val funcNode =
+ TreeBuilder.makeFunction("castDATE", Lists.newArrayList(child_node), resultType)
+ (funcNode, resultType)
+ } else if (dataType == DecimalType) {
+ throw new UnsupportedOperationException(s"not currently supported: ${dataType}.")
+ } else {
+ throw new UnsupportedOperationException(s"not currently supported: ${dataType}.")
+ }
+ }
+}
+
object ColumnarUnaryOperator {
def create(child: Expression, original: Expression): Expression = original match {
@@ -148,7 +189,7 @@ object ColumnarUnaryOperator {
case u: Upper =>
new ColumnarUpper(child, u)
case c: Cast =>
- child
+ new ColumnarCast(child, c.dataType, c.timeZoneId, c)
case a: KnownFloatingPointNormalized =>
child
case a: NormalizeNaNAndZero =>
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ConverterUtils.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ConverterUtils.scala
index 51f7f4683..16c417003 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ConverterUtils.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ConverterUtils.scala
@@ -17,24 +17,19 @@
package com.intel.sparkColumnarPlugin.expression
-import java.util.concurrent.atomic.AtomicLong
-import io.netty.buffer.ArrowBuf
-
import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
-
+import io.netty.buffer.ArrowBuf
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.ipc.message.{ArrowFieldNode, ArrowRecordBatch}
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer._
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
-import org.apache.arrow.vector._
-import org.apache.arrow.vector.ipc.message.ArrowFieldNode
-import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
-import org.apache.arrow.vector.types.pojo.Schema
-import org.apache.arrow.vector.types.pojo.Field
-import org.apache.arrow.vector.types.pojo.ArrowType
-import org.apache.spark.internal.Logging
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.ColumnarBatch
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
@@ -75,7 +70,9 @@ object ConverterUtils extends Logging {
new ArrowRecordBatch(numRowsInBatch, fieldNodes.toList.asJava, inputData.toList.asJava)
}
- def fromArrowRecordBatch(recordBatchSchema: Schema, recordBatch: ArrowRecordBatch): Array[ArrowWritableColumnVector] = {
+ def fromArrowRecordBatch(
+ recordBatchSchema: Schema,
+ recordBatch: ArrowRecordBatch): Array[ArrowWritableColumnVector] = {
val numRows = recordBatch.getLength()
ArrowWritableColumnVector.loadColumns(numRows, recordBatchSchema, recordBatch)
}
@@ -115,15 +112,24 @@ object ConverterUtils extends Logging {
getAttrFromExpr(c.children(0))
case i: IsNull =>
getAttrFromExpr(i.child)
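+    // For compound expressions, recurse into the first (left-most) child to
+    // locate the underlying attribute.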
+ case a: Add =>
+ getAttrFromExpr(a.left)
+ case s: Subtract =>
+ getAttrFromExpr(s.left)
+ case u: Upper =>
+ getAttrFromExpr(u.child)
+ case ss: Substring =>
+ getAttrFromExpr(ss.children(0))
case other =>
- throw new UnsupportedOperationException(s"makeStructField is unable to parse from $other (${other.getClass}).")
+ throw new UnsupportedOperationException(
+ s"makeStructField is unable to parse from $other (${other.getClass}).")
}
}
- def getResultAttrFromExpr(fieldExpr: Expression, name: String = "None"): AttributeReference = {
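+  // The optional dataType carries a Cast's target type so the resulting
+  // AttributeReference reports the cast result type rather than the child's type.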
+ def getResultAttrFromExpr(fieldExpr: Expression, name: String = "None", dataType: Option[DataType]=None): AttributeReference = {
fieldExpr match {
case a: Cast =>
- getResultAttrFromExpr(a.child, name)
+ getResultAttrFromExpr(a.child, name, Some(a.dataType))
case a: AttributeReference =>
if (name != "None") {
new AttributeReference(name, a.dataType, a.nullable)()
@@ -134,9 +140,15 @@ object ConverterUtils extends Logging {
//TODO: a walkaround since we didn't support cast yet
if (a.child.isInstanceOf[Cast]) {
val tmp = if (name != "None") {
- new Alias(a.child.asInstanceOf[Cast].child, name)(a.exprId, a.qualifier, a.explicitMetadata)
+ new Alias(a.child.asInstanceOf[Cast].child, name)(
+ a.exprId,
+ a.qualifier,
+ a.explicitMetadata)
} else {
- new Alias(a.child.asInstanceOf[Cast].child, a.name)(a.exprId, a.qualifier, a.explicitMetadata)
+ new Alias(a.child.asInstanceOf[Cast].child, a.name)(
+ a.exprId,
+ a.qualifier,
+ a.explicitMetadata)
}
tmp.toAttribute.asInstanceOf[AttributeReference]
} else {
@@ -147,13 +159,22 @@ object ConverterUtils extends Logging {
a.toAttribute.asInstanceOf[AttributeReference]
}
}
+ case d: ColumnarDivide =>
+ new AttributeReference(name, DoubleType, d.nullable)()
+ case m: ColumnarMultiply =>
+ new AttributeReference(name, m.dataType, m.nullable)()
case other =>
val a = if (name != "None") {
new Alias(other, name)()
} else {
new Alias(other, "res")()
}
- a.toAttribute.asInstanceOf[AttributeReference]
+ val tmpAttr = a.toAttribute.asInstanceOf[AttributeReference]
+ if (dataType.isDefined) {
+ new AttributeReference(tmpAttr.name, dataType.getOrElse(null), tmpAttr.nullable)()
+ } else {
+ tmpAttr
+ }
}
}
@@ -166,15 +187,21 @@ object ConverterUtils extends Logging {
}
def combineArrowRecordBatch(rb1: ArrowRecordBatch, rb2: ArrowRecordBatch): ArrowRecordBatch = {
- val numRows = rb1.getLength()
- val rb1_nodes = rb1.getNodes()
- val rb2_nodes = rb2.getNodes()
- val rb1_bufferlist = rb1.getBuffers()
- val rb2_bufferlist = rb2.getBuffers()
-
- val combined_nodes = rb1_nodes.addAll(rb2_nodes)
- val combined_bufferlist = rb1_bufferlist.addAll(rb2_bufferlist)
- new ArrowRecordBatch(numRows, rb1_nodes, rb1_bufferlist)
+ val numRows = rb1.getLength()
+ val rb1_nodes = rb1.getNodes()
+ val rb2_nodes = rb2.getNodes()
+ val rb1_bufferlist = rb1.getBuffers()
+ val rb2_bufferlist = rb2.getBuffers()
+
+ val combined_nodes = rb1_nodes.addAll(rb2_nodes)
+ val combined_bufferlist = rb1_bufferlist.addAll(rb2_bufferlist)
+ new ArrowRecordBatch(numRows, rb1_nodes, rb1_bufferlist)
+ }
+
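+  // Builds an Arrow Schema from Spark attributes by first assembling a
+  // StructType, using the session-local time zone for any timestamp fields.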
+ def toArrowSchema(attributes: Seq[Attribute]): Schema = {
+ def fromAttributes(attributes: Seq[Attribute]): StructType =
+ StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
+ ArrowUtils.toArrowSchema(fromAttributes(attributes), SQLConf.get.sessionLocalTimeZone)
}
override def toString(): String = {
@@ -182,4 +209,3 @@ object ConverterUtils extends Logging {
}
}
-
diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/vectorized/ArrowColumnarBatchSerializer.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/vectorized/ArrowColumnarBatchSerializer.scala
index 3a8b92ad9..f66bbf541 100644
--- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/vectorized/ArrowColumnarBatchSerializer.scala
+++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/vectorized/ArrowColumnarBatchSerializer.scala
@@ -20,102 +20,68 @@ package com.intel.sparkColumnarPlugin.vectorized
import java.io._
import java.nio.ByteBuffer
-import com.intel.sparkColumnarPlugin.expression.CodeGeneration
-import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
-
import org.apache.arrow.memory.BufferAllocator
-import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
-import org.apache.arrow.vector.types.pojo.{Field, Schema}
-import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
-import org.apache.spark.sql.util.ArrowUtils
-
+import org.apache.arrow.util.SchemaUtils
+import org.apache.arrow.vector.ipc.ArrowStreamReader
+import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{
DeserializationStream,
SerializationStream,
Serializer,
SerializerInstance
}
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import scala.collection.JavaConverters._
-import scala.collection.immutable.List
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
import scala.reflect.ClassTag
-class ArrowColumnarBatchSerializer extends Serializer with Serializable {
+class ArrowColumnarBatchSerializer(readBatchNumRows: SQLMetric = null)
+ extends Serializer
+ with Serializable {
/** Creates a new [[SerializerInstance]]. */
override def newInstance(): SerializerInstance =
- new ArrowColumnarBatchSerializerInstance
+ new ArrowColumnarBatchSerializerInstance(readBatchNumRows)
}
-private class ArrowColumnarBatchSerializerInstance extends SerializerInstance {
-
- override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
-
- override def writeValue[T: ClassTag](value: T): SerializationStream = {
- val root = createVectorSchemaRoot(value.asInstanceOf[ColumnarBatch])
- val writer = new ArrowStreamWriter(root, null, out)
- writer.writeBatch()
- writer.end()
- this
- }
-
- override def writeKey[T: ClassTag](key: T): SerializationStream = {
- // The key is only needed on the map side when computing partition ids. It does not need to
- // be shuffled.
- assert(null == key || key.isInstanceOf[Int])
- this
- }
-
- override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = {
- // This method is never called by shuffle code.
- throw new UnsupportedOperationException
- }
-
- override def writeObject[T: ClassTag](t: T): SerializationStream = {
- // This method is never called by shuffle code.
- throw new UnsupportedOperationException
- }
-
- override def flush(): Unit = {
- out.flush()
- }
-
- override def close(): Unit = {
- out.close()
- }
-
- private def createVectorSchemaRoot(cb: ColumnarBatch): VectorSchemaRoot = {
- val fieldTypesList = List
- .range(0, cb.numCols())
- .map(i => Field.nullable(s"c_$i", CodeGeneration.getResultType(cb.column(i).dataType())))
- val arrowSchema = new Schema(fieldTypesList.asJava)
- val vectors = List
- .range(0, cb.numCols())
- .map(
- i =>
- cb.column(i)
- .asInstanceOf[ArrowWritableColumnVector]
- .getValueVector
- .asInstanceOf[FieldVector])
- val root = new VectorSchemaRoot(arrowSchema, vectors.asJava, cb.numRows)
- root.setRowCount(cb.numRows)
- root
- }
- }
+private class ArrowColumnarBatchSerializerInstance(readBatchNumRows: SQLMetric)
+ extends SerializerInstance
+ with Logging {
override def deserializeStream(in: InputStream): DeserializationStream = {
new DeserializationStream {
- private[this] val columnBatchSize = SQLConf.get.columnBatchSize
- private[this] var allocator: BufferAllocator = _
+ private val columnBatchSize = SQLConf.get.columnBatchSize
+ private val compressionEnabled =
+ SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)
+ private val compressionCodec = SparkEnv.get.conf.get("spark.io.compression.codec", "lz4")
+ private val allocator: BufferAllocator = ArrowUtils.rootAllocator
+ .newChildAllocator("ArrowColumnarBatch deserialize", 0, Long.MaxValue)
- private[this] var reader: ArrowStreamReader = _
- private[this] var root: VectorSchemaRoot = _
- private[this] var vectors: Array[ArrowWritableColumnVector] = _
+ private var reader: ArrowStreamReader = _
+ private var root: VectorSchemaRoot = _
+ private var vectors: Array[ColumnVector] = _
+ private var cb: ColumnarBatch = _
+ private var batchLoaded = true
- // TODO: see if asKeyValueIterator should be override
+ private var jniWrapper: ShuffleDecompressionJniWrapper = _
+ private var schemaHolderId: Long = 0
+ private var vectorLoader: VectorLoader = _
+
+ private var numBatchesTotal: Long = _
+ private var numRowsTotal: Long = _
+
+ override def asIterator: Iterator[Any] = {
+ // This method is never called by shuffle code.
+ throw new UnsupportedOperationException
+ }
override def readKey[T: ClassTag](): T = {
// We skipped serialization of the key in writeKey(), so just return a dummy value since
@@ -125,33 +91,60 @@ private class ArrowColumnarBatchSerializerInstance extends SerializerInstance {
@throws(classOf[EOFException])
override def readValue[T: ClassTag](): T = {
- try {
- allocator = ArrowUtils.rootAllocator
- .newChildAllocator("ArrowColumnarBatch deserialize", 0, Long.MaxValue)
- reader = new ArrowStreamReader(in, allocator)
- root = reader.getVectorSchemaRoot
- // vectors = new ArrayBuffer[ColumnVector]()
- } catch {
- case _: IOException =>
- reader.close(false)
- root.close()
+ if (reader != null && batchLoaded) {
+ root.clear()
+ if (cb != null) {
+ cb.close()
+ cb = null
+ }
+
+ try {
+ batchLoaded = reader.loadNextBatch()
+ } catch {
+ case ioe: IOException =>
+ this.close()
+ logError("Failed to load next RecordBatch", ioe)
+ throw ioe
+ }
+ if (batchLoaded) {
+ val numRows = root.getRowCount
+ logDebug(s"Read ColumnarBatch of ${numRows} rows")
+
+ numBatchesTotal += 1
+ numRowsTotal += numRows
+
+ // jni call to decompress buffers
+ if (compressionEnabled) {
+ decompressVectors()
+ }
+
+ val newFieldVectors = root.getFieldVectors.asScala.map { vector =>
+ val newVector = vector.getField.createVector(allocator)
+ vector.makeTransferPair(newVector).transfer()
+ newVector
+ }.asJava
+
+ vectors = ArrowWritableColumnVector
+ .loadColumns(numRows, newFieldVectors)
+ .toArray[ColumnVector]
+
+ cb = new ColumnarBatch(vectors, numRows)
+ cb.asInstanceOf[T]
+ } else {
+ this.close()
throw new EOFException
+ }
+ } else {
+ reader = new ArrowCompressedStreamReader(in, allocator)
+ try {
+ root = reader.getVectorSchemaRoot
+ } catch {
+ case _: IOException =>
+ this.close()
+ throw new EOFException
+ }
+ readValue()
}
-
- var numRows = 0
- // "root.rowCount" is set to 0 when reader.loadNextBatch reach EOF
- while (reader.loadNextBatch()) {
- numRows += root.getRowCount
- }
-
- assert(
- numRows <= columnBatchSize,
- "the number of loaded rows exceed the maximum columnar batch size")
-
- vectors = ArrowWritableColumnVector.loadColumns(numRows, root.getFieldVectors)
- val cb = new ColumnarBatch(vectors.toArray, numRows)
- // the underlying ColumnVectors of this ColumnarBatch might be empty
- cb.asInstanceOf[T]
}
override def readObject[T: ClassTag](): T = {
@@ -160,13 +153,65 @@ private class ArrowColumnarBatchSerializerInstance extends SerializerInstance {
}
override def close(): Unit = {
- if (reader != null) reader.close(false)
- if (root != null) root.close()
- in.close()
+ if (numBatchesTotal > 0) {
+ readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal)
+ }
+ if (cb != null) cb.close()
+ if (reader != null) reader.close(true)
+ if (jniWrapper != null) jniWrapper.close(schemaHolderId)
+ }
+
+ private def decompressVectors(): Unit = {
+ if (jniWrapper == null) {
+ jniWrapper = new ShuffleDecompressionJniWrapper
+ schemaHolderId = jniWrapper.make(SchemaUtils.get.serialize(root.getSchema))
+ }
+ if (vectorLoader == null) {
+ vectorLoader = new VectorLoader(root)
+ }
+ val bufAddrs = new ListBuffer[Long]()
+ val bufSizes = new ListBuffer[Long]()
+ val bufBS = mutable.BitSet()
+ var bufIdx = 0
+
+ root.getFieldVectors.asScala.foreach { vector =>
+ val validityBuf = vector.getValidityBuffer
+ if (validityBuf
+ .capacity() <= 8 || java.lang.Long.bitCount(validityBuf.getLong(0)) == 64 ||
+ java.lang.Long.bitCount(validityBuf.getLong(0)) == 0) {
+ bufBS.add(bufIdx)
+ }
+ vector.getBuffers(false).foreach { buffer =>
+ bufAddrs += buffer.memoryAddress()
+ // buffer.readableBytes() would return the wrong length here because it is initialized
+ // from the IPC message header, which does not reflect the actual compressed length
+ bufSizes += buffer.capacity()
+ bufIdx += 1
+ }
+ }
+
+ val builder = jniWrapper.decompress(
+ schemaHolderId,
+ compressionCodec,
+ root.getRowCount,
+ bufAddrs.toArray,
+ bufSizes.toArray,
+ bufBS.toBitMask)
+ val builderImpl = new ArrowRecordBatchBuilderImpl(builder)
+ val decompressedRecordBatch = builderImpl.build
+
+ root.clear()
+ if (decompressedRecordBatch != null) {
+ vectorLoader.load(decompressedRecordBatch)
+ }
}
}
}
+ // The columnar shuffle write path does not need this.
+ override def serializeStream(s: OutputStream): SerializationStream =
+ throw new UnsupportedOperationException
+
// These methods are never called by shuffle code.
override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
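
Editor's note — a minimal consumption sketch, not part of the patch: the reduce side is expected to drive the DeserializationStream above with alternating readKey/readValue calls until it throws EOFException. Here `in` and `avgReadBatchNumRows` are assumptions standing in for the fetched shuffle stream and the metric passed in by the exchange.

import java.io.{EOFException, InputStream}
import scala.collection.mutable.ArrayBuffer
import com.intel.sparkColumnarPlugin.vectorized.ArrowColumnarBatchSerializer
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.vectorized.ColumnarBatch

def readAllBatches(in: InputStream, avgReadBatchNumRows: SQLMetric): Seq[ColumnarBatch] = {
  val stream = new ArrowColumnarBatchSerializer(avgReadBatchNumRows)
    .newInstance()
    .deserializeStream(in)
  val batches = ArrayBuffer[ColumnarBatch]()
  try {
    while (true) {
      stream.readKey[Int]()                        // dummy key, skipped on the write side
      batches += stream.readValue[ColumnarBatch]() // loads, decompresses and wraps one RecordBatch
    }
  } catch {
    case _: EOFException => // end of stream
  } finally {
    stream.close() // also updates avgReadBatchNumRows
  }
  batches
}
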
diff --git a/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala
new file mode 100644
index 000000000..7b49d8c8e
--- /dev/null
+++ b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.execution.metric.SQLMetric
+
+import scala.reflect.ClassTag
+
+/**
+ * :: DeveloperApi ::
+ * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
+ * the RDD is transient since we don't need it on the executor side.
+ *
+ * @param _rdd the parent RDD
+ * @param partitioner partitioner used to partition the shuffle output
+ * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If not set
+ * explicitly then the default serializer, as specified by `spark.serializer`
+ * config option, will be used.
+ * @param keyOrdering key ordering for RDD's shuffles
+ * @param aggregator map/reduce-side aggregator for RDD's shuffle
+ * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine)
+ * @param shuffleWriterProcessor the processor to control the write behavior in ShuffleMapTask
+ * @param serializedSchema serialized [[org.apache.arrow.vector.types.pojo.Schema]] for ColumnarBatch
+ * @param dataSize for shuffle data size tracking
+ */
+class ColumnarShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
+ @transient private val _rdd: RDD[_ <: Product2[K, V]],
+ override val partitioner: Partitioner,
+ override val serializer: Serializer = SparkEnv.get.serializer,
+ override val keyOrdering: Option[Ordering[K]] = None,
+ override val aggregator: Option[Aggregator[K, V, C]] = None,
+ override val mapSideCombine: Boolean = false,
+ override val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
+ val serializedSchema: Array[Byte],
+ val dataSize: SQLMetric,
+ val splitTime: SQLMetric)
+ extends ShuffleDependency[K, V, C](
+ _rdd,
+ partitioner,
+ serializer,
+ keyOrdering,
+ aggregator,
+ mapSideCombine,
+ shuffleWriterProcessor) {}
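
Editor's note — a small round-trip sketch, not part of the patch, of the serializedSchema field: the exchange stores the Arrow schema as bytes via toByteArray, and ColumnarShuffleWriter restores it with Schema.deserialize in write() below. The two-column schema here is purely illustrative.

import java.nio.ByteBuffer
import scala.collection.JavaConverters._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

val schema = new Schema(List(
  Field.nullable("pid", new ArrowType.Int(32, true)),
  Field.nullable("f_1", new ArrowType.Int(32, true))).asJava)

val bytes: Array[Byte] = schema.toByteArray                        // stored on the dependency
val restored: Schema = Schema.deserialize(ByteBuffer.wrap(bytes))  // recovered in the writer
assert(restored == schema)
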
diff --git a/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
new file mode 100644
index 000000000..1619c4b6f
--- /dev/null
+++ b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.{File, FileInputStream, FileOutputStream, IOException}
+import java.nio.ByteBuffer
+
+import com.google.common.annotations.VisibleForTesting
+import com.google.common.io.Closeables
+import com.intel.sparkColumnarPlugin.vectorized.{
+ ArrowWritableColumnVector,
+ ShuffleSplitterJniWrapper
+}
+import org.apache.arrow.util.SchemaUtils
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.Utils
+
+import scala.collection.mutable.ListBuffer
+
+class ColumnarShuffleWriter[K, V](
+ shuffleBlockResolver: IndexShuffleBlockResolver,
+ handle: BaseShuffleHandle[K, V, V],
+ mapId: Long,
+ writeMetrics: ShuffleWriteMetricsReporter)
+ extends ShuffleWriter[K, V]
+ with Logging {
+
+ private val dep = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]]
+
+ private val conf = SparkEnv.get.conf
+
+ private val blockManager = SparkEnv.get.blockManager
+
+ // Are we in the process of stopping? Because map tasks can call stop() with success = true
+ // and then call stop() with success = false if they get an exception, we want to make sure
+ // we don't try deleting files, etc twice.
+ private var stopping = false
+
+ private var mapStatus: MapStatus = _
+
+ private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
+ private val compressionEnabled = conf.getBoolean("spark.shuffle.compress", true)
+ private val compressionCodec = conf.get("spark.io.compression.codec", "lz4")
+ private val nativeBufferSize =
+ conf.getLong("spark.sql.execution.arrow.maxRecordsPerBatch", 4096)
+
+ private val jniWrapper = new ShuffleSplitterJniWrapper()
+
+ private var nativeSplitter: Long = 0
+
+ private var partitionLengths: Array[Long] = _
+
+ @throws[IOException]
+ override def write(records: Iterator[Product2[K, V]]): Unit = {
+ if (!records.hasNext) {
+ partitionLengths = new Array[Long](dep.partitioner.numPartitions)
+ shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, null)
+ mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
+ return
+ }
+
+ if (nativeSplitter == 0) {
+ val schema: Schema = Schema.deserialize(ByteBuffer.wrap(dep.serializedSchema))
+ val localDirs = Utils.getConfiguredLocalDirs(conf).mkString(",")
+ nativeSplitter =
+ jniWrapper.make(SchemaUtils.get.serialize(schema), nativeBufferSize, localDirs)
+ if (compressionEnabled) {
+ jniWrapper.setCompressionCodec(nativeSplitter, compressionCodec)
+ }
+ }
+
+ while (records.hasNext) {
+ val cb = records.next()._2.asInstanceOf[ColumnarBatch]
+ if (cb.numRows == 0 || cb.numCols == 0) {
+ logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols")
+ } else {
+ val bufAddrs = new ListBuffer[Long]()
+ val bufSizes = new ListBuffer[Long]()
+ (0 until cb.numCols).foreach { idx =>
+ val column = cb.column(idx).asInstanceOf[ArrowWritableColumnVector]
+ column.getValueVector
+ .getBuffers(false)
+ .foreach { buffer =>
+ bufAddrs += buffer.memoryAddress()
+ bufSizes += buffer.readableBytes()
+ }
+ }
+ dep.dataSize.add(bufSizes.sum)
+
+ val startTime = System.nanoTime()
+ jniWrapper.split(nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray)
+ dep.splitTime.add(System.nanoTime() - startTime)
+ writeMetrics.incRecordsWritten(1)
+ }
+ }
+
+ val startTime = System.nanoTime()
+ jniWrapper.stop(nativeSplitter)
+ dep.splitTime.add(System.nanoTime() - startTime)
+ writeMetrics.incBytesWritten(jniWrapper.getTotalBytesWritten(nativeSplitter))
+
+ val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+ val tmp = Utils.tempFileWith(output)
+ try {
+ partitionLengths = writePartitionedFile(tmp)
+ shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
+ } finally {
+ if (tmp.exists() && !tmp.delete()) {
+ logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
+ }
+ }
+ mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
+ }
+
+ override def stop(success: Boolean): Option[MapStatus] = {
+
+ try {
+ if (stopping) {
+ return None
+ }
+ stopping = true
+ if (success) {
+ Option(mapStatus)
+ } else {
+ None
+ }
+ } finally {
+ // delete the temporary files held by the native splitter
+ if (nativeSplitter != 0) {
+ try {
+ jniWrapper.getPartitionFileInfo(nativeSplitter).foreach { fileInfo =>
+ {
+ val pid = fileInfo.getPid
+ val file = new File(fileInfo.getFilePath)
+ if (file.exists()) {
+ if (!file.delete()) {
+ logError(s"Unable to delete file for partition ${pid}")
+ }
+ }
+ }
+ }
+ } finally {
+ jniWrapper.close(nativeSplitter)
+ nativeSplitter = 0
+ }
+ }
+ }
+ }
+
+ @throws[IOException]
+ private def writePartitionedFile(outputFile: File): Array[Long] = {
+
+ val lengths = new Array[Long](dep.partitioner.numPartitions)
+ val out = new FileOutputStream(outputFile, true)
+ val writerStartTime = System.nanoTime()
+ var threwException = true
+
+ try {
+ jniWrapper.getPartitionFileInfo(nativeSplitter).foreach { fileInfo =>
+ {
+ val pid = fileInfo.getPid
+ val filePath = fileInfo.getFilePath
+
+ val file = new File(filePath)
+ if (file.exists()) {
+ val in = new FileInputStream(file)
+ var copyThrewException = true
+
+ try {
+ lengths(pid) = Utils.copyStream(in, out, false, transferToEnabled)
+ copyThrewException = false
+ } finally {
+ Closeables.close(in, copyThrewException)
+ }
+ if (!file.delete()) {
+ logError(s"Unable to delete file for partition ${pid}")
+ } else {
+ logDebug(s"Deleting temporary shuffle file ${filePath} for partition ${pid}")
+ }
+ } else {
+ logWarning(
+ s"Native shuffle writer temporary file ${filePath} for partition ${pid} not exists")
+ }
+ }
+ }
+ threwException = false
+ } finally {
+ Closeables.close(out, threwException)
+ writeMetrics.incWriteTime(System.nanoTime - writerStartTime)
+ }
+ lengths
+ }
+
+ @VisibleForTesting
+ def getPartitionLengths: Array[Long] = partitionLengths
+
+}
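
Editor's note — a sketch of the buffer flattening that write() performs before calling the native splitter. It assumes `cb` is a ColumnarBatch whose columns are ArrowWritableColumnVectors (with the leading "pid" column), as produced by the columnar exchange; the function name is illustrative.

import scala.collection.mutable.ListBuffer
import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
import org.apache.spark.sql.vectorized.ColumnarBatch

def flattenBuffers(cb: ColumnarBatch): (Array[Long], Array[Long]) = {
  val bufAddrs = new ListBuffer[Long]()
  val bufSizes = new ListBuffer[Long]()
  (0 until cb.numCols).foreach { idx =>
    val vector = cb.column(idx).asInstanceOf[ArrowWritableColumnVector].getValueVector
    // validity/offset/data buffers, retrieved without transferring ownership
    vector.getBuffers(false).foreach { buf =>
      bufAddrs += buf.memoryAddress()
      bufSizes += buf.readableBytes()
    }
  }
  (bufAddrs.toArray, bufSizes.toArray)
}
// write() then passes these arrays to jniWrapper.split(nativeSplitter, cb.numRows, ...).
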
diff --git a/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
new file mode 100644
index 000000000..b6b191631
--- /dev/null
+++ b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.InputStream
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
+import org.apache.spark.storage.BlockId
+import org.apache.spark.util.collection.OpenHashSet
+
+import scala.collection.JavaConverters._
+
+class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+
+ import ColumnarShuffleManager._
+
+ /**
+ * A mapping from shuffle ids to the number of mappers producing output for those shuffles.
+ */
+ private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]()
+
+ private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+
+ override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
+
+ /**
+ * Obtains a [[ShuffleHandle]] to pass to tasks.
+ */
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ if (dependency.isInstanceOf[ColumnarShuffleDependency[K, V, V]]) {
+ logInfo(s"Registering ColumnarShuffle shuffleId: ${shuffleId}")
+ new ColumnarShuffleHandle[K, V](
+ shuffleId,
+ dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
+ // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
+ // need map-side aggregation, then write numPartitions files directly and just concatenate
+ // them at the end. This avoids doing serialization and deserialization twice to merge
+ // together the spilled files, which would happen with the normal code path. The downside is
+ // having multiple files open at a time and thus more memory allocated to buffers.
+ new BypassMergeSortShuffleHandle[K, V](
+ shuffleId,
+ dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
+ // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
+ new SerializedShuffleHandle[K, V](
+ shuffleId,
+ dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else {
+ // Otherwise, buffer map outputs in a deserialized form:
+ new BaseShuffleHandle(shuffleId, dependency)
+ }
+ }
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Long,
+ context: TaskContext,
+ metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+ val mapTaskIds =
+ taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
+ mapTaskIds.synchronized { mapTaskIds.add(context.taskAttemptId()) }
+ val env = SparkEnv.get
+ handle match {
+ case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
+ new ColumnarShuffleWriter(shuffleBlockResolver, columnarShuffleHandle, mapId, metrics)
+ case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
+ new UnsafeShuffleWriter(
+ env.blockManager,
+ context.taskMemoryManager(),
+ unsafeShuffleHandle,
+ mapId,
+ context,
+ env.conf,
+ metrics,
+ shuffleExecutorComponents)
+ case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
+ new BypassMergeSortShuffleWriter(
+ env.blockManager,
+ bypassMergeSortHandle,
+ mapId,
+ env.conf,
+ metrics,
+ shuffleExecutorComponents)
+ case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
+ new SortShuffleWriter(
+ shuffleBlockResolver,
+ other,
+ mapId,
+ context,
+ shuffleExecutorComponents)
+ }
+ }
+
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ val blocksByAddress = SparkEnv.get.mapOutputTracker
+ .getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
+ if (handle.isInstanceOf[ColumnarShuffleHandle[K, _]]) {
+ new BlockStoreShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+ blocksByAddress,
+ context,
+ metrics,
+ new SerializerManager(
+ SparkEnv.get.serializer,
+ SparkEnv.get.conf,
+ SparkEnv.get.securityManager.getIOEncryptionKey()) {
+ // Bypass block-level shuffle read decompression: the Arrow buffers are compressed
+ // individually by the native splitter and decompressed via JNI in the serializer.
+ override def wrapStream(blockId: BlockId, s: InputStream): InputStream = {
+ wrapForEncryption(s)
+ }
+ })
+ } else {
+ new BlockStoreShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+ blocksByAddress,
+ context,
+ metrics,
+ shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
+ }
+ }
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ override def unregisterShuffle(shuffleId: Int): Boolean = {
+ Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { mapTaskIds =>
+ mapTaskIds.iterator.foreach { mapId =>
+ shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+ }
+ }
+ true
+ }
+
+ /** Shut down this ShuffleManager. */
+ override def stop(): Unit = {
+ shuffleBlockResolver.stop()
+ }
+
+ override def getReaderForRange[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange(
+ handle.shuffleId,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition)
+ new BlockStoreShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+ blocksByAddress,
+ context,
+ metrics)
+ }
+
+}
+
+object ColumnarShuffleManager extends Logging {
+ private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
+ val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
+ val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap
+ executorComponents.initializeExecutor(
+ conf.getAppId,
+ SparkEnv.get.executorId,
+ extraConfigs.asJava)
+ executorComponents
+ }
+}
+
+private[spark] class ColumnarShuffleHandle[K, V](
+ shuffleId: Int,
+ dependency: ShuffleDependency[K, V, V])
+ extends BaseShuffleHandle(shuffleId, dependency) {}
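
Editor's note — a minimal configuration sketch, with values taken from the test suites in this patch and the settings the new code reads, for routing shuffles through the columnar path:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
  .set("spark.sql.extensions", "com.intel.sparkColumnarPlugin.ColumnarPlugin")
  .set("spark.sql.execution.arrow.maxRecordsPerBatch", "4096") // native splitter buffer size
  .set("spark.shuffle.compress", "true")                       // buffers compressed by the splitter
  .set("spark.io.compression.codec", "lz4")                    // codec passed to the native code

Non-columnar dependencies still fall through to the stock sort-based handles in registerShuffle, so the manager can be enabled globally.
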
diff --git a/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
index 6ff2e7bfc..3d82fe9d4 100644
--- a/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
+++ b/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
@@ -19,14 +19,28 @@ package org.apache.spark.sql.execution
import java.util.Random
-import com.intel.sparkColumnarPlugin.vectorized.ArrowColumnarBatchSerializer
-import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
+import com.intel.sparkColumnarPlugin.expression.ConverterUtils
+import com.intel.sparkColumnarPlugin.vectorized.{
+ ArrowColumnarBatchSerializer,
+ ArrowWritableColumnVector
+}
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.arrow.vector.{FieldVector, IntVector}
+import org.apache.spark._
+import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ColumnarShuffleDependency
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.expressions.{
+ Attribute,
+ AttributeReference,
+ BoundReference,
+ UnsafeProjection
+}
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffleWriteProcessor
import org.apache.spark.sql.execution.metric.{
SQLMetric,
@@ -34,67 +48,66 @@ import org.apache.spark.sql.execution.metric.{
SQLShuffleReadMetricsReporter,
SQLShuffleWriteMetricsReporter
}
-import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.{HashPartitioner, Partitioner, ShuffleDependency, TaskContext}
+import org.apache.spark.util.MutablePair
import scala.collection.JavaConverters._
-import scala.collection.mutable
-case class ColumnarShuffleExchangeExec(
+class ColumnarShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
canChangeNumPartitions: Boolean = true)
- extends Exchange {
+ extends ShuffleExchangeExec(outputPartitioning, child, canChangeNumPartitions) {
private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
- private lazy val readMetrics =
+ override private[sql] lazy val readMetrics =
SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
override lazy val metrics: Map[String, SQLMetric] = Map(
- "dataSize" -> SQLMetrics
- .createSizeMetric(sparkContext, "data size")) ++ readMetrics ++ writeMetrics
+ "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+ "splitTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "split time"),
+ "avgReadBatchNumRows" -> SQLMetrics
+ .createAverageMetric(sparkContext, "avg read batch num rows")) ++ readMetrics ++ writeMetrics
override def nodeName: String = "ColumnarExchange"
override def supportsColumnar: Boolean = true
- // Disable code generation
- @transient lazy val inputRDD: RDD[ColumnarBatch] = child.executeColumnar()
+ @transient lazy val inputColumnarRDD: RDD[ColumnarBatch] = child.executeColumnar()
- private val serializer: Serializer = new ArrowColumnarBatchSerializer
-
- override protected def doExecute(): RDD[InternalRow] = {
- child.execute()
- }
+ private val serializer: Serializer = new ArrowColumnarBatchSerializer(
+ longMetric("avgReadBatchNumRows"))
@transient
- lazy val shuffleDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
+ lazy val columnarShuffleDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
ColumnarShuffleExchangeExec.prepareShuffleDependency(
- inputRDD,
+ inputColumnarRDD,
child.output,
outputPartitioning,
serializer,
- writeMetrics)
+ writeMetrics,
+ longMetric("dataSize"),
+ longMetric("splitTime"))
}
- def createShuffledRDD(partitionStartIndices: Option[Array[Int]]): ShuffledColumnarBatchRDD = {
- new ShuffledColumnarBatchRDD(shuffleDependency, readMetrics, partitionStartIndices)
+ def createColumnarShuffledRDD(
+ partitionStartIndices: Option[Array[Int]]): ShuffledColumnarBatchRDD = {
+ new ShuffledColumnarBatchRDD(columnarShuffleDependency, readMetrics, partitionStartIndices)
}
- private var cachedShuffleRDD: ShuffledColumnarBatchRDD = null
+ private var cachedShuffleRDD: ShuffledColumnarBatchRDD = _
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
if (cachedShuffleRDD == null) {
- cachedShuffleRDD = createShuffledRDD(None)
+ cachedShuffleRDD = createColumnarShuffledRDD(None)
}
cachedShuffleRDD
}
}
-object ColumnarShuffleExchangeExec {
+object ColumnarShuffleExchangeExec extends Logging {
def prepareShuffleDependency(
rdd: RDD[ColumnarBatch],
@@ -102,10 +115,12 @@ object ColumnarShuffleExchangeExec {
newPartitioning: Partitioning,
serializer: Serializer,
writeMetrics: Map[String, SQLMetric],
- enableArrowColumnVector: Boolean = true)
- : ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
+ dataSize: SQLMetric,
+ splitTime: SQLMetric): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
- assert(enableArrowColumnVector, "only support arrow column vector")
+ val arrowSchema: Schema =
+ ConverterUtils.toArrowSchema(
+ AttributeReference("pid", IntegerType, nullable = false)() +: outputAttributes)
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
@@ -116,125 +131,172 @@ object ColumnarShuffleExchangeExec {
// `HashPartitioning.partitionIdExpression` to produce partitioning key.
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}
+ case RangePartitioning(sortingExpressions, numPartitions) =>
+ // Extract only the fields used for sorting, to avoid collecting large fields that do not
+ // affect the sorting result when deciding partition bounds in RangePartitioner
+ val rddForSampling = rdd.mapPartitionsInternal { iter =>
+ // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+ // partition bounds. To get accurate samples, we need to copy the mutable keys.
+ iter.flatMap(batch => {
+ val rows = batch.rowIterator.asScala
+ val projection =
+ UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
+ val mutablePair = new MutablePair[InternalRow, Null]()
+ rows.map(row => mutablePair.update(projection(row).copy(), null))
+ })
+ }
+ // Construct ordering on extracted sort key.
+ val orderingAttributes = sortingExpressions.zipWithIndex.map {
+ case (ord, i) =>
+ ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
+ }
+ implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
+ new RangePartitioner(
+ numPartitions,
+ rddForSampling,
+ ascending = true,
+ samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
case SinglePartition =>
new Partitioner {
override def numPartitions: Int = 1
override def getPartition(key: Any): Int = 0
}
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
- // TODO: Handle RangePartitioning.
// TODO: Handle BroadcastPartitioning.
}
val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
newPartitioning.numPartitions > 1
- val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
-
- val schema = StructType.fromAttributes(outputAttributes)
-
- def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
- case RoundRobinPartitioning(numPartitions) =>
- // Distributes elements evenly across output partitions, starting from a random partition.
- var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions)
- (row: InternalRow) => {
- // The HashPartitioner will handle the `mod` by the number of partitions
- position += 1
- position
- }
- case h: HashPartitioning =>
- val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
- row => projection(row).getInt(0)
- case RangePartitioning(sortingExpressions, _) =>
- val projection =
- UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
- row => projection(row)
- case SinglePartition => identity
- case _ => sys.error(s"Exchange not implemented for $newPartitioning")
- }
-
- def getPartitionIdMask(cb: ColumnarBatch): Array[Int] = {
- val getPartitionKey = getPartitionKeyExtractor()
- val rowIterator = cb.rowIterator().asScala
- val partitionIdMask = new Array[Int](cb.numRows())
-
- rowIterator.zipWithIndex.foreach {
- case (row, idx) =>
- partitionIdMask(idx) = part.getPartition(getPartitionKey(row))
- }
- partitionIdMask
- }
-
- def createColumnVectors(numRows: Int): Array[WritableColumnVector] = {
- val vectors: Seq[WritableColumnVector] =
- ArrowWritableColumnVector.allocateColumns(numRows, schema)
- vectors.toArray
- }
-
- def partitionIdToColumnVectors(
- partitionCounter: PartitionCounter): Map[Int, Array[WritableColumnVector]] = {
- val columnVectorsWithPartitionId = mutable.Map[Int, Array[WritableColumnVector]]()
- for (pid <- partitionCounter.keys) {
- val numRows = partitionCounter(pid)
- if (numRows > 0) {
- columnVectorsWithPartitionId.update(pid, createColumnVectors(numRows))
- }
- }
- columnVectorsWithPartitionId.toMap[Int, Array[WritableColumnVector]]
- }
+ // The RDD passed to ShuffleDependency must consist of key-value pairs.
+ // For columnar shuffle we create a new column holding the partition id of each row in a
+ // ColumnarBatch and prepend it to the batch. ColumnarShuffleWriter extracts partition ids
+ // from the ColumnarBatch rather than reading the "key" part, so the key is never used.
+ val rddWithDummyKey: RDD[Product2[Int, ColumnarBatch]] = {
+ val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
- val rddWithPartitionIds: RDD[Product2[Int, ColumnarBatch]] = {
rdd.mapPartitionsWithIndexInternal(
(_, cbIterator) => {
- val converters = new RowToColumnConverter(schema)
- cbIterator.flatMap {
- cb =>
- val partitionCounter = new PartitionCounter
- val partitionIdMask = getPartitionIdMask(cb)
- partitionCounter.update(partitionIdMask)
-
- val toNewVectors = partitionIdToColumnVectors(partitionCounter)
- val rowIterator = cb.rowIterator().asScala
- rowIterator.zipWithIndex.foreach { rowWithIdx =>
- val idx = rowWithIdx._2
- val pid = partitionIdMask(idx)
- converters.convert(rowWithIdx._1, toNewVectors(pid))
- }
- toNewVectors.toSeq.map {
- case (pid, vectors) =>
- (pid, new ColumnarBatch(vectors.toArray, partitionCounter(pid)))
- }
+ newPartitioning match {
+ case SinglePartition =>
+ CloseablePairedColumnarBatchIterator(
+ cbIterator
+ .filter(cb => cb.numRows != 0 && cb.numCols != 0)
+ .map(cb => {
+ val pids = Array.fill(cb.numRows)(0)
+ (0, pushFrontPartitionIds(pids, cb))
+ }))
+ case RoundRobinPartitioning(numPartitions) =>
+ // Distributes elements evenly across output partitions, starting from a random partition.
+ var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions)
+ CloseablePairedColumnarBatchIterator(
+ cbIterator
+ .filter(cb => cb.numRows != 0 && cb.numCols != 0)
+ .map(cb => {
+ val pids = cb.rowIterator.asScala.map { _ =>
+ position += 1
+ part.getPartition(position)
+ }.toArray
+ (0, pushFrontPartitionIds(pids, cb))
+ }))
+ case h: HashPartitioning =>
+ val projection =
+ UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
+ CloseablePairedColumnarBatchIterator(
+ cbIterator
+ .filter(cb => cb.numRows != 0 && cb.numCols != 0)
+ .map(cb => {
+ val pids = cb.rowIterator.asScala.map { row =>
+ part.getPartition(projection(row).getInt(0))
+ }.toArray
+ (0, pushFrontPartitionIds(pids, cb))
+ }))
+ case RangePartitioning(sortingExpressions, _) =>
+ val projection =
+ UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
+ CloseablePairedColumnarBatchIterator(
+ cbIterator
+ .filter(cb => cb.numRows != 0 && cb.numCols != 0)
+ .map(cb => {
+ val pids = cb.rowIterator.asScala.map { row =>
+ part.getPartition(projection(row))
+ }.toArray
+ (0, pushFrontPartitionIds(pids, cb))
+ }))
+ case _ => sys.error(s"Exchange not implemented for $newPartitioning")
}
},
isOrderSensitive = isOrderSensitive)
}
val dependency =
- new ShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
- rddWithPartitionIds,
- new PartitionIdPassthrough(part.numPartitions),
+ new ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
+ rddWithDummyKey,
+ new PartitionIdPassthrough(newPartitioning.numPartitions),
serializer,
- shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics))
+ shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics),
+ serializedSchema = arrowSchema.toByteArray,
+ dataSize = dataSize,
+ splitTime = splitTime)
dependency
}
+
+ def pushFrontPartitionIds(partitionIds: Seq[Int], cb: ColumnarBatch): ColumnarBatch = {
+ val length = partitionIds.length
+
+ val vectors = (0 until cb.numCols()).map { idx =>
+ val vector = cb
+ .column(idx)
+ .asInstanceOf[ArrowWritableColumnVector]
+ .getValueVector
+ .asInstanceOf[FieldVector]
+ vector.setValueCount(length)
+ vector
+ }
+ val pidVec = new IntVector("pid", vectors(0).getAllocator)
+
+ pidVec.allocateNew(length)
+ (0 until length).foreach { i =>
+ pidVec.set(i, partitionIds(i))
+ }
+ pidVec.setValueCount(length)
+
+ val newVectors = ArrowWritableColumnVector.loadColumns(length, (pidVec +: vectors).asJava)
+ new ColumnarBatch(newVectors.toArray, cb.numRows)
+ }
}
-private class PartitionCounter {
- private val pidCounter = mutable.Map[Int, Int]()
+case class CloseablePairedColumnarBatchIterator(iter: Iterator[(Int, ColumnarBatch)])
+ extends Iterator[(Int, ColumnarBatch)]
+ with Logging {
- def size: Int = pidCounter.size
+ private var cur: (Int, ColumnarBatch) = _
- def keys: Iterable[Int] = pidCounter.keys
+ TaskContext.get().addTaskCompletionListener[Unit] { _ =>
+ closeAppendedVector()
+ }
- def update(partitionIdMask: Array[Int]): Unit = {
- partitionIdMask.foreach { partitionId =>
- pidCounter.update(partitionId, pidCounter.get(partitionId) match {
- case Some(cnt) => cnt + 1
- case None => 1
- })
+ private def closeAppendedVector(): Unit = {
+ if (cur != null) {
+ logDebug("Close appended partition id vector")
+ cur match {
+ case (_, cb: ColumnarBatch) =>
+ cb.column(0).asInstanceOf[ArrowWritableColumnVector].close()
+ }
+ cur = null
}
}
- def apply(partitionId: Int): Int = pidCounter.getOrElse(partitionId, 0)
+ override def hasNext: Boolean = {
+ iter.hasNext
+ }
+
+ override def next(): (Int, ColumnarBatch) = {
+ closeAppendedVector()
+ if (iter.hasNext) {
+ cur = iter.next()
+ cur
+ } else Iterator.empty.next()
+ }
}
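
Editor's note — a small sketch (allocator and values are illustrative) of the batch layout that pushFrontPartitionIds produces: a leading "pid" IntVector followed by the original columns, which is the layout ColumnarShuffleWriter and the shuffle schema ("pid" prepended to the output attributes) assume.

import scala.collection.JavaConverters._
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{FieldVector, IntVector}
import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
import org.apache.spark.sql.vectorized.ColumnarBatch

val allocator = new RootAllocator(Long.MaxValue)
val pidVec = new IntVector("pid", allocator)
val dataVec = new IntVector("f_1", allocator)
pidVec.allocateNew(3)
dataVec.allocateNew(3)
Seq(1, 2, 1).zipWithIndex.foreach { case (pid, i) => pidVec.set(i, pid) }
Seq(100, 200, 300).zipWithIndex.foreach { case (v, i) => dataVec.set(i, v) }
pidVec.setValueCount(3)
dataVec.setValueCount(3)

val fieldVectors: List[FieldVector] = List(pidVec, dataVec)
val columns = ArrowWritableColumnVector.loadColumns(3, fieldVectors.asJava)
val cb = new ColumnarBatch(columns.toArray, 3) // column 0 carries the per-row partition ids
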
diff --git a/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/VectorizedFilePartitionReaderHandler.scala b/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/VectorizedFilePartitionReaderHandler.scala
index bf265286b..1d678c0f7 100644
--- a/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/VectorizedFilePartitionReaderHandler.scala
+++ b/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/VectorizedFilePartitionReaderHandler.scala
@@ -26,7 +26,11 @@ import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.hadoop.fs.Path
-import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, PartitionReader}
+import org.apache.spark.sql.connector.read.{
+ InputPartition,
+ PartitionReaderFactory,
+ PartitionReader
+}
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}
import org.apache.spark.sql.execution.datasources.v2.PartitionedFileReader
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetPartitionReaderFactory
@@ -35,48 +39,54 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
object VectorizedFilePartitionReaderHandler {
def get(
- inputPartition: InputPartition,
- parquetReaderFactory: ParquetPartitionReaderFactory)
- : FilePartitionReader[ColumnarBatch] = {
- val iter: Iterator[PartitionedFileReader[ColumnarBatch]]
- = inputPartition.asInstanceOf[FilePartition].files.toIterator.map { file =>
- val filePath = new Path(new URI(file.filePath))
- val split =
- new org.apache.parquet.hadoop.ParquetInputSplit(
- filePath,
- file.start,
- file.start + file.length,
- file.length,
- Array.empty,
- null)
- //val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion
- /*val convertTz =
+ inputPartition: InputPartition,
+ parquetReaderFactory: ParquetPartitionReaderFactory,
+ tmpDir: String): FilePartitionReader[ColumnarBatch] = {
+ val iter: Iterator[PartitionedFileReader[ColumnarBatch]] =
+ inputPartition.asInstanceOf[FilePartition].files.toIterator.map { file =>
+ val filePath = new Path(new URI(file.filePath))
+ val split =
+ new org.apache.parquet.hadoop.ParquetInputSplit(
+ filePath,
+ file.start,
+ file.start + file.length,
+ file.length,
+ Array.empty,
+ null)
+ //val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion
+ /*val convertTz =
if (timestampConversion && !isCreatedByParquetMr) {
Some(DateTimeUtils.getZoneId(conf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key)))
} else {
None
}*/
- val capacity = 4096
- //partitionReaderFactory.createColumnarReader(inputPartition)
- val dataSchema = parquetReaderFactory.dataSchema
- val readDataSchema = parquetReaderFactory.readDataSchema
-
- val conf = parquetReaderFactory.broadcastedConf.value.value
- val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
- val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+ val capacity = 4096
+ //partitionReaderFactory.createColumnarReader(inputPartition)
+ val dataSchema = parquetReaderFactory.dataSchema
+ val readDataSchema = parquetReaderFactory.readDataSchema
- val vectorizedReader = new VectorizedParquetArrowReader(split.getPath().toString(), null, false, capacity, dataSchema, readDataSchema)
- vectorizedReader.initialize(split, hadoopAttemptContext)
- val partitionReader = new PartitionReader[ColumnarBatch] {
- override def next(): Boolean = vectorizedReader.nextKeyValue()
- override def get(): ColumnarBatch =
- vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
- override def close(): Unit = vectorizedReader.close()
- }
+ val conf = parquetReaderFactory.broadcastedConf.value.value
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ val vectorizedReader = new VectorizedParquetArrowReader(
+ split.getPath().toString(),
+ null,
+ false,
+ capacity,
+ dataSchema,
+ readDataSchema,
+ tmpDir)
+ vectorizedReader.initialize(split, hadoopAttemptContext)
+ val partitionReader = new PartitionReader[ColumnarBatch] {
+ override def next(): Boolean = vectorizedReader.nextKeyValue()
+ override def get(): ColumnarBatch =
+ vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
+ override def close(): Unit = vectorizedReader.close()
+ }
- PartitionedFileReader(file, partitionReader)
- }
+ PartitionedFileReader(file, partitionReader)
+ }
new FilePartitionReader[ColumnarBatch](iter)
}
}
-
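
Editor's note — a consumption sketch, not part of the patch, for the FilePartitionReader[ColumnarBatch] returned above: the standard next()/get()/close() loop, where each next() advances the underlying VectorizedParquetArrowReader.

import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.vectorized.ColumnarBatch

def countRows(reader: PartitionReader[ColumnarBatch]): Long = {
  var rows = 0L
  try {
    while (reader.next()) {
      rows += reader.get().numRows()
    }
  } finally {
    reader.close()
  }
  rows
}
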
diff --git a/oap-native-sql/core/src/test/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReadTest.java b/oap-native-sql/core/src/test/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReadTest.java
index 785382a86..701963b9b 100644
--- a/oap-native-sql/core/src/test/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReadTest.java
+++ b/oap-native-sql/core/src/test/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReadTest.java
@@ -46,7 +46,6 @@
import io.netty.buffer.ArrowBuf;
public class ParquetReadTest {
-
@Rule public TemporaryFolder testFolder = new TemporaryFolder();
private BufferAllocator allocator;
@@ -63,29 +62,26 @@ public void teardown() {
@Test
public void testParquetRead() throws Exception {
-
File testFile = testFolder.newFile("_tmpfile_ParquetReadTest");
- //String path = testFile.getAbsolutePath();
- String path = "hdfs://sr602:9000/part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet?user=root&replication=1";
+ // String path = testFile.getAbsolutePath();
+ String path =
+ "hdfs://sr602:9000/part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet?user=root&replication=1";
int numColumns = 0;
int[] rowGroupIndices = {0};
int[] columnIndices = new int[numColumns];
- Schema schema =
- new Schema(
- asList(
- field("n_nationkey", new Int(64, true)),
- field("n_name", new Utf8()),
- field("n_regionkey", new Int(64, true)),
- field("n_comment", new Utf8())
- ));
+ Schema schema = new Schema(
+ asList(field("n_nationkey", new Int(64, true)), field("n_name", new Utf8()),
+ field("n_regionkey", new Int(64, true)), field("n_comment", new Utf8())));
- ParquetReader reader = new ParquetReader(path, rowGroupIndices, columnIndices, 16, allocator);
+ ParquetReader reader =
+ new ParquetReader(path, rowGroupIndices, columnIndices, 16, allocator, "");
Schema readedSchema = reader.getSchema();
for (int i = 0; i < readedSchema.getFields().size(); i++) {
- assertEquals(schema.getFields().get(i).getName(), readedSchema.getFields().get(i).getName());
+ assertEquals(
+ schema.getFields().get(i).getName(), readedSchema.getFields().get(i).getName());
}
VectorSchemaRoot actualSchemaRoot = VectorSchemaRoot.create(readedSchema, allocator);
@@ -103,7 +99,8 @@ public void testParquetRead() throws Exception {
testFile.delete();
}
- private static Field field(String name, boolean nullable, ArrowType type, Field... children) {
+ private static Field field(
+ String name, boolean nullable, ArrowType type, Field... children) {
return new Field(name, new FieldType(nullable, type, null, null), asList(children));
}
diff --git a/oap-native-sql/core/src/test/resources/part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet b/oap-native-sql/core/src/test/resources/part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet
new file mode 100644
index 000000000..9eb182285
Binary files /dev/null and b/oap-native-sql/core/src/test/resources/part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet differ
diff --git a/oap-native-sql/core/src/test/scala/com.intel.sparkColumnarPlugin/ExtensionSuite.scala b/oap-native-sql/core/src/test/scala/com.intel.sparkColumnarPlugin/ExtensionSuite.scala
deleted file mode 100644
index bebb60389..000000000
--- a/oap-native-sql/core/src/test/scala/com.intel.sparkColumnarPlugin/ExtensionSuite.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.intel.sparkColumnarPlugin
-
-import org.apache.spark.sql.{Row, SparkSession}
-
-import org.scalatest.FunSuite
-
-class ExtensionSuite extends FunSuite {
-
- private def stop(spark: SparkSession): Unit = {
- spark.stop()
- SparkSession.clearActiveSession()
- SparkSession.clearDefaultSession()
- }
-
- test("inject columnar exchange") {
- val session = SparkSession
- .builder()
- .master("local[1]")
- .config("org.apache.spark.example.columnar.enabled", value = true)
- .config("spark.sql.extensions", "com.intel.sparkColumnarPlugin.ColumnarPlugin")
- .appName("inject columnar exchange")
- .getOrCreate()
-
- try {
- import session.sqlContext.implicits._
-
- val input = Seq((100), (200), (300))
- val data = input.toDF("vals").repartition(1)
- val result = data.collect()
- assert(result sameElements input.map(x => Row(x)))
- } finally {
- stop(session)
- }
- }
-}
diff --git a/oap-native-sql/core/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala b/oap-native-sql/core/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala
new file mode 100644
index 000000000..5216eb8aa
--- /dev/null
+++ b/oap-native-sql/core/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.nio.file.Files
+
+import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.ipc.ArrowStreamReader
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
+import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
+import org.apache.arrow.vector.{FieldVector, IntVector}
+import org.apache.spark._
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.sort.ColumnarShuffleHandle
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.Utils
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.{any, anyInt, anyLong}
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.{Mock, MockitoAnnotations}
+
+import scala.collection.JavaConverters._
+
+class ColumnarShuffleWriterSuite extends SharedSparkSession {
+ @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var dependency
+ : ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = _
+
+ override def sparkConf: SparkConf =
+ super.sparkConf
+ .setAppName("test ColumnarShuffleWriter")
+ .set("spark.file.transferTo", "true")
+ .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
+ .set("spark.sql.execution.arrow.maxRecordsPerBatch", "4096")
+
+ private var taskMetrics: TaskMetrics = _
+ private var tempDir: File = _
+ private var outputFile: File = _
+
+ private var shuffleHandle: ColumnarShuffleHandle[Int, ColumnarBatch] = _
+ private val schema = new Schema(
+ List(
+ Field.nullable("pid", new ArrowType.Int(32, true)),
+ Field.nullable("f_1", new ArrowType.Int(32, true)),
+ Field.nullable("f_2", new ArrowType.Int(32, true)),
+ ).asJava)
+ private val allocator = new RootAllocator(Long.MaxValue)
+
+ override def beforeEach() = {
+ super.beforeEach()
+
+ tempDir = Utils.createTempDir()
+ outputFile = File.createTempFile("shuffle", null, tempDir)
+ taskMetrics = new TaskMetrics
+
+ MockitoAnnotations.initMocks(this)
+
+ shuffleHandle =
+ new ColumnarShuffleHandle[Int, ColumnarBatch](shuffleId = 0, dependency = dependency)
+
+ when(dependency.partitioner).thenReturn(new HashPartitioner(11))
+ when(dependency.serializer).thenReturn(new JavaSerializer(sparkConf))
+ when(dependency.serializedSchema).thenReturn(schema.toByteArray)
+ when(dependency.dataSize)
+ .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "data size"))
+ when(dependency.splitTime)
+ .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "split time"))
+ when(taskContext.taskMetrics()).thenReturn(taskMetrics)
+ when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
+
+ doAnswer { (invocationOnMock: InvocationOnMock) =>
+ val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ if (tmp != null) {
+ outputFile.delete
+ tmp.renameTo(outputFile)
+ }
+ null
+ }.when(blockResolver)
+ .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))
+ }
+
+ override def afterEach(): Unit = {
+ try {
+ Utils.deleteRecursively(tempDir)
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ override def afterAll(): Unit = {
+ allocator.close()
+ super.afterAll()
+ }
+
+ test("write empty iterator") {
+ val writer = new ColumnarShuffleWriter[Int, ColumnarBatch](
+ blockResolver,
+ shuffleHandle,
+ 0, // MapId
+ taskContext.taskMetrics().shuffleWriteMetrics)
+ writer.write(Iterator.empty)
+ writer.stop( /* success = */ true)
+
+ assert(writer.getPartitionLengths.sum === 0)
+ assert(outputFile.exists())
+ assert(outputFile.length() === 0)
+ val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics
+ assert(shuffleWriteMetrics.bytesWritten === 0)
+ assert(shuffleWriteMetrics.recordsWritten === 0)
+ assert(taskMetrics.diskBytesSpilled === 0)
+ assert(taskMetrics.memoryBytesSpilled === 0)
+ }
+
+ test("write empty column batch") {
+ val vectorPid = new IntVector("pid", allocator)
+ val vector1 = new IntVector("v1", allocator)
+ val vector2 = new IntVector("v2", allocator)
+
+ ColumnarShuffleWriterSuite.setIntVector(vectorPid)
+ ColumnarShuffleWriterSuite.setIntVector(vector1)
+ ColumnarShuffleWriterSuite.setIntVector(vector2)
+ val cb = ColumnarShuffleWriterSuite.makeColumnarBatch(
+ vectorPid.getValueCount,
+ List(vectorPid, vector1, vector2))
+
+ def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb), (0, cb))
+
+ val writer = new ColumnarShuffleWriter[Int, ColumnarBatch](
+ blockResolver,
+ shuffleHandle,
+ 0L, // MapId
+ taskContext.taskMetrics().shuffleWriteMetrics)
+
+ writer.write(records)
+ writer.stop(success = true)
+ cb.close()
+
+ assert(writer.getPartitionLengths.sum === 0)
+ assert(outputFile.exists())
+ assert(outputFile.length() === 0)
+ val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics
+ assert(shuffleWriteMetrics.bytesWritten === 0)
+ assert(shuffleWriteMetrics.recordsWritten === 0)
+ assert(taskMetrics.diskBytesSpilled === 0)
+ assert(taskMetrics.memoryBytesSpilled === 0)
+ }
+
+ test("write with some empty partitions") {
+ val vectorPid = new IntVector("pid", allocator)
+ val vector1 = new IntVector("v1", allocator)
+ val vector2 = new IntVector("v2", allocator)
+ ColumnarShuffleWriterSuite.setIntVector(vectorPid, 1, 2, 1, 10)
+ ColumnarShuffleWriterSuite.setIntVector(vector1, null, null, null, null)
+ ColumnarShuffleWriterSuite.setIntVector(vector2, 100, 100, null, null)
+ val cb = ColumnarShuffleWriterSuite.makeColumnarBatch(
+ vectorPid.getValueCount,
+ List(vectorPid, vector1, vector2))
+
+ def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb), (0, cb))
+
+ val writer = new ColumnarShuffleWriter[Int, ColumnarBatch](
+ blockResolver,
+ shuffleHandle,
+ 0L, // MapId
+ taskContext.taskMetrics().shuffleWriteMetrics)
+
+ writer.write(records)
+ writer.stop(success = true)
+
+ assert(writer.getPartitionLengths.sum === outputFile.length())
+ assert(writer.getPartitionLengths.count(_ == 0L) === 8) // should be (11 - 3) zero length files
+
+ val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics
+ assert(shuffleWriteMetrics.bytesWritten === outputFile.length())
+ assert(shuffleWriteMetrics.recordsWritten === records.length)
+
+ assert(taskMetrics.diskBytesSpilled === 0)
+ assert(taskMetrics.memoryBytesSpilled === 0)
+
+ val bytes = Files.readAllBytes(outputFile.toPath)
+ val reader = new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(bytes), allocator)
+ try {
+ val schema = reader.getVectorSchemaRoot.getSchema
+ assert(schema.getFields == this.schema.getFields.subList(1, this.schema.getFields.size()))
+ } finally {
+ reader.close()
+ cb.close()
+ }
+ }
+}
+
+object ColumnarShuffleWriterSuite {
+
+ def setIntVector(vector: IntVector, values: Integer*): Unit = {
+ val length = values.length
+ vector.allocateNew(length)
+ (0 until length).foreach { i =>
+ if (values(i) != null) {
+ vector.set(i, values(i).asInstanceOf[Int])
+ }
+ }
+ vector.setValueCount(length)
+ }
+
+ def makeColumnarBatch(capacity: Int, vectors: List[FieldVector]): ColumnarBatch = {
+ val columnVectors = ArrowWritableColumnVector.loadColumns(capacity, vectors.asJava)
+ new ColumnarBatch(columnVectors.toArray, capacity)
+ }
+
+}
diff --git a/oap-native-sql/core/src/test/scala/org/apache/spark/sql/RepartitionSuite.scala b/oap-native-sql/core/src/test/scala/org/apache/spark/sql/RepartitionSuite.scala
new file mode 100644
index 000000000..7bb6e5cb9
--- /dev/null
+++ b/oap-native-sql/core/src/test/scala/org/apache/spark/sql/RepartitionSuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.{
+ ColumnarShuffleExchangeExec,
+ ColumnarToRowExec,
+ RowToColumnarExec
+}
+import org.apache.spark.sql.test.SharedSparkSession
+
+class RepartitionSuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
+ override def sparkConf: SparkConf =
+ super.sparkConf
+ .setAppName("test repartition")
+ .set("spark.sql.parquet.columnarReaderBatchSize", "4096")
+ .set("spark.sql.sources.useV1SourceList", "avro")
+ .set("spark.sql.extensions", "com.intel.sparkColumnarPlugin.ColumnarPlugin")
+ .set("spark.sql.execution.arrow.maxRecordsPerBatch", "4096")
+ .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
+
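+  // Weights each plan node class differently (1/10/100) so that a distinct sum of 110
+  // proves the plan contains RowToColumnarExec, ColumnarToRowExec and ColumnarShuffleExchangeExec.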
+  def checkColumnarExec(data: DataFrame) = {
+ val found = data.queryExecution.executedPlan
+ .collect {
+ case r2c: RowToColumnarExec => 1
+ case c2r: ColumnarToRowExec => 10
+ case exc: ColumnarShuffleExchangeExec => 100
+ }
+ .distinct
+ .sum
+ assert(found == 110)
+ }
+
+ def withInput(input: DataFrame)(
+ transformation: Option[DataFrame => DataFrame],
+ repartition: DataFrame => DataFrame): Unit = {
+ val expected = transformation.getOrElse(identity[DataFrame](_))(input)
+ val data = repartition(expected)
+    checkColumnarExec(data)
+ checkAnswer(data, expected)
+ }
+
+ lazy val input: DataFrame = Seq((1, "1"), (2, "20"), (3, "300")).toDF("id", "val")
+
+ def withTransformationAndRepartition(
+ transformation: DataFrame => DataFrame,
+ repartition: DataFrame => DataFrame): Unit =
+ withInput(input)(Some(transformation), repartition)
+
+ def withRepartition: (DataFrame => DataFrame) => Unit = withInput(input)(None, _)
+}
+
+class SmallDataRepartitionSuite extends RepartitionSuite {
+ import testImplicits._
+
+ test("test round robin partitioning") {
+ withRepartition(df => df.repartition(2))
+ }
+
+ test("test hash partitioning") {
+ withRepartition(df => df.repartition('id))
+ }
+
+ test("test range partitioning") {
+ withRepartition(df => df.repartitionByRange('id))
+ }
+
+ ignore("test cached repartiiton") {
+ val data = input.cache.repartition(2)
+
+ val found = data.queryExecution.executedPlan.collect {
+ case cache: InMemoryTableScanExec => 1
+ case c2r: ColumnarToRowExec => 10
+ case exc: ColumnarShuffleExchangeExec => 100
+ }.sum
+ assert(found == 111)
+
+ checkAnswer(data, input)
+ }
+}
+
+class TPCHTableRepartitionSuite extends RepartitionSuite {
+ import testImplicits._
+
+ val filePath = getClass.getClassLoader
+ .getResource("part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet")
+ .getFile
+
+ override lazy val input = spark.read.parquet(filePath)
+
+ test("test tpch round robin partitioning") {
+ withRepartition(df => df.repartition(2))
+ }
+
+ test("test tpch hash partitioning") {
+ withRepartition(df => df.repartition('n_nationkey))
+ }
+
+ test("test tpch range partitioning") {
+ withRepartition(df => df.repartitionByRange('n_name))
+ }
+
+ test("test tpch sum after repartition") {
+ withTransformationAndRepartition(
+ df => df.groupBy("n_regionkey").agg(Map("n_nationkey" -> "sum")),
+ df => df.repartition(2))
+ }
+}
+
+class DisableColumnarShuffle extends SmallDataRepartitionSuite {
+ override def sparkConf: SparkConf = {
+ super.sparkConf
+ .set("spark.shuffle.manager", "sort")
+ .set("spark.sql.codegen.wholeStage", "true")
+ }
+
+  override def checkColumnarExec(data: DataFrame) = {
+ val found = data.queryExecution.executedPlan
+ .collect {
+ case c2r: ColumnarToRowExec => 1
+ case exc: ColumnarShuffleExchangeExec => 10
+ case exc: ShuffleExchangeExec => 100
+ }
+ .distinct
+ .sum
+ assert(found == 101)
+ }
+}
diff --git a/oap-native-sql/cpp/.gitignore b/oap-native-sql/cpp/.gitignore
new file mode 100644
index 000000000..22da1f13b
--- /dev/null
+++ b/oap-native-sql/cpp/.gitignore
@@ -0,0 +1,24 @@
+thirdparty/*.tar*
+CMakeFiles/
+CMakeCache.txt
+CTestTestfile.cmake
+Makefile
+cmake_install.cmake
+build/
+*-build/
+Testing/
+cmake-build-debug/
+cmake-build-release/
+
+#########################################
+# Editor temporary/working/backup files #
+.#*
+*\#*\#
+[#]*#
+*~
+*$
+*.bak
+*flymake*
+*.kdev4
+*.log
+*.swp
diff --git a/oap-native-sql/cpp/compile.sh b/oap-native-sql/cpp/compile.sh
new file mode 100755
index 000000000..8f89e491b
--- /dev/null
+++ b/oap-native-sql/cpp/compile.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+
+set -eu
+
+CURRENT_DIR=$(cd "$(dirname "$BASH_SOURCE")"; pwd)
+echo $CURRENT_DIR
+
+cd ${CURRENT_DIR}
+if [ -d build ]; then
+ rm -r build
+fi
+mkdir build
+cd build
+cmake ..
+make
+
+set +eu
+
diff --git a/oap-native-sql/cpp/src/CMakeLists.txt b/oap-native-sql/cpp/src/CMakeLists.txt
index 06352cae2..604a9206d 100644
--- a/oap-native-sql/cpp/src/CMakeLists.txt
+++ b/oap-native-sql/cpp/src/CMakeLists.txt
@@ -18,6 +18,11 @@ option(TESTS "Build the tests" OFF)
option(BENCHMARKS "Build the benchmarks" OFF)
option(DEBUG "Enable Debug Info" OFF)
+# same as the version required in arrow/ci/conda_env_cpp.yml
+set(BOOST_MIN_VERSION "1.42.0")
+find_package(Boost REQUIRED)
+INCLUDE_DIRECTORIES(${Boost_INCLUDE_DIRS})
+
find_package(JNI REQUIRED)
set(source_root_directory ${CMAKE_CURRENT_SOURCE_DIR})
@@ -133,8 +138,8 @@ endif()
if(BENCHMARKS)
find_package(GTest)
- #add_definitions(-DBENCHMARK_FILE_PATH=${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/source_files/)
- add_compile_definitions(BENCHMARK_FILE_PATH="file://${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/source_files/")
+ add_definitions(-DBENCHMARK_FILE_PATH="file://${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/source_files/")
+ #add_compile_definitions(BENCHMARK_FILE_PATH="file://${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/source_files/")
macro(package_add_benchmark TESTNAME)
#configure_file(${ARGN}.in ${ARGN})
add_executable(${TESTNAME} ${ARGN})
@@ -167,6 +172,22 @@ if(NOT GANDIVA_LIB)
message(FATAL_ERROR "Gandiva library not found")
endif()
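+# Stage the headers needed by runtime code generation under releases/include,
+# so that code compiled on the fly (see codegen_common.cc) can include them.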
+set(CODEGEN_HEADERS
+ third_party/
+ )
+file(MAKE_DIRECTORY ${root_directory}/releases/include)
+file(MAKE_DIRECTORY ${root_directory}/releases/include/codegen/common/)
+file(MAKE_DIRECTORY ${root_directory}/releases/include/codegen/third_party/)
+file(MAKE_DIRECTORY ${root_directory}/releases/include/codegen/arrow_compute/ext/)
+file(COPY third_party/ DESTINATION ${root_directory}/releases/include/)
+file(COPY third_party/ DESTINATION ${root_directory}/releases/include/third_party/)
+file(COPY codegen/arrow_compute/ext/array_item_index.h DESTINATION ${root_directory}/releases/include/codegen/arrow_compute/ext/)
+file(COPY codegen/arrow_compute/ext/code_generator_base.h DESTINATION ${root_directory}/releases/include/codegen/arrow_compute/ext/)
+file(COPY codegen/arrow_compute/ext/kernels_ext.h DESTINATION ${root_directory}/releases/include/codegen/arrow_compute/ext/)
+file(COPY codegen/common/result_iterator.h DESTINATION ${root_directory}/releases/include/codegen/common/)
+
+add_definitions(-DNATIVESQL_SRC_PATH="${root_directory}/releases")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated-declarations")
set(SPARK_COLUMNAR_PLUGIN_SRCS
jni/jni_wrapper.cc
${PROTO_SRCS}
@@ -174,15 +195,13 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS
proto/protobuf_utils.cc
codegen/expr_visitor.cc
codegen/arrow_compute/expr_visitor.cc
- codegen/arrow_compute/ext/item_iterator.cc
- codegen/arrow_compute/ext/shuffle_v2_action.cc
- codegen/arrow_compute/ext/codegen_node_visitor.cc
- codegen/arrow_compute/ext/codegen_node_visitor_v2.cc
- codegen/arrow_compute/ext/conditioned_shuffle_kernel.cc
- codegen/arrow_compute/ext/conditioned_probe_kernel.cc
+ codegen/arrow_compute/ext/probe_kernel.cc
codegen/arrow_compute/ext/sort_kernel.cc
codegen/arrow_compute/ext/kernels_ext.cc
codegen/arrow_compute/ext/codegen_common.cc
+ codegen/arrow_compute/ext/codegen_node_visitor.cc
+ shuffle/splitter.cc
+ shuffle/partition_writer.cc
)
file(MAKE_DIRECTORY ${root_directory}/releases)
diff --git a/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_join.cc b/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_join.cc
index 51700efb3..258cc11e2 100644
--- a/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_join.cc
+++ b/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_join.cc
@@ -139,14 +139,6 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) {
"codegen_withTwoInputs", {n_probeArrays, n_left, n_right}, uint32());
auto probeArrays_expr = TreeExprBuilder::MakeExpression(n_codegen_probe, f_indices);
- auto n_conditionedShuffleArrayList =
- TreeExprBuilder::MakeFunction("conditionedShuffleArrayList", {}, uint32());
- auto n_codegen_shuffle = TreeExprBuilder::MakeFunction(
- "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right},
- uint32());
-
- auto conditionShuffleExpr = TreeExprBuilder::MakeExpression(n_codegen_shuffle, f_res);
-
auto schema_table_0 = arrow::schema(left_field_list);
auto schema_table_1 = arrow::schema(right_field_list);
std::vector> field_list(left_field_list.size() +
@@ -160,7 +152,6 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) {
std::vector> dummy_result_batches;
std::shared_ptr> probe_result_iterator;
- std::shared_ptr> shuffle_result_iterator;
////////////////////// evaluate //////////////////////
std::shared_ptr left_record_batch;
@@ -169,16 +160,14 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) {
uint64_t elapse_left_read = 0;
uint64_t elapse_right_read = 0;
uint64_t elapse_eval = 0;
+ uint64_t elapse_finish = 0;
uint64_t elapse_probe_process = 0;
uint64_t elapse_shuffle_process = 0;
uint64_t num_batches = 0;
uint64_t num_rows = 0;
TIME_MICRO_OR_THROW(elapse_gen, CreateCodeGenerator(left_schema, {probeArrays_expr},
- {f_indices}, &expr_probe, true));
- TIME_MICRO_OR_THROW(
- elapse_gen, CreateCodeGenerator(schema_table_0, {conditionShuffleExpr}, field_list,
- &expr_shuffle, true));
+ field_list, &expr_probe, true));
do {
TIME_MICRO_OR_THROW(elapse_left_read,
@@ -186,16 +175,12 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) {
if (left_record_batch) {
TIME_MICRO_OR_THROW(elapse_eval,
expr_probe->evaluate(left_record_batch, &dummy_result_batches));
- TIME_MICRO_OR_THROW(
- elapse_eval, expr_shuffle->evaluate(left_record_batch, &dummy_result_batches));
num_batches += 1;
}
} while (left_record_batch);
std::cout << "Readed left table with " << num_batches << " batches." << std::endl;
- TIME_MICRO_OR_THROW(elapse_eval, expr_probe->finish(&probe_result_iterator));
- TIME_MICRO_OR_THROW(elapse_eval, expr_shuffle->SetDependency(probe_result_iterator));
- TIME_MICRO_OR_THROW(elapse_eval, expr_shuffle->finish(&shuffle_result_iterator));
+ TIME_MICRO_OR_THROW(elapse_finish, expr_probe->finish(&probe_result_iterator));
num_batches = 0;
uint64_t num_output_batches = 0;
@@ -209,9 +194,7 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) {
right_column_vector.push_back(right_record_batch->column(i));
}
TIME_MICRO_OR_THROW(elapse_probe_process,
- probe_result_iterator->ProcessAndCacheOne(right_column_vector));
- TIME_MICRO_OR_THROW(elapse_shuffle_process,
- shuffle_result_iterator->Process(right_column_vector, &out));
+ probe_result_iterator->Process(right_column_vector, &out));
num_batches += 1;
num_output_batches++;
num_rows += out->num_rows();
@@ -219,16 +202,17 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) {
} while (right_record_batch);
std::cout << "Readed right table with " << num_batches << " batches." << std::endl;
- std::cout << "BenchmarkArrowComputeJoin processed " << num_batches
- << " batches, then output " << num_output_batches << " batches with "
- << num_rows << " rows, to complete, it took " << TIME_TO_STRING(elapse_gen)
- << " doing codegen, took " << TIME_TO_STRING(elapse_left_read)
- << " doing left BatchRead, took " << TIME_TO_STRING(elapse_right_read)
- << " doing right BatchRead, took " << TIME_TO_STRING(elapse_eval)
- << " doing left table hashmap insert, took "
- << TIME_TO_STRING(elapse_probe_process) << " doing probe indice fetch, took "
- << TIME_TO_STRING(elapse_shuffle_process) << " doing final shuffle."
- << std::endl;
+ std::cout << "=========================================="
+ << "\nBenchmarkArrowComputeJoin processed " << num_batches << " batches"
+ << "\noutput " << num_output_batches << " batches with " << num_rows
+ << " rows"
+ << "\nCodeGen took " << TIME_TO_STRING(elapse_gen)
+ << "\nLeft Batch Read took " << TIME_TO_STRING(elapse_left_read)
+ << "\nRight Batch Read took " << TIME_TO_STRING(elapse_right_read)
+ << "\nLeft Table Hash Insert took " << TIME_TO_STRING(elapse_eval)
+ << "\nMake Result Iterator took " << TIME_TO_STRING(elapse_finish)
+ << "\nProbe and Shuffle took " << TIME_TO_STRING(elapse_probe_process) << "\n"
+ << "===========================================" << std::endl;
}
TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) {
@@ -266,14 +250,6 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) {
"codegen_withTwoInputs", {n_probeArrays, n_left, n_right}, uint32());
auto probeArrays_expr = TreeExprBuilder::MakeExpression(n_codegen_probe, f_indices);
- auto n_conditionedShuffleArrayList =
- TreeExprBuilder::MakeFunction("conditionedShuffleArrayList", {}, uint32());
- auto n_codegen_shuffle = TreeExprBuilder::MakeFunction(
- "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right},
- uint32());
-
- auto conditionShuffleExpr = TreeExprBuilder::MakeExpression(n_codegen_shuffle, f_res);
-
auto schema_table_0 = arrow::schema(left_field_list);
auto schema_table_1 = arrow::schema(right_field_list);
std::vector> field_list(left_field_list.size() +
@@ -283,32 +259,25 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) {
auto schema_table = arrow::schema(field_list);
///////////////////// Calculation //////////////////
std::shared_ptr expr_probe;
- std::shared_ptr expr_shuffle;
std::vector> dummy_result_batches;
std::shared_ptr> probe_result_iterator;
- std::shared_ptr> shuffle_result_iterator;
////////////////////// evaluate //////////////////////
std::shared_ptr left_record_batch;
std::shared_ptr right_record_batch;
uint64_t elapse_gen = 0;
- auto n_codegen = TreeExprBuilder::MakeFunction(
- "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right},
- uint32());
uint64_t elapse_left_read = 0;
uint64_t elapse_right_read = 0;
uint64_t elapse_eval = 0;
+ uint64_t elapse_finish = 0;
uint64_t elapse_probe_process = 0;
uint64_t elapse_shuffle_process = 0;
uint64_t num_batches = 0;
uint64_t num_rows = 0;
TIME_MICRO_OR_THROW(elapse_gen, CreateCodeGenerator(left_schema, {probeArrays_expr},
- {f_indices}, &expr_probe, true));
- TIME_MICRO_OR_THROW(
- elapse_gen, CreateCodeGenerator(schema_table_0, {conditionShuffleExpr}, field_list,
- &expr_shuffle, true));
+ field_list, &expr_probe, true));
do {
TIME_MICRO_OR_THROW(elapse_left_read,
@@ -316,16 +285,12 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) {
if (left_record_batch) {
TIME_MICRO_OR_THROW(elapse_eval,
expr_probe->evaluate(left_record_batch, &dummy_result_batches));
- TIME_MICRO_OR_THROW(
- elapse_eval, expr_shuffle->evaluate(left_record_batch, &dummy_result_batches));
num_batches += 1;
}
} while (left_record_batch);
std::cout << "Readed left table with " << num_batches << " batches." << std::endl;
- TIME_MICRO_OR_THROW(elapse_eval, expr_probe->finish(&probe_result_iterator));
- TIME_MICRO_OR_THROW(elapse_eval, expr_shuffle->SetDependency(probe_result_iterator));
- TIME_MICRO_OR_THROW(elapse_eval, expr_shuffle->finish(&shuffle_result_iterator));
+ TIME_MICRO_OR_THROW(elapse_finish, expr_probe->finish(&probe_result_iterator));
num_batches = 0;
uint64_t num_output_batches = 0;
@@ -339,9 +304,7 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) {
right_column_vector.push_back(right_record_batch->column(i));
}
TIME_MICRO_OR_THROW(elapse_probe_process,
- probe_result_iterator->ProcessAndCacheOne(right_column_vector));
- TIME_MICRO_OR_THROW(elapse_shuffle_process,
- shuffle_result_iterator->Process(right_column_vector, &out));
+ probe_result_iterator->Process(right_column_vector, &out));
num_batches += 1;
num_output_batches++;
num_rows += out->num_rows();
@@ -349,16 +312,17 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) {
} while (right_record_batch);
std::cout << "Readed right table with " << num_batches << " batches." << std::endl;
- std::cout << "BenchmarkArrowComputeJoin processed " << num_batches
- << " batches, then output " << num_output_batches << " batches with "
- << num_rows << " rows, to complete, it took " << TIME_TO_STRING(elapse_gen)
- << " doing codegen, took " << TIME_TO_STRING(elapse_left_read)
- << " doing left BatchRead, took " << TIME_TO_STRING(elapse_right_read)
- << " doing right BatchRead, took " << TIME_TO_STRING(elapse_eval)
- << " doing left table hashmap insert, took "
- << TIME_TO_STRING(elapse_probe_process) << " doing probe indice fetch, took "
- << TIME_TO_STRING(elapse_shuffle_process) << " doing final shuffle."
- << std::endl;
+ std::cout << "=========================================="
+ << "\nBenchmarkArrowComputeJoin processed " << num_batches << " batches"
+ << "\noutput " << num_output_batches << " batches with " << num_rows
+ << " rows"
+ << "\nCodeGen took " << TIME_TO_STRING(elapse_gen)
+ << "\nLeft Batch Read took " << TIME_TO_STRING(elapse_left_read)
+ << "\nRight Batch Read took " << TIME_TO_STRING(elapse_right_read)
+ << "\nLeft Table Hash Insert took " << TIME_TO_STRING(elapse_eval)
+ << "\nMake Result Iterator took " << TIME_TO_STRING(elapse_finish)
+ << "\nProbe and Shuffle took " << TIME_TO_STRING(elapse_probe_process) << "\n"
+ << "===========================================" << std::endl;
}
} // namespace codegen
} // namespace sparkcolumnarplugin
diff --git a/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_sort.cc b/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_sort.cc
index 92bb22a83..9fc6b5b8a 100644
--- a/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_sort.cc
+++ b/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_sort.cc
@@ -67,25 +67,12 @@ class BenchmarkArrowComputeSort : public ::testing::Test {
ASSERT_NOT_OK(
parquet_reader->GetRecordBatchReader({0}, {0, 1, 2}, &record_batch_reader));
- schema = record_batch_reader->schema();
- std::cout << schema->ToString() << std::endl;
-
////////////////// expr prepration ////////////////
field_list = record_batch_reader->schema()->fields();
ret_field_list = record_batch_reader->schema()->fields();
}
- void StartWithIterator() {
- uint64_t elapse_gen = 0;
- uint64_t elapse_read = 0;
- uint64_t elapse_eval = 0;
- uint64_t elapse_sort = 0;
- uint64_t elapse_shuffle = 0;
- uint64_t num_batches = 0;
- std::shared_ptr sort_expr;
- TIME_MICRO_OR_THROW(elapse_gen, CreateCodeGenerator(schema, {sortArrays_expr},
- {f_indices}, &sort_expr, true));
-
+ void StartWithIterator(std::shared_ptr sort_expr) {
std::vector> input_batch_list;
std::vector> dummy_result_batches;
std::shared_ptr> sort_result_iterator;
@@ -126,45 +113,84 @@ class BenchmarkArrowComputeSort : public ::testing::Test {
std::shared_ptr file;
std::unique_ptr<::parquet::arrow::FileReader> parquet_reader;
std::shared_ptr record_batch_reader;
- std::shared_ptr schema;
std::vector> field_list;
std::vector> ret_field_list;
int primary_key_index = 0;
- std::shared_ptr f_indices;
std::shared_ptr f_res;
- ::gandiva::ExpressionPtr sortArrays_expr;
- ::gandiva::ExpressionPtr conditionShuffleExpr;
+
+ uint64_t elapse_gen = 0;
+ uint64_t elapse_read = 0;
+ uint64_t elapse_eval = 0;
+ uint64_t elapse_sort = 0;
+ uint64_t elapse_shuffle = 0;
+ uint64_t num_batches = 0;
};
TEST_F(BenchmarkArrowComputeSort, SortBenchmark) {
+ elapse_gen = 0;
+ elapse_read = 0;
+ elapse_eval = 0;
+ elapse_sort = 0;
+ elapse_shuffle = 0;
+ num_batches = 0;
+ ////////////////////// prepare expr_vector ///////////////////////
+ auto indices_type = std::make_shared(16);
+ f_res = field("res", arrow::uint64());
+
+ std::vector> gandiva_field_list;
+ for (auto field : field_list) {
+ gandiva_field_list.push_back(TreeExprBuilder::MakeField(field));
+ }
+ auto n_sort_to_indices =
+ TreeExprBuilder::MakeFunction("sortArraysToIndicesNullsFirstAsc",
+ {gandiva_field_list[primary_key_index]}, uint64());
+ std::shared_ptr schema;
+ schema = arrow::schema(field_list);
+ std::cout << schema->ToString() << std::endl;
+
+ ::gandiva::ExpressionPtr sortArrays_expr;
+ sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort_to_indices, f_res);
+
+ std::shared_ptr sort_expr;
+ TIME_MICRO_OR_THROW(elapse_gen, CreateCodeGenerator(schema, {sortArrays_expr},
+ ret_field_list, &sort_expr, true));
+
+ ///////////////////// Calculation //////////////////
+ StartWithIterator(sort_expr);
+}
+
+TEST_F(BenchmarkArrowComputeSort, SortBenchmarkWOPayLoad) {
+ elapse_gen = 0;
+ elapse_read = 0;
+ elapse_eval = 0;
+ elapse_sort = 0;
+ elapse_shuffle = 0;
+ num_batches = 0;
////////////////////// prepare expr_vector ///////////////////////
auto indices_type = std::make_shared(16);
- f_indices = field("indices", indices_type);
f_res = field("res", arrow::uint64());
std::vector> gandiva_field_list;
for (auto field : field_list) {
gandiva_field_list.push_back(TreeExprBuilder::MakeField(field));
}
- auto n_left =
- TreeExprBuilder::MakeFunction("codegen_left_schema", gandiva_field_list, uint64());
- auto n_right = TreeExprBuilder::MakeFunction("codegen_right_schema", {}, uint64());
auto n_sort_to_indices =
TreeExprBuilder::MakeFunction("sortArraysToIndicesNullsFirstAsc",
{gandiva_field_list[primary_key_index]}, uint64());
- sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort_to_indices, f_indices);
+ std::shared_ptr schema;
+ schema = arrow::schema({field_list[primary_key_index]});
+ ::gandiva::ExpressionPtr sortArrays_expr;
+ sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort_to_indices, f_res);
- auto n_conditionedShuffleArrayList =
- TreeExprBuilder::MakeFunction("conditionedShuffleArrayList", {}, uint64());
- auto n_codegen_shuffle = TreeExprBuilder::MakeFunction(
- "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right},
- uint64());
- conditionShuffleExpr = TreeExprBuilder::MakeExpression(n_codegen_shuffle, f_res);
+ std::shared_ptr sort_expr;
+ TIME_MICRO_OR_THROW(elapse_gen, CreateCodeGenerator(schema, {sortArrays_expr},
+ {ret_field_list[primary_key_index]},
+ &sort_expr, true));
///////////////////// Calculation //////////////////
- StartWithIterator();
+ StartWithIterator(sort_expr);
}
} // namespace codegen
diff --git a/oap-native-sql/cpp/src/benchmarks/make.sh b/oap-native-sql/cpp/src/benchmarks/make.sh
new file mode 100755
index 000000000..c85c35956
--- /dev/null
+++ b/oap-native-sql/cpp/src/benchmarks/make.sh
@@ -0,0 +1,2 @@
+#gcc $1.cc -std=c++17 -O3 -shared -fPIC /mnt/nvme2/chendi/intel-bigdata/arrow/cpp/release-build/release/libarrow.a -o $1.so
+gcc -I /mnt/nvme2/chendi/intel-bigdata/OAP/oap-native-sql/cpp/src/ -I/mnt/nvme2/chendi/intel-bigdata/OAP/oap-native-sql/cpp/src/third_party/sparsehash $1.cc -std=c++17 -O3 -shared -fPIC -larrow -o $1.so
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor.cc
index 74e7caf75..7a37ea570 100644
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor.cc
+++ b/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor.cc
@@ -226,15 +226,6 @@ arrow::Status ExprVisitor::MakeExprVisitorImpl(
std::vector> left_field_list,
std::vector> right_field_list,
std::vector> ret_fields, ExprVisitor* p) {
- if (func_name.compare("conditionedShuffleArrayList") == 0) {
- std::shared_ptr child_node;
- if (func_node->children().size() > 0) {
- child_node = func_node->children()[0];
- }
- RETURN_NOT_OK(ConditionedShuffleArrayListVisitorImpl::Make(
- child_node, left_field_list, right_field_list, ret_fields, p, &impl_));
- goto finish;
- }
if (func_name.compare("conditionedProbeArraysInner") == 0 ||
func_name.compare("conditionedProbeArraysOuter") == 0 ||
func_name.compare("conditionedProbeArraysAnti") == 0 ||
@@ -272,7 +263,7 @@ arrow::Status ExprVisitor::MakeExprVisitorImpl(
}
RETURN_NOT_OK(ConditionedProbeArraysVisitorImpl::Make(
left_key_list, right_key_list, condition_node, join_type, left_field_list,
- right_field_list, p, &impl_));
+ right_field_list, ret_fields, p, &impl_));
goto finish;
}
finish:
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor_impl.h b/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor_impl.h
index 3fa7c6c69..8e216b87b 100644
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor_impl.h
+++ b/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor_impl.h
@@ -414,93 +414,6 @@ class SortArraysToIndicesVisitorImpl : public ExprVisitorImpl {
bool asc_;
};
-////////////////////////// ConditionedShuffleArrayListVisitorImpl
-//////////////////////////
-class ConditionedShuffleArrayListVisitorImpl : public ExprVisitorImpl {
- public:
- ConditionedShuffleArrayListVisitorImpl(
- std::shared_ptr func_node,
- std::vector> left_field_list,
- std::vector> right_field_list,
- std::vector> ret_field_list, ExprVisitor* p)
- : func_node_(func_node),
- left_field_list_(left_field_list),
- right_field_list_(right_field_list),
- output_field_list_(ret_field_list),
- ExprVisitorImpl(p) {}
- static arrow::Status Make(std::shared_ptr func_node,
- std::vector> left_field_list,
- std::vector> right_field_list,
- std::vector> ret_field_list,
- ExprVisitor* p, std::shared_ptr* out) {
- auto impl = std::make_shared(
- func_node, left_field_list, right_field_list, ret_field_list, p);
- *out = impl;
- return arrow::Status::OK();
- }
-
- arrow::Status Init() override {
- if (initialized_) {
- return arrow::Status::OK();
- }
- RETURN_NOT_OK(extra::ConditionedShuffleArrayListKernel::Make(
- &p_->ctx_, func_node_, left_field_list_, right_field_list_, output_field_list_,
- &kernel_));
- initialized_ = true;
- return arrow::Status::OK();
- }
-
- arrow::Status Eval() override {
- switch (p_->dependency_result_type_) {
- case ArrowComputeResultType::None: {
- ArrayList in;
- for (int i = 0; i < p_->in_record_batch_->num_columns(); i++) {
- in.push_back(p_->in_record_batch_->column(i));
- }
- TIME_MICRO_OR_RAISE(p_->elapse_time_, kernel_->Evaluate(in));
- finish_return_type_ = ArrowComputeResultType::Batch;
- } break;
- default:
- return arrow::Status::NotImplemented(
- "ConditionedShuffleArrayListVisitorImpl: Does not support this type of "
- "input.");
- }
- return arrow::Status::OK();
- }
-
- arrow::Status SetDependency(
- const std::shared_ptr>& dependency_iter,
- int index) override {
- RETURN_NOT_OK(kernel_->SetDependencyIter(dependency_iter, index));
- p_->dependency_result_type_ = ArrowComputeResultType::BatchIterator;
- return arrow::Status::OK();
- }
-
- arrow::Status MakeResultIterator(
- std::shared_ptr schema,
- std::shared_ptr>* out) override {
- switch (finish_return_type_) {
- case ArrowComputeResultType::Batch: {
- TIME_MICRO_OR_RAISE(p_->elapse_time_, kernel_->MakeResultIterator(schema, out));
- p_->return_type_ = ArrowComputeResultType::BatchIterator;
- } break;
- default:
- return arrow::Status::Invalid(
- "ConditionedShuffleArrayListVisitorImpl MakeResultIterator does not "
- "support "
- "dependency type other than Batch.");
- }
- return arrow::Status::OK();
- }
-
- private:
- int col_id_;
- std::shared_ptr func_node_;
- std::vector> left_field_list_;
- std::vector> right_field_list_;
- std::vector> output_field_list_;
-};
-
////////////////////////// ConditionedProbeArraysVisitorImpl ///////////////////////
class ConditionedProbeArraysVisitorImpl : public ExprVisitorImpl {
public:
@@ -509,23 +422,26 @@ class ConditionedProbeArraysVisitorImpl : public ExprVisitorImpl {
std::vector> right_key_list,
std::shared_ptr func_node, int join_type,
std::vector> left_field_list,
- std::vector> right_field_list, ExprVisitor* p)
+ std::vector> right_field_list,
+ std::vector> ret_fields, ExprVisitor* p)
: left_key_list_(left_key_list),
right_key_list_(right_key_list),
join_type_(join_type),
func_node_(func_node),
left_field_list_(left_field_list),
right_field_list_(right_field_list),
+ ret_fields_(ret_fields),
ExprVisitorImpl(p) {}
static arrow::Status Make(std::vector> left_key_list,
std::vector> right_key_list,
std::shared_ptr func_node, int join_type,
std::vector> left_field_list,
std::vector> right_field_list,
+ std::vector> ret_fields,
ExprVisitor* p, std::shared_ptr* out) {
auto impl = std::make_shared(
left_key_list, right_key_list, func_node, join_type, left_field_list,
- right_field_list, p);
+ right_field_list, ret_fields, p);
*out = impl;
return arrow::Status::OK();
}
@@ -536,7 +452,7 @@ class ConditionedProbeArraysVisitorImpl : public ExprVisitorImpl {
}
RETURN_NOT_OK(extra::ConditionedProbeArraysKernel::Make(
&p_->ctx_, left_key_list_, right_key_list_, func_node_, join_type_,
- left_field_list_, right_field_list_, &kernel_));
+ left_field_list_, right_field_list_, arrow::schema(ret_fields_), &kernel_));
initialized_ = true;
return arrow::Status::OK();
}
@@ -583,6 +499,7 @@ class ConditionedProbeArraysVisitorImpl : public ExprVisitorImpl {
std::vector> right_key_list_;
std::vector> left_field_list_;
std::vector> right_field_list_;
+ std::vector> ret_fields_;
};
} // namespace arrowcompute
} // namespace codegen
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/array_item_index.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/array_item_index.h
index 01f22ea0a..4cc39bd10 100644
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/array_item_index.h
+++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/array_item_index.h
@@ -23,10 +23,10 @@ namespace codegen {
namespace arrowcompute {
namespace extra {
struct ArrayItemIndex {
- uint64_t id = 0;
- uint64_t array_id = 0;
+ uint16_t id = 0;
+ uint16_t array_id = 0;
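+  // 16-bit fields keep this index struct at 4 bytes; this assumes no more than
+  // 65536 cached batches and 65536 rows per batch.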
ArrayItemIndex() : array_id(0), id(0) {}
- ArrayItemIndex(uint64_t array_id, uint64_t id) : array_id(array_id), id(id) {}
+ ArrayItemIndex(uint16_t array_id, uint16_t id) : array_id(array_id), id(id) {}
};
} // namespace extra
} // namespace arrowcompute
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/code_generator_base.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/code_generator_base.h
index afd15d261..86b8bde79 100644
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/code_generator_base.h
+++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/code_generator_base.h
@@ -31,16 +31,17 @@ using ArrayList = std::vector>;
class CodeGenBase {
public:
virtual arrow::Status Evaluate(const ArrayList& in) {
- return arrow::Status::NotImplemented("SortBase Evaluate is an abstract interface.");
+ return arrow::Status::NotImplemented(
+ "CodeGenBase Evaluate is an abstract interface.");
}
virtual arrow::Status Finish(std::shared_ptr* out) {
- return arrow::Status::NotImplemented("SortBase Finish is an abstract interface.");
+ return arrow::Status::NotImplemented("CodeGenBase Finish is an abstract interface.");
}
virtual arrow::Status MakeResultIterator(
std::shared_ptr schema,
std::shared_ptr>* out) {
return arrow::Status::NotImplemented(
- "SortBase MakeResultIterator is an abstract interface.");
+ "CodeGenBase MakeResultIterator is an abstract interface.");
}
};
} // namespace extra
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.cc
index bce0918ec..b51b669c4 100644
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.cc
+++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.cc
@@ -48,104 +48,159 @@ std::string BaseCodes() {
#include
#include
-template class ResultIterator {
-public:
- virtual bool HasNext() { return false; }
- virtual arrow::Status Next(std::shared_ptr *out) {
- return arrow::Status::NotImplemented("ResultIterator abstract Next function");
- }
- virtual arrow::Status
- Process(std::vector> in,
- std::shared_ptr *out,
- const std::shared_ptr &selection = nullptr) {
- return arrow::Status::NotImplemented("ResultIterator abstract Process function");
- }
- virtual arrow::Status
- ProcessAndCacheOne(std::vector> in,
- const std::shared_ptr &selection = nullptr) {
- return arrow::Status::NotImplemented(
- "ResultIterator abstract ProcessAndCacheOne function");
- }
- virtual arrow::Status GetResult(std::shared_ptr* out) {
- return arrow::Status::NotImplemented("ResultIterator abstract GetResult function");
- }
- virtual std::string ToString() { return ""; }
-};
-
-using ArrayList = std::vector>;
-struct ArrayItemIndex {
- uint64_t id = 0;
- uint64_t array_id = 0;
- ArrayItemIndex(uint64_t array_id, uint64_t id) : array_id(array_id), id(id) {}
-};
-
-class CodeGenBase {
- public:
- virtual arrow::Status Evaluate(const ArrayList& in) {
- return arrow::Status::NotImplemented("SortBase Evaluate is an abstract interface.");
- }
- virtual arrow::Status Finish(std::shared_ptr* out) {
- return arrow::Status::NotImplemented("SortBase Finish is an abstract interface.");
- }
- virtual arrow::Status MakeResultIterator(
- std::shared_ptr schema,
- std::shared_ptr>* out) {
- return arrow::Status::NotImplemented(
- "SortBase MakeResultIterator is an abstract interface.");
- }
-};)";
-}
-
-int FileSpinLock(std::string path) {
- std::string lockfile = path + "/nativesql_compile.lock";
+#include "codegen/arrow_compute/ext/array_item_index.h"
+#include "codegen/arrow_compute/ext/code_generator_base.h"
+#include "codegen/arrow_compute/ext/kernels_ext.h"
+#include "codegen/common/result_iterator.h"
+#include "sparsehash/sparse_hash_map.h"
+#include "third_party/arrow/utils/hashing.h"
- auto fd = open(lockfile.c_str(), O_CREAT, S_IRWXU | S_IRWXG);
- flock(fd, LOCK_EX);
+using namespace sparkcolumnarplugin::codegen::arrowcompute::extra;
- return fd;
+)";
}
-void FileSpinUnLock(int fd) {
- flock(fd, LOCK_UN);
- close(fd);
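+// Maps an Arrow DataType to the arrow:: type-factory expression (e.g. "int32()")
+// emitted into generated source code.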
+std::string GetArrowTypeDefString(std::shared_ptr type) {
+ switch (type->id()) {
+ case arrow::UInt8Type::type_id:
+ return "uint8()";
+ case arrow::Int8Type::type_id:
+ return "int8()";
+ case arrow::UInt16Type::type_id:
+ return "uint16()";
+ case arrow::Int16Type::type_id:
+ return "int16()";
+ case arrow::UInt32Type::type_id:
+ return "uint32()";
+ case arrow::Int32Type::type_id:
+ return "int32()";
+ case arrow::UInt64Type::type_id:
+ return "uint64()";
+ case arrow::Int64Type::type_id:
+ return "int64()";
+ case arrow::FloatType::type_id:
+ return "float632()";
+ case arrow::DoubleType::type_id:
+ return "float64()";
+ case arrow::Date32Type::type_id:
+ return "date32()";
+ case arrow::StringType::type_id:
+ return "utf8()";
+ default:
+ std::cout << "GetTypeString can't convert " << type->ToString() << std::endl;
+ throw;
+ }
}
-
-std::string GetTypeString(std::shared_ptr type) {
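+// Maps an Arrow DataType to the C++ value type (e.g. "int32_t") used in generated code.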
+std::string GetCTypeString(std::shared_ptr type) {
+ switch (type->id()) {
+ case arrow::UInt8Type::type_id:
+ return "uint8_t";
+ case arrow::Int8Type::type_id:
+ return "int8_t";
+ case arrow::UInt16Type::type_id:
+ return "uint16_t";
+ case arrow::Int16Type::type_id:
+ return "int16_t";
+ case arrow::UInt32Type::type_id:
+ return "uint32_t";
+ case arrow::Int32Type::type_id:
+ return "int32_t";
+ case arrow::UInt64Type::type_id:
+ return "uint64_t";
+ case arrow::Int64Type::type_id:
+ return "int64_t";
+ case arrow::FloatType::type_id:
+ return "float";
+ case arrow::DoubleType::type_id:
+ return "double";
+ case arrow::Date32Type::type_id:
+ std::cout << "Can't handle Data32Type yet" << std::endl;
+ throw;
+ case arrow::StringType::type_id:
+ return "std::string";
+ default:
+ std::cout << "GetTypeString can't convert " << type->ToString() << std::endl;
+ throw;
+ }
+}
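+// Maps an Arrow DataType to its Arrow class-name stem plus a caller-supplied suffix,
+// e.g. Int32 + "Type" or Int32 + "Array" (the default tail is "Type").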
+std::string GetTypeString(std::shared_ptr type, std::string tail) {
switch (type->id()) {
case arrow::UInt8Type::type_id:
- return "UInt8Type";
+ return "UInt8" + tail;
case arrow::Int8Type::type_id:
- return "Int8Type";
+ return "Int8" + tail;
case arrow::UInt16Type::type_id:
- return "UInt16Type";
+ return "UInt16" + tail;
case arrow::Int16Type::type_id:
- return "Int16Type";
+ return "Int16" + tail;
case arrow::UInt32Type::type_id:
- return "UInt32Type";
+ return "UInt32" + tail;
case arrow::Int32Type::type_id:
- return "Int32Type";
+ return "Int32" + tail;
case arrow::UInt64Type::type_id:
- return "UInt64Type";
+ return "UInt64" + tail;
case arrow::Int64Type::type_id:
- return "Int64Type";
+ return "Int64" + tail;
case arrow::FloatType::type_id:
- return "FloatType";
+ return "Float" + tail;
case arrow::DoubleType::type_id:
- return "DoubleType";
+ return "Double" + tail;
case arrow::Date32Type::type_id:
- return "Date32Type";
+ return "Date32" + tail;
case arrow::StringType::type_id:
- return "StringType";
+ return "String" + tail;
default:
std::cout << "GetTypeString can't convert " << type->ToString() << std::endl;
throw;
}
}
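+// Resolves the scratch directory for generated code: NATIVESQL_TMP_DIR if set,
+// otherwise the compile-time NATIVESQL_SRC_PATH (the releases directory).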
+std::string GetTempPath() {
+ std::string tmp_dir_;
+ const char* env_tmp_dir = std::getenv("NATIVESQL_TMP_DIR");
+ if (env_tmp_dir != nullptr) {
+ tmp_dir_ = std::string(env_tmp_dir);
+ } else {
+#ifdef NATIVESQL_SRC_PATH
+ tmp_dir_ = NATIVESQL_SRC_PATH;
+#else
+ std::cerr << "envioroment variable NATIVESQL_TMP_DIR is not set" << std::endl;
+ throw;
+#endif
+ }
+ return tmp_dir_;
+}
+
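+// Reads the codegen batch size from NATIVESQL_BATCH_SIZE, defaulting to 10000 rows.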
+int GetBatchSize() {
+ int batch_size;
+ const char* env_batch_size = std::getenv("NATIVESQL_BATCH_SIZE");
+ if (env_batch_size != nullptr) {
+ batch_size = atoi(env_batch_size);
+ } else {
+ batch_size = 10000;
+ }
+ return batch_size;
+}
+
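+// Serializes on-the-fly compilation across processes with an exclusive flock()
+// on a lock file under the temp path; FileSpinUnLock releases and closes it.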
+int FileSpinLock() {
+ std::string lockfile = GetTempPath() + "/nativesql_compile.lock";
+
+ auto fd = open(lockfile.c_str(), O_CREAT, S_IRWXU | S_IRWXG);
+ flock(fd, LOCK_EX);
+
+ return fd;
+}
+
+void FileSpinUnLock(int fd) {
+ flock(fd, LOCK_UN);
+ close(fd);
+}
+
arrow::Status CompileCodes(std::string codes, std::string signature) {
// temporary cpp/library output files
srand(time(NULL));
- std::string outpath = "/tmp";
+ std::string outpath = GetTempPath() + "/tmp/";
+ mkdir(outpath.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH);
std::string prefix = "/spark-columnar-plugin-codegen-";
std::string cppfile = outpath + prefix + signature + ".cc";
std::string libfile = outpath + prefix + signature + ".so";
@@ -158,6 +213,10 @@ arrow::Status CompileCodes(std::string codes, std::string signature) {
exit(EXIT_FAILURE);
}
out << codes;
+#ifdef DEBUG
+ std::cout << "BatchSize is " << GetBatchSize() << std::endl;
+ std::cout << codes << std::endl;
+#endif
out.flush();
out.close();
@@ -170,18 +229,29 @@ arrow::Status CompileCodes(std::string codes, std::string signature) {
const char* env_arrow_dir = std::getenv("LIBARROW_DIR");
std::string arrow_header;
- std::string arrow_lib;
+ std::string arrow_lib, arrow_lib2;
+ std::string nativesql_header = " -I" + GetTempPath() + "/nativesql_include/ ";
+ std::string nativesql_lib = " -L" + GetTempPath() + " ";
if (env_arrow_dir != nullptr) {
arrow_header = " -I" + std::string(env_arrow_dir) + "/include ";
arrow_lib = " -L" + std::string(env_arrow_dir) + "/lib64 ";
+    // in case libarrow.so lives in a different location
+ arrow_lib2 = " -L" + std::string(env_arrow_dir) + "/lib ";
}
// compile the code
- std::string cmd = env_gcc + " -std=c++11 -Wall -Wextra " + arrow_header + arrow_lib +
- cppfile + " -o " + libfile + " -O3 -shared -fPIC -larrow 2> " +
+ std::string cmd = env_gcc + " -std=c++14 -Wno-deprecated-declarations " + arrow_header +
+ arrow_lib + arrow_lib2 + nativesql_header + nativesql_lib + cppfile + " -o " +
+ libfile + " -O3 -march=native -shared -fPIC -larrow -lspark_columnar_jni 2> " +
logfile;
+ //#ifdef DEBUG
+ std::cout << cmd << std::endl;
+ //#endif
int ret = system(cmd.c_str());
if (WEXITSTATUS(ret) != EXIT_SUCCESS) {
std::cout << "compilation failed, see " << logfile << std::endl;
+ std::cout << cmd << std::endl;
+ cmd = "ls -R -l " + GetTempPath() + "; cat " + logfile;
+ system(cmd.c_str());
exit(EXIT_FAILURE);
}
@@ -197,10 +267,9 @@ arrow::Status CompileCodes(std::string codes, std::string signature) {
arrow::Status LoadLibrary(std::string signature, arrow::compute::FunctionContext* ctx,
std::shared_ptr* out) {
- std::string outpath = "/tmp";
+ std::string outpath = GetTempPath() + "/tmp/";
std::string prefix = "/spark-columnar-plugin-codegen-";
std::string libfile = outpath + prefix + signature + ".so";
- std::cout << "LoadLibrary " << libfile << std::endl;
// load dynamic library
void* dynlib = dlopen(libfile.c_str(), RTLD_LAZY);
if (!dynlib) {
@@ -210,6 +279,7 @@ arrow::Status LoadLibrary(std::string signature, arrow::compute::FunctionContext
// loading symbol from library and assign to pointer
// (to be cast to function pointer later)
+ std::cout << "LoadLibrary " << libfile << std::endl;
void (*MakeCodeGen)(arrow::compute::FunctionContext * ctx,
std::shared_ptr * out);
*(void**)(&MakeCodeGen) = dlsym(dynlib, "MakeCodeGen");
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.h
index 1316659ae..a47590586 100644
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.h
+++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.h
@@ -30,11 +30,15 @@ namespace extra {
std::string BaseCodes();
-int FileSpinLock(std::string path);
+int FileSpinLock();
void FileSpinUnLock(int fd);
-std::string GetTypeString(std::shared_ptr type);
+int GetBatchSize();
+std::string GetArrowTypeDefString(std::shared_ptr type);
+std::string GetCTypeString(std::shared_ptr type);
+std::string GetTypeString(std::shared_ptr type,
+ std::string tail = "Type");
arrow::Status CompileCodes(std::string codes, std::string signature);
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc
index 3196f5928..da1f1313b 100644
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc
+++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc
@@ -32,87 +32,52 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::FunctionNode& node) {
std::shared_ptr child_visitor;
*func_count_ = *func_count_ + 1;
RETURN_NOT_OK(MakeCodeGenNodeVisitor(child, field_list_v_, func_count_, codes_ss_,
- &child_visitor));
+ left_indices_, right_indices_, &child_visitor));
child_visitor_list.push_back(child_visitor);
}
auto func_name = node.descriptor()->name();
std::stringstream ss;
if (func_name.compare("less_than") == 0) {
- auto check_str = child_visitor_list[0]->GetPreCheck();
- if (!check_str.empty()) {
- ss << check_str << " || ";
- }
- check_str = child_visitor_list[1]->GetPreCheck();
- if (!check_str.empty()) {
- ss << check_str << " || ";
- }
ss << "(" << child_visitor_list[0]->GetResult() << " < "
<< child_visitor_list[1]->GetResult() << ")";
- }
- if (func_name.compare("greater_than") == 0) {
- auto check_str = child_visitor_list[0]->GetPreCheck();
- if (!check_str.empty()) {
- ss << check_str << " || ";
- }
- check_str = child_visitor_list[1]->GetPreCheck();
- if (!check_str.empty()) {
- ss << check_str << " || ";
- }
+ } else if (func_name.compare("greater_than") == 0) {
ss << "(" << child_visitor_list[0]->GetResult() << " > "
<< child_visitor_list[1]->GetResult() << ")";
- }
- if (func_name.compare("less_than_or_equal_to") == 0) {
- auto check_str = child_visitor_list[0]->GetPreCheck();
- if (!check_str.empty()) {
- ss << check_str << " || ";
- }
- check_str = child_visitor_list[1]->GetPreCheck();
- if (!check_str.empty()) {
- ss << check_str << " || ";
- }
+ } else if (func_name.compare("less_than_or_equal_to") == 0) {
ss << "(" << child_visitor_list[0]->GetResult()
<< " <= " << child_visitor_list[1]->GetResult() << ")";
- }
- if (func_name.compare("greater_than_or_equal_to") == 0) {
- auto check_str = child_visitor_list[0]->GetPreCheck();
- if (!check_str.empty()) {
- ss << check_str << " || ";
- }
- check_str = child_visitor_list[1]->GetPreCheck();
- if (!check_str.empty()) {
- ss << check_str << " || ";
- }
+ } else if (func_name.compare("greater_than_or_equal_to") == 0) {
ss << "(" << child_visitor_list[0]->GetResult()
<< " >= " << child_visitor_list[1]->GetResult() << ")";
- }
- if (func_name.compare("equal") == 0) {
- auto check_str = child_visitor_list[0]->GetPreCheck();
- if (!check_str.empty()) {
- ss << check_str << " || ";
- }
- check_str = child_visitor_list[1]->GetPreCheck();
- if (!check_str.empty()) {
- ss << check_str << " || ";
- }
+ } else if (func_name.compare("equal") == 0) {
ss << "(" << child_visitor_list[0]->GetResult()
<< " == " << child_visitor_list[1]->GetResult() << ")";
- }
- if (func_name.compare("not") == 0) {
+ } else if (func_name.compare("not") == 0) {
ss << "!(" << child_visitor_list[0]->GetResult() << ")";
+ } else {
+ ss << child_visitor_list[0]->GetResult();
+ }
+ if (cur_func_id == 0) {
+ codes_str_ = ss.str();
+ } else {
+ codes_str_ = ss.str();
}
- codes_str_ = ss.str();
return arrow::Status::OK();
}
+
arrow::Status CodeGenNodeVisitor::Visit(const gandiva::FieldNode& node) {
auto cur_func_id = *func_count_;
auto this_field = node.field();
int arg_id = 0;
bool found = false;
+ int index = 0;
for (auto field_list : field_list_v_) {
+ arg_id = 0;
for (auto field : field_list) {
if (field->name() == this_field->name()) {
found = true;
+ InsertToIndices(index, arg_id);
break;
}
arg_id++;
@@ -120,27 +85,26 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::FieldNode& node) {
if (found) {
break;
}
+ index = 1;
}
- std::string type_name;
- if (this_field->type()->id() == arrow::Type::STRING) {
- type_name = "std::string";
+ if (index == 0) {
+ *codes_ss_ << "if (cached_0_" << arg_id
+ << "_[x.array_id]->IsNull(x.id)) {return false;}" << std::endl;
+ *codes_ss_ << "auto input_field_" << cur_func_id << " = cached_0_" << arg_id
+ << "_[x.array_id]->GetView(x.id);" << std::endl;
+
} else {
- type_name = this_field->type()->name();
+ *codes_ss_ << "if (cached_1_" << arg_id << "_->IsNull(y)) {return false;}"
+ << std::endl;
+ *codes_ss_ << "auto input_field_" << cur_func_id << " = cached_1_" << arg_id
+ << "_->GetView(y);" << std::endl;
}
- *codes_ss_ << type_name << " *input_field_" << cur_func_id << " = nullptr;"
- << std::endl;
- *codes_ss_ << "if (!is_null_func_list[" << arg_id << "]()) {" << std::endl;
- *codes_ss_ << " input_field_" << cur_func_id << " = (" << type_name
- << "*)get_func_list[" << arg_id << "]();" << std::endl;
- *codes_ss_ << "}" << std::endl;
std::stringstream ss;
- ss << "*input_field_" << cur_func_id;
+ ss << "input_field_" << cur_func_id;
codes_str_ = ss.str();
- std::stringstream ss_check;
- ss_check << "(input_field_" << cur_func_id << " == nullptr)";
- check_str_ = ss_check.str();
+ check_str_ = "";
return arrow::Status::OK();
}
@@ -151,12 +115,12 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::IfNode& node) {
arrow::Status CodeGenNodeVisitor::Visit(const gandiva::LiteralNode& node) {
auto cur_func_id = *func_count_;
if (node.return_type()->id() == arrow::Type::STRING) {
- *codes_ss_ << "const std::string input_field_" << cur_func_id << R"( = ")"
+ *codes_ss_ << "auto input_field_" << cur_func_id << R"( = ")"
<< gandiva::ToString(node.holder()) << R"(";)" << std::endl;
} else {
- *codes_ss_ << "const " << node.return_type()->name() << " input_field_" << cur_func_id
- << " = " << gandiva::ToString(node.holder()) << ";" << std::endl;
+ *codes_ss_ << "auto input_field_" << cur_func_id << " = "
+ << gandiva::ToString(node.holder()) << ";" << std::endl;
}
std::stringstream ss;
@@ -172,7 +136,7 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::BooleanNode& node) {
std::shared_ptr child_visitor;
*func_count_ = *func_count_ + 1;
RETURN_NOT_OK(MakeCodeGenNodeVisitor(child, field_list_v_, func_count_, codes_ss_,
- &child_visitor));
+ left_indices_, right_indices_, &child_visitor));
child_visitor_list.push_back(child_visitor);
}
@@ -188,14 +152,106 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::BooleanNode& node) {
codes_str_ = ss.str();
return arrow::Status::OK();
}
+
arrow::Status CodeGenNodeVisitor::Visit(const gandiva::InExpressionNode& node) {
+ auto cur_func_id = *func_count_;
+ std::shared_ptr child_visitor;
+ *func_count_ = *func_count_ + 1;
+ RETURN_NOT_OK(MakeCodeGenNodeVisitor(node.eval_expr(), field_list_v_, func_count_,
+ codes_ss_, left_indices_, right_indices_,
+ &child_visitor));
+ *codes_ss_ << "std::vector input_field_" << cur_func_id << " = {";
+ bool add_comma = false;
+ for (auto& value : node.values()) {
+ if (add_comma) {
+ *codes_ss_ << ", ";
+ }
+ // add type in the front to differentiate
+ *codes_ss_ << value;
+ add_comma = true;
+ }
+ *codes_ss_ << "};" << std::endl;
+
+ std::stringstream ss;
+ ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id
+ << ".end(), " << child_visitor->GetResult() << ") != "
+ << "input_field_" << cur_func_id << ".end()";
+ codes_str_ = ss.str();
+ check_str_ = child_visitor->GetPreCheck();
return arrow::Status::OK();
}
+
arrow::Status CodeGenNodeVisitor::Visit(const gandiva::InExpressionNode& node) {
+ auto cur_func_id = *func_count_;
+ std::shared_ptr child_visitor;
+ *func_count_ = *func_count_ + 1;
+ RETURN_NOT_OK(MakeCodeGenNodeVisitor(node.eval_expr(), field_list_v_, func_count_,
+ codes_ss_, left_indices_, right_indices_,
+ &child_visitor));
+ *codes_ss_ << "std::vector input_field_" << cur_func_id << " = {";
+ bool add_comma = false;
+ for (auto& value : node.values()) {
+ if (add_comma) {
+ *codes_ss_ << ", ";
+ }
+ // add type in the front to differentiate
+ *codes_ss_ << value;
+ add_comma = true;
+ }
+ *codes_ss_ << "};" << std::endl;
+
+ std::stringstream ss;
+ ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id
+ << ".end(), " << child_visitor->GetResult() << ") != "
+ << "input_field_" << cur_func_id << ".end()";
+ codes_str_ = ss.str();
+ check_str_ = child_visitor->GetPreCheck();
return arrow::Status::OK();
}
+
arrow::Status CodeGenNodeVisitor::Visit(
const gandiva::InExpressionNode& node) {
+ auto cur_func_id = *func_count_;
+ std::shared_ptr child_visitor;
+ *func_count_ = *func_count_ + 1;
+ RETURN_NOT_OK(MakeCodeGenNodeVisitor(node.eval_expr(), field_list_v_, func_count_,
+ codes_ss_, left_indices_, right_indices_,
+ &child_visitor));
+ *codes_ss_ << "std::vector input_field_" << cur_func_id << " = {";
+ bool add_comma = false;
+ for (auto& value : node.values()) {
+ if (add_comma) {
+ *codes_ss_ << ", ";
+ }
+ // add type in the front to differentiate
+ *codes_ss_ << R"(")" << value << R"(")";
+ add_comma = true;
+ }
+ *codes_ss_ << "};" << std::endl;
+
+ std::stringstream ss;
+ ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id
+ << ".end(), " << child_visitor->GetResult() << ") != "
+ << "input_field_" << cur_func_id << ".end()";
+ codes_str_ = ss.str();
+ check_str_ = child_visitor->GetPreCheck();
+ return arrow::Status::OK();
+}
+
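+// Records which column of the left (index 0) or right (index 1) schema a visited
+// field refers to, skipping ids that were already recorded.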
+arrow::Status CodeGenNodeVisitor::InsertToIndices(int index, int arg_id) {
+ if (index == 0) {
+ if (std::find((*left_indices_).begin(), (*left_indices_).end(), arg_id) ==
+ (*left_indices_).end()) {
+ (*left_indices_).push_back(arg_id);
+ }
+ }
+ if (index == 1) {
+ if (std::find((*right_indices_).begin(), (*right_indices_).end(), arg_id) ==
+ (*right_indices_).end()) {
+ (*right_indices_).push_back(arg_id);
+ }
+ }
+
return arrow::Status::OK();
}
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.h
index 127d95534..bbb77eebe 100644
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.h
+++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.h
@@ -28,11 +28,14 @@ class CodeGenNodeVisitor : public VisitorBase {
public:
CodeGenNodeVisitor(std::shared_ptr func,
std::vector>> field_list_v,
- int* func_count, std::stringstream* codes_ss)
+ int* func_count, std::stringstream* codes_ss,
+ std::vector* left_indices, std::vector* right_indices)
: func_(func),
field_list_v_(field_list_v),
func_count_(func_count),
- codes_ss_(codes_ss) {}
+ codes_ss_(codes_ss),
+ left_indices_(left_indices),
+ right_indices_(right_indices) {}
arrow::Status Eval() {
RETURN_NOT_OK(func_->Accept(*this));
@@ -57,13 +60,17 @@ class CodeGenNodeVisitor : public VisitorBase {
std::stringstream* codes_ss_;
std::string codes_str_;
std::string check_str_;
+ std::vector* left_indices_;
+ std::vector* right_indices_;
+ arrow::Status InsertToIndices(int index, int arg_id);
};
static arrow::Status MakeCodeGenNodeVisitor(
std::shared_ptr func,
std::vector>> field_list_v, int* func_count,
- std::stringstream* codes_ss, std::shared_ptr* out) {
- auto visitor =
- std::make_shared(func, field_list_v, func_count, codes_ss);
+ std::stringstream* codes_ss, std::vector* left_indices,
+ std::vector* right_indices, std::shared_ptr* out) {
+ auto visitor = std::make_shared(
+ func, field_list_v, func_count, codes_ss, left_indices, right_indices);
RETURN_NOT_OK(visitor->Eval());
*out = visitor;
return arrow::Status::OK();
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.cc
deleted file mode 100644
index 7a9af5781..000000000
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.cc
+++ /dev/null
@@ -1,251 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "codegen/arrow_compute/ext/codegen_node_visitor_v2.h"
-
-#include
-
-#include
-#include
-
-namespace sparkcolumnarplugin {
-namespace codegen {
-namespace arrowcompute {
-namespace extra {
-std::string CodeGenNodeVisitorV2::GetResult() { return codes_str_; }
-std::string CodeGenNodeVisitorV2::GetPreCheck() { return check_str_; }
-arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::FunctionNode& node) {
- std::vector> child_visitor_list;
- auto cur_func_id = *func_count_;
- for (auto child : node.children()) {
- std::shared_ptr child_visitor;
- *func_count_ = *func_count_ + 1;
- RETURN_NOT_OK(MakeCodeGenNodeVisitorV2(child, field_list_v_, func_count_, codes_ss_,
- &child_visitor));
- child_visitor_list.push_back(child_visitor);
- }
-
- auto func_name = node.descriptor()->name();
- std::stringstream ss;
- if (func_name.compare("less_than") == 0) {
- ss << "(" << child_visitor_list[0]->GetResult() << " < "
- << child_visitor_list[1]->GetResult() << ")";
- } else if (func_name.compare("greater_than") == 0) {
- ss << "(" << child_visitor_list[0]->GetResult() << " > "
- << child_visitor_list[1]->GetResult() << ")";
- } else if (func_name.compare("less_than_or_equal_to") == 0) {
- ss << "(" << child_visitor_list[0]->GetResult()
- << " <= " << child_visitor_list[1]->GetResult() << ")";
- } else if (func_name.compare("greater_than_or_equal_to") == 0) {
- ss << "(" << child_visitor_list[0]->GetResult()
- << " >= " << child_visitor_list[1]->GetResult() << ")";
- } else if (func_name.compare("equal") == 0) {
- ss << "(" << child_visitor_list[0]->GetResult()
- << " == " << child_visitor_list[1]->GetResult() << ")";
- } else if (func_name.compare("not") == 0) {
- ss << "!(" << child_visitor_list[0]->GetResult() << ")";
- } else {
- ss << child_visitor_list[0]->GetResult();
- }
- if (cur_func_id == 0) {
- codes_str_ = ss.str();
- } else {
- codes_str_ = ss.str();
- }
- return arrow::Status::OK();
-}
-
-arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::FieldNode& node) {
- auto cur_func_id = *func_count_;
- auto this_field = node.field();
- int arg_id = 0;
- bool found = false;
- std::string is_null_func_list = "left_is_null_func_list_";
- std::string get_func_list = "left_get_func_list_";
- std::string index = "left_index";
- for (auto field_list : field_list_v_) {
- arg_id = 0;
- for (auto field : field_list) {
- if (field->name() == this_field->name()) {
- found = true;
- break;
- }
- arg_id++;
- }
- if (found) {
- break;
- }
- is_null_func_list = "right_is_null_func_list_";
- get_func_list = "right_get_func_list_";
- index = "right_index";
- }
- std::string type_name;
- if (this_field->type()->id() == arrow::Type::STRING) {
- type_name = "std::string";
- } else {
- type_name = this_field->type()->name();
- }
- *codes_ss_ << type_name << " *input_field_" << cur_func_id << " = nullptr;"
- << std::endl;
- *codes_ss_ << "if (!" << is_null_func_list << "[" << arg_id << "](" + index + ")) {"
- << std::endl;
- *codes_ss_ << " input_field_" << cur_func_id << " = (" << type_name << "*)"
- << get_func_list << "[" << arg_id << "](" + index + ");" << std::endl;
- *codes_ss_ << "} else {" << std::endl;
- *codes_ss_ << " return false;" << std::endl;
- *codes_ss_ << "}" << std::endl;
-
- std::stringstream ss;
- ss << "*input_field_" << cur_func_id;
- codes_str_ = ss.str();
-
- check_str_ = "";
- return arrow::Status::OK();
-}
-
-arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::IfNode& node) {
- return arrow::Status::OK();
-}
-
-arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::LiteralNode& node) {
- auto cur_func_id = *func_count_;
- if (node.return_type()->id() == arrow::Type::STRING) {
- *codes_ss_ << "const std::string input_field_" << cur_func_id << R"( = ")"
- << gandiva::ToString(node.holder()) << R"(";)" << std::endl;
-
- } else {
- *codes_ss_ << "const " << node.return_type()->name() << " input_field_" << cur_func_id
- << " = " << gandiva::ToString(node.holder()) << ";" << std::endl;
- }
-
- std::stringstream ss;
- ss << "input_field_" << cur_func_id;
- codes_str_ = ss.str();
- return arrow::Status::OK();
-}
-
-arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::BooleanNode& node) {
- std::vector<std::shared_ptr<CodeGenNodeVisitorV2>> child_visitor_list;
- auto cur_func_id = *func_count_;
- for (auto child : node.children()) {
- std::shared_ptr<CodeGenNodeVisitorV2> child_visitor;
- *func_count_ = *func_count_ + 1;
- RETURN_NOT_OK(MakeCodeGenNodeVisitorV2(child, field_list_v_, func_count_, codes_ss_,
- &child_visitor));
- child_visitor_list.push_back(child_visitor);
- }
-
- std::stringstream ss;
- if (node.expr_type() == gandiva::BooleanNode::AND) {
- ss << "(" << child_visitor_list[0]->GetResult() << ") && ("
- << child_visitor_list[1]->GetResult() << ")";
- }
- if (node.expr_type() == gandiva::BooleanNode::OR) {
- ss << "(" << child_visitor_list[0]->GetResult() << ") || ("
- << child_visitor_list[1]->GetResult() << ")";
- }
- codes_str_ = ss.str();
- return arrow::Status::OK();
-}
-
-arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::InExpressionNode<int>& node) {
- auto cur_func_id = *func_count_;
- std::shared_ptr<CodeGenNodeVisitorV2> child_visitor;
- *func_count_ = *func_count_ + 1;
- RETURN_NOT_OK(MakeCodeGenNodeVisitorV2(node.eval_expr(), field_list_v_, func_count_,
- codes_ss_, &child_visitor));
- *codes_ss_ << "std::vector<int> input_field_" << cur_func_id << " = {";
- bool add_comma = false;
- for (auto& value : node.values()) {
- if (add_comma) {
- *codes_ss_ << ", ";
- }
- // add type in the front to differentiate
- *codes_ss_ << value;
- add_comma = true;
- }
- *codes_ss_ << "};" << std::endl;
-
- std::stringstream ss;
- ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id
- << ".end(), " << child_visitor->GetResult() << ") != "
- << "input_field_" << cur_func_id << ".end()";
- codes_str_ = ss.str();
- check_str_ = child_visitor->GetPreCheck();
- return arrow::Status::OK();
-}
-
-arrow::Status CodeGenNodeVisitorV2::Visit(
- const gandiva::InExpressionNode<long int>& node) {
- auto cur_func_id = *func_count_;
- std::shared_ptr<CodeGenNodeVisitorV2> child_visitor;
- *func_count_ = *func_count_ + 1;
- RETURN_NOT_OK(MakeCodeGenNodeVisitorV2(node.eval_expr(), field_list_v_, func_count_,
- codes_ss_, &child_visitor));
- *codes_ss_ << "std::vector<long int> input_field_" << cur_func_id << " = {";
- bool add_comma = false;
- for (auto& value : node.values()) {
- if (add_comma) {
- *codes_ss_ << ", ";
- }
- // add type in the front to differentiate
- *codes_ss_ << value;
- add_comma = true;
- }
- *codes_ss_ << "};" << std::endl;
-
- std::stringstream ss;
- ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id
- << ".end(), " << child_visitor->GetResult() << ") != "
- << "input_field_" << cur_func_id << ".end()";
- codes_str_ = ss.str();
- check_str_ = child_visitor->GetPreCheck();
- return arrow::Status::OK();
-}
-
-arrow::Status CodeGenNodeVisitorV2::Visit(
- const gandiva::InExpressionNode<std::string>& node) {
- auto cur_func_id = *func_count_;
- std::shared_ptr<CodeGenNodeVisitorV2> child_visitor;
- *func_count_ = *func_count_ + 1;
- RETURN_NOT_OK(MakeCodeGenNodeVisitorV2(node.eval_expr(), field_list_v_, func_count_,
- codes_ss_, &child_visitor));
- *codes_ss_ << "std::vector<std::string> input_field_" << cur_func_id << " = {";
- bool add_comma = false;
- for (auto& value : node.values()) {
- if (add_comma) {
- *codes_ss_ << ", ";
- }
- // add type in the front to differentiate
- *codes_ss_ << R"(")" << value << R"(")";
- add_comma = true;
- }
- *codes_ss_ << "};" << std::endl;
-
- std::stringstream ss;
- ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id
- << ".end(), " << child_visitor->GetResult() << ") != "
- << "input_field_" << cur_func_id << ".end()";
- codes_str_ = ss.str();
- check_str_ = child_visitor->GetPreCheck();
- return arrow::Status::OK();
-}
-
-} // namespace extra
-} // namespace arrowcompute
-} // namespace codegen
-} // namespace sparkcolumnarplugin
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.h
deleted file mode 100644
index 88cbd3caa..000000000
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include
-#include "codegen/common/visitor_base.h"
-
-namespace sparkcolumnarplugin {
-namespace codegen {
-namespace arrowcompute {
-namespace extra {
-class CodeGenNodeVisitorV2 : public VisitorBase {
- public:
- CodeGenNodeVisitorV2(
- std::shared_ptr<gandiva::Node> func,
- std::vector<std::vector<std::shared_ptr<arrow::Field>>> field_list_v,
- int* func_count, std::stringstream* codes_ss)
- : func_(func),
- field_list_v_(field_list_v),
- func_count_(func_count),
- codes_ss_(codes_ss) {}
-
- arrow::Status Eval() {
- RETURN_NOT_OK(func_->Accept(*this));
- return arrow::Status::OK();
- }
- std::string GetResult();
- std::string GetPreCheck();
- arrow::Status Visit(const gandiva::FunctionNode& node) override;
- arrow::Status Visit(const gandiva::FieldNode& node) override;
- arrow::Status Visit(const gandiva::IfNode& node) override;
- arrow::Status Visit(const gandiva::LiteralNode& node) override;
- arrow::Status Visit(const gandiva::BooleanNode& node) override;
- arrow::Status Visit(const gandiva::InExpressionNode<int>& node) override;
- arrow::Status Visit(const gandiva::InExpressionNode<long int>& node) override;
- arrow::Status Visit(const gandiva::InExpressionNode<std::string>& node) override;
-
- private:
- std::shared_ptr<gandiva::Node> func_;
- std::vector<std::vector<std::shared_ptr<arrow::Field>>> field_list_v_;
- int* func_count_;
- // output
- std::stringstream* codes_ss_;
- std::string codes_str_;
- std::string check_str_;
-};
-static arrow::Status MakeCodeGenNodeVisitorV2(
- std::shared_ptr<gandiva::Node> func,
- std::vector<std::vector<std::shared_ptr<arrow::Field>>> field_list_v, int* func_count,
- std::stringstream* codes_ss, std::shared_ptr<CodeGenNodeVisitorV2>* out) {
- auto visitor =
- std::make_shared<CodeGenNodeVisitorV2>(func, field_list_v, func_count, codes_ss);
- RETURN_NOT_OK(visitor->Eval());
- *out = visitor;
- return arrow::Status::OK();
-}
-} // namespace extra
-} // namespace arrowcompute
-} // namespace codegen
-} // namespace sparkcolumnarplugin
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioned_probe_kernel.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioned_probe_kernel.cc
deleted file mode 100644
index a10b5ebf5..000000000
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioned_probe_kernel.cc
+++ /dev/null
@@ -1,815 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-#include
-#include
-#include
-#include
-
-#include "codegen/arrow_compute/ext/codegen_node_visitor_v2.h"
-#include "codegen/arrow_compute/ext/conditioner.h"
-#include "codegen/arrow_compute/ext/item_iterator.h"
-#include "codegen/arrow_compute/ext/kernels_ext.h"
-#include "codegen/arrow_compute/ext/shuffle_v2_action.h"
-#include "third_party/arrow/utils/hashing.h"
-
-namespace sparkcolumnarplugin {
-namespace codegen {
-namespace arrowcompute {
-namespace extra {
-
-using ArrayList = std::vector<std::shared_ptr<arrow::Array>>;
-
-/////////////// ConditionedProbeArrays ////////////////
-class ConditionedProbeArraysKernel::Impl {
- public:
- Impl() {}
- virtual ~Impl() {}
- virtual arrow::Status Evaluate(const ArrayList& in) {
- return arrow::Status::NotImplemented(
- "ConditionedProbeArraysKernel::Impl Evaluate is abstract");
- } // namespace extra
- virtual arrow::Status MakeResultIterator(
- std::shared_ptr<arrow::Schema> schema,
- std::shared_ptr<ResultIterator<arrow::RecordBatch>>* out) {
- return arrow::Status::NotImplemented(
- "ConditionedProbeArraysKernel::Impl MakeResultIterator is abstract");
- }
-}; // namespace arrowcompute
-
-template
-class ConditionedProbeArraysTypedImpl : public ConditionedProbeArraysKernel::Impl {
- public:
- ConditionedProbeArraysTypedImpl(
- arrow::compute::FunctionContext* ctx,
- std::vector<std::shared_ptr<arrow::Field>> left_key_list,
- std::vector<std::shared_ptr<arrow::Field>> right_key_list,
- std::shared_ptr<gandiva::Node> func_node, int join_type,
- std::vector<std::shared_ptr<arrow::Field>> left_field_list,
- std::vector<std::shared_ptr<arrow::Field>> right_field_list)
- : ctx_(ctx),
- left_key_list_(left_key_list),
- right_key_list_(right_key_list),
- func_node_(func_node),
- join_type_(join_type),
- left_field_list_(left_field_list),
- right_field_list_(right_field_list) {
- hash_table_ = std::make_shared(ctx_->memory_pool());
- for (auto key_field : left_key_list) {
- int i = 0;
- for (auto field : left_field_list) {
- if (key_field->name() == field->name()) {
- break;
- }
- i++;
- }
- left_key_indices_.push_back(i);
- }
- for (auto key_field : right_key_list) {
- int i = 0;
- for (auto field : right_field_list) {
- if (key_field->name() == field->name()) {
- break;
- }
- i++;
- }
- right_key_indices_.push_back(i);
- }
- if (left_key_list.size() > 1) {
- std::vector<std::shared_ptr<arrow::DataType>> type_list;
- for (auto key_field : left_key_list) {
- type_list.push_back(key_field->type());
- }
- extra::HashAggrArrayKernel::Make(ctx_, type_list, &concat_kernel_);
- }
- for (auto field : left_field_list) {
- std::shared_ptr iter;
- MakeArrayListItemIterator(field->type(), &iter);
- input_iterator_list_.push_back(iter);
- }
- for (auto field : right_field_list) {
- std::shared_ptr iter;
- MakeArrayItemIterator(field->type(), &iter);
- input_iterator_list_.push_back(iter);
- }
- input_cache_.resize(left_field_list.size());
-
- // This function is supposed to return a lambda function for the later ResultIterator
- auto start = std::chrono::steady_clock::now();
- if (func_node_) {
- LoadJITFunction(func_node_, left_field_list_, right_field_list_, &conditioner_);
- }
- auto end = std::chrono::steady_clock::now();
- std::cout
- << "Code Generation took "
- << (std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() /
- 1000)
- << " ms." << std::endl;
- }
-
- arrow::Status Evaluate(const ArrayList& in_arr_list) override {
- if (in_arr_list.size() != input_cache_.size()) {
- return arrow::Status::Invalid(
- "ConditionedShuffleArrayListKernel input arrayList size does not match numCols "
- "in cache, which are ",
- in_arr_list.size(), " and ", input_cache_.size());
- }
- // we need to convert std::vector to std::vector
- for (int col_id = 0; col_id < input_cache_.size(); col_id++) {
- input_cache_[col_id].push_back(in_arr_list[col_id]);
- }
-
- // do concat_join if necessary
- std::shared_ptr<arrow::Array> in;
- if (concat_kernel_) {
- ArrayList concat_kernel_arr_list;
- for (auto i : left_key_indices_) {
- concat_kernel_arr_list.push_back(in_arr_list[i]);
- }
- RETURN_NOT_OK(concat_kernel_->Evaluate(concat_kernel_arr_list, &in));
- } else {
- in = in_arr_list[left_key_indices_[0]];
- }
-
- // we should put items into hashmap
- auto typed_array = std::dynamic_pointer_cast(in);
- auto insert_on_found = [this](int32_t i) {
- left_table_size_++;
- memo_index_to_arrayid_[i].emplace_back(cur_array_id_, cur_id_);
- };
- auto insert_on_not_found = [this](int32_t i) {
- left_table_size_++;
- memo_index_to_arrayid_.push_back({ArrayItemIndex(cur_array_id_, cur_id_)});
- };
-
- cur_id_ = 0;
- int memo_index = 0;
- if (typed_array->null_count() == 0) {
- for (; cur_id_ < typed_array->length(); cur_id_++) {
- hash_table_->GetOrInsert(typed_array->GetView(cur_id_), insert_on_found,
- insert_on_not_found, &memo_index);
- }
- } else {
- for (; cur_id_ < typed_array->length(); cur_id_++) {
- if (typed_array->IsNull(cur_id_)) {
- hash_table_->GetOrInsertNull(insert_on_found, insert_on_not_found);
- } else {
- hash_table_->GetOrInsert(typed_array->GetView(cur_id_), insert_on_found,
- insert_on_not_found, &memo_index);
- }
- }
- }
- cur_array_id_++;
- return arrow::Status::OK();
- }
-
- arrow::Status MakeResultIterator(
- std::shared_ptr<arrow::Schema> schema,
- std::shared_ptr<ResultIterator<arrow::RecordBatch>>* out) override {
- std::function<arrow::Status(const std::shared_ptr<arrow::Array>& in,
- std::function<bool(ArrayItemIndex, int)> ConditionCheck,
- std::shared_ptr<arrow::RecordBatch>* out)>
- eval_func;
- // prepare process next function
- std::cout << "HashMap lenghth is " << memo_index_to_arrayid_.size()
- << ", total stored items count is " << left_table_size_ << std::endl;
- switch (join_type_) {
- case 0: { /*Inner Join*/
- eval_func = [this, schema](
- const std::shared_ptr<arrow::Array>& in,
- std::function<bool(ArrayItemIndex, int)> ConditionCheck,
- std::shared_ptr<arrow::RecordBatch>* out) {
- // prepare
- std::unique_ptr<arrow::FixedSizeBinaryBuilder> left_indices_builder;
- auto left_array_type = arrow::fixed_size_binary(sizeof(ArrayItemIndex));
- left_indices_builder.reset(
- new arrow::FixedSizeBinaryBuilder(left_array_type, ctx_->memory_pool()));
-
- std::unique_ptr<arrow::UInt32Builder> right_indices_builder;
- right_indices_builder.reset(
- new arrow::UInt32Builder(arrow::uint32(), ctx_->memory_pool()));
-
- auto typed_array = std::dynamic_pointer_cast(in);
- for (int i = 0; i < typed_array->length(); i++) {
- if (!typed_array->IsNull(i)) {
- auto index = hash_table_->Get(typed_array->GetView(i));
- if (index != -1) {
- for (auto tmp : memo_index_to_arrayid_[index]) {
- if (ConditionCheck(tmp, i)) {
- RETURN_NOT_OK(left_indices_builder->Append((uint8_t*)&tmp));
- RETURN_NOT_OK(right_indices_builder->Append(i));
- }
- }
- }
- }
- }
- // create buffer and null_vector to FixedSizeBinaryArray
- std::shared_ptr<arrow::Array> left_arr_out;
- std::shared_ptr<arrow::Array> right_arr_out;
- RETURN_NOT_OK(left_indices_builder->Finish(&left_arr_out));
- RETURN_NOT_OK(right_indices_builder->Finish(&right_arr_out));
- auto result_schema =
- arrow::schema({arrow::field("left_indices", left_array_type),
- arrow::field("right_indices", arrow::uint32())});
- *out = arrow::RecordBatch::Make(result_schema, right_arr_out->length(),
- {left_arr_out, right_arr_out});
- return arrow::Status::OK();
- };
- } break;
- case 1: { /*Outer Join*/
- eval_func = [this, schema](
- const std::shared_ptr<arrow::Array>& in,
- std::function<bool(ArrayItemIndex, int)> ConditionCheck,
- std::shared_ptr<arrow::RecordBatch>* out) {
- std::unique_ptr<arrow::FixedSizeBinaryBuilder> left_indices_builder;
- auto left_array_type = arrow::fixed_size_binary(sizeof(ArrayItemIndex));
- left_indices_builder.reset(
- new arrow::FixedSizeBinaryBuilder(left_array_type, ctx_->memory_pool()));
-
- std::unique_ptr<arrow::UInt32Builder> right_indices_builder;
- right_indices_builder.reset(
- new arrow::UInt32Builder(arrow::uint32(), ctx_->memory_pool()));
-
- auto typed_array = std::dynamic_pointer_cast(in);
- for (int i = 0; i < typed_array->length(); i++) {
- if (typed_array->IsNull(i)) {
- auto index = hash_table_->GetNull();
- if (index == -1) {
- RETURN_NOT_OK(left_indices_builder->AppendNull());
- RETURN_NOT_OK(right_indices_builder->Append(i));
- } else {
- for (auto tmp : memo_index_to_arrayid_[index]) {
- if (ConditionCheck(tmp, i)) {
- RETURN_NOT_OK(left_indices_builder->Append((uint8_t*)&tmp));
- RETURN_NOT_OK(right_indices_builder->Append(i));
- }
- }
- }
- } else {
- auto index = hash_table_->Get(typed_array->GetView(i));
- if (index == -1) {
- RETURN_NOT_OK(left_indices_builder->AppendNull());
- RETURN_NOT_OK(right_indices_builder->Append(i));
- } else {
- for (auto tmp : memo_index_to_arrayid_[index]) {
- if (ConditionCheck(tmp, i)) {
- RETURN_NOT_OK(left_indices_builder->Append((uint8_t*)&tmp));
- RETURN_NOT_OK(right_indices_builder->Append(i));
- }
- }
- }
- }
- }
- // create buffer and null_vector to FixedSizeBinaryArray
- std::shared_ptr<arrow::Array> left_arr_out;
- std::shared_ptr