diff --git a/.travis.yml b/.travis.yml index 6ab1db883..d41228055 100644 --- a/.travis.yml +++ b/.travis.yml @@ -31,6 +31,7 @@ jobs: - echo ${TRAVIS_COMMIT_MESSAGE} #- if [[ ${TRAVIS_COMMIT_MESSAGE} != \[oap-native-sql\]* ]]; then travis_terminate 0 ; fi ; - sudo apt-get install cmake + - sudo apt-get install libboost-all-dev - export | grep JAVA_HOME install: - # Download spark 3.0 diff --git a/oap-data-source/arrow/pom.xml b/oap-data-source/arrow/pom.xml index 4d65be407..769411cc8 100644 --- a/oap-data-source/arrow/pom.xml +++ b/oap-data-source/arrow/pom.xml @@ -8,7 +8,7 @@ 2.12.10 2.12 3.1.0-SNAPSHOT - 0.17.0 + 0.17.0 UTF-8 UTF-8 @@ -112,40 +112,43 @@ src/test/scala - org.codehaus.mojo - build-helper-maven-plugin + org.apache.maven.plugins + maven-resources-plugin + copy-writable-vector-source generate-sources - add-source + copy-resources - - ${project.build.directory}/generated-sources/downloaded/java/ - + ${project.build.directory}/generated-sources/copied/java/ + + + ${project.basedir}/../../oap-native-sql/core/src/main/java/ + + **/ArrowWritableColumnVector.java + + + org.codehaus.mojo - exec-maven-plugin + build-helper-maven-plugin - download-writable-column-vector + add-src-1 generate-sources - exec + add-source - curl - - https://raw.githubusercontent.com/Intel-bigdata/OAP/branch-nativesql-spark-3.0.0/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java - --create-dirs - -o> - ${project.build.directory}/generated-sources/downloaded/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java - + + ${project.build.directory}/generated-sources/copied/java/ + diff --git a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala index b120fbf3f..b3cf69dcf 100644 --- a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala +++ b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala @@ -17,12 +17,15 @@ package com.intel.oap.spark.sql.execution.datasources.arrow +import java.net.URLDecoder + import scala.collection.JavaConverters._ import com.intel.oap.spark.sql.execution.datasources.arrow.ArrowFileFormat.UnsafeItr -import com.intel.oap.spark.sql.execution.datasources.v2.arrow.{ArrowFilters, ArrowOptions} +import com.intel.oap.spark.sql.execution.datasources.v2.arrow.{ArrowFilters, ArrowOptions, ExecutionMemoryAllocationListener} import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._ import org.apache.arrow.dataset.scanner.ScanOptions +import org.apache.arrow.memory.AllocationListener import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.mapreduce.Job @@ -71,10 +74,12 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab val sqlConf = sparkSession.sessionState.conf; val enableFilterPushDown = sqlConf.arrowFilterPushDown + val taskMemoryManager = ArrowUtils.getTaskMemoryManager() val factory = ArrowUtils.makeArrowDiscovery( - file.filePath, new ArrowOptions( + URLDecoder.decode(file.filePath, "UTF-8"), new ArrowOptions( new CaseInsensitiveStringMap( - options.asJava).asScala.toMap)) + options.asJava).asScala.toMap), + new ExecutionMemoryAllocationListener(taskMemoryManager)) // todo predicate validation / pushdown val dataset = 
factory.finish(); @@ -98,7 +103,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab val itr = itrList .toIterator .flatMap(itr => itr.asScala) - .map(vsr => ArrowUtils.loadVsr(vsr, file.partitionValues, partitionSchema, dataSchema)) + .map(vsr => ArrowUtils.loadVectors(vsr, file.partitionValues, partitionSchema, dataSchema)) new UnsafeItr(itr).asInstanceOf[Iterator[InternalRow]] } } diff --git a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala index 6c2431aaf..47b56f965 100644 --- a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala +++ b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala @@ -16,6 +16,8 @@ */ package com.intel.oap.spark.sql.execution.datasources.v2.arrow +import java.net.URLDecoder + import scala.collection.JavaConverters._ import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowPartitionReaderFactory.ColumnarBatchRetainer @@ -23,6 +25,7 @@ import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._ import org.apache.arrow.dataset.scanner.ScanOptions import org.apache.spark.broadcast.Broadcast +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.datasources.PartitionedFile @@ -43,7 +46,7 @@ case class ArrowPartitionReaderFactory( options: ArrowOptions) extends FilePartitionReaderFactory { - private val batchSize = 4096 + private val batchSize = sqlConf.parquetVectorizedReaderBatchSize private val enableFilterPushDown: Boolean = sqlConf.arrowFilterPushDown override def supportColumnarReads(partition: InputPartition): Boolean = true @@ -56,7 +59,9 @@ case class ArrowPartitionReaderFactory( override def buildColumnarReader( partitionedFile: PartitionedFile): PartitionReader[ColumnarBatch] = { val path = partitionedFile.filePath - val factory = ArrowUtils.makeArrowDiscovery(path, options) + val taskMemoryManager = ArrowUtils.getTaskMemoryManager() + val factory = ArrowUtils.makeArrowDiscovery(URLDecoder.decode(path, "UTF-8"), options, + new ExecutionMemoryAllocationListener(taskMemoryManager)) val dataset = factory.finish() val filter = if (enableFilterPushDown) { ArrowFilters.translateFilters(ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema)) @@ -79,7 +84,7 @@ case class ArrowPartitionReaderFactory( val batchItr = vsrItrList .toIterator .flatMap(itr => itr.asScala) - .map(vsr => ArrowUtils.loadVsr(vsr, partitionedFile.partitionValues, + .map(bundledVectors => ArrowUtils.loadVectors(bundledVectors, partitionedFile.partitionValues, readPartitionSchema, readDataSchema)) new PartitionReader[ColumnarBatch] { diff --git a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowTable.scala b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowTable.scala index 7702f0283..ed4e4b183 100644 --- a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowTable.scala +++ b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowTable.scala @@ -16,6 
+16,7 @@ */ package com.intel.oap.spark.sql.execution.datasources.v2.arrow +import org.apache.arrow.memory.AllocationListener import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession diff --git a/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ExecutionMemoryAllocationListener.scala b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ExecutionMemoryAllocationListener.scala new file mode 100644 index 000000000..4e2eb96bc --- /dev/null +++ b/oap-data-source/arrow/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ExecutionMemoryAllocationListener.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.intel.oap.spark.sql.execution.datasources.v2.arrow + +import org.apache.arrow.memory.{AllocationListener, OutOfMemoryException} + +import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} + +class ExecutionMemoryAllocationListener(mm: TaskMemoryManager) + extends MemoryConsumer(mm, mm.pageSizeBytes(), MemoryMode.OFF_HEAP) with AllocationListener { + + + override def onPreAllocation(size: Long): Unit = { + if (size == 0) { + return + } + val granted = acquireMemory(size) + if (granted < size) { + throw new OutOfMemoryException("Failed allocating spark execution memory. 
Acquired: " + + size + ", granted: " + granted) + } + } + + override def onRelease(size: Long): Unit = { + freeMemory(size) + } + + override def spill(size: Long, trigger: MemoryConsumer): Long = { + // not spillable + 0L + } +} diff --git a/oap-data-source/arrow/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala b/oap-data-source/arrow/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala index 101dff283..589fb5e93 100644 --- a/oap-data-source/arrow/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala +++ b/oap-data-source/arrow/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala @@ -17,28 +17,38 @@ package org.apache.spark.sql.execution.datasources.v2.arrow import java.net.URI -import java.util.TimeZone +import java.util.{TimeZone, UUID} import scala.collection.JavaConverters._ import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowOptions import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector import org.apache.arrow.dataset.file.{FileSystem, SingleFileDatasetFactory} -import org.apache.arrow.memory.BaseAllocator -import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} +import org.apache.arrow.dataset.scanner.ScanTask +import org.apache.arrow.memory.{AllocationListener, BaseAllocator} +import org.apache.arrow.vector.FieldVector +import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID import org.apache.arrow.vector.types.pojo.Schema import org.apache.hadoop.fs.FileStatus +import org.apache.spark.TaskContext +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.{ColumnVectorUtils, OnHeapColumnVector} +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch object ArrowUtils { + + def getTaskMemoryManager(): TaskMemoryManager = { + TaskContext.get().taskMemoryManager() + } + def readSchema(file: FileStatus, options: CaseInsensitiveStringMap): Option[StructType] = { val factory: SingleFileDatasetFactory = - makeArrowDiscovery(file.getPath.toString, new ArrowOptions(options.asScala.toMap)) + makeArrowDiscovery(file.getPath.toString, new ArrowOptions(options.asScala.toMap), + AllocationListener.NOOP) val schema = factory.inspect() try { Option(org.apache.spark.sql.util.ArrowUtils.fromArrowSchema(schema)) @@ -50,13 +60,17 @@ object ArrowUtils { def readSchema(files: Seq[FileStatus], options: CaseInsensitiveStringMap): Option[StructType] = readSchema(files.toList.head, options) // todo merge schema - def makeArrowDiscovery(file: String, options: ArrowOptions): SingleFileDatasetFactory = { + def makeArrowDiscovery(file: String, options: ArrowOptions, + al: AllocationListener): SingleFileDatasetFactory = { val format = getFormat(options).getOrElse(throw new IllegalStateException) val fs = getFs(options).getOrElse(throw new IllegalStateException) - + val parent = defaultAllocator() + val allocator = parent + .newChildAllocator("Spark Managed Allocator - " + UUID.randomUUID().toString, al, + 0, parent.getLimit) val factory = new SingleFileDatasetFactory( - org.apache.spark.sql.util.ArrowUtils.rootAllocator, + allocator, format, fs, rewriteFilePath(file)) @@ -82,12 +96,14 @@ object ArrowUtils { org.apache.spark.sql.util.ArrowUtils.toArrowSchema(t, TimeZone.getDefault.getID) 
} - def loadVsr(vsr: VectorSchemaRoot, partitionValues: InternalRow, - partitionSchema: StructType, dataSchema: StructType): ColumnarBatch = { - val fvs = getDataVectors(vsr, dataSchema) + def loadVectors(bundledVectors: ScanTask.ArrowBundledVectors, partitionValues: InternalRow, + partitionSchema: StructType, dataSchema: StructType): ColumnarBatch = { + val rowCount: Int = getRowCount(bundledVectors) + val dataVectors = getDataVectors(bundledVectors, dataSchema) + val dictionaryVectors = getDictionaryVectors(bundledVectors, dataSchema) - val rowCount = vsr.getRowCount - val vectors = ArrowWritableColumnVector.loadColumns(rowCount, fvs.asJava) + val vectors = ArrowWritableColumnVector.loadColumns(rowCount, dataVectors.asJava, + dictionaryVectors.asJava) val partitionColumns = ArrowWritableColumnVector.allocateColumns(rowCount, partitionSchema) (0 until partitionColumns.length).foreach(i => { ColumnVectorUtils.populate(partitionColumns(i), partitionValues, i) @@ -98,16 +114,23 @@ object ArrowUtils { batch } - def rootAllocator(): BaseAllocator = { + def defaultAllocator(): BaseAllocator = { org.apache.spark.sql.util.ArrowUtils.rootAllocator } - private def getDataVectors(vsr: VectorSchemaRoot, + private def getRowCount(bundledVectors: ScanTask.ArrowBundledVectors) = { + val valueVectors = bundledVectors.valueVectors + val rowCount = valueVectors.getRowCount + rowCount + } + + private def getDataVectors(bundledVectors: ScanTask.ArrowBundledVectors, dataSchema: StructType): List[FieldVector] = { // TODO Deprecate following (bad performance maybe brought). // TODO Assert vsr strictly matches dataSchema instead. + val valueVectors = bundledVectors.valueVectors dataSchema.map(f => { - val vector = vsr.getVector(f.name) + val vector = valueVectors.getVector(f.name) if (vector == null) { throw new IllegalStateException("Error: no vector named " + f.name + " in record bach") } @@ -115,6 +138,30 @@ object ArrowUtils { }).toList } + private def getDictionaryVectors(bundledVectors: ScanTask.ArrowBundledVectors, + dataSchema: StructType): List[FieldVector] = { + val valueVectors = bundledVectors.valueVectors + val dictionaryVectorMap = bundledVectors.dictionaryVectors + + val fieldNameToDictionaryEncoding = valueVectors.getSchema.getFields.asScala.map(f => { + f.getName -> f.getDictionary + }).toMap + + val dictionaryVectorsWithNulls = dataSchema.map(f => { + val de = fieldNameToDictionaryEncoding(f.name) + + Option(de) match { + case None => null + case _ => + if (de.getIndexType.getTypeID != ArrowTypeID.Int) { + throw new IllegalArgumentException("Wrong index type: " + de.getIndexType) + } + dictionaryVectorMap.get(de.getId).getVector + } + }).toList + dictionaryVectorsWithNulls + } + private def getFormat( options: ArrowOptions): Option[org.apache.arrow.dataset.file.FileFormat] = { Option(options.originalFormat match { diff --git a/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTPCHBasedTest.scala b/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTPCHBasedTest.scala index 402c76a88..976475836 100644 --- a/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTPCHBasedTest.scala +++ b/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTPCHBasedTest.scala @@ -21,6 +21,7 @@ import java.util.concurrent.{Executors, TimeUnit} import com.intel.oap.spark.sql.DataFrameReaderImplicits._ import 
com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowOptions +import org.apache.spark.SparkConf import org.apache.spark.sql.QueryTest import org.apache.spark.sql.execution.datasources.v2.arrow.ArrowUtils import org.apache.spark.sql.internal.SQLConf @@ -38,6 +39,13 @@ class ArrowDataSourceTPCHBasedTest extends QueryTest with SharedSparkSession { private val orders = prefix + tpchFolder + "/orders" private val nation = prefix + tpchFolder + "/nation" + + override protected def sparkConf: SparkConf = { + val conf = super.sparkConf + conf.set("spark.memory.offHeap.size", String.valueOf(128 * 1024 * 1024)) + conf + } + ignore("tpch lineitem - desc") { val frame = spark.read .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet") @@ -48,6 +56,16 @@ class ArrowDataSourceTPCHBasedTest extends QueryTest with SharedSparkSession { spark.sql("describe lineitem").show() } + ignore("tpch part - special characters in path") { + val frame = spark.read + .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet") + .option(ArrowOptions.KEY_FILESYSTEM, "hdfs") + .arrow(part) + frame.createOrReplaceTempView("part") + + spark.sql("select * from part limit 100").show() + } + ignore("tpch lineitem - read partition values") { val frame = spark.read .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet") @@ -110,7 +128,7 @@ class ArrowDataSourceTPCHBasedTest extends QueryTest with SharedSparkSession { val aPrev = System.currentTimeMillis() (0 until iterations).foreach(_ => { // scalastyle:off println - println(ArrowUtils.rootAllocator().getAllocatedMemory()) + println(ArrowUtils.defaultAllocator().getAllocatedMemory()) // scalastyle:on println spark.sql("select\n\tsum(l_extendedprice * l_discount) as revenue\n" + "from\n\tlineitem_arrow\n" + @@ -258,7 +276,7 @@ class ArrowDataSourceTPCHBasedTest extends QueryTest with SharedSparkSession { }) Executors.newSingleThreadScheduledExecutor().scheduleWithFixedDelay(() => { println("[org.apache.spark.sql.util.ArrowUtils.rootAllocator] " + - "Allocated memory amount: " + ArrowUtils.rootAllocator().getAllocatedMemory) + "Allocated memory amount: " + ArrowUtils.defaultAllocator().getAllocatedMemory) println("[com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector.allocator] " + "Allocated memory amount: " + com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector.allocator.getAllocatedMemory) }, 0L, 100L, TimeUnit.MILLISECONDS) diff --git a/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala b/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala index efaa81423..a648438cd 100644 --- a/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala +++ b/oap-data-source/arrow/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala @@ -22,16 +22,26 @@ import java.lang.management.ManagementFactory import com.intel.oap.spark.sql.DataFrameReaderImplicits._ import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowOptions +import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector import com.sun.management.UnixOperatingSystemMXBean import org.apache.commons.io.FileUtils +import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.execution.datasources.v2.arrow.ArrowUtils import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{LongType, 
StructField, StructType} +import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType} class ArrowDataSourceTest extends QueryTest with SharedSparkSession { private val parquetFile1 = "parquet-1.parquet" private val parquetFile2 = "parquet-2.parquet" + private val parquetFile3 = "parquet-3.parquet" + + override protected def sparkConf: SparkConf = { + val conf = super.sparkConf + conf.set("spark.memory.offHeap.size", String.valueOf(1 * 1024 * 1024)) + conf + } override def beforeAll(): Unit = { super.beforeAll() @@ -47,11 +57,19 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession { .toDS()) .repartition(1) .write.parquet(ArrowDataSourceTest.locateResourcePath(parquetFile2)) + + spark.read + .json(Seq("{\"col1\": \"apple\", \"col2\": 100}", "{\"col1\": \"pear\", \"col2\": 200}", + "{\"col1\": \"apple\", \"col2\": 300}") + .toDS()) + .repartition(1) + .write.parquet(ArrowDataSourceTest.locateResourcePath(parquetFile3)) } override def afterAll(): Unit = { delete(ArrowDataSourceTest.locateResourcePath(parquetFile1)) delete(ArrowDataSourceTest.locateResourcePath(parquetFile2)) + delete(ArrowDataSourceTest.locateResourcePath(parquetFile3)) super.afterAll() } @@ -64,7 +82,7 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession { .arrow(path)) } - test("simple SQL query on parquet file") { + test("simple SQL query on parquet file - 1") { val path = ArrowDataSourceTest.locateResourcePath(parquetFile1) val frame = spark.read .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet") @@ -76,6 +94,27 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession { verifyParquet(spark.sql("select col from ptab where col is not null or col is null")) } + test("simple SQL query on parquet file - 2") { + val path = ArrowDataSourceTest.locateResourcePath(parquetFile3) + val frame = spark.read + .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet") + .option(ArrowOptions.KEY_FILESYSTEM, "hdfs") + .arrow(path) + frame.createOrReplaceTempView("ptab") + val sqlFrame = spark.sql("select * from ptab") + assert( + sqlFrame.schema === + StructType(Seq(StructField("col1", StringType), StructField("col2", LongType)))) + val rows = sqlFrame.collect() + assert(rows(0).get(0) == "apple") + assert(rows(0).get(1) == 100) + assert(rows(1).get(0) == "pear") + assert(rows(1).get(1) == 200) + assert(rows(2).get(0) == "apple") + assert(rows(2).get(1) == 300) + assert(rows.length === 3) + } + test("simple SQL query on parquet file with pushed filters") { val path = ArrowDataSourceTest.locateResourcePath(parquetFile1) val frame = spark.read @@ -162,6 +201,11 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession { def delete(path: String): Unit = { FileUtils.forceDelete(new File(path)) } + + def closeAllocators(): Unit = { + ArrowUtils.defaultAllocator().close() + ArrowWritableColumnVector.allocator.close() + } } object ArrowDataSourceTest { diff --git a/oap-native-sql/README.md b/oap-native-sql/README.md index 6ba286707..82ffc6bcf 100644 --- a/oap-native-sql/README.md +++ b/oap-native-sql/README.md @@ -1,161 +1,44 @@ -# SparkColumnarPlugin +# Spark Native SQL Engine -## Contents - -- [Introduction](#introduction) -- [Installation](#installation) -- [Benchmark](#benchmark) -- [Contact](#contact) +A Native Engine for Spark SQL with vectorized SIMD optimizations ## Introduction -### Key concepts of this project: -1. Using Apache Arrow as column vector format as intermediate data among spark operator. -2. 
Enable Apache Arrow native readers for Parquet and other formats. -3. Leverage Apache Arrow Gandiva/Compute to evaluate columnar operator expressions. -4. (WIP)New native columnar shuffle operator with efficient compression support. - -![Overview](/oap-native-sql/resource/Native_SQL_Engine_Intro.jpg) - -## Installation - -For detailed testing scripts, please refer to [solution guide](https://github.com/Intel-bigdata/Solution_navigator/tree/master/nativesql) - -### Installation option 1: For evaluation, simple and fast - -#### install spark 3.0.0 or above - -[spark download](https://spark.apache.org/downloads.html) - -Remove original Arrow Jars inside Spark assemply folder -``` shell -yes | rm assembly/target/scala-2.12/jars/arrow-format-0.15.1.jar -yes | rm assembly/target/scala-2.12/jars/arrow-vector-0.15.1.jar -yes | rm assembly/target/scala-2.12/jars/arrow-memory-0.15.1.jar -``` - -#### install arrow 0.17.0 - -``` -git clone https://github.com/apache/arrow && cd arrow & git checkout arrow-0.17.0 -vim ci/conda_env_gandiva.yml -clangdev=7 -llvmdev=7 - -conda create -y -n pyarrow-dev -c conda-forge \ --file ci/conda_env_unix.yml \ --file ci/conda_env_cpp.yml \ --file ci/conda_env_python.yml \ --file ci/conda_env_gandiva.yml \ compilers \ python=3.7 \ pandas -conda activate pyarrow-dev -``` - -#### Build native-sql cpp - -``` shell -git clone https://github.com/Intel-bigdata/OAP.git -cd OAP && git checkout branch-nativesql-spark-3.0.0 -cd oap-native-sql -cp cpp/src/resources/libhdfs.so ${HADOOP_HOME}/lib/native/ -cp cpp/src/resources/libprotobuf.so.13 /usr/lib64/ -``` - -Download spark-columnar-core-1.0-jar-with-dependencies.jar to local, add classPath to spark.driver.extraClassPath and spark.executor.extraClassPath -``` shell -Internal Location: vsr602://mnt/nvme2/chendi/000000/spark-columnar-core-1.0-jar-with-dependencies.jar -``` - -Download spark-sql_2.12-3.1.0-SNAPSHOT.jar to ${SPARK_HOME}/assembly/target/scala-2.12/jars/spark-sql_2.12-3.1.0-SNAPSHOT.jar -``` shell -Internal Location: vsr602://mnt/nvme2/chendi/000000/spark-sql_2.12-3.1.0-SNAPSHOT.jar -``` - -### Installation option 2: For contribution, Patch and build - -#### install spark 3.0.0 or above - -Please refer this link to install Spark. -[Apache Spark Installation](/oap-native-sql/resource/SparkInstallation.md) - -Remove original Arrow Jars inside Spark assemply folder -``` shell -yes | rm assembly/target/scala-2.12/jars/arrow-format-0.15.1.jar -yes | rm assembly/target/scala-2.12/jars/arrow-vector-0.15.1.jar -yes | rm assembly/target/scala-2.12/jars/arrow-memory-0.15.1.jar -``` - -#### install arrow 0.17.0 +![Overview](/oap-native-sql/resource/nativesql_arch.png) -Please refer this markdown to install Apache Arrow and Gandiva. -[Apache Arrow Installation](/oap-native-sql/resource/ApacheArrowInstallation.md) +Spark SQL works very well with structured row-based data. It uses WholeStageCodeGen to improve performance through Java JIT code. However, Java JIT is usually not very good at utilizing the latest SIMD instructions, especially under complicated queries. [Apache Arrow](https://arrow.apache.org/) provides a CPU-cache friendly columnar in-memory layout; its SIMD-optimized kernels and the LLVM-based SQL engine Gandiva are also very efficient. The Native SQL Engine uses these technologies to bring better performance to Spark SQL. 
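To make the entry point concrete, the Arrow data source in this repository is driven through the `.arrow(...)` reader shortcut, as exercised by the test suites in this patch. A minimal usage sketch, assuming an existing `SparkSession` named `spark` and an illustrative file path:

```scala
import com.intel.oap.spark.sql.DataFrameReaderImplicits._
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowOptions

// Read Parquet through the native Arrow Dataset reader instead of Spark's built-in one.
val frame = spark.read
  .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet")
  .option(ArrowOptions.KEY_FILESYSTEM, "hdfs")
  .arrow("/path/to/data.parquet") // illustrative path
frame.createOrReplaceTempView("ptab")
spark.sql("select * from ptab limit 10").show()
```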
-#### compile and install oap-native-sql +## Key Features -##### Install Googletest and Googlemock +### Apache Arrow formatted intermediate data among Spark operators -``` shell -yum install gtest-devel -yum install gmock -``` +![Overview](/oap-native-sql/resource/columnar.png) -##### Build this project +With [SPARK-27396](https://issues.apache.org/jira/browse/SPARK-27396) it is possible to pass an RDD of ColumnarBatch to operators. We implemented this API with the Arrow columnar format. -``` shell -git clone https://github.com/Intel-bigdata/OAP.git -cd OAP && git checkout branch-nativesql-spark-3.0.0 -cd oap-native-sql -cd cpp/ -mkdir build/ -cd build/ -cmake .. -DTESTS=ON -make -j -make install -#when deploying on multiple node, make sure all nodes copied libhdfs.so and libprotobuf.so.13 -``` +### Apache Arrow based Native Readers for Parquet and other formats -``` shell -cd SparkColumnarPlugin/core/ -mvn clean package -DskipTests -``` -### Additonal Notes -[Notes for Installation Issues](/oap-native-sql/resource/InstallationNotes.md) - +![Overview](/oap-native-sql/resource/dataset.png) -## Spark Configuration +A native Parquet reader was developed to speed up data loading. It is based on Apache Arrow Dataset. For details please check [Arrow Data Source](../oap-data-source/README.md). -Add below configuration to spark-defaults.conf - -``` -##### Columnar Process Configuration +### Apache Arrow Compute/Gandiva based operators -spark.sql.parquet.columnarReaderBatchSize 4096 -spark.sql.sources.useV1SourceList avro -spark.sql.join.preferSortMergeJoin false -spark.sql.extensions com.intel.sparkColumnarPlugin.ColumnarPlugin +![Overview](/oap-native-sql/resource/kernel.png) -spark.driver.extraClassPath ${PATH_TO_OAP_NATIVE_SQL}/core/target/spark-columnar-core-1.0-jar-with-dependencies.jar -spark.executor.extraClassPath ${PATH_TO_OAP_NATIVE_SQL}/core/target/spark-columnar-core-1.0-jar-with-dependencies.jar +We implemented common operators based on Apache Arrow Compute and Gandiva. The SQL expression is compiled to an expression tree with protobuf and passed to the native kernels, which then evaluate these expressions against the input columnar batch. -###### -``` -## Benchmark +### Native Columnar Shuffle Operator with efficient compression support -For initial microbenchmark performance, we add 10 fields up with spark, data size is 200G data +![Overview](/oap-native-sql/resource/shuffle.png) -![Performance](/oap-native-sql/resource/performance.png) +We implemented columnar shuffle to improve the shuffle performance. With the columnar layout we can do very efficient data compression for different data formats. -## Coding Style +## Testing -* For Java code, we used [google-java-format](https://github.com/google/google-java-format) -* For Scala code, we used [Spark Scala Format](https://github.com/apache/spark/blob/master/dev/.scalafmt.conf), please use [scalafmt](https://github.com/scalameta/scalafmt) or run ./scalafmt for scala codes format -* For Cpp codes, we used Clang-Format, check on this link [google-vim-codefmt](https://github.com/google/vim-codefmt) for details. 
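The suites in this patch also enable off-heap memory so the execution-memory listener has a budget to draw from; see the `sparkConf` overrides in the test diffs above. A hedged sketch of an equivalent standalone session setup — the extension class follows the configuration section removed from this README, the off-heap sizing mirrors the tests, and both may differ in your build:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative local session for trying the plugin; sizes are arbitrary.
val spark = SparkSession.builder()
  .master("local[1]")
  .config("spark.memory.offHeap.enabled", "true")
  .config("spark.memory.offHeap.size", String.valueOf(128 * 1024 * 1024))
  .config("spark.sql.extensions", "com.intel.sparkColumnarPlugin.ColumnarPlugin")
  .getOrCreate()
```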
Check out the detailed [installation/testing guide](/oap-native-sql/resource/installation.md) for quick testing. ## contact chendi.xue@intel.com -yuan.zhou@intel.com -binwei.yang@intel.com -jian.zhang@intel.com +binwei.yang@intel.com \ No newline at end of file diff --git a/oap-native-sql/core/pom.xml b/oap-native-sql/core/pom.xml index abb37bedf..9d452b94b 100644 --- a/oap-native-sql/core/pom.xml +++ b/oap-native-sql/core/pom.xml @@ -17,6 +17,7 @@ 0.17.0 + ../cpp/ ../cpp/build/releases/ @@ -33,18 +34,39 @@ ${spark.version} provided + + com.intel.spark + spark-core_${scala.binary.version} + ${spark.version} + test-jar + test + com.intel.spark spark-catalyst_${scala.binary.version} ${spark.version} provided + + com.intel.spark + spark-catalyst_${scala.binary.version} + ${spark.version} + test-jar + test + com.intel.spark spark-sql_${scala.binary.version} ${spark.version} provided + + com.intel.spark + spark-sql_${scala.binary.version} + ${spark.version} + test-jar + test + com.intel.spark spark-network-shuffle_${scala.binary.version} @@ -83,6 +105,22 @@ + + org.apache.arrow + arrow-dataset + ${arrow.version} + + + io.netty + netty-common + + + io.netty + netty-buffer + + + + org.apache.arrow.gandiva arrow-gandiva @@ -136,19 +174,19 @@ com.fasterxml.jackson.core jackson-core - 2.9.8 + 2.10.0 test com.fasterxml.jackson.core jackson-annotations - 2.9.8 + 2.10.0 test com.fasterxml.jackson.core jackson-databind - 2.9.8 + 2.10.0 test @@ -157,14 +195,28 @@ ${cpp.build.dir} - - **/libspark_columnar_jni.so - target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + exec-maven-plugin + org.codehaus.mojo + 1.6.0 + + + Build cpp + generate-resources + + exec + + + ${cpp.dir}/compile.sh + + + + net.alchim31.maven scala-maven-plugin @@ -278,6 +330,18 @@ + + org.scalatest + scalatest-maven-plugin + + + test + + test + + + + diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/VectorizedParquetArrowReader.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/VectorizedParquetArrowReader.java index 7a81f3640..b55dcfa93 100644 --- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/VectorizedParquetArrowReader.java +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/VectorizedParquetArrowReader.java @@ -47,8 +47,8 @@ import org.slf4j.LoggerFactory; public class VectorizedParquetArrowReader extends VectorizedParquetRecordReader { - - private static final Logger LOG = LoggerFactory.getLogger(VectorizedParquetArrowReader.class); + private static final Logger LOG = + LoggerFactory.getLogger(VectorizedParquetArrowReader.class); private ParquetReader reader = null; private String path; private long capacity; @@ -58,45 +58,40 @@ public class VectorizedParquetArrowReader extends VectorizedParquetRecordReader private int numLoaded = 0; private int numReaded = 0; private long totalLength; + private String tmp_dir; private ArrowRecordBatch next_batch; - //private ColumnarBatch last_columnar_batch; + // private ColumnarBatch last_columnar_batch; private StructType sourceSchema; private StructType readDataSchema; private Schema schema = null; - public VectorizedParquetArrowReader( - String path, - ZoneId convertTz, - boolean useOffHeap, - int capacity, - StructType sourceSchema, - StructType readDataSchema - ) { + public VectorizedParquetArrowReader(String path, ZoneId convertTz, boolean useOffHeap, + int capacity, StructType sourceSchema, StructType 
readDataSchema, String tmp_dir) { super(convertTz, useOffHeap, capacity); this.capacity = capacity; this.path = path; + this.tmp_dir = tmp_dir; this.sourceSchema = sourceSchema; this.readDataSchema = readDataSchema; } @Override - public void initBatch(StructType partitionColumns, InternalRow partitionValues) { - } + public void initBatch(StructType partitionColumns, InternalRow partitionValues) {} @Override public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) - throws IOException, InterruptedException, UnsupportedOperationException { + throws IOException, InterruptedException, UnsupportedOperationException { final ParquetInputSplit parquetInputSplit = toParquetSplit(inputSplit); final Configuration configuration = ContextUtil.getConfiguration(taskAttemptContext); initialize(parquetInputSplit, configuration); } public void initialize(ParquetInputSplit inputSplit, Configuration configuration) - throws IOException, InterruptedException, UnsupportedOperationException { + throws IOException, InterruptedException, UnsupportedOperationException { this.totalLength = inputSplit.getLength(); int ordinal = 0; @@ -104,7 +99,7 @@ public void initialize(ParquetInputSplit inputSplit, Configuration configuration int[] column_indices = new int[readDataSchema.size()]; List targetSchema = Arrays.asList(readDataSchema.names()); - for (String fieldName: sourceSchema.names()) { + for (String fieldName : sourceSchema.names()) { if (targetSchema.contains(fieldName)) { column_indices[cur_index++] = ordinal; } @@ -116,16 +111,17 @@ public void initialize(ParquetInputSplit inputSplit, Configuration configuration if (uriPath.contains("hdfs")) { uriPath = this.path + "?user=root&replication=1"; } - ParquetInputSplit split = (ParquetInputSplit)inputSplit; - LOG.info("ParquetReader uri path is " + uriPath + ", rowGroupIndices is " + Arrays.toString(rowGroupIndices) + ", column_indices is " + Arrays.toString(column_indices)); - this.reader = new ParquetReader(uriPath, - split.getStart(), split.getEnd(), column_indices, capacity, ArrowWritableColumnVector.getNewAllocator()); + ParquetInputSplit split = (ParquetInputSplit) inputSplit; + LOG.info("ParquetReader uri path is " + uriPath + ", rowGroupIndices is " + + Arrays.toString(rowGroupIndices) + ", column_indices is " + + Arrays.toString(column_indices)); + this.reader = new ParquetReader(uriPath, split.getStart(), split.getEnd(), + column_indices, capacity, ArrowWritableColumnVector.getNewAllocator(), tmp_dir); } @Override - public void initialize(String path, List columns) throws IOException, - UnsupportedOperationException { - } + public void initialize(String path, List columns) + throws IOException, UnsupportedOperationException {} @Override public boolean nextKeyValue() throws IOException { @@ -155,7 +151,7 @@ public Object getCurrentValue() { } numReaded += lastReadLength; ArrowWritableColumnVector[] columnVectors = - ArrowWritableColumnVector.loadColumns(next_batch.getLength(), schema, next_batch); + ArrowWritableColumnVector.loadColumns(next_batch.getLength(), schema, next_batch); next_batch.close(); return new ColumnarBatch(columnVectors, next_batch.getLength()); } @@ -170,10 +166,11 @@ public void close() throws IOException { @Override public float getProgress() { - return (float) (numReaded/totalLength); + return (float) (numReaded / totalLength); } - private int[] filterRowGroups(ParquetInputSplit parquetInputSplit, Configuration configuration) throws IOException { + private int[] filterRowGroups(ParquetInputSplit 
parquetInputSplit, + Configuration configuration) throws IOException { final long[] rowGroupOffsets = parquetInputSplit.getRowGroupOffsets(); if (rowGroupOffsets != null) { throw new UnsupportedOperationException(); @@ -184,30 +181,34 @@ private int[] filterRowGroups(ParquetInputSplit parquetInputSplit, Configuration final List filteredRowGroups; final List unfilteredRowGroups; - try (ParquetFileReader reader = ParquetFileReader.open( - HadoopInputFile.fromPath(path, configuration), createOptions(parquetInputSplit, configuration))) { + try (ParquetFileReader reader = + ParquetFileReader.open(HadoopInputFile.fromPath(path, configuration), + createOptions(parquetInputSplit, configuration))) { unfilteredRowGroups = reader.getFooter().getBlocks(); filteredRowGroups = reader.getRowGroups(); } final int[] acc = {0}; - final Map dict = unfilteredRowGroups.stream() - .collect(Collectors.toMap(BlockMetaDataWrapper::wrap, b -> acc[0]++)); + final Map dict = unfilteredRowGroups.stream().collect( + Collectors.toMap(BlockMetaDataWrapper::wrap, b -> acc[0]++)); return filteredRowGroups.stream() - .map(BlockMetaDataWrapper::wrap) - .map(b -> { - if (!dict.containsKey(b)) { - // This should not happen - throw new IllegalStateException("Unrecognizable filtered row group: " + b); - } - return dict.get(b); - }).mapToInt(n -> n).toArray(); + .map(BlockMetaDataWrapper::wrap) + .map(b -> { + if (!dict.containsKey(b)) { + // This should not happen + throw new IllegalStateException("Unrecognizable filtered row group: " + b); + } + return dict.get(b); + }) + .mapToInt(n -> n) + .toArray(); } - private ParquetReadOptions createOptions(ParquetInputSplit split, Configuration configuration) { + private ParquetReadOptions createOptions( + ParquetInputSplit split, Configuration configuration) { return HadoopReadOptions.builder(configuration) - .withRange(split.getStart(), split.getEnd()) - .build(); + .withRange(split.getStart(), split.getEnd()) + .build(); } private ParquetInputSplit toParquetSplit(InputSplit split) throws IOException { @@ -215,11 +216,12 @@ private ParquetInputSplit toParquetSplit(InputSplit split) throws IOException { return (ParquetInputSplit) split; } else { throw new IllegalArgumentException( - "Invalid split (not a ParquetInputSplit): " + split); + "Invalid split (not a ParquetInputSplit): " + split); } } - // ID for BlockMetaData, to prevent from resulting in mutable BlockMetaData instances after being filtered + // ID for BlockMetaData, to prevent from resulting in mutable BlockMetaData instances + // after being filtered private static class BlockMetaDataWrapper { private BlockMetaData m; @@ -233,8 +235,10 @@ public static BlockMetaDataWrapper wrap(BlockMetaData m) { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; BlockMetaDataWrapper that = (BlockMetaDataWrapper) o; return equals(m, that.m); } @@ -252,6 +256,4 @@ private int hash(BlockMetaData m) { return Objects.hash(m.getStartingPos()); } } - } - diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReader.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReader.java index a0d5530bb..80af1d356 100644 --- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReader.java +++ 
b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReader.java @@ -36,7 +36,6 @@ /** Parquet Reader Class. */ public class ParquetReader implements AutoCloseable { - /** reference to native reader instance. */ private long nativeInstanceId; @@ -56,14 +55,9 @@ public class ParquetReader implements AutoCloseable { * @param allocator A BufferAllocator reference. * @throws IOException throws io exception in case of native failure. */ - public ParquetReader( - String path, - int[] rowGroupIndices, - int[] columnIndices, - long batchSize, - BufferAllocator allocator) - throws IOException { - this.jniWrapper = new ParquetReaderJniWrapper(); + public ParquetReader(String path, int[] rowGroupIndices, int[] columnIndices, + long batchSize, BufferAllocator allocator, String tmp_dir) throws IOException { + this.jniWrapper = new ParquetReaderJniWrapper(tmp_dir); this.allocator = allocator; this.nativeInstanceId = jniWrapper.nativeOpenParquetReader(path, batchSize); jniWrapper.nativeInitParquetReader(nativeInstanceId, columnIndices, rowGroupIndices); @@ -80,18 +74,13 @@ public ParquetReader( * @param allocator A BufferAllocator reference. * @throws IOException throws io exception in case of native failure. */ - public ParquetReader( - String path, - long startPos, - long endPos, - int[] columnIndices, - long batchSize, - BufferAllocator allocator) - throws IOException { - this.jniWrapper = new ParquetReaderJniWrapper(); + public ParquetReader(String path, long startPos, long endPos, int[] columnIndices, + long batchSize, BufferAllocator allocator, String tmp_dir) throws IOException { + this.jniWrapper = new ParquetReaderJniWrapper(tmp_dir); this.allocator = allocator; this.nativeInstanceId = jniWrapper.nativeOpenParquetReader(path, batchSize); - jniWrapper.nativeInitParquetReader2(nativeInstanceId, columnIndices, startPos, endPos); + jniWrapper.nativeInitParquetReader2( + nativeInstanceId, columnIndices, startPos, endPos); } /** @@ -103,10 +92,9 @@ public ParquetReader( public Schema getSchema() throws IOException { byte[] schemaBytes = jniWrapper.nativeGetSchema(nativeInstanceId); - try (MessageChannelReader schemaReader = - new MessageChannelReader( - new ReadChannel(new ByteArrayReadableSeekableByteChannel(schemaBytes)), allocator)) { - + try (MessageChannelReader schemaReader = new MessageChannelReader( + new ReadChannel(new ByteArrayReadableSeekableByteChannel(schemaBytes)), + allocator)) { MessageResult result = schemaReader.readNext(); if (result == null) { throw new IOException("Unexpected end of input. 
Missing schema."); @@ -123,7 +111,8 @@ public Schema getSchema() throws IOException { * @throws IOException throws io exception in case of native failure */ public ArrowRecordBatch readNext() throws IOException { - ArrowRecordBatchBuilder recordBatchBuilder = jniWrapper.nativeReadNext(nativeInstanceId); + ArrowRecordBatchBuilder recordBatchBuilder = + jniWrapper.nativeReadNext(nativeInstanceId); if (recordBatchBuilder == null) { return null; } diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReaderJniWrapper.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReaderJniWrapper.java index f52951ddc..efc4d6ead 100644 --- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReaderJniWrapper.java +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReaderJniWrapper.java @@ -22,13 +22,11 @@ import java.io.IOException; - /** Wrapper for Parquet Reader native API. */ public class ParquetReaderJniWrapper { - /** Construct a Jni Instance. */ - ParquetReaderJniWrapper() throws IOException { - JniUtils.getInstance(); + ParquetReaderJniWrapper(String tmp_dir) throws IOException { + JniUtils.getInstance(tmp_dir); } /** @@ -39,7 +37,8 @@ public class ParquetReaderJniWrapper { * @return long id of the parquet reader instance * @throws IOException throws exception in case of any io exception in native codes */ - public native long nativeOpenParquetReader(String path, long batchSize) throws IOException; + public native long nativeOpenParquetReader(String path, long batchSize) + throws IOException; /** * Init a parquet file reader by specifying columns and rowgroups. @@ -49,8 +48,8 @@ public class ParquetReaderJniWrapper { * @param rowGroupIndices a array of indexes indicate which row groups to be read * @throws IOException throws exception in case of any io exception in native codes */ - public native void nativeInitParquetReader(long id, int[] columnIndices, int[] rowGroupIndices) - throws IOException; + public native void nativeInitParquetReader( + long id, int[] columnIndices, int[] rowGroupIndices) throws IOException; /** * Init a parquet file reader by specifying columns and rowgroups. diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/AdaptorReferenceManager.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/AdaptorReferenceManager.java index 0a46b8deb..67f6a875b 100644 --- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/AdaptorReferenceManager.java +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/AdaptorReferenceManager.java @@ -17,13 +17,11 @@ package com.intel.sparkColumnarPlugin.vectorized; +import io.netty.buffer.ArrowBuf; import java.io.IOException; -import java.lang.UnsupportedOperationException; import java.util.concurrent.atomic.AtomicInteger; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.OwnershipTransferResult; -import org.apache.arrow.memory.ReferenceManager; +import org.apache.arrow.memory.*; import org.apache.arrow.util.Preconditions; import io.netty.buffer.ArrowBuf; @@ -31,8 +29,8 @@ import org.slf4j.LoggerFactory; /** - * A simple reference manager implementation for memory allocated by native - * code. The underlying memory will be released when reference count reach zero. 
+ * A simple reference manager implementation for memory allocated by native code. The underlying + * memory will be released when reference count reach zero. */ public class AdaptorReferenceManager implements ReferenceManager { private native void nativeRelease(long nativeMemoryHolder); @@ -42,10 +40,14 @@ public class AdaptorReferenceManager implements ReferenceManager { private long nativeMemoryHolder; private int size = 0; + // Required by netty dependencies, but is never used. + private BaseAllocator allocator; + AdaptorReferenceManager(long nativeMemoryHolder, int size) throws IOException { JniUtils.getInstance(); this.nativeMemoryHolder = nativeMemoryHolder; this.size = size; + this.allocator = new RootAllocator(0); } @Override @@ -60,7 +62,8 @@ public boolean release() { @Override public boolean release(int decrement) { - Preconditions.checkState(decrement >= 1, "ref count decrement should be greater than or equal to 1"); + Preconditions.checkState( + decrement >= 1, "ref count decrement should be greater than or equal to 1"); // decrement the ref count final int refCnt; synchronized (this) { @@ -104,13 +107,14 @@ public ArrowBuf deriveBuffer(ArrowBuf sourceBuffer, long index, long length) { } @Override - public OwnershipTransferResult transferOwnership(ArrowBuf sourceBuffer, BufferAllocator targetAllocator) { - throw new UnsupportedOperationException(); + public OwnershipTransferResult transferOwnership( + ArrowBuf sourceBuffer, BufferAllocator targetAllocator) { + return NO_OP.transferOwnership(sourceBuffer, targetAllocator); } @Override public BufferAllocator getAllocator() { - return null; + return allocator; } @Override diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowCompressedStreamReader.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowCompressedStreamReader.java new file mode 100644 index 000000000..f5fbf52f4 --- /dev/null +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowCompressedStreamReader.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.intel.sparkColumnarPlugin.vectorized; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.DictionaryUtility; + +import java.io.IOException; +import java.io.InputStream; +import java.util.*; + +/** + * This class reads from an input stream containing compressed buffers and produces + * ArrowRecordBatches. + */ +public class ArrowCompressedStreamReader extends ArrowStreamReader { + + public ArrowCompressedStreamReader(InputStream in, BufferAllocator allocator) { + super(in, allocator); + } + + protected void initialize() throws IOException { + Schema originalSchema = readSchema(); + List fields = new ArrayList<>(); + List vectors = new ArrayList<>(); + Map dictionaries = new HashMap<>(); + + // Convert fields with dictionaries to have the index type + for (Field field : originalSchema.getFields()) { + Field updated = DictionaryUtility.toMemoryFormat(field, allocator, dictionaries); + fields.add(updated); + vectors.add(updated.createVector(allocator)); + } + Schema schema = new Schema(fields, originalSchema.getCustomMetadata()); + + this.root = new VectorSchemaRoot(schema, vectors, 0); + this.loader = new CompressedVectorLoader(root); + this.dictionaries = Collections.unmodifiableMap(dictionaries); + } + + @Override + protected void loadRecordBatch(ArrowRecordBatch batch) { + try { + ((CompressedVectorLoader) loader).loadCompressed(batch); + } finally { + batch.close(); + } + } +} diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java index c0f6e1cf0..bb876ac3f 100644 --- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ArrowWritableColumnVector.java @@ -29,6 +29,7 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; +import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.holders.NullableVarCharHolder; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.Schema; @@ -62,6 +63,7 @@ public final class ArrowWritableColumnVector extends WritableColumnVector { private int ordinal; private ValueVector vector; + private ValueVector dictionaryVector; public static BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); public static BufferAllocator getNewAllocator() { @@ -90,7 +92,24 @@ public static ArrowWritableColumnVector[] allocateColumns(int capacity, StructTy return vectors; } - public static ArrowWritableColumnVector[] loadColumns(int capacity, List fieldVectors) { + public static ArrowWritableColumnVector[] loadColumns(int capacity, + List fieldVectors, + List dictionaryVectors) { + if (fieldVectors.size() != dictionaryVectors.size()) { + throw new IllegalArgumentException("Mismatched field vectors and dictionary vectors. 
" + + "Field vector count: " + fieldVectors.size() + ", " + + "dictionary vector count: " + dictionaryVectors.size()); + } + ArrowWritableColumnVector[] vectors = new ArrowWritableColumnVector[fieldVectors.size()]; + for (int i = 0; i < fieldVectors.size(); i++) { + vectors[i] = new ArrowWritableColumnVector(fieldVectors.get(i), dictionaryVectors.get(i), + i, capacity, false); + } + return vectors; + } + + public static ArrowWritableColumnVector[] loadColumns(int capacity, + List fieldVectors) { ArrowWritableColumnVector[] vectors = new ArrowWritableColumnVector[fieldVectors.size()]; for (int i = 0; i < fieldVectors.size(); i++) { vectors[i] = new ArrowWritableColumnVector(fieldVectors.get(i), i, capacity, false); @@ -106,19 +125,26 @@ public static ArrowWritableColumnVector[] loadColumns(int capacity, Schema arrow return loadColumns(capacity, root.getFieldVectors()); } - public ArrowWritableColumnVector(ValueVector vector, int ordinal, int capacity, boolean init){ + @Deprecated + public ArrowWritableColumnVector(ValueVector vector, int ordinal, int capacity, boolean init) { + this(vector, null, ordinal, capacity, init); + } + + public ArrowWritableColumnVector(ValueVector vector, ValueVector dicionaryVector, + int ordinal, int capacity, boolean init) { super(capacity, ArrowUtils.fromArrowField(vector.getField())); vectorCount.getAndIncrement(); refCnt.getAndIncrement(); this.ordinal = ordinal; this.vector = vector; + this.dictionaryVector = dicionaryVector; if (init) { vector.setInitialCapacity(capacity); vector.allocateNew(); } writer = createVectorWriter(vector); - createVectorAccessor(vector); + createVectorAccessor(vector, dicionaryVector); } public ArrowWritableColumnVector(int capacity, DataType dataType) { @@ -135,15 +161,30 @@ public ArrowWritableColumnVector(int capacity, DataType dataType) { vector.setInitialCapacity(capacity); vector.allocateNew(); this.writer = createVectorWriter(vector); - createVectorAccessor(vector); - + createVectorAccessor(vector, null); } public ValueVector getValueVector() { return vector; } - private void createVectorAccessor(ValueVector vector) { + private void createVectorAccessor(ValueVector vector, ValueVector dictionary) { + if (dictionary != null) { + if (!(vector instanceof IntVector)) { + throw new IllegalArgumentException("Expect int32 index vector. 
Found: " + + vector.getMinorType()); + } + IntVector index = (IntVector) vector; + if (dictionary instanceof VarBinaryVector) { + accessor = new DictionaryEncodedBinaryAccessor(index, (VarBinaryVector) dictionary); + } else if (dictionary instanceof VarCharVector) { + accessor = new DictionaryEncodedStringAccessor(index, (VarCharVector) dictionary); + } else { + throw new IllegalArgumentException("Unrecognized index value type: " + + dictionary.getMinorType()); + } + return; + } if (vector instanceof BitVector) { accessor = new BooleanAccessor((BitVector) vector); } else if (vector instanceof TinyIntVector) { @@ -271,6 +312,9 @@ public void close() { childColumns = null; } vector.close(); + if (dictionaryVector != null) { + dictionaryVector.close(); + } } public static String stat() { @@ -675,7 +719,8 @@ public void putArray(int rowId, int offset, int length) { @Override public int putByteArray(int rowId, byte[] value, int offset, int length) { - throw new UnsupportedOperationException(); + writer.setBytes(rowId, length, value, offset); + return length; } // @@ -963,6 +1008,32 @@ final UTF8String getUTF8String(int rowId) { } } + private static class DictionaryEncodedStringAccessor extends ArrowVectorAccessor { + + private final IntVector index; + private final VarCharVector dictionary; + private final NullableVarCharHolder stringResult = new NullableVarCharHolder(); + + DictionaryEncodedStringAccessor(IntVector index, VarCharVector dictionary) { + super(index); + this.index = index; + this.dictionary = dictionary; + } + + @Override + final UTF8String getUTF8String(int rowId) { + int idx = index.get(rowId); + dictionary.get(idx, stringResult); + if (stringResult.isSet == 0) { + return null; + } else { + return UTF8String.fromAddress(null, + stringResult.buffer.memoryAddress() + stringResult.start, + stringResult.end - stringResult.start); + } + } + } + private static class BinaryAccessor extends ArrowVectorAccessor { private final VarBinaryVector accessor; @@ -978,6 +1049,23 @@ final byte[] getBinary(int rowId) { } } + private static class DictionaryEncodedBinaryAccessor extends ArrowVectorAccessor { + private final IntVector index; + private final VarBinaryVector dictionary; + + DictionaryEncodedBinaryAccessor(IntVector index, VarBinaryVector dictionary) { + super(index); + this.index = index; + this.dictionary = dictionary; + } + + @Override + final byte[] getBinary(int rowId) { + int idx = index.get(rowId); + return dictionary.getObject(idx); + } + } + private static class DateAccessor extends ArrowVectorAccessor { private final DateDayVector accessor; diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/CompressedVectorLoader.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/CompressedVectorLoader.java new file mode 100644 index 000000000..52398cad0 --- /dev/null +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/CompressedVectorLoader.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.sparkColumnarPlugin.vectorized; + +import java.util.Iterator; + +import org.apache.arrow.util.Collections2; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; + +import io.netty.buffer.ArrowBuf; + +/** Loads compressed buffers into vectors. */ +public class CompressedVectorLoader extends VectorLoader { + + /** + * Construct with a root to load and will create children in root based on schema. + * + * @param root the root to add vectors to based on schema + */ + public CompressedVectorLoader(VectorSchemaRoot root) { + super(root); + } + + /** + * Loads the record batch in the vectors. will not close the record batch + * + * @param recordBatch the batch to load + */ + public void loadCompressed(ArrowRecordBatch recordBatch) { + Iterator buffers = recordBatch.getBuffers().iterator(); + Iterator nodes = recordBatch.getNodes().iterator(); + for (FieldVector fieldVector : root.getFieldVectors()) { + loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes); + } + root.setRootRowCount(recordBatch.getLength()); + if (nodes.hasNext() || buffers.hasNext()) { + throw new IllegalArgumentException( + "not all nodes and buffers were consumed. nodes: " + + Collections2.toList(nodes).toString() + + " buffers: " + + Collections2.toList(buffers).toString()); + } + } +} diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluator.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluator.java index 4076d8079..8311ea93c 100644 --- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluator.java +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluator.java @@ -17,6 +17,7 @@ package com.intel.sparkColumnarPlugin.vectorized; +import com.intel.sparkColumnarPlugin.ColumnarPluginConfig; import io.netty.buffer.ArrowBuf; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -33,59 +34,75 @@ import org.apache.arrow.vector.types.pojo.Schema; public class ExpressionEvaluator implements AutoCloseable { - private long nativeHandler = 0; private ExpressionEvaluatorJniWrapper jniWrapper; /** Wrapper for native API. */ public ExpressionEvaluator() throws IOException { - jniWrapper = new ExpressionEvaluatorJniWrapper(); + String tmp_dir = ColumnarPluginConfig.getTempFile(); + if (tmp_dir == null) { + tmp_dir = System.getProperty("java.io.tmpdir"); + } + jniWrapper = new ExpressionEvaluatorJniWrapper(tmp_dir); + jniWrapper.nativeSetJavaTmpDir(tmp_dir); + jniWrapper.nativeSetBatchSize(ColumnarPluginConfig.getBatchSize()); } /** Convert ExpressionTree into native function. 
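CompressedVectorLoader extends Arrow's stock VectorLoader, so it is driven the same way: construct it over a VectorSchemaRoot, then feed it record batches; only the loadCompressed entry point differs. A round trip with the uncompressed base class, under an invented one-column schema:

```java
import java.util.Collections;

import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;

public class VectorLoaderRoundTrip {
  public static void main(String[] args) throws Exception {
    Schema schema = new Schema(Collections.singletonList(
        Field.nullable("i", new ArrowType.Int(32, true))));
    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
         VectorSchemaRoot source = VectorSchemaRoot.create(schema, allocator);
         VectorSchemaRoot target = VectorSchemaRoot.create(schema, allocator)) {
      IntVector vector = (IntVector) source.getVector("i");
      vector.allocateNew(3);
      for (int i = 0; i < 3; i++) {
        vector.setSafe(i, i * 10);
      }
      source.setRowCount(3);

      // Unload the populated root into a record batch, then load it into the empty root.
      try (ArrowRecordBatch batch = new VectorUnloader(source).getRecordBatch()) {
        new VectorLoader(target).load(batch);
      }
      System.out.println(target.contentToTSVString());
    }
  }
}
```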
*/ public void build(Schema schema, List exprs) throws RuntimeException, IOException, GandivaException { - nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), null, false); + nativeHandler = jniWrapper.nativeBuild( + getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), null, false); } /** Convert ExpressionTree into native function. */ public void build(Schema schema, List exprs, boolean finishReturn) throws RuntimeException, IOException, GandivaException { - nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), null, finishReturn); + nativeHandler = jniWrapper.nativeBuild( + getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), null, finishReturn); } /** Convert ExpressionTree into native function. */ public void build(Schema schema, List exprs, Schema resSchema) throws RuntimeException, IOException, GandivaException { - nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), getSchemaBytesBuf(resSchema), false); + nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema), + getExprListBytesBuf(exprs), getSchemaBytesBuf(resSchema), false); } /** Convert ExpressionTree into native function. */ - public void build(Schema schema, List exprs, Schema resSchema, boolean finishReturn) - throws RuntimeException, IOException, GandivaException { - nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), getSchemaBytesBuf(resSchema), finishReturn); + public void build(Schema schema, List exprs, Schema resSchema, + boolean finishReturn) throws RuntimeException, IOException, GandivaException { + nativeHandler = jniWrapper.nativeBuild(getSchemaBytesBuf(schema), + getExprListBytesBuf(exprs), getSchemaBytesBuf(resSchema), finishReturn); } /** Convert ExpressionTree into native function. */ - public void build(Schema schema, List exprs, List finish_exprs) + public void build( + Schema schema, List exprs, List finish_exprs) throws RuntimeException, IOException, GandivaException { - nativeHandler = jniWrapper.nativeBuildWithFinish(getSchemaBytesBuf(schema), getExprListBytesBuf(exprs), getExprListBytesBuf(finish_exprs)); + nativeHandler = jniWrapper.nativeBuildWithFinish(getSchemaBytesBuf(schema), + getExprListBytesBuf(exprs), getExprListBytesBuf(finish_exprs)); } /** Set result Schema in some special cases */ - public void setReturnFields(Schema schema) throws RuntimeException, IOException, GandivaException { + public void setReturnFields(Schema schema) + throws RuntimeException, IOException, GandivaException { jniWrapper.nativeSetReturnFields(nativeHandler, getSchemaBytesBuf(schema)); } - /** Evaluate input data using builded native function, and output as recordBatch. */ + /** + * Evaluate input data using built native function, and output as recordBatch. + */ public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch) throws RuntimeException, IOException { return evaluate(recordBatch, null); } - /** Evaluate input data using builded native function, and output as recordBatch. */ - public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch, SelectionVectorInt16 selectionVector) - throws RuntimeException, IOException { + /** + * Evaluate input data using built native function, and output as recordBatch. 
+ */ + public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch, + SelectionVectorInt16 selectionVector) throws RuntimeException, IOException { List buffers = recordBatch.getBuffers(); List buffersLayout = recordBatch.getBuffersLayout(); long[] bufAddrs = new long[buffers.size()]; @@ -105,14 +122,15 @@ public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch, SelectionVector int selectionVectorRecordCount = selectionVector.getRecordCount(); long selectionVectorAddr = selectionVector.getBuffer().memoryAddress(); long selectionVectorSize = selectionVector.getBuffer().capacity(); - resRecordBatchBuilderList = - jniWrapper.nativeEvaluateWithSelection(nativeHandler, recordBatch.getLength(), bufAddrs, bufSizes, - selectionVectorRecordCount, selectionVectorAddr, selectionVectorSize); + resRecordBatchBuilderList = jniWrapper.nativeEvaluateWithSelection(nativeHandler, + recordBatch.getLength(), bufAddrs, bufSizes, selectionVectorRecordCount, + selectionVectorAddr, selectionVectorSize); } else { - resRecordBatchBuilderList = - jniWrapper.nativeEvaluate(nativeHandler, recordBatch.getLength(), bufAddrs, bufSizes); + resRecordBatchBuilderList = jniWrapper.nativeEvaluate( + nativeHandler, recordBatch.getLength(), bufAddrs, bufSizes); } - ArrowRecordBatch[] recordBatchList = new ArrowRecordBatch[resRecordBatchBuilderList.length]; + ArrowRecordBatch[] recordBatchList = + new ArrowRecordBatch[resRecordBatchBuilderList.length]; for (int i = 0; i < resRecordBatchBuilderList.length; i++) { if (resRecordBatchBuilderList[i] == null) { recordBatchList[i] = null; @@ -125,7 +143,9 @@ public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch, SelectionVector return recordBatchList; } - /** Evaluate input data using builded native function, and output as recordBatch. */ + /** + * Evaluate input data using built native function, and output as recordBatch. 
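Both evaluate overloads above reduce the incoming ArrowRecordBatch to two parallel long arrays, buffer addresses and buffer sizes, before crossing the JNI boundary. The helper below restates just that step; the class and method names are invented for the sketch:

```java
import java.util.List;

import io.netty.buffer.ArrowBuf;
import org.apache.arrow.vector.ipc.message.ArrowBuffer;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;

public final class BufferFlattener {
  private BufferFlattener() {}

  /** Raw memory address of every buffer in the batch, in layout order. */
  public static long[] bufferAddresses(ArrowRecordBatch recordBatch) {
    List<ArrowBuf> buffers = recordBatch.getBuffers();
    long[] addrs = new long[buffers.size()];
    for (int i = 0; i < buffers.size(); i++) {
      addrs[i] = buffers.get(i).memoryAddress();
    }
    return addrs;
  }

  /** Size of each buffer, taken from the batch's buffer layout. */
  public static long[] bufferSizes(ArrowRecordBatch recordBatch) {
    List<ArrowBuffer> layout = recordBatch.getBuffersLayout();
    long[] sizes = new long[layout.size()];
    for (int i = 0; i < layout.size(); i++) {
      sizes[i] = layout.get(i).getSize();
    }
    return sizes;
  }
}
```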
+ */ public void SetMember(ArrowRecordBatch recordBatch) throws RuntimeException, IOException { List buffers = recordBatch.getBuffers(); @@ -142,12 +162,15 @@ public void SetMember(ArrowRecordBatch recordBatch) bufSizes[idx++] = bufLayout.getSize(); } - jniWrapper.nativeSetMember(nativeHandler, recordBatch.getLength(), bufAddrs, bufSizes); + jniWrapper.nativeSetMember( + nativeHandler, recordBatch.getLength(), bufAddrs, bufSizes); } public ArrowRecordBatch[] finish() throws RuntimeException, IOException { - ArrowRecordBatchBuilder[] resRecordBatchBuilderList = jniWrapper.nativeFinish(nativeHandler); - ArrowRecordBatch[] recordBatchList = new ArrowRecordBatch[resRecordBatchBuilderList.length]; + ArrowRecordBatchBuilder[] resRecordBatchBuilderList = + jniWrapper.nativeFinish(nativeHandler); + ArrowRecordBatch[] recordBatchList = + new ArrowRecordBatch[resRecordBatchBuilderList.length]; for (int i = 0; i < resRecordBatchBuilderList.length; i++) { if (resRecordBatchBuilderList[i] == null) { recordBatchList[i] = null; @@ -169,7 +192,8 @@ public void setDependency(BatchIterator child) throws RuntimeException, IOExcept jniWrapper.nativeSetDependency(nativeHandler, child.getInstanceId(), -1); } - public void setDependency(BatchIterator child, int index) throws RuntimeException, IOException { + public void setDependency(BatchIterator child, int index) + throws RuntimeException, IOException { jniWrapper.nativeSetDependency(nativeHandler, child.getInstanceId(), index); } @@ -185,7 +209,8 @@ byte[] getSchemaBytesBuf(Schema schema) throws IOException { } byte[] getExprListBytesBuf(List exprs) throws GandivaException { - GandivaTypes.ExpressionList.Builder builder = GandivaTypes.ExpressionList.newBuilder(); + GandivaTypes.ExpressionList.Builder builder = + GandivaTypes.ExpressionList.newBuilder(); for (ExpressionTree expr : exprs) { builder.addExprs(expr.toProtobuf()); } diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluatorJniWrapper.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluatorJniWrapper.java index de8a14357..b6b0d7997 100644 --- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluatorJniWrapper.java +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ExpressionEvaluatorJniWrapper.java @@ -20,102 +20,133 @@ import java.io.IOException; /** - * This class is implemented in JNI. This provides the Java interface to invoke functions in JNI. - * This file is used to generated the .h files required for jni. Avoid all external dependencies in - * this file. + * This class is implemented in JNI. This provides the Java interface to invoke + * functions in JNI. This file is used to generate the .h files required for + * jni. Avoid all external dependencies in this file. */ public class ExpressionEvaluatorJniWrapper { - /** Wrapper for native API. */ public ExpressionEvaluatorJniWrapper(String tmp_dir) throws IOException { JniUtils.getInstance(tmp_dir); } /** - * Generates the projector module to evaluate the expressions with custom configuration. + * Set native env variable NATIVE_TMP_DIR + * + * @param path tmp path for native codes, use java.io.tmpdir + */ + native void nativeSetJavaTmpDir(String path); + + /** + * Set native env variable NATIVE_BATCH_SIZE * - * @param schemaBuf The schema serialized as a protobuf. 
See Types.proto to see the protobuf - * specification - * @param exprListBuf The serialized protobuf of the expression vector. Each expression is created - * using TreeBuilder::MakeExpression. - * @param resSchemaBuf The schema serialized as a protobuf. See Types.proto to see the protobuf - * specification - * @param finishReturn This parameter is used to indicate that this expression should return when calling finish - * @return A nativeHandler that is passed to the evaluateProjector() and closeProjector() methods + * @param batch_size numRows of one batch, use + * spark.sql.execution.arrow.maxRecordsPerBatch */ - native long nativeBuild(byte[] schemaBuf, byte[] exprListBuf, byte[] resSchemaBuf, boolean finishReturn) throws RuntimeException; + native void nativeSetBatchSize(int batch_size); /** - * Generates the projector module to evaluate the expressions with custom configuration. + * Generates the projector module to evaluate the expressions with custom + * configuration. * - * @param schemaBuf The schema serialized as a protobuf. See Types.proto to see the protobuf - * specification - * @param exprListBuf The serialized protobuf of the expression vector. Each expression is created - * using TreeBuilder::MakeExpression. - * @param finishExprListBuf The serialized protobuf of the expression vector. Each expression is created - * using TreeBuilder::MakeExpression. - * @return A nativeHandler that is passed to the evaluateProjector() and closeProjector() methods + * @param schemaBuf The schema serialized as a protobuf. See Types.proto to + * see the protobuf specification + * @param exprListBuf The serialized protobuf of the expression vector. Each + * expression is created using TreeBuilder::MakeExpression. + * @param resSchemaBuf The schema serialized as a protobuf. See Types.proto to + * see the protobuf specification + * @param finishReturn This parameter is used to indicate that this expression + * should return when calling finish + * @return A nativeHandler that is passed to the evaluateProjector() and + * closeProjector() methods */ - native long nativeBuildWithFinish(byte[] schemaBuf, byte[] exprListBuf, byte[] finishExprListBuf) throws RuntimeException; + native long nativeBuild(byte[] schemaBuf, byte[] exprListBuf, byte[] resSchemaBuf, + boolean finishReturn) throws RuntimeException; + + /** + * Generates the projector module to evaluate the expressions with custom + * configuration. + * + * @param schemaBuf The schema serialized as a protobuf. See Types.proto + * to see the protobuf specification + * @param exprListBuf The serialized protobuf of the expression vector. + * Each expression is created using + * TreeBuilder::MakeExpression. + * @param finishExprListBuf The serialized protobuf of the expression vector. + * Each expression is created using + * TreeBuilder::MakeExpression. + * @return A nativeHandler that is passed to the evaluateProjector() and + * closeProjector() methods + */ + native long nativeBuildWithFinish(byte[] schemaBuf, byte[] exprListBuf, + byte[] finishExprListBuf) throws RuntimeException; /** * Set return schema for this expressionTree. * - * @param nativeHandler nativeHandler representing expressions. Created using a call to - * buildNativeCode - * @param schemaBuf The schema serialized as a protobuf. See Types.proto to see the protobuf - * specification + * @param nativeHandler nativeHandler representing expressions. Created using a + * call to buildNativeCode + * @param schemaBuf The schema serialized as a protobuf. 
See Types.proto to + * see the protobuf specification */ - native void nativeSetReturnFields(long nativeHandler, byte[] schemaBuf) throws RuntimeException; + native void nativeSetReturnFields(long nativeHandler, byte[] schemaBuf) + throws RuntimeException; /** - * Evaluate the expressions represented by the nativeHandler on a record batch and store the - * output in ValueVectors. Throws an exception in case of errors + * Evaluate the expressions represented by the nativeHandler on a record batch + * and store the output in ValueVectors. Throws an exception in case of errors * - * @param nativeHandler nativeHandler representing expressions. Created using a call to - * buildNativeCode - * @param numRows Number of rows in the record batch - * @param bufAddrs An array of memory addresses. Each memory address points to a validity vector - * or a data vector (will add support for offset vectors later). - * @param bufSizes An array of buffer sizes. For each memory address in bufAddrs, the size of the - * buffer is present in bufSizes - * @return A list of ArrowRecordBatchBuilder which can be used to build a List of ArrowRecordBatch + * @param nativeHandler nativeHandler representing expressions. Created using a + * call to buildNativeCode + * @param numRows Number of rows in the record batch + * @param bufAddrs An array of memory addresses. Each memory address points + * to a validity vector or a data vector (will add support + * for offset vectors later). + * @param bufSizes An array of buffer sizes. For each memory address in + * bufAddrs, the size of the buffer is present in bufSizes + * @return A list of ArrowRecordBatchBuilder which can be used to build a List + * of ArrowRecordBatch */ - native ArrowRecordBatchBuilder[] nativeEvaluate( - long nativeHandler, int numRows, long[] bufAddrs, long[] bufSizes) throws RuntimeException; + native ArrowRecordBatchBuilder[] nativeEvaluate(long nativeHandler, int numRows, + long[] bufAddrs, long[] bufSizes) throws RuntimeException; /** - * Evaluate the expressions represented by the nativeHandler on a record batch and store the - * output in ValueVectors. Throws an exception in case of errors + * Evaluate the expressions represented by the nativeHandler on a record batch + * and store the output in ValueVectors. Throws an exception in case of errors * - * @param nativeHandler nativeHandler representing expressions. Created using a call to - * buildNativeCode - * @param numRows Number of rows in the record batch - * @param bufAddrs An array of memory addresses. Each memory address points to a validity vector - * or a data vector (will add support for offset vectors later). - * @param bufSizes An array of buffer sizes. For each memory address in bufAddrs, the size of the - * buffer is present in bufSizes - * @param selectionVector valid selected item record count - * @param selectionVector selectionVector memory address + * @param nativeHandler nativeHandler representing expressions. Created + * using a call to buildNativeCode + * @param numRows Number of rows in the record batch + * @param bufAddrs An array of memory addresses. Each memory address + * points to a validity vector or a data vector (will + * add support for offset vectors later). + * @param bufSizes An array of buffer sizes. 
For each memory address + * in bufAddrs, the size of the buffer is present in + * bufSizes + * @param selectionVectorRecordCount valid selected item record count + * @param selectionVectorAddr selectionVector memory address * @param selectionVectorSize selectionVector total size - * @return A list of ArrowRecordBatchBuilder which can be used to build a List of ArrowRecordBatch + * @return A list of ArrowRecordBatchBuilder which can be used to build a List + * of ArrowRecordBatch */ - native ArrowRecordBatchBuilder[] nativeEvaluateWithSelection( - long nativeHandler, int numRows, long[] bufAddrs, long[] bufSizes, - int selectionVectorRecordCount, long selectionVectorAddr, long selectionVectorSize) throws RuntimeException; + native ArrowRecordBatchBuilder[] nativeEvaluateWithSelection(long nativeHandler, + int numRows, long[] bufAddrs, long[] bufSizes, int selectionVectorRecordCount, + long selectionVectorAddr, long selectionVectorSize) throws RuntimeException; native void nativeSetMember( long nativeHandler, int numRows, long[] bufAddrs, long[] bufSizes); /** - * Evaluate the expressions represented by the nativeHandler on a record batch and store the - * output in ValueVectors. Throws an exception in case of errors + * Evaluate the expressions represented by the nativeHandler on a record batch + * and store the output in ValueVectors. Throws an exception in case of errors * - * @param nativeHandler nativeHandler representing expressions. Created using a call to - * buildNativeCode - * @return A list of ArrowRecordBatchBuilder which can be used to build a List of ArrowRecordBatch + * @param nativeHandler nativeHandler representing expressions. Created using a + * call to buildNativeCode + * @return A list of ArrowRecordBatchBuilder which can be used to build a List + * of ArrowRecordBatch */ - native ArrowRecordBatchBuilder[] nativeFinish(long nativeHandler) throws RuntimeException; + native ArrowRecordBatchBuilder[] nativeFinish(long nativeHandler) + throws RuntimeException; /** * Call Finish to get result, result will be as a iterator. @@ -128,11 +159,12 @@ native void nativeSetMember( /** * Set another evaluator's iterator as this one's dependency. * - * @param nativeHandler nativeHandler of this expression + * @param nativeHandler nativeHandler of this expression * @param childInstanceId childInstanceId of a child BatchIterator - * @param index exptected index of the output of BatchIterator + * @param index expected index of the output of BatchIterator */ - native void nativeSetDependency(long nativeHandler, long childInstanceId, int index) throws RuntimeException; + native void nativeSetDependency(long nativeHandler, long childInstanceId, int index) + throws RuntimeException; /** * Closes the projector referenced by nativeHandler. 
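The selection-vector arguments documented above mirror Gandiva's own Java API: a Filter fills a vector of selected row ordinals (two bytes each for the Int16 variant), and later evaluation touches only those rows. A rough sketch of the stock Gandiva flow, with an invented single-column schema and condition; it assumes Gandiva's bundled native library is available:

```java
import java.util.Arrays;

import org.apache.arrow.gandiva.evaluator.Filter;
import org.apache.arrow.gandiva.evaluator.SelectionVectorInt16;
import org.apache.arrow.gandiva.expression.Condition;
import org.apache.arrow.gandiva.expression.TreeBuilder;
import org.apache.arrow.gandiva.expression.TreeNode;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;

public class SelectionVectorSketch {
  /** Builds "x > 0" and filters one batch into a selection vector of ordinals. */
  static void filterBatch(RootAllocator allocator, ArrowRecordBatch batch, int numRows)
      throws Exception {
    Field x = Field.nullable("x", new ArrowType.Int(32, true));
    Schema schema = new Schema(Arrays.asList(x));
    TreeNode gt = TreeBuilder.makeFunction("greater_than",
        Arrays.asList(TreeBuilder.makeField(x), TreeBuilder.makeLiteral(0)),
        new ArrowType.Bool());
    Condition condition = TreeBuilder.makeCondition(gt);
    Filter filter = Filter.make(schema, condition);
    // Two bytes per selected ordinal, hence the Int16 variant used by the wrapper.
    SelectionVectorInt16 selection =
        new SelectionVectorInt16(allocator.buffer(numRows * 2));
    filter.evaluate(batch, selection);
    System.out.println("selected rows: " + selection.getRecordCount());
    filter.close();
  }
}
```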
diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/JniUtils.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/JniUtils.java index 5f80c6bfa..c750ffc4e 100644 --- a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/JniUtils.java +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/JniUtils.java @@ -19,64 +19,146 @@ import java.io.File; import java.io.FileNotFoundException; +import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; +import java.net.JarURLConnection; +import java.net.URL; +import java.net.URLConnection; import java.nio.file.Files; import java.nio.file.StandardCopyOption; +import java.util.Enumeration; +import java.util.jar.JarEntry; +import java.util.jar.JarFile; /** Helper class for JNI related operations. */ public class JniUtils { private static final String LIBRARY_NAME = "spark_columnar_jni"; private static boolean isLoaded = false; private static volatile JniUtils INSTANCE; + private String tmp_dir; public static JniUtils getInstance() throws IOException { + String tmp_dir = System.getProperty("java.io.tmpdir"); + return getInstance(tmp_dir); + } + + public static JniUtils getInstance(String tmp_dir) throws IOException { if (INSTANCE == null) { synchronized (JniUtils.class) { if (INSTANCE == null) { try { - INSTANCE = new JniUtils(); + INSTANCE = new JniUtils(tmp_dir); } catch (IllegalAccessException ex) { throw new IOException("IllegalAccess"); } } } } - return INSTANCE; } - private JniUtils() throws IOException, IllegalAccessException { - try { - loadLibraryFromJar(); - } catch (IOException ex) { - System.loadLibrary(LIBRARY_NAME); + private JniUtils(String tmp_dir) + throws IOException, IllegalAccessException, IllegalStateException { + if (!isLoaded) { + try { + loadLibraryFromJar(tmp_dir); + } catch (IOException ex) { + System.loadLibrary(LIBRARY_NAME); + } + loadIncludeFromJar(tmp_dir); + isLoaded = true; } } - static void loadLibraryFromJar() throws IOException, IllegalAccessException { + static void loadLibraryFromJar(String tmp_dir) + throws IOException, IllegalAccessException { synchronized (JniUtils.class) { - if (!isLoaded) { - final String libraryToLoad = System.mapLibraryName(LIBRARY_NAME); - final File libraryFile = - moveFileFromJarToTemp(System.getProperty("java.io.tmpdir"), libraryToLoad); - System.load(libraryFile.getAbsolutePath()); - isLoaded = true; + if (tmp_dir == null) { + tmp_dir = System.getProperty("java.io.tmpdir"); + System.out.println("loadLibraryFromJar " + tmp_dir); } + final String libraryToLoad = System.mapLibraryName(LIBRARY_NAME); + final File libraryFile = moveFileFromJarToTemp(tmp_dir, libraryToLoad); + System.load(libraryFile.getAbsolutePath()); } } - private static File moveFileFromJarToTemp(final String tmpDir, String libraryToLoad) + private static void loadIncludeFromJar(String tmp_dir) + throws IOException, IllegalAccessException { + synchronized (JniUtils.class) { + if (tmp_dir == null) { + tmp_dir = System.getProperty("java.io.tmpdir"); + System.out.println("loadIncludeFromJar " + tmp_dir); + } + final String folderToLoad = "include"; + final URLConnection urlConnection = + JniUtils.class.getClassLoader().getResource("include").openConnection(); + if (urlConnection instanceof JarURLConnection) { + final JarFile jarFile = ((JarURLConnection) urlConnection).getJarFile(); + copyResourcesToDirectory(jarFile, folderToLoad, tmp_dir + "/" + "nativesql_include"); + } else { + throw 
new IOException(urlConnection.toString() + " is not JarUrlConnection"); + } + } + } + + private static File moveFileFromJarToTemp(String tmpDir, String libraryToLoad) throws IOException { - final File temp = File.createTempFile(tmpDir, libraryToLoad); + // final File temp = File.createTempFile(tmpDir, libraryToLoad); + final File temp = new File(tmpDir + "/" + libraryToLoad); try (final InputStream is = - JniUtils.class.getClassLoader().getResourceAsStream(libraryToLoad)) { + JniUtils.class.getClassLoader().getResourceAsStream(libraryToLoad)) { if (is == null) { throw new FileNotFoundException(libraryToLoad); } else { - Files.copy(is, temp.toPath(), StandardCopyOption.REPLACE_EXISTING); + try { + Files.copy(is, temp.toPath()); + } catch (Exception e) { + } } } return temp; } + + /** + * Copies a directory from a jar file to an external directory. + */ + public static void copyResourcesToDirectory( + JarFile fromJar, String jarDir, String destDir) throws IOException { + for (Enumeration entries = fromJar.entries(); entries.hasMoreElements();) { + JarEntry entry = entries.nextElement(); + if (entry.getName().startsWith(jarDir + "/") && !entry.isDirectory()) { + File dest = + new File(destDir + "/" + entry.getName().substring(jarDir.length() + 1)); + File parent = dest.getParentFile(); + if (parent != null) { + parent.mkdirs(); + } + + FileOutputStream out = new FileOutputStream(dest); + InputStream in = fromJar.getInputStream(entry); + + try { + byte[] buffer = new byte[8 * 1024]; + + int s = 0; + while ((s = in.read(buffer)) > 0) { + out.write(buffer, 0, s); + } + } catch (IOException e) { + throw new IOException("Could not copy asset from jar file", e); + } finally { + try { + in.close(); + } catch (IOException ignored) { + } + try { + out.close(); + } catch (IOException ignored) { + } + } + } + } + } } diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/PartitionFileInfo.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/PartitionFileInfo.java new file mode 100644 index 000000000..bedab8f15 --- /dev/null +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/PartitionFileInfo.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
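The hand-rolled stream loop in copyResourcesToDirectory above is functionally equivalent to the java.nio idiom below, which copies one non-directory jar entry and lets try-with-resources close the stream even when the copy throws; the class and method names are placeholders:

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

public class JarEntryCopy {
  /** Copies a single non-directory jar entry under jarDir into destDir. */
  static void copyEntry(JarFile jar, JarEntry entry, String jarDir, String destDir)
      throws IOException {
    Path dest = Paths.get(destDir, entry.getName().substring(jarDir.length() + 1));
    Path parent = dest.getParent();
    if (parent != null) {
      Files.createDirectories(parent);
    }
    try (InputStream in = jar.getInputStream(entry)) {
      Files.copy(in, dest, StandardCopyOption.REPLACE_EXISTING);
    }
  }
}
```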
+ */ + +package com.intel.sparkColumnarPlugin.vectorized; + +public class PartitionFileInfo { + private int pid; + private String filePath; + + public PartitionFileInfo(int pid, String filePath) { + this.pid = pid; + this.filePath = filePath; + } + + public int getPid() { + return pid; + } + + public String getFilePath() { + return filePath; + } +} diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleDecompressionJniWrapper.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleDecompressionJniWrapper.java new file mode 100644 index 000000000..5bc46f7c8 --- /dev/null +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleDecompressionJniWrapper.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.sparkColumnarPlugin.vectorized; + +import java.io.IOException; + +public class ShuffleDecompressionJniWrapper { + + public ShuffleDecompressionJniWrapper() throws IOException { + JniUtils.getInstance(); + } + + /** + * Make for multiple decompression with the same schema + * + * @param schemaBuf serialized arrow schema + * @return native schema holder id + * @throws RuntimeException + */ + public native long make(byte[] schemaBuf) throws RuntimeException; + + public native ArrowRecordBatchBuilder decompress( + long schemaHolderId, + String compressionCodec, + int numRows, + long[] bufAddrs, + long[] bufSizes, + long[] bufMask) + throws RuntimeException; + + /** + * Release resources associated with designated schema holder instance. + * + * @param schemaHolderId of the schema holder instance. + */ + public native void close(long schemaHolderId); +} diff --git a/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleSplitterJniWrapper.java b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleSplitterJniWrapper.java new file mode 100644 index 000000000..1cdc300a5 --- /dev/null +++ b/oap-native-sql/core/src/main/java/com/intel/sparkColumnarPlugin/vectorized/ShuffleSplitterJniWrapper.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.sparkColumnarPlugin.vectorized; + +import java.io.IOException; + +public class ShuffleSplitterJniWrapper { + + public ShuffleSplitterJniWrapper() throws IOException { + JniUtils.getInstance(); + } + + /** + * Construct a native splitter for shuffled RecordBatches. + * + * @param schemaBuf serialized arrow schema + * @param bufferSize size of native buffers held by each partition writer + * @return native splitter instance id if created successfully. + * @throws RuntimeException + */ + public native long make(byte[] schemaBuf, long bufferSize, String localDirs) + throws RuntimeException; + + /** + * Split one record batch represented by bufAddrs and bufSizes into several batches. The batch is + * split according to the first column as partition id. During splitting, the data in native + * buffers will be written to disk when the buffers are full. + * + * @param splitterId + * @param numRows Rows per batch + * @param bufAddrs Addresses of buffers + * @param bufSizes Sizes of buffers + * @throws RuntimeException + */ + public native void split(long splitterId, int numRows, long[] bufAddrs, long[] bufSizes) + throws RuntimeException; + + /** + * Write the data remaining in the buffers held by the native splitter to each partition's temporary + * file, and stop splitting. + * + * @param splitterId + * @throws RuntimeException + */ + public native void stop(long splitterId) throws RuntimeException; + + /** + * Set the output buffer size for each partition. The splitter maintains one buffer for each partition + * id that has occurred, and writes data to file when the buffer is full. The default buffer size is + * 4096 rows. + * + * @param splitterId + * @param bufferSize In rows, not bytes. The default buffer size is 4096 rows. + */ + public native void setPartitionBufferSize(long splitterId, long bufferSize); + + /** + * Set the compression codec for the splitter's output. The default is uncompressed. + * + * @param splitterId + * @param codec "lz4", "zstd", "uncompressed" + */ + public native void setCompressionCodec(long splitterId, String codec); + + /** + * Get information on all files created by the splitter. Used by the {@link + * org.apache.spark.shuffle.ColumnarShuffleWriter}. These files exist temporarily and will be + * deleted after the combination. + * + * @param splitterId + * @return an array of information on all files + */ + public native PartitionFileInfo[] getPartitionFileInfo(long splitterId); + + /** + * Get the total bytes written to disk. + * + * @param splitterId + * @return + */ + public native long getTotalBytesWritten(long splitterId); + + /** + * Release resources associated with the designated splitter instance. + * + * @param splitterId of the splitter instance. 
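The natives above imply a fixed call order: make, optional configuration, split once per batch, stop to flush the remaining rows, inspect the partition files, then close. A compile-only sketch of that lifecycle; the argument values are invented, and actually running it requires the plugin's native library to be loadable by JniUtils:

```java
import com.intel.sparkColumnarPlugin.vectorized.PartitionFileInfo;
import com.intel.sparkColumnarPlugin.vectorized.ShuffleSplitterJniWrapper;

public class SplitterLifecycleSketch {
  /** schemaBuf is the shuffle schema serialized as a protobuf, as elsewhere in this patch. */
  static void shuffleOneBatch(byte[] schemaBuf, String localDirs, int numRows,
      long[] bufAddrs, long[] bufSizes) throws Exception {
    ShuffleSplitterJniWrapper splitter = new ShuffleSplitterJniWrapper();
    long splitterId = splitter.make(schemaBuf, 4096, localDirs);
    try {
      splitter.setCompressionCodec(splitterId, "lz4");
      splitter.split(splitterId, numRows, bufAddrs, bufSizes);
      splitter.stop(splitterId); // flush buffered rows to the per-partition files
      for (PartitionFileInfo info : splitter.getPartitionFileInfo(splitterId)) {
        System.out.println(info.getPid() + " -> " + info.getFilePath());
      }
    } finally {
      splitter.close(splitterId);
    }
  }
}
```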
+ */ + public native void close(long splitterId); +} diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPlugin.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPlugin.scala index 21bca93d7..5842150fb 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPlugin.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPlugin.scala @@ -17,8 +17,9 @@ package com.intel.sparkColumnarPlugin -import com.intel.sparkColumnarPlugin.execution._ +import java.util.Locale +import com.intel.sparkColumnarPlugin.execution._ import org.apache.spark.internal.Logging import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.rules.Rule @@ -73,12 +74,20 @@ case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] { logDebug(s"Columnar Processing for ${plan.getClass} is not currently supported.") plan.withNewChildren(children) } - /*case plan: ShuffleExchangeExec => - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - new ColumnarShuffleExchangeExec( - plan.outputPartitioning, - replaceWithColumnarPlan(plan.child), - plan.canChangeNumPartitions)*/ + case plan: ShuffleExchangeExec => + if (columnarConf.enableColumnarShuffle) { + val child = replaceWithColumnarPlan(plan.child) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + CoalesceBatchesExec( + new ColumnarShuffleExchangeExec( + plan.outputPartitioning, + child, + plan.canChangeNumPartitions)) + } else { + val children = plan.children.map(replaceWithColumnarPlan) + logDebug(s"Columnar Processing for ${plan.getClass} is not currently supported.") + plan.withNewChildren(children) + } case plan: ShuffledHashJoinExec => val left = replaceWithColumnarPlan(plan.left) val right = replaceWithColumnarPlan(plan.right) diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPluginConfig.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPluginConfig.scala index fd72fcaab..ee1cf7ac1 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPluginConfig.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/ColumnarPluginConfig.scala @@ -22,6 +22,13 @@ import org.apache.spark.SparkConf class ColumnarPluginConfig(conf: SparkConf) { val enableColumnarSort: Boolean = conf.getBoolean("spark.sql.columnar.sort", defaultValue = false) + val enableColumnarShuffle: Boolean = conf + .get("spark.shuffle.manager", "sort") + .equals("org.apache.spark.shuffle.sort.ColumnarShuffleManager") + val batchSize: Int = + conf.getInt("spark.sql.execution.arrow.maxRecordsPerBatch", defaultValue = 10000) + val tmpFile: String = + conf.getOption("spark.sql.columnar.tmp_dir").getOrElse(null) } object ColumnarPluginConfig { @@ -41,4 +48,18 @@ object ColumnarPluginConfig { ins } } + def getBatchSize: Int = synchronized { + if (ins == null) { + 10000 + } else { + ins.batchSize + } + } + def getTempFile: String = synchronized { + if (ins != null) { + ins.tmpFile + } else { + System.getProperty("java.io.tmpdir") + } + } } diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/CoalesceBatchesExec.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/CoalesceBatchesExec.scala new file mode 100644 index 000000000..f92eea89a --- /dev/null +++ 
b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/CoalesceBatchesExec.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.sparkColumnarPlugin.execution + +import java.util.concurrent.TimeUnit + +import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector +import org.apache.arrow.vector.util.VectorBatchAppender +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.vectorized.ColumnarBatch + +import scala.collection.mutable.ListBuffer + +case class CoalesceBatchesExec(child: SparkPlan) extends UnaryExecNode { + + override def output: Seq[Attribute] = child.output + + override def supportsColumnar: Boolean = true + + override def nodeName: String = "CoalesceBatches" + + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"), + "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"), + "concatTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "concat batch time total"), + "collectTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "collect batch time total"), + "avgCoalescedNumRows" -> SQLMetrics + .createAverageMetric(sparkContext, "avg coalesced batch num rows")) + // TODO: peak device memory total + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + import CoalesceBatchesExec._ + + val recordsPerBatch = conf.arrowMaxRecordsPerBatch + val numInputRows = longMetric("numInputRows") + val numInputBatches = longMetric("numInputBatches") + val numOutputBatches = longMetric("numOutputBatches") + val concatTime = longMetric("concatTime") + val collectTime = longMetric("collectTime") + val avgCoalescedNumRows = longMetric("avgCoalescedNumRows") + + child.executeColumnar().mapPartitions { iter => + if (iter.hasNext) { + new Iterator[ColumnarBatch] { + var target: ColumnarBatch = _ + var numBatchesTotal: Long = _ + var numRowsTotal: Long = _ + + TaskContext.get().addTaskCompletionListener[Unit] { _ => + closePrevious() + if (numBatchesTotal > 0) { + avgCoalescedNumRows.set(numRowsTotal.toDouble / numBatchesTotal) + } + } + + private def closePrevious(): Unit = { + if (target != null) { + target.close() + target = null + } + } + + override def hasNext: Boolean = { + 
iter.hasNext + } + + override def next(): ColumnarBatch = { + closePrevious() + var rowCount = 0 + val batchesToAppend = ListBuffer[ColumnarBatch]() + + val beforeCollect = System.nanoTime + target = iter.next() + target.retain() + rowCount += target.numRows + + while (iter.hasNext && rowCount < recordsPerBatch) { + val delta = iter.next() + delta.retain() + rowCount += delta.numRows + batchesToAppend += delta + } + + val beforeConcat = System.nanoTime + collectTime += beforeConcat - beforeCollect + + coalesce(target, batchesToAppend.toList) + target.setNumRows(rowCount) + + concatTime += System.nanoTime - beforeConcat + numInputRows += rowCount + numInputBatches += (1 + batchesToAppend.length) + numOutputBatches += 1 + + // used for calculating avgCoalescedNumRows + numRowsTotal += rowCount + numBatchesTotal += 1 + + batchesToAppend.foreach(cb => cb.close()) + + target + } + } + } else { + Iterator.empty + } + } + } +} + +object CoalesceBatchesExec { + implicit class ArrowColumnarBatchRetainer(val cb: ColumnarBatch) { + def retain(): Unit = { + (0 until cb.numCols).toList.foreach(i => + cb.column(i).asInstanceOf[ArrowWritableColumnVector].retain()) + } + } + + def coalesce(targetBatch: ColumnarBatch, batchesToAppend: List[ColumnarBatch]): Unit = { + (0 until targetBatch.numCols).toList.foreach { i => + val targetVector = + targetBatch.column(i).asInstanceOf[ArrowWritableColumnVector].getValueVector + val vectorsToAppend = batchesToAppend.map { cb => + cb.column(i).asInstanceOf[ArrowWritableColumnVector].getValueVector + } + VectorBatchAppender.batchAppend(targetVector, vectorsToAppend: _*) + } + } +} diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarBatchScanExec.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarBatchScanExec.scala index 6acef2272..45501a95a 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarBatchScanExec.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarBatchScanExec.scala @@ -17,6 +17,7 @@ package com.intel.sparkColumnarPlugin.execution +import com.intel.sparkColumnarPlugin.ColumnarPluginConfig import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan} @@ -24,9 +25,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} -class ColumnarBatchScanExec( - output: Seq[AttributeReference], - @transient scan: Scan) extends BatchScanExec(output, scan) { +class ColumnarBatchScanExec(output: Seq[AttributeReference], @transient scan: Scan) + extends BatchScanExec(output, scan) { + val tmpDir = ColumnarPluginConfig.getConf(sparkContext.getConf).tmpFile override def supportsColumnar(): Boolean = true override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), @@ -34,7 +35,8 @@ class ColumnarBatchScanExec( override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") val scanTime = longMetric("scanTime") - val inputColumnarRDD = new ColumnarDataSourceRDD(sparkContext, partitions, readerFactory, true, scanTime) + val inputColumnarRDD = + new ColumnarDataSourceRDD(sparkContext, partitions, readerFactory, true, scanTime, tmpDir) inputColumnarRDD.map { r 
=> numOutputRows += r.numRows() r diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarDataSourceRDD.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarDataSourceRDD.scala index 748deba7d..9a8cbfefc 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarDataSourceRDD.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarDataSourceRDD.scala @@ -20,7 +20,11 @@ package com.intel.sparkColumnarPlugin.execution import com.intel.sparkColumnarPlugin.vectorized._ import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, PartitionReader} +import org.apache.spark.sql.connector.read.{ + InputPartition, + PartitionReaderFactory, + PartitionReader +} import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -28,7 +32,8 @@ import org.apache.spark.sql.execution.datasources.v2.VectorizedFilePartitionRead import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetPartitionReaderFactory class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition) - extends Partition with Serializable + extends Partition + with Serializable // TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for // columnar scan. @@ -37,8 +42,9 @@ class ColumnarDataSourceRDD( @transient private val inputPartitions: Seq[InputPartition], partitionReaderFactory: PartitionReaderFactory, columnarReads: Boolean, - scanTime: SQLMetric) - extends RDD[ColumnarBatch](sc, Nil) { + scanTime: SQLMetric, + tmp_dir: String) + extends RDD[ColumnarBatch](sc, Nil) { override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { @@ -56,7 +62,7 @@ class ColumnarDataSourceRDD( val reader = if (columnarReads) { partitionReaderFactory match { case factory: ParquetPartitionReaderFactory => - VectorizedFilePartitionReaderHandler.get(inputPartition, factory) + VectorizedFilePartitionReaderHandler.get(inputPartition, factory, tmp_dir) case _ => partitionReaderFactory.createColumnarReader(inputPartition) } } else { @@ -92,7 +98,8 @@ class ColumnarDataSourceRDD( reader.get() } } - val closeableColumnarBatchIterator = new CloseableColumnBatchIterator(iter.asInstanceOf[Iterator[ColumnarBatch]]) + val closeableColumnarBatchIterator = new CloseableColumnBatchIterator( + iter.asInstanceOf[Iterator[ColumnarBatch]]) // TODO: SPARK-25083 remove the type erasure hack in data source scan new InterruptibleIterator(context, closeableColumnarBatchIterator) } diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarHashAggregateExec.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarHashAggregateExec.scala index b3e2b438f..cf79c5d55 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarHashAggregateExec.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarHashAggregateExec.scala @@ -67,6 +67,7 @@ class ColumnarHashAggregateExec( resultExpressions, child) { + val sparkConf = sparkContext.getConf override def supportsColumnar = true // Disable code generation @@ -77,7 +78,8 @@ class ColumnarHashAggregateExec( 
"numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"), "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of Input batches"), "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation process"), - "elapseTime" -> SQLMetrics.createTimingMetric(sparkContext, "elapse time from very begin to this process")) + "elapseTime" -> SQLMetrics + .createTimingMetric(sparkContext, "elapse time from very begin to this process")) override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") @@ -108,10 +110,13 @@ class ColumnarHashAggregateExec( numOutputBatches, numOutputRows, aggTime, - elapseTime) - TaskContext.get().addTaskCompletionListener[Unit](_ => { - aggregation.close() - }) + elapseTime, + sparkConf) + TaskContext + .get() + .addTaskCompletionListener[Unit](_ => { + aggregation.close() + }) new CloseableColumnBatchIterator(aggregation.createIterator(iter)) } res diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarShuffledHashJoinExec.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarShuffledHashJoinExec.scala index abc079b6c..dbd106c15 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarShuffledHashJoinExec.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarShuffledHashJoinExec.scala @@ -61,15 +61,10 @@ class ColumnarShuffledHashJoinExec( buildSide: BuildSide, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends ShuffledHashJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide, - condition, - left, - right) { + right: SparkPlan) + extends ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) { + val sparkConf = sparkContext.getConf override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "joinTime" -> SQLMetrics.createTimingMetric(sparkContext, "join time"), @@ -85,15 +80,30 @@ class ColumnarShuffledHashJoinExec( val joinTime = longMetric("joinTime") val buildTime = longMetric("buildTime") val resultSchema = this.schema - streamedPlan.executeColumnar().zipPartitions(buildPlan.executeColumnar()) { (streamIter, buildIter) => - //val hashed = buildHashedRelation(buildIter) - //join(streamIter, hashed, numOutputRows) - val vjoin = ColumnarShuffledHashJoin.create(leftKeys, rightKeys, resultSchema, joinType, buildSide, condition, left, right, buildTime, joinTime, numOutputRows) - val vjoinResult = vjoin.columnarInnerJoin(streamIter, buildIter) - TaskContext.get().addTaskCompletionListener[Unit](_ => { - vjoin.close() - }) - new CloseableColumnBatchIterator(vjoinResult) + streamedPlan.executeColumnar().zipPartitions(buildPlan.executeColumnar()) { + (streamIter, buildIter) => + //val hashed = buildHashedRelation(buildIter) + //join(streamIter, hashed, numOutputRows) + val vjoin = ColumnarShuffledHashJoin.create( + leftKeys, + rightKeys, + resultSchema, + joinType, + buildSide, + condition, + left, + right, + buildTime, + joinTime, + numOutputRows, + sparkConf) + val vjoinResult = vjoin.columnarInnerJoin(streamIter, buildIter) + TaskContext + .get() + .addTaskCompletionListener[Unit](_ => { + vjoin.close() + }) + new CloseableColumnBatchIterator(vjoinResult) } } } diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarSortExec.scala 
b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarSortExec.scala index ff2b2fb3a..6e58aa73b 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarSortExec.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/execution/ColumnarSortExec.scala @@ -40,6 +40,8 @@ class ColumnarSortExec( child: SparkPlan, testSpillFrequency: Int = 0) extends SortExec(sortOrder, global, child, testSpillFrequency) { + + val sparkConf = sparkContext.getConf override def supportsColumnar = true // Disable code generation @@ -72,7 +74,8 @@ class ColumnarSortExec( numOutputBatches, numOutputRows, shuffleTime, - elapse) + elapse, + sparkConf) TaskContext .get() .addTaskCompletionListener[Unit](_ => { diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/CodeGeneration.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/CodeGeneration.scala index 221562a9e..70a9c4f2c 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/CodeGeneration.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/CodeGeneration.scala @@ -30,16 +30,8 @@ object CodeGeneration { val timeZoneId = SQLConf.get.sessionLocalTimeZone def getResultType(left: ArrowType, right: ArrowType): ArrowType = { - if (left.equals(right)) { - left - } else { - val left_precise_level = getPreciseLevel(left) - val right_precise_level = getPreciseLevel(right) - if (left_precise_level > right_precise_level) - left - else - right - } + //TODO(): remove this API + left } def getResultType(dataType: DataType): ArrowType = { diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregateExpression.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregateExpression.scala index a2ed66892..7ea630549 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregateExpression.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregateExpression.scala @@ -80,7 +80,7 @@ class ColumnarAggregateExpression( var aggrFieldList: List[Field] = _ val (funcName, argSize, resSize) = mode match { - case Partial => + case Partial | PartialMerge => aggregateFunction.prettyName match { case "avg" => ("sum_count", 1, 2) case "count" => { diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregation.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregation.scala index 4b0fb0810..24c4043d3 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregation.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarAggregation.scala @@ -23,6 +23,7 @@ import java.util.Collections import java.util.concurrent.TimeUnit._ import util.control.Breaks._ +import com.intel.sparkColumnarPlugin.ColumnarPluginConfig import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector import org.apache.spark.sql.util.ArrowUtils import com.intel.sparkColumnarPlugin.vectorized.ExpressionEvaluator @@ -30,6 +31,7 @@ import com.intel.sparkColumnarPlugin.vectorized.BatchIterator import com.google.common.collect.Lists import org.apache.hadoop.mapreduce.TaskAttemptID +import org.apache.spark.SparkConf import org.apache.spark.internal.Logging 
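For context on the CoalesceBatchesExec added earlier in this patch: its coalesce helper delegates entirely to Arrow's VectorBatchAppender, which grows the first batch's vectors in place, one column at a time. A tiny standalone demonstration with invented values:

```java
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.util.VectorBatchAppender;

public class CoalesceSketch {
  public static void main(String[] args) {
    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
         IntVector target = new IntVector("target", allocator);
         IntVector delta = new IntVector("delta", allocator)) {
      target.allocateNew(2);
      target.setSafe(0, 1);
      target.setSafe(1, 2);
      target.setValueCount(2);

      delta.allocateNew(2);
      delta.setSafe(0, 3);
      delta.setSafe(1, 4);
      delta.setValueCount(2);

      // Appends delta's values after target's, growing target as needed --
      // the same call CoalesceBatchesExec issues once per column.
      VectorBatchAppender.batchAppend(target, delta);
      target.setValueCount(4); // make the combined row count explicit
      for (int i = 0; i < target.getValueCount(); i++) {
        System.out.println(target.get(i)); // 1, 2, 3, 4
      }
    }
  }
}
```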
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -68,9 +70,11 @@ class ColumnarAggregation( numOutputBatches: SQLMetric, numOutputRows: SQLMetric, aggrTime: SQLMetric, - elapseTime: SQLMetric) + elapseTime: SQLMetric, + sparkConf: SparkConf) extends Logging { // build gandiva projection here. + ColumnarPluginConfig.getConf(sparkConf) var elapseTime_make: Long = 0 var rowId: Int = 0 var processedNumRows: Int = 0 @@ -103,12 +107,12 @@ class ColumnarAggregation( // 3. map original input to aggregate input var beforeAggregateProjector: ColumnarProjection = _ - var projectOrdinalList : List[Int] = _ - var aggregateInputAttributes : List[AttributeReference] = _ + var projectOrdinalList: List[Int] = _ + var aggregateInputAttributes: List[AttributeReference] = _ if (mode == null) { projectOrdinalList = List[Int]() - aggregateInputAttributes = List[AttributeReference]() + aggregateInputAttributes = List[AttributeReference]() } else { mode match { case Partial => { @@ -119,7 +123,7 @@ class ColumnarAggregation( projectOrdinalList = beforeAggregateProjector.getOrdinalList aggregateInputAttributes = beforeAggregateProjector.output } - case Final => { + case Final | PartialMerge => { val ordinal_attr_list = originalInputAttributes.toList.zipWithIndex .filter{case(expr, i) => !groupingOrdinalList.contains(i)} .map{case(expr, i) => { @@ -160,7 +164,14 @@ class ColumnarAggregation( // 5. create nativeAggregate evaluator val allNativeExpressions = groupingNativeExpression ::: aggregateNativeExpressions val allAggregateInputFieldList = groupingFieldList ::: aggregateFieldList - val allAggregateResultAttributes = groupingAttributes ::: aggregateAttributes.toList + var allAggregateResultAttributes : List[Attribute] = _ + mode match { + case Partial | PartialMerge => + val aggregateResultAttributes = getAttrForAggregateExpr(aggregateExpressions) + allAggregateResultAttributes = groupingAttributes ::: aggregateResultAttributes + case _ => + allAggregateResultAttributes = groupingAttributes ::: aggregateAttributes.toList + } val aggregateResultFieldList = allAggregateResultAttributes.map(attr => { Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType)) }) @@ -169,7 +180,7 @@ class ColumnarAggregation( val resultType = CodeGeneration.getResultType() val resultField = Field.nullable(s"dummy_res", resultType) - val expressionTree: List[ExpressionTree] = allNativeExpressions.map( expr => { + val expressionTree: List[ExpressionTree] = allNativeExpressions.map(expr => { val node = expr.doColumnarCodeGen_ext((groupingFieldList, allAggregateInputFieldList, resultType, resultField)) TreeBuilder.makeExpression(node, resultField) @@ -201,6 +212,47 @@ class ColumnarAggregation( } } + def getAttrForAggregateExpr(aggregateExpressions: Seq[AggregateExpression]): List[Attribute] = { + var aggregateAttr = new ListBuffer[Attribute]() + val size = aggregateExpressions.size + for (expIdx <- 0 until size) { + val exp: AggregateExpression = aggregateExpressions(expIdx) + val aggregateFunc = exp.aggregateFunction + aggregateFunc match { + case Average(_) => + val avg = aggregateFunc.asInstanceOf[Average] + val aggBufferAttr = avg.inputAggBufferAttributes + for (index <- 0 until aggBufferAttr.size) { + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) + aggregateAttr += attr + } + case Sum(_) => + val sum = aggregateFunc.asInstanceOf[Sum] + val aggBufferAttr = sum.inputAggBufferAttributes + val attr = 
ConverterUtils.getAttrFromExpr(aggBufferAttr(0)) + aggregateAttr += attr + case Count(_) => + val count = aggregateFunc.asInstanceOf[Count] + val aggBufferAttr = count.inputAggBufferAttributes + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(0)) + aggregateAttr += attr + case Max(_) => + val max = aggregateFunc.asInstanceOf[Max] + val aggBufferAttr = max.inputAggBufferAttributes + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(0)) + aggregateAttr += attr + case Min(_) => + val min = aggregateFunc.asInstanceOf[Min] + val aggBufferAttr = min.inputAggBufferAttributes + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(0)) + aggregateAttr += attr + case other => + throw new UnsupportedOperationException(s"not currently supported: $other.") + } + } + aggregateAttr.toList + } + def updateAggregationResult(columnarBatch: ColumnarBatch): Unit = { val numRows = columnarBatch.numRows val groupingProjectCols = groupingOrdinalList.map(i => { @@ -290,7 +342,7 @@ class ColumnarAggregation( if (nextCalled == false && resultColumnarBatch != null) { return true } - if ( !nextBatch ) { + if (!nextBatch) { return false } @@ -358,7 +410,8 @@ object ColumnarAggregation { numOutputBatches: SQLMetric, numOutputRows: SQLMetric, aggrTime: SQLMetric, - elapseTime: SQLMetric): ColumnarAggregation = synchronized { + elapseTime: SQLMetric, + sparkConf: SparkConf): ColumnarAggregation = synchronized { columnarAggregation = new ColumnarAggregation( partIndex, groupingExpressions, @@ -371,7 +424,8 @@ object ColumnarAggregation { numOutputBatches, numOutputRows, aggrTime, - elapseTime) + elapseTime, + sparkConf) columnarAggregation } diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarCoalesceOperator.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarCoalesceOperator.scala index 6dad05b4e..4e0d6d486 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarCoalesceOperator.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarCoalesceOperator.scala @@ -37,9 +37,9 @@ import scala.collection.mutable.ListBuffer * coalesce(1, 2) => 1 * coalesce(null, 1, 2) => 1 * coalesce(null, null, 2) => 2 + * coalesce(null, null, null) => null * }}} **/ -//TODO(): coalesce(null, null, null) => null class ColumnarCoalesce(exps: Seq[Expression], original: Expression) extends Coalesce(exps: Seq[Expression]) @@ -47,7 +47,7 @@ class ColumnarCoalesce(exps: Seq[Expression], original: Expression) with Logging { override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = { val iter: Iterator[Expression] = exps.iterator - val exp = exps.head + val exp = iter.next() val (exp_node, expType): (TreeNode, ArrowType) = exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) @@ -63,10 +63,7 @@ class ColumnarCoalesce(exps: Seq[Expression], original: Expression) // Return the last element no matter if it is null val (exp_node, expType): (TreeNode, ArrowType) = exps.last.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) - val isnotnullNode = - TreeBuilder.makeFunction("isnotnull", Lists.newArrayList(exp_node), new ArrowType.Bool()) - val funcNode = TreeBuilder.makeIf(isnotnullNode, exp_node, exp_node, expType) - funcNode + exp_node } else { val exp = iter.next() val (exp_node, expType): (TreeNode, ArrowType) = diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarConcatOperator.scala 
b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarConcatOperator.scala new file mode 100644 index 000000000..ff0f70999 --- /dev/null +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarConcatOperator.scala @@ -0,0 +1,64 @@ +package com.intel.sparkColumnarPlugin.expression + +import com.google.common.collect.Lists +import com.google.common.collect.Sets +import org.apache.arrow.gandiva.evaluator._ +import org.apache.arrow.gandiva.exceptions.GandivaException +import org.apache.arrow.gandiva.expression._ +import org.apache.arrow.vector.types.pojo.ArrowType +import org.apache.arrow.vector.types.pojo.Field +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +import scala.collection.mutable.ListBuffer + +class ColumnarConcat(exps: Seq[Expression], original: Expression) + extends Concat(exps: Seq[Expression]) + with ColumnarExpression + with Logging { + override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = { + val iter: Iterator[Expression] = exps.iterator + val exp = iter.next() + val iterFaster: Iterator[Expression] = exps.iterator + iterFaster.next() + iterFaster.next() + + val (exp_node, expType): (TreeNode, ArrowType) = + exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + + val resultType = new ArrowType.Utf8() + val funcNode = TreeBuilder.makeFunction("concat", + Lists.newArrayList(exp_node, rightNode(args, exps, iter, iterFaster)), resultType) + (funcNode, expType) + } + + def rightNode(args: java.lang.Object, exps: Seq[Expression], + iter: Iterator[Expression], iterFaster: Iterator[Expression]): TreeNode = { + if (!iterFaster.hasNext) { + // When iter reaches the last but one expression + val (exp_node, expType): (TreeNode, ArrowType) = + exps.last.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + exp_node + } else { + val exp = iter.next() + iterFaster.next() + val (exp_node, expType): (TreeNode, ArrowType) = + exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + val resultType = new ArrowType.Utf8() + val funcNode = TreeBuilder.makeFunction("concat", + Lists.newArrayList(exp_node, rightNode(args, exps, iter, iterFaster)), resultType) + funcNode + } + } +} + +object ColumnarConcatOperator { + + def create(exps: Seq[Expression], original: Expression): Expression = original match { + case c: Concat => + new ColumnarConcat(exps, original) + case other => + throw new UnsupportedOperationException(s"not currently supported: $other.") + } +} diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarExpressionConverter.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarExpressionConverter.scala index ff9cf131f..3dffba3f6 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarExpressionConverter.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarExpressionConverter.scala @@ -25,10 +25,10 @@ object ColumnarExpressionConverter extends Logging { var check_if_no_calculation = true - def replaceWithColumnarExpression(expr: Expression, attributeSeq: Seq[Attribute] = null): Expression = expr match { + def replaceWithColumnarExpression(expr: Expression, attributeSeq: Seq[Attribute] = null, expIdx : Int = -1): Expression = expr match { case a: Alias => logInfo(s"${expr.getClass} ${expr} is supported, no_cal is 
$check_if_no_calculation.") - new ColumnarAlias(replaceWithColumnarExpression(a.child, attributeSeq), a.name)( + new ColumnarAlias(replaceWithColumnarExpression(a.child, attributeSeq, expIdx), a.name)( a.exprId, a.qualifier, a.explicitMetadata) @@ -37,7 +37,11 @@ object ColumnarExpressionConverter extends Logging { if (attributeSeq != null) { val bindReference = BindReferences.bindReference(expr, attributeSeq, true) if (bindReference == expr) { - return new ColumnarAttributeReference(a.name, a.dataType, a.nullable, a.metadata)(a.exprId, a.qualifier) + if (expIdx == -1) { + return new ColumnarAttributeReference(a.name, a.dataType, a.nullable, a.metadata)(a.exprId, a.qualifier) + } else { + return new ColumnarBoundReference(expIdx, a.dataType, a.nullable) + } } val b = bindReference.asInstanceOf[BoundReference] new ColumnarBoundReference(b.ordinal, b.dataType, b.nullable) @@ -144,6 +148,22 @@ object ColumnarExpressionConverter extends Logging { case s: org.apache.spark.sql.execution.ScalarSubquery => logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.") new ColumnarScalarSubquery(s) + case c: Concat => + check_if_no_calculation = false + logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.") + val exps = c.children.map{ expr => + replaceWithColumnarExpression(expr, attributeSeq) + } + ColumnarConcatOperator.create( + exps, + expr) + case r: Round => + check_if_no_calculation = false + logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.") + ColumnarRoundOperator.create( + replaceWithColumnarExpression(r.child, attributeSeq), + replaceWithColumnarExpression(r.scale), + expr) case expr => logWarning(s"${expr.getClass} ${expr} is not currently supported.") expr diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarLiterals.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarLiterals.scala index b071b1aa0..b1ae938b0 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarLiterals.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarLiterals.scala @@ -39,13 +39,23 @@ class ColumnarLiteral(lit: Literal) val resultType = CodeGeneration.getResultType(dataType) dataType match { case t: StringType => - (TreeBuilder.makeStringLiteral(value.toString().asInstanceOf[String]), resultType) + value match { + case null => + (TreeBuilder.makeStringLiteral("null": java.lang.String), resultType) + case _ => + (TreeBuilder.makeStringLiteral(value.toString().asInstanceOf[String]), resultType) + } case t: IntegerType => (TreeBuilder.makeLiteral(value.asInstanceOf[Integer]), resultType) case t: LongType => (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Long]), resultType) case t: DoubleType => - (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Double]), resultType) + value match { + case null => + (TreeBuilder.makeLiteral(0.0: java.lang.Double), resultType) + case _ => + (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Double]), resultType) + } case d: DecimalType => val v = value.asInstanceOf[Decimal] (TreeBuilder.makeDecimalLiteral(v.toString, v.precision, v.scale), resultType) diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarProjection.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarProjection.scala index f317d0407..2e899a72c 100644 --- 
a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarProjection.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarProjection.scala @@ -73,10 +73,7 @@ class ColumnarProjection ( case (expr, i) => { ColumnarExpressionConverter.reset() var columnarExpr: Expression = - ColumnarExpressionConverter.replaceWithColumnarExpression(expr, originalInputAttributes) - if (columnarExpr.isInstanceOf[AttributeReference]) { - columnarExpr = new ColumnarBoundReference(i, columnarExpr.dataType, columnarExpr.nullable) - } + ColumnarExpressionConverter.replaceWithColumnarExpression(expr, originalInputAttributes, i) if (ColumnarExpressionConverter.ifNoCalculation == false) { check_if_no_calculation = false } @@ -87,7 +84,8 @@ class ColumnarProjection ( } } - val (ordinalList, arrowSchema) = if (projPrepareList.size > 0 && check_if_no_calculation == false) { + val (ordinalList, arrowSchema) = if (projPrepareList.size > 0 && + (!check_if_no_calculation || projPrepareList.size != inputList.size)) { val inputFieldList = inputList.asScala.toList.distinct val schema = new Schema(inputFieldList.asJava) projector = Projector.make(schema, projPrepareList.toList.asJava) diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarRoundOperator.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarRoundOperator.scala new file mode 100644 index 000000000..506ad4849 --- /dev/null +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarRoundOperator.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.intel.sparkColumnarPlugin.expression + +import com.google.common.collect.Lists + +import org.apache.arrow.gandiva.evaluator._ +import org.apache.arrow.gandiva.exceptions.GandivaException +import org.apache.arrow.gandiva.expression._ +import org.apache.arrow.vector.types.pojo.ArrowType +import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.pojo.Field +import org.apache.arrow.vector.types.DateUnit + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.types._ + +import scala.collection.mutable.ListBuffer + +class ColumnarRound(child: Expression, scale: Expression, original: Expression) + extends Round(child: Expression, scale: Expression) + with ColumnarExpression + with Logging { + override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = { + val (child_node, childType): (TreeNode, ArrowType) = + child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + val (scale_node, scaleType): (TreeNode, ArrowType) = + scale.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + //TODO(): get precision and scale from decimal + val precision = 8 + val decimalScale = 2 + val castNode = TreeBuilder.makeFunction("castDECIMAL", Lists.newArrayList(child_node), + new ArrowType.Decimal(precision, decimalScale)) + val funcNode = TreeBuilder.makeFunction("round", Lists.newArrayList(castNode, scale_node), + new ArrowType.Decimal(precision, decimalScale)) + + val resultType = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + val castFloat = TreeBuilder.makeFunction("castFLOAT8", + Lists.newArrayList(funcNode), resultType) + (castFloat, resultType) + } +} + +object ColumnarRoundOperator { + + def create(child: Expression, scale: Expression, original: Expression): Expression = original match { + case r: Round => + new ColumnarRound(child, scale, original) + case other => + throw new UnsupportedOperationException(s"not currently supported: $other.") + } +} diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarShuffledHashJoin.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarShuffledHashJoin.scala index f66eb1994..26be8f85e 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarShuffledHashJoin.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarShuffledHashJoin.scala @@ -19,6 +19,7 @@ package com.intel.sparkColumnarPlugin.expression import java.util.concurrent.TimeUnit._ +import com.intel.sparkColumnarPlugin.ColumnarPluginConfig import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector import org.apache.spark.TaskContext @@ -34,6 +35,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import scala.collection.JavaConverters._ +import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import scala.collection.mutable.ListBuffer @@ -66,184 +68,192 @@ class ColumnarShuffledHashJoin( right: SparkPlan, buildTime: SQLMetric, joinTime: SQLMetric, - totalOutputNumRows: SQLMetric - ) extends Logging{ - - var build_cb : ColumnarBatch = null - var last_cb: ColumnarBatch = null - - val inputBatchHolder = new ListBuffer[ColumnarBatch]() - // TODO - val 
l_input_schema: List[Attribute] = left.output.toList - val r_input_schema: List[Attribute] = right.output.toList - logInfo(s"\nleft_schema is ${l_input_schema}, right_schema is ${r_input_schema}, \nleftKeys is ${leftKeys}, rightKeys is ${rightKeys}, \nresultSchema is ${resultSchema}, \njoinType is ${joinType}, buildSide is ${buildSide}, condition is ${conditionOption}") - - val l_input_field_list: List[Field] = l_input_schema.toList.map(attr => { - Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType)) - }) - val r_input_field_list: List[Field] = r_input_schema.toList.map(attr => { - Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType)) - }) - - val resultFieldList = resultSchema.map(field => { + totalOutputNumRows: SQLMetric, + sparkConf: SparkConf) + extends Logging { + ColumnarPluginConfig.getConf(sparkConf) + + var build_cb: ColumnarBatch = null + var last_cb: ColumnarBatch = null + + val inputBatchHolder = new ListBuffer[ColumnarBatch]() + // TODO + val l_input_schema: List[Attribute] = left.output.toList + val r_input_schema: List[Attribute] = right.output.toList + logInfo( + s"\nleft_schema is ${l_input_schema}, right_schema is ${r_input_schema}, \nleftKeys is ${leftKeys}, rightKeys is ${rightKeys}, \nresultSchema is ${resultSchema}, \njoinType is ${joinType}, buildSide is ${buildSide}, condition is ${conditionOption}") + + val l_input_field_list: List[Field] = l_input_schema.toList.map(attr => { + Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType)) + }) + val r_input_field_list: List[Field] = r_input_schema.toList.map(attr => { + Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType)) + }) + + val resultFieldList = resultSchema + .map(field => { Field.nullable(field.name, CodeGeneration.getResultType(field.dataType)) - }).toList - - val leftKeyAttributes = leftKeys.toList.map(expr => { - ConverterUtils.getAttrFromExpr(expr).asInstanceOf[Expression] - }) - val rightKeyAttributes = rightKeys.toList.map(expr => { - ConverterUtils.getAttrFromExpr(expr).asInstanceOf[Expression] - }) - - val lkeyFieldList: List[Field] = leftKeyAttributes.toList.map(expr => { - val attr = ConverterUtils.getAttrFromExpr(expr) - Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(expr.dataType)) }) + .toList + + val leftKeyAttributes = leftKeys.toList.map(expr => { + ConverterUtils.getAttrFromExpr(expr).asInstanceOf[Expression] + }) + val rightKeyAttributes = rightKeys.toList.map(expr => { + ConverterUtils.getAttrFromExpr(expr).asInstanceOf[Expression] + }) + + val lkeyFieldList: List[Field] = leftKeyAttributes.toList.map(expr => { + val attr = ConverterUtils.getAttrFromExpr(expr) + Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(expr.dataType)) + }) + + val rkeyFieldList: List[Field] = rightKeyAttributes.toList.map(expr => { + val attr = ConverterUtils.getAttrFromExpr(expr) + Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(expr.dataType)) + }) + + val ( + build_key_field_list, + stream_key_field_list, + stream_key_ordinal_list, + build_input_field_list, + stream_input_field_list) = buildSide match { + case BuildLeft => + val stream_key_expr_list = bindReferences(rightKeyAttributes, r_input_schema) + ( + lkeyFieldList, + rkeyFieldList, + stream_key_expr_list.toList.map(_.asInstanceOf[BoundReference].ordinal), + l_input_field_list, + r_input_field_list) + + 
case BuildRight => + val stream_key_expr_list = bindReferences(leftKeyAttributes, l_input_schema) + ( + rkeyFieldList, + lkeyFieldList, + stream_key_expr_list.toList.map(_.asInstanceOf[BoundReference].ordinal), + r_input_field_list, + l_input_field_list) - val rkeyFieldList: List[Field] = rightKeyAttributes.toList.map(expr => { - val attr = ConverterUtils.getAttrFromExpr(expr) - Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(expr.dataType)) - }) - - val (build_key_field_list, stream_key_field_list, stream_key_ordinal_list, build_input_field_list, stream_input_field_list) - = buildSide match { - case BuildLeft => - val stream_key_expr_list = bindReferences(rightKeyAttributes, r_input_schema) - (lkeyFieldList, rkeyFieldList, stream_key_expr_list.toList.map(_.asInstanceOf[BoundReference].ordinal), l_input_field_list, r_input_field_list) - - case BuildRight => - val stream_key_expr_list = bindReferences(leftKeyAttributes, l_input_schema) - (rkeyFieldList, lkeyFieldList, stream_key_expr_list.toList.map(_.asInstanceOf[BoundReference].ordinal), r_input_field_list, l_input_field_list) - - } - - val (probe_func_name, build_output_field_list, stream_output_field_list) = joinType match { - case _: InnerLike => - ("conditionedProbeArraysInner", build_input_field_list, stream_input_field_list) - case LeftSemi => - ("conditionedProbeArraysSemi", List[Field](), stream_input_field_list) - case LeftOuter => - ("conditionedProbeArraysOuter", build_input_field_list, stream_input_field_list) - case RightOuter => - ("conditionedProbeArraysOuter", build_input_field_list, stream_input_field_list) - case LeftAnti => - ("conditionedProbeArraysAnti", List[Field](), stream_input_field_list) - case _ => - throw new UnsupportedOperationException(s"Join Type ${joinType} is not supported yet.") - } + } - val build_input_arrow_schema: Schema = new Schema(build_input_field_list.asJava) + val (probe_func_name, build_output_field_list, stream_output_field_list) = joinType match { + case _: InnerLike => + ("conditionedProbeArraysInner", build_input_field_list, stream_input_field_list) + case LeftSemi => + ("conditionedProbeArraysSemi", List[Field](), stream_input_field_list) + case LeftOuter => + ("conditionedProbeArraysOuter", build_input_field_list, stream_input_field_list) + case RightOuter => + ("conditionedProbeArraysOuter", build_input_field_list, stream_input_field_list) + case LeftAnti => + ("conditionedProbeArraysAnti", List[Field](), stream_input_field_list) + case _ => + throw new UnsupportedOperationException(s"Join Type ${joinType} is not supported yet.") + } - val stream_input_arrow_schema: Schema = new Schema(stream_input_field_list.asJava) + val build_input_arrow_schema: Schema = new Schema(build_input_field_list.asJava) - val build_output_arrow_schema: Schema = new Schema(build_output_field_list.asJava) - val stream_output_arrow_schema: Schema = new Schema(stream_output_field_list.asJava) + val stream_input_arrow_schema: Schema = new Schema(stream_input_field_list.asJava) - val stream_key_arrow_schema: Schema = new Schema(stream_key_field_list.asJava) - var output_arrow_schema: Schema = _ + val build_output_arrow_schema: Schema = new Schema(build_output_field_list.asJava) + val stream_output_arrow_schema: Schema = new Schema(stream_output_field_list.asJava) + val stream_key_arrow_schema: Schema = new Schema(stream_key_field_list.asJava) + var output_arrow_schema: Schema = _ - logInfo(s"\nbuild_key_field_list is ${build_key_field_list}, stream_key_field_list is 
${stream_key_field_list}, stream_key_ordinal_list is ${stream_key_ordinal_list}, \nbuild_input_field_list is ${build_input_field_list}, stream_input_field_list is ${stream_input_field_list}, \nbuild_output_field_list is ${build_output_field_list}, stream_output_field_list is ${stream_output_field_list}") + logInfo( + s"\nbuild_key_field_list is ${build_key_field_list}, stream_key_field_list is ${stream_key_field_list}, stream_key_ordinal_list is ${stream_key_ordinal_list}, \nbuild_input_field_list is ${build_input_field_list}, stream_input_field_list is ${stream_input_field_list}, \nbuild_output_field_list is ${build_output_field_list}, stream_output_field_list is ${stream_output_field_list}") - /////////////////////////////// Create Prober ///////////////////////////// - // Prober is used to insert left table primary key into hashMap - // Then use iterator to probe right table primary key from hashmap - // to get corresponding indices of left table - // - val condition = conditionOption match { - case Some(c) => - c - case None => - null - } - val (conditionInputFieldList, conditionOutputFieldList) = buildSide match { - case BuildLeft => - (build_input_field_list, build_output_field_list ::: stream_output_field_list) - case BuildRight => - (build_input_field_list, stream_output_field_list ::: build_output_field_list) - } - val conditionArrowSchema = new Schema(conditionInputFieldList.asJava) - output_arrow_schema = new Schema(conditionOutputFieldList.asJava) - var conditionInputList : java.util.List[Field] = Lists.newArrayList() - val build_args_node = TreeBuilder.makeFunction( - "codegen_left_schema", - build_input_field_list.map(field => { + /////////////////////////////// Create Prober ///////////////////////////// + // Prober is used to insert left table primary key into hashMap + // Then use iterator to probe right table primary key from hashmap + // to get corresponding indices of left table + // + val condition = conditionOption match { + case Some(c) => + c + case None => + null + } + val (conditionInputFieldList, conditionOutputFieldList) = buildSide match { + case BuildLeft => + (build_input_field_list, build_output_field_list ::: stream_output_field_list) + case BuildRight => + (build_input_field_list, stream_output_field_list ::: build_output_field_list) + } + val conditionArrowSchema = new Schema(conditionInputFieldList.asJava) + output_arrow_schema = new Schema(conditionOutputFieldList.asJava) + var conditionInputList: java.util.List[Field] = Lists.newArrayList() + val build_args_node = TreeBuilder.makeFunction( + "codegen_left_schema", + build_input_field_list + .map(field => { TreeBuilder.makeField(field) - }).asJava, - new ArrowType.Int(32, true)/*dummy ret type, won't be used*/) - val stream_args_node = TreeBuilder.makeFunction( - "codegen_right_schema", - stream_input_field_list.map(field => { + }) + .asJava, + new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ ) + val stream_args_node = TreeBuilder.makeFunction( + "codegen_right_schema", + stream_input_field_list + .map(field => { TreeBuilder.makeField(field) - }).asJava, - new ArrowType.Int(32, true)/*dummy ret type, won't be used*/) - val build_keys_node = TreeBuilder.makeFunction( - "codegen_left_key_schema", - build_key_field_list.map(field => { + }) + .asJava, + new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ ) + val build_keys_node = TreeBuilder.makeFunction( + "codegen_left_key_schema", + build_key_field_list + .map(field => { TreeBuilder.makeField(field) - }).asJava, - new 
ArrowType.Int(32, true)/*dummy ret type, won't be used*/) - val stream_keys_node = TreeBuilder.makeFunction( - "codegen_right_key_schema", - stream_key_field_list.map(field => { + }) + .asJava, + new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ ) + val stream_keys_node = TreeBuilder.makeFunction( + "codegen_right_key_schema", + stream_key_field_list + .map(field => { TreeBuilder.makeField(field) - }).asJava, - new ArrowType.Int(32, true)/*dummy ret type, won't be used*/) - val condition_expression_node_list : java.util.List[org.apache.arrow.gandiva.expression.TreeNode] = - if (condition != null) { - val columnarExpression: Expression = - ColumnarExpressionConverter.replaceWithColumnarExpression(condition) - val (condition_expression_node, resultType) = - columnarExpression.asInstanceOf[ColumnarExpression].doColumnarCodeGen(conditionInputList) - Lists.newArrayList(build_keys_node, stream_keys_node, condition_expression_node) - } else { - Lists.newArrayList(build_keys_node, stream_keys_node) - } - val retType = Field.nullable("res", new ArrowType.Int(32, true)) - - // Make Expresion for conditionedProbe - var prober = new ExpressionEvaluator() - val condition_probe_node = TreeBuilder.makeFunction( - probe_func_name, - condition_expression_node_list, - new ArrowType.Int(32, true)/*dummy ret type, won't be used*/) - val codegen_probe_node = TreeBuilder.makeFunction( - "codegen_withTwoInputs", - Lists.newArrayList(condition_probe_node, build_args_node, stream_args_node), - new ArrowType.Int(32, true)/*dummy ret type, won't be used*/) - val condition_probe_expr = TreeBuilder.makeExpression( - codegen_probe_node, - retType) - - prober.build(build_input_arrow_schema, Lists.newArrayList(condition_probe_expr), true) - - /////////////////////////////// Create Shuffler ///////////////////////////// - // Shuffler will use input indices array to shuffle current table - // output a new table with new sequence. 
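The prober construction in this hunk follows one Gandiva idiom throughout: schemas, key fields, and the probe function are all encoded as TreeBuilder function nodes whose Int(32) return types are placeholders the codegen never reads. A condensed sketch of that idiom, with a hypothetical field name and TreeBuilder per Gandiva's Java API as used above:

```scala
import com.google.common.collect.Lists
import org.apache.arrow.gandiva.expression.TreeBuilder
import org.apache.arrow.vector.types.pojo.{ArrowType, Field}

// Wrap a key field and a probe kernel name into nested function nodes;
// the Int(32) return types are dummies, mirroring the pattern above.
val keyField = Field.nullable("id#1", new ArrowType.Int(64, true))
val keysNode = TreeBuilder.makeFunction(
  "codegen_left_key_schema",
  Lists.newArrayList(TreeBuilder.makeField(keyField)),
  new ArrowType.Int(32, true))
val probeNode = TreeBuilder.makeFunction(
  "conditionedProbeArraysInner",
  Lists.newArrayList(keysNode),
  new ArrowType.Int(32, true))
// The resulting ExpressionTree is what ExpressionEvaluator.build consumes.
val probeExpr = TreeBuilder.makeExpression(
  probeNode,
  Field.nullable("res", new ArrowType.Int(32, true)))
```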
- // - var build_shuffler = new ExpressionEvaluator() - var probe_iterator: BatchIterator = _ - var build_shuffle_iterator: BatchIterator = _ - - // Make Expresion for conditionedShuffle - val condition_shuffle_node = TreeBuilder.makeFunction( - "conditionedShuffleArrayList", - Lists.newArrayList(), - new ArrowType.Int(32, true)/*dummy ret type, won't be used*/) - val codegen_shuffle_node = TreeBuilder.makeFunction( - "codegen_withTwoInputs", - Lists.newArrayList(condition_shuffle_node, build_args_node, stream_args_node), - new ArrowType.Int(32, true)/*dummy ret type, won't be used*/) - val condition_shuffle_expr = TreeBuilder.makeExpression( - codegen_shuffle_node, - retType) - build_shuffler.build(conditionArrowSchema, Lists.newArrayList(condition_shuffle_expr), output_arrow_schema, true) - + }) + .asJava, + new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ ) + val condition_expression_node_list + : java.util.List[org.apache.arrow.gandiva.expression.TreeNode] = + if (condition != null) { + val columnarExpression: Expression = + ColumnarExpressionConverter.replaceWithColumnarExpression(condition) + val (condition_expression_node, resultType) = + columnarExpression.asInstanceOf[ColumnarExpression].doColumnarCodeGen(conditionInputList) + Lists.newArrayList(build_keys_node, stream_keys_node, condition_expression_node) + } else { + Lists.newArrayList(build_keys_node, stream_keys_node) + } + val retType = Field.nullable("res", new ArrowType.Int(32, true)) + + // Make Expresion for conditionedProbe + var prober = new ExpressionEvaluator() + val condition_probe_node = TreeBuilder.makeFunction( + probe_func_name, + condition_expression_node_list, + new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ ) + val codegen_probe_node = TreeBuilder.makeFunction( + "codegen_withTwoInputs", + Lists.newArrayList(condition_probe_node, build_args_node, stream_args_node), + new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ ) + val condition_probe_expr = TreeBuilder.makeExpression(codegen_probe_node, retType) + + prober.build( + build_input_arrow_schema, + Lists.newArrayList(condition_probe_expr), + output_arrow_schema, + true) + var probe_iterator: BatchIterator = _ def columnarInnerJoin( - streamIter: Iterator[ColumnarBatch], - buildIter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { + streamIter: Iterator[ColumnarBatch], + buildIter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { val beforeBuild = System.nanoTime() @@ -253,10 +263,10 @@ class ColumnarShuffledHashJoin( } build_cb = buildIter.next() val build_rb = ConverterUtils.createArrowRecordBatch(build_cb) - (0 until build_cb.numCols).toList.foreach(i => build_cb.column(i).asInstanceOf[ArrowWritableColumnVector].retain()) + (0 until build_cb.numCols).toList.foreach(i => + build_cb.column(i).asInstanceOf[ArrowWritableColumnVector].retain()) inputBatchHolder += build_cb prober.evaluate(build_rb) - build_shuffler.evaluate(build_rb) ConverterUtils.releaseArrowRecordBatch(build_rb) } if (build_cb != null) { @@ -279,13 +289,11 @@ class ColumnarShuffledHashJoin( // there will be different when condition is null or not null probe_iterator = prober.finishByIterator() - build_shuffler.setDependency(probe_iterator) - build_shuffle_iterator = build_shuffler.finishByIterator() buildTime += NANOSECONDS.toMillis(System.nanoTime() - beforeBuild) - + new Iterator[ColumnarBatch] { override def hasNext: Boolean = { - if(streamIter.hasNext) { + if (streamIter.hasNext) { true } else { inputBatchHolder.foreach(cb => cb.close()) 
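With the shuffler object gone, the stream side of the join collapses to one native call per batch: probe_iterator.process(...) both probes the hash table and materializes the joined output (see the hunk that follows). A sketch of that loop under the plugin classes used in this patch; the standalone helper and its signature are illustrative:

```scala
import com.intel.sparkColumnarPlugin.expression.ConverterUtils
import com.intel.sparkColumnarPlugin.vectorized.{BatchIterator, ExpressionEvaluator}
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

// One probe iterator per join: each stream batch goes in as an Arrow
// record batch and comes back as a joined record batch (null when every
// row was filtered out, which the real code also skips).
def probeStream(
    prober: ExpressionEvaluator,
    streamIter: Iterator[ColumnarBatch],
    streamSchema: Schema,
    outputSchema: Schema): Iterator[ColumnarBatch] = {
  val probeItr: BatchIterator = prober.finishByIterator()
  streamIter.flatMap { cb =>
    val streamRb = ConverterUtils.createArrowRecordBatch(cb)
    val outputRb = probeItr.process(streamSchema, streamRb)
    ConverterUtils.releaseArrowRecordBatch(streamRb)
    Option(outputRb).map { rb =>
      val numRows = rb.getLength
      val vectors = ConverterUtils.fromArrowRecordBatch(outputSchema, rb)
      ConverterUtils.releaseArrowRecordBatch(rb)
      new ColumnarBatch(vectors.map(v => v.asInstanceOf[ColumnVector]), numRows)
    }
  }
}
```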
@@ -298,9 +306,8 @@ class ColumnarShuffledHashJoin( last_cb = cb val beforeJoin = System.nanoTime() val stream_rb: ArrowRecordBatch = ConverterUtils.createArrowRecordBatch(cb) - probe_iterator.processAndCacheOne(stream_input_arrow_schema, stream_rb) - val output_rb = build_shuffle_iterator.process(stream_input_arrow_schema, stream_rb) - + val output_rb = probe_iterator.process(stream_input_arrow_schema, stream_rb) + ConverterUtils.releaseArrowRecordBatch(stream_rb) joinTime += NANOSECONDS.toMillis(System.nanoTime() - beforeJoin) if (output_rb == null) { @@ -311,22 +318,14 @@ class ColumnarShuffledHashJoin( val outputNumRows = output_rb.getLength val output = ConverterUtils.fromArrowRecordBatch(output_arrow_schema, output_rb) ConverterUtils.releaseArrowRecordBatch(output_rb) - totalOutputNumRows += outputNumRows + totalOutputNumRows += outputNumRows new ColumnarBatch(output.map(v => v.asInstanceOf[ColumnVector]).toArray, outputNumRows) - } + } } } } def close(): Unit = { - if (build_shuffler != null) { - build_shuffler.close() - build_shuffler = null - } - if (build_shuffle_iterator != null) { - build_shuffle_iterator.close() - build_shuffle_iterator = null - } if (prober != null) { prober.close() prober = null @@ -341,19 +340,31 @@ class ColumnarShuffledHashJoin( object ColumnarShuffledHashJoin { var columnarShuffedHahsJoin: ColumnarShuffledHashJoin = _ def create( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - resultSchema: StructType, - joinType: JoinType, - buildSide: BuildSide, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan, - buildTime: SQLMetric, - joinTime: SQLMetric, - numOutputRows: SQLMetric - ): ColumnarShuffledHashJoin = synchronized { - columnarShuffedHahsJoin = new ColumnarShuffledHashJoin(leftKeys, rightKeys, resultSchema, joinType, buildSide, condition, left, right, buildTime, joinTime, numOutputRows) + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + resultSchema: StructType, + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + buildTime: SQLMetric, + joinTime: SQLMetric, + numOutputRows: SQLMetric, + sparkConf: SparkConf): ColumnarShuffledHashJoin = synchronized { + columnarShuffedHahsJoin = new ColumnarShuffledHashJoin( + leftKeys, + rightKeys, + resultSchema, + joinType, + buildSide, + condition, + left, + right, + buildTime, + joinTime, + numOutputRows, + sparkConf) columnarShuffedHahsJoin } diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSorter.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSorter.scala index 39a0091ed..629e66253 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSorter.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSorter.scala @@ -21,11 +21,13 @@ import java.nio.ByteBuffer import java.util.concurrent.TimeUnit._ import com.google.common.collect.Lists +import com.intel.sparkColumnarPlugin.ColumnarPluginConfig import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector import com.intel.sparkColumnarPlugin.vectorized.ExpressionEvaluator import com.intel.sparkColumnarPlugin.vectorized.BatchIterator import org.apache.spark.internal.Logging +import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ 
-57,10 +59,12 @@ class ColumnarSorter( outputBatches: SQLMetric, outputRows: SQLMetric, shuffleTime: SQLMetric, - elapse: SQLMetric) + elapse: SQLMetric, + sparkConf: SparkConf) extends Logging { logInfo(s"ColumnarSorter sortOrder is ${sortOrder}, outputAttributes is ${outputAttributes}") + ColumnarPluginConfig.getConf(sparkConf) /////////////// Prepare ColumnarSorter ////////////// var processedNumRows: Long = 0 var sort_elapse: Long = 0 @@ -216,7 +220,8 @@ object ColumnarSorter { outputBatches: SQLMetric, outputRows: SQLMetric, shuffleTime: SQLMetric, - elapse: SQLMetric): ColumnarSorter = synchronized { + elapse: SQLMetric, + sparkConf: SparkConf): ColumnarSorter = synchronized { new ColumnarSorter( sortOrder, outputAsColumnar, @@ -225,7 +230,8 @@ object ColumnarSorter { outputBatches, outputRows, shuffleTime, - elapse) + elapse, + sparkConf) } } diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSubquery.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSubquery.scala index 5991fc467..bed28e169 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSubquery.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarSubquery.scala @@ -52,13 +52,23 @@ class ColumnarScalarSubquery( val resultType = CodeGeneration.getResultType(query.dataType) query.dataType match { case t: StringType => - (TreeBuilder.makeStringLiteral(value.toString().asInstanceOf[String]), resultType) + value match { + case null => + (TreeBuilder.makeStringLiteral("null": java.lang.String), resultType) + case _ => + (TreeBuilder.makeStringLiteral(value.toString().asInstanceOf[String]), resultType) + } case t: IntegerType => (TreeBuilder.makeLiteral(value.asInstanceOf[Integer]), resultType) case t: LongType => (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Long]), resultType) case t: DoubleType => - (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Double]), resultType) + value match { + case null => + (TreeBuilder.makeLiteral(0.0: java.lang.Double), resultType) + case _ => + (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Double]), resultType) + } case d: DecimalType => val v = value.asInstanceOf[Decimal] (TreeBuilder.makeDecimalLiteral(v.toString, v.precision, v.scale), resultType) diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarUnaryOperator.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarUnaryOperator.scala index 4071a23e9..0d25f88cc 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarUnaryOperator.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ColumnarUnaryOperator.scala @@ -132,6 +132,47 @@ class ColumnarUpper(child: Expression, original: Expression) } } +class ColumnarCast(child: Expression, datatype: DataType, timeZoneId: Option[String], original: Expression) + extends Cast(child: Expression, datatype: DataType, timeZoneId: Option[String]) + with ColumnarExpression + with Logging { + override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = { + val (child_node, childType): (TreeNode, ArrowType) = + child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + + val resultType = CodeGeneration.getResultType(dataType) + if (dataType == StringType) { + //TODO: fix cast uft8 + (child_node, childType) + } else if (dataType == 
IntegerType) { + val funcNode = + TreeBuilder.makeFunction("castINT", Lists.newArrayList(child_node), resultType) + (funcNode, resultType) + } else if (dataType == LongType) { + val funcNode = + TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(child_node), resultType) + (funcNode, resultType) + //(child_node, childType) + } else if (dataType == FloatType) { + val funcNode = + TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(child_node), resultType) + (funcNode, resultType) + } else if (dataType == DoubleType) { + val funcNode = + TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(child_node), resultType) + (funcNode, resultType) + } else if (dataType == DateType) { + val funcNode = + TreeBuilder.makeFunction("castDATE", Lists.newArrayList(child_node), resultType) + (funcNode, resultType) + } else if (dataType == DecimalType) { + throw new UnsupportedOperationException(s"not currently supported: ${dataType}.") + } else { + throw new UnsupportedOperationException(s"not currently supported: ${dataType}.") + } + } +} + object ColumnarUnaryOperator { def create(child: Expression, original: Expression): Expression = original match { @@ -148,7 +189,7 @@ object ColumnarUnaryOperator { case u: Upper => new ColumnarUpper(child, u) case c: Cast => - child + new ColumnarCast(child, c.dataType, c.timeZoneId, c) case a: KnownFloatingPointNormalized => child case a: NormalizeNaNAndZero => diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ConverterUtils.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ConverterUtils.scala index 51f7f4683..16c417003 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ConverterUtils.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/expression/ConverterUtils.scala @@ -17,24 +17,19 @@ package com.intel.sparkColumnarPlugin.expression -import java.util.concurrent.atomic.AtomicLong -import io.netty.buffer.ArrowBuf - import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector - +import io.netty.buffer.ArrowBuf +import org.apache.arrow.vector._ +import org.apache.arrow.vector.ipc.message.{ArrowFieldNode, ArrowRecordBatch} +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer._ -import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} -import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.message.ArrowFieldNode -import org.apache.arrow.vector.ipc.message.ArrowRecordBatch -import org.apache.arrow.vector.types.pojo.Schema -import org.apache.arrow.vector.types.pojo.Field -import org.apache.arrow.vector.types.pojo.ArrowType -import org.apache.spark.internal.Logging +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.ColumnarBatch import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer @@ -75,7 +70,9 @@ object ConverterUtils extends Logging { new ArrowRecordBatch(numRowsInBatch, fieldNodes.toList.asJava, inputData.toList.asJava) } - def fromArrowRecordBatch(recordBatchSchema: Schema, recordBatch: ArrowRecordBatch): Array[ArrowWritableColumnVector] = { + def fromArrowRecordBatch( + recordBatchSchema: Schema, + 
recordBatch: ArrowRecordBatch): Array[ArrowWritableColumnVector] = { val numRows = recordBatch.getLength() ArrowWritableColumnVector.loadColumns(numRows, recordBatchSchema, recordBatch) } @@ -115,15 +112,24 @@ object ConverterUtils extends Logging { getAttrFromExpr(c.children(0)) case i: IsNull => getAttrFromExpr(i.child) + case a: Add => + getAttrFromExpr(a.left) + case s: Subtract => + getAttrFromExpr(s.left) + case u: Upper => + getAttrFromExpr(u.child) + case ss: Substring => + getAttrFromExpr(ss.children(0)) case other => - throw new UnsupportedOperationException(s"makeStructField is unable to parse from $other (${other.getClass}).") + throw new UnsupportedOperationException( + s"makeStructField is unable to parse from $other (${other.getClass}).") } } - def getResultAttrFromExpr(fieldExpr: Expression, name: String = "None"): AttributeReference = { + def getResultAttrFromExpr(fieldExpr: Expression, name: String = "None", dataType: Option[DataType]=None): AttributeReference = { fieldExpr match { case a: Cast => - getResultAttrFromExpr(a.child, name) + getResultAttrFromExpr(a.child, name, Some(a.dataType)) case a: AttributeReference => if (name != "None") { new AttributeReference(name, a.dataType, a.nullable)() @@ -134,9 +140,15 @@ object ConverterUtils extends Logging { //TODO: a walkaround since we didn't support cast yet if (a.child.isInstanceOf[Cast]) { val tmp = if (name != "None") { - new Alias(a.child.asInstanceOf[Cast].child, name)(a.exprId, a.qualifier, a.explicitMetadata) + new Alias(a.child.asInstanceOf[Cast].child, name)( + a.exprId, + a.qualifier, + a.explicitMetadata) } else { - new Alias(a.child.asInstanceOf[Cast].child, a.name)(a.exprId, a.qualifier, a.explicitMetadata) + new Alias(a.child.asInstanceOf[Cast].child, a.name)( + a.exprId, + a.qualifier, + a.explicitMetadata) } tmp.toAttribute.asInstanceOf[AttributeReference] } else { @@ -147,13 +159,22 @@ object ConverterUtils extends Logging { a.toAttribute.asInstanceOf[AttributeReference] } } + case d: ColumnarDivide => + new AttributeReference(name, DoubleType, d.nullable)() + case m: ColumnarMultiply => + new AttributeReference(name, m.dataType, m.nullable)() case other => val a = if (name != "None") { new Alias(other, name)() } else { new Alias(other, "res")() } - a.toAttribute.asInstanceOf[AttributeReference] + val tmpAttr = a.toAttribute.asInstanceOf[AttributeReference] + if (dataType.isDefined) { + new AttributeReference(tmpAttr.name, dataType.getOrElse(null), tmpAttr.nullable)() + } else { + tmpAttr + } } } @@ -166,15 +187,21 @@ object ConverterUtils extends Logging { } def combineArrowRecordBatch(rb1: ArrowRecordBatch, rb2: ArrowRecordBatch): ArrowRecordBatch = { - val numRows = rb1.getLength() - val rb1_nodes = rb1.getNodes() - val rb2_nodes = rb2.getNodes() - val rb1_bufferlist = rb1.getBuffers() - val rb2_bufferlist = rb2.getBuffers() - - val combined_nodes = rb1_nodes.addAll(rb2_nodes) - val combined_bufferlist = rb1_bufferlist.addAll(rb2_bufferlist) - new ArrowRecordBatch(numRows, rb1_nodes, rb1_bufferlist) + val numRows = rb1.getLength() + val rb1_nodes = rb1.getNodes() + val rb2_nodes = rb2.getNodes() + val rb1_bufferlist = rb1.getBuffers() + val rb2_bufferlist = rb2.getBuffers() + + val combined_nodes = rb1_nodes.addAll(rb2_nodes) + val combined_bufferlist = rb1_bufferlist.addAll(rb2_bufferlist) + new ArrowRecordBatch(numRows, rb1_nodes, rb1_bufferlist) + } + + def toArrowSchema(attributes: Seq[Attribute]): Schema = { + def fromAttributes(attributes: Seq[Attribute]): StructType = + 
StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + ArrowUtils.toArrowSchema(fromAttributes(attributes), SQLConf.get.sessionLocalTimeZone) } override def toString(): String = { @@ -182,4 +209,3 @@ object ConverterUtils extends Logging { } } - diff --git a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/vectorized/ArrowColumnarBatchSerializer.scala b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/vectorized/ArrowColumnarBatchSerializer.scala index 3a8b92ad9..f66bbf541 100644 --- a/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/vectorized/ArrowColumnarBatchSerializer.scala +++ b/oap-native-sql/core/src/main/scala/com/intel/sparkColumnarPlugin/vectorized/ArrowColumnarBatchSerializer.scala @@ -20,102 +20,68 @@ package com.intel.sparkColumnarPlugin.vectorized import java.io._ import java.nio.ByteBuffer -import com.intel.sparkColumnarPlugin.expression.CodeGeneration -import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector - import org.apache.arrow.memory.BufferAllocator -import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} -import org.apache.arrow.vector.types.pojo.{Field, Schema} -import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} -import org.apache.spark.sql.util.ArrowUtils - +import org.apache.arrow.util.SchemaUtils +import org.apache.arrow.vector.ipc.ArrowStreamReader +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging import org.apache.spark.serializer.{ DeserializationStream, SerializationStream, Serializer, SerializerInstance } +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} import scala.collection.JavaConverters._ -import scala.collection.immutable.List +import scala.collection.mutable +import scala.collection.mutable.ListBuffer import scala.reflect.ClassTag -class ArrowColumnarBatchSerializer extends Serializer with Serializable { +class ArrowColumnarBatchSerializer(readBatchNumRows: SQLMetric = null) + extends Serializer + with Serializable { /** Creates a new [[SerializerInstance]]. */ override def newInstance(): SerializerInstance = - new ArrowColumnarBatchSerializerInstance + new ArrowColumnarBatchSerializerInstance(readBatchNumRows) } -private class ArrowColumnarBatchSerializerInstance extends SerializerInstance { - - override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { - - override def writeValue[T: ClassTag](value: T): SerializationStream = { - val root = createVectorSchemaRoot(value.asInstanceOf[ColumnarBatch]) - val writer = new ArrowStreamWriter(root, null, out) - writer.writeBatch() - writer.end() - this - } - - override def writeKey[T: ClassTag](key: T): SerializationStream = { - // The key is only needed on the map side when computing partition ids. It does not need to - // be shuffled. - assert(null == key || key.isInstanceOf[Int]) - this - } - - override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { - // This method is never called by shuffle code. - throw new UnsupportedOperationException - } - - override def writeObject[T: ClassTag](t: T): SerializationStream = { - // This method is never called by shuffle code. 
- throw new UnsupportedOperationException - } - - override def flush(): Unit = { - out.flush() - } - - override def close(): Unit = { - out.close() - } - - private def createVectorSchemaRoot(cb: ColumnarBatch): VectorSchemaRoot = { - val fieldTypesList = List - .range(0, cb.numCols()) - .map(i => Field.nullable(s"c_$i", CodeGeneration.getResultType(cb.column(i).dataType()))) - val arrowSchema = new Schema(fieldTypesList.asJava) - val vectors = List - .range(0, cb.numCols()) - .map( - i => - cb.column(i) - .asInstanceOf[ArrowWritableColumnVector] - .getValueVector - .asInstanceOf[FieldVector]) - val root = new VectorSchemaRoot(arrowSchema, vectors.asJava, cb.numRows) - root.setRowCount(cb.numRows) - root - } - } +private class ArrowColumnarBatchSerializerInstance(readBatchNumRows: SQLMetric) + extends SerializerInstance + with Logging { override def deserializeStream(in: InputStream): DeserializationStream = { new DeserializationStream { - private[this] val columnBatchSize = SQLConf.get.columnBatchSize - private[this] var allocator: BufferAllocator = _ + private val columnBatchSize = SQLConf.get.columnBatchSize + private val compressionEnabled = + SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true) + private val compressionCodec = SparkEnv.get.conf.get("spark.io.compression.codec", "lz4") + private val allocator: BufferAllocator = ArrowUtils.rootAllocator + .newChildAllocator("ArrowColumnarBatch deserialize", 0, Long.MaxValue) - private[this] var reader: ArrowStreamReader = _ - private[this] var root: VectorSchemaRoot = _ - private[this] var vectors: Array[ArrowWritableColumnVector] = _ + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var vectors: Array[ColumnVector] = _ + private var cb: ColumnarBatch = _ + private var batchLoaded = true - // TODO: see if asKeyValueIterator should be override + private var jniWrapper: ShuffleDecompressionJniWrapper = _ + private var schemaHolderId: Long = 0 + private var vectorLoader: VectorLoader = _ + + private var numBatchesTotal: Long = _ + private var numRowsTotal: Long = _ + + override def asIterator: Iterator[Any] = { + // This method is never called by shuffle code. 
+ throw new UnsupportedOperationException + } override def readKey[T: ClassTag](): T = { // We skipped serialization of the key in writeKey(), so just return a dummy value since @@ -125,33 +91,60 @@ private class ArrowColumnarBatchSerializerInstance extends SerializerInstance { @throws(classOf[EOFException]) override def readValue[T: ClassTag](): T = { - try { - allocator = ArrowUtils.rootAllocator - .newChildAllocator("ArrowColumnarBatch deserialize", 0, Long.MaxValue) - reader = new ArrowStreamReader(in, allocator) - root = reader.getVectorSchemaRoot - // vectors = new ArrayBuffer[ColumnVector]() - } catch { - case _: IOException => - reader.close(false) - root.close() + if (reader != null && batchLoaded) { + root.clear() + if (cb != null) { + cb.close() + cb = null + } + + try { + batchLoaded = reader.loadNextBatch() + } catch { + case ioe: IOException => + this.close() + logError("Failed to load next RecordBatch", ioe) + throw ioe + } + if (batchLoaded) { + val numRows = root.getRowCount + logDebug(s"Read ColumnarBatch of ${numRows} rows") + + numBatchesTotal += 1 + numRowsTotal += numRows + + // jni call to decompress buffers + if (compressionEnabled) { + decompressVectors() + } + + val newFieldVectors = root.getFieldVectors.asScala.map { vector => + val newVector = vector.getField.createVector(allocator) + vector.makeTransferPair(newVector).transfer() + newVector + }.asJava + + vectors = ArrowWritableColumnVector + .loadColumns(numRows, newFieldVectors) + .toArray[ColumnVector] + + cb = new ColumnarBatch(vectors, numRows) + cb.asInstanceOf[T] + } else { + this.close() throw new EOFException + } + } else { + reader = new ArrowCompressedStreamReader(in, allocator) + try { + root = reader.getVectorSchemaRoot + } catch { + case _: IOException => + this.close() + throw new EOFException + } + readValue() } - - var numRows = 0 - // "root.rowCount" is set to 0 when reader.loadNextBatch reach EOF - while (reader.loadNextBatch()) { - numRows += root.getRowCount - } - - assert( - numRows <= columnBatchSize, - "the number of loaded rows exceed the maximum columnar batch size") - - vectors = ArrowWritableColumnVector.loadColumns(numRows, root.getFieldVectors) - val cb = new ColumnarBatch(vectors.toArray, numRows) - // the underlying ColumnVectors of this ColumnarBatch might be empty - cb.asInstanceOf[T] } override def readObject[T: ClassTag](): T = { @@ -160,13 +153,65 @@ private class ArrowColumnarBatchSerializerInstance extends SerializerInstance { } override def close(): Unit = { - if (reader != null) reader.close(false) - if (root != null) root.close() - in.close() + if (numBatchesTotal > 0) { + readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal) + } + if (cb != null) cb.close() + if (reader != null) reader.close(true) + if (jniWrapper != null) jniWrapper.close(schemaHolderId) + } + + private def decompressVectors(): Unit = { + if (jniWrapper == null) { + jniWrapper = new ShuffleDecompressionJniWrapper + schemaHolderId = jniWrapper.make(SchemaUtils.get.serialize(root.getSchema)) + } + if (vectorLoader == null) { + vectorLoader = new VectorLoader(root) + } + val bufAddrs = new ListBuffer[Long]() + val bufSizes = new ListBuffer[Long]() + val bufBS = mutable.BitSet() + var bufIdx = 0 + + root.getFieldVectors.asScala.foreach { vector => + val validityBuf = vector.getValidityBuffer + if (validityBuf + .capacity() <= 8 || java.lang.Long.bitCount(validityBuf.getLong(0)) == 64 || + java.lang.Long.bitCount(validityBuf.getLong(0)) == 0) { + bufBS.add(bufIdx) + } + 
vector.getBuffers(false).foreach { buffer => + bufAddrs += buffer.memoryAddress() + // buffer.readableBytes() will return the wrong readable length here since it is initialized by + // data stored in the IPC message header, which is not the actual compressed length + bufSizes += buffer.capacity() + bufIdx += 1 + } + } + + val builder = jniWrapper.decompress( + schemaHolderId, + compressionCodec, + root.getRowCount, + bufAddrs.toArray, + bufSizes.toArray, + bufBS.toBitMask) + val builderImpl = new ArrowRecordBatchBuilderImpl(builder) + val decompressedRecordBatch = builderImpl.build + + root.clear() + if (decompressedRecordBatch != null) { + vectorLoader.load(decompressedRecordBatch) + } } } } + // Columnar shuffle write process doesn't need this. + override def serializeStream(s: OutputStream): SerializationStream = + throw new UnsupportedOperationException + // These methods are never called by shuffle code. override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException diff --git a/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala new file mode 100644 index 000000000..7b49d8c8e --- /dev/null +++ b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.execution.metric.SQLMetric + +import scala.reflect.ClassTag + +/** + * :: DeveloperApi :: + * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle, + * the RDD is transient since we don't need it on the executor side. + * + * @param _rdd the parent RDD + * @param partitioner partitioner used to partition the shuffle output + * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If not set + * explicitly then the default serializer, as specified by `spark.serializer` + * config option, will be used.
+ * @param keyOrdering key ordering for RDD's shuffles + * @param aggregator map/reduce-side aggregator for RDD's shuffle + * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine) + * @param shuffleWriterProcessor the processor to control the write behavior in ShuffleMapTask + * @param serializedSchema serialized [[org.apache.arrow.vector.types.pojo.Schema]] for ColumnarBatch + * @param dataSize for shuffle data size tracking + */ +class ColumnarShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( + @transient private val _rdd: RDD[_ <: Product2[K, V]], + override val partitioner: Partitioner, + override val serializer: Serializer = SparkEnv.get.serializer, + override val keyOrdering: Option[Ordering[K]] = None, + override val aggregator: Option[Aggregator[K, V, C]] = None, + override val mapSideCombine: Boolean = false, + override val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, + val serializedSchema: Array[Byte], + val dataSize: SQLMetric, + val splitTime: SQLMetric) + extends ShuffleDependency[K, V, C]( + _rdd, + partitioner, + serializer, + keyOrdering, + aggregator, + mapSideCombine, + shuffleWriterProcessor) {} diff --git a/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala new file mode 100644 index 000000000..1619c4b6f --- /dev/null +++ b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.{File, FileInputStream, FileOutputStream, IOException} +import java.nio.ByteBuffer + +import com.google.common.annotations.VisibleForTesting +import com.google.common.io.Closeables +import com.intel.sparkColumnarPlugin.vectorized.{ + ArrowWritableColumnVector, + ShuffleSplitterJniWrapper +} +import org.apache.arrow.util.SchemaUtils +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + +import scala.collection.mutable.ListBuffer + +class ColumnarShuffleWriter[K, V]( + shuffleBlockResolver: IndexShuffleBlockResolver, + handle: BaseShuffleHandle[K, V, V], + mapId: Long, + writeMetrics: ShuffleWriteMetricsReporter) + extends ShuffleWriter[K, V] + with Logging { + + private val dep = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]] + + private val conf = SparkEnv.get.conf + + private val blockManager = SparkEnv.get.blockManager + + // Are we in the process of stopping? 
+  // Because map tasks can call stop() with success = true
+  // and then call stop() with success = false if they get an exception, we want to make sure
+  // we don't try deleting files, etc., twice.
+  private var stopping = false
+
+  private var mapStatus: MapStatus = _
+
+  private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
+  private val compressionEnabled = conf.getBoolean("spark.shuffle.compress", true)
+  private val compressionCodec = conf.get("spark.io.compression.codec", "lz4")
+  private val nativeBufferSize =
+    conf.getLong("spark.sql.execution.arrow.maxRecordsPerBatch", 4096)
+
+  private val jniWrapper = new ShuffleSplitterJniWrapper()
+
+  private var nativeSplitter: Long = 0
+
+  private var partitionLengths: Array[Long] = _
+
+  @throws[IOException]
+  override def write(records: Iterator[Product2[K, V]]): Unit = {
+    if (!records.hasNext) {
+      partitionLengths = new Array[Long](dep.partitioner.numPartitions)
+      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, null)
+      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
+      return
+    }
+
+    if (nativeSplitter == 0) {
+      val schema: Schema = Schema.deserialize(ByteBuffer.wrap(dep.serializedSchema))
+      val localDirs = Utils.getConfiguredLocalDirs(conf).mkString(",")
+      nativeSplitter =
+        jniWrapper.make(SchemaUtils.get.serialize(schema), nativeBufferSize, localDirs)
+      if (compressionEnabled) {
+        jniWrapper.setCompressionCodec(nativeSplitter, compressionCodec)
+      }
+    }
+
+    while (records.hasNext) {
+      val cb = records.next()._2.asInstanceOf[ColumnarBatch]
+      if (cb.numRows == 0 || cb.numCols == 0) {
+        logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols")
+      } else {
+        val bufAddrs = new ListBuffer[Long]()
+        val bufSizes = new ListBuffer[Long]()
+        (0 until cb.numCols).foreach { idx =>
+          val column = cb.column(idx).asInstanceOf[ArrowWritableColumnVector]
+          column.getValueVector
+            .getBuffers(false)
+            .foreach { buffer =>
+              bufAddrs += buffer.memoryAddress()
+              bufSizes += buffer.readableBytes()
+            }
+        }
+        dep.dataSize.add(bufSizes.sum)
+
+        val startTime = System.nanoTime()
+        jniWrapper.split(nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray)
+        dep.splitTime.add(System.nanoTime() - startTime)
+        writeMetrics.incRecordsWritten(1)
+      }
+    }
+
+    val startTime = System.nanoTime()
+    jniWrapper.stop(nativeSplitter)
+    dep.splitTime.add(System.nanoTime() - startTime)
+    writeMetrics.incBytesWritten(jniWrapper.getTotalBytesWritten(nativeSplitter))
+
+    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+    val tmp = Utils.tempFileWith(output)
+    try {
+      partitionLengths = writePartitionedFile(tmp)
+      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
+    } finally {
+      if (tmp.exists() && !tmp.delete()) {
+        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
+      }
+    }
+    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
+  }
+
+  override def stop(success: Boolean): Option[MapStatus] = {
+
+    try {
+      if (stopping) {
+        return None
+      }
+      stopping = true
+      if (success) {
+        Option(mapStatus)
+      } else {
+        None
+      }
+    } finally {
+      // delete the temporary files held by the native splitter
+      if (nativeSplitter != 0) {
+        try {
+          jniWrapper.getPartitionFileInfo(nativeSplitter).foreach { fileInfo =>
+            {
+              val pid = fileInfo.getPid
+              val file = new File(fileInfo.getFilePath)
+              if (file.exists()) {
+                if (!file.delete()) {
+                  logError(s"Unable to delete file for partition ${pid}")
+                }
+              }
+            }
+          }
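+          // the native splitter handle itself is released in the finally block just below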
+        } finally {
+          jniWrapper.close(nativeSplitter)
+          nativeSplitter = 0
+        }
+      }
+    }
+  }
+
+  @throws[IOException]
+  private def writePartitionedFile(outputFile: File): Array[Long] = {
+
+    val lengths = new Array[Long](dep.partitioner.numPartitions)
+    val out = new FileOutputStream(outputFile, true)
+    val writerStartTime = System.nanoTime()
+    var threwException = true
+
+    try {
+      jniWrapper.getPartitionFileInfo(nativeSplitter).foreach { fileInfo =>
+        {
+          val pid = fileInfo.getPid
+          val filePath = fileInfo.getFilePath
+
+          val file = new File(filePath)
+          if (file.exists()) {
+            val in = new FileInputStream(file)
+            var copyThrewException = true
+
+            try {
+              lengths(pid) = Utils.copyStream(in, out, false, transferToEnabled)
+              copyThrewException = false
+            } finally {
+              Closeables.close(in, copyThrewException)
+            }
+            if (!file.delete()) {
+              logError(s"Unable to delete file for partition ${pid}")
+            } else {
+              logDebug(s"Deleted temporary shuffle file ${filePath} for partition ${pid}")
+            }
+          } else {
+            logWarning(
+              s"Native shuffle writer temporary file ${filePath} for partition ${pid} does not exist")
+          }
+        }
+      }
+      threwException = false
+    } finally {
+      Closeables.close(out, threwException)
+      writeMetrics.incWriteTime(System.nanoTime - writerStartTime)
+    }
+    lengths
+  }
+
+  @VisibleForTesting
+  def getPartitionLengths: Array[Long] = partitionLengths
+
+}
diff --git a/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
new file mode 100644
index 000000000..b6b191631
--- /dev/null
+++ b/oap-native-sql/core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.InputStream
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
+import org.apache.spark.storage.BlockId
+import org.apache.spark.util.collection.OpenHashSet
+
+import scala.collection.JavaConverters._
+
+class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+
+  import ColumnarShuffleManager._
+
+  /**
+   * A mapping from shuffle ids to the task ids of mappers producing output for those shuffles.
+ */ + private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() + + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + + /** + * Obtains a [[ShuffleHandle]] to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (dependency.isInstanceOf[ColumnarShuffleDependency[K, V, V]]) { + logInfo(s"Registering ColumnarShuffle shuffleId: ${shuffleId}") + new ColumnarShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: + new SerializedShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, dependency) + } + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + val mapTaskIds = + taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16)) + mapTaskIds.synchronized { mapTaskIds.add(context.taskAttemptId()) } + val env = SparkEnv.get + handle match { + case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] => + new ColumnarShuffleWriter(shuffleBlockResolver, columnarShuffleHandle, mapId, metrics) + case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => + new UnsafeShuffleWriter( + env.blockManager, + context.taskMemoryManager(), + unsafeShuffleHandle, + mapId, + context, + env.conf, + metrics, + shuffleExecutorComponents) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new BypassMergeSortShuffleWriter( + env.blockManager, + bypassMergeSortHandle, + mapId, + env.conf, + metrics, + shuffleExecutorComponents) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => + new SortShuffleWriter( + shuffleBlockResolver, + other, + mapId, + context, + shuffleExecutorComponents) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. 
+ */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val blocksByAddress = SparkEnv.get.mapOutputTracker + .getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) + if (handle.isInstanceOf[ColumnarShuffleHandle[K, _]]) { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, + new SerializerManager( + SparkEnv.get.serializer, + SparkEnv.get.conf, + SparkEnv.get.securityManager.getIOEncryptionKey()) { + // Bypass the shuffle read compression + override def wrapStream(blockId: BlockId, s: InputStream): InputStream = { + wrapForEncryption(s) + } + }) + } else { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { mapTaskIds => + mapTaskIds.iterator.foreach { mapId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + shuffleBlockResolver.stop() + } + + override def getReaderForRange[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange( + handle.shuffleId, + startMapIndex, + endMapIndex, + startPartition, + endPartition) + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics) + } + +} + +object ColumnarShuffleManager extends Logging { + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap + executorComponents.initializeExecutor( + conf.getAppId, + SparkEnv.get.executorId, + extraConfigs.asJava) + executorComponents + } +} + +private[spark] class ColumnarShuffleHandle[K, V]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, dependency) {} diff --git a/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index 6ff2e7bfc..3d82fe9d4 100644 --- a/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -19,14 +19,28 @@ package org.apache.spark.sql.execution import java.util.Random -import com.intel.sparkColumnarPlugin.vectorized.ArrowColumnarBatchSerializer -import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector +import com.intel.sparkColumnarPlugin.expression.ConverterUtils +import com.intel.sparkColumnarPlugin.vectorized.{ + ArrowColumnarBatchSerializer, + ArrowWritableColumnVector +} +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.arrow.vector.{FieldVector, 
IntVector} +import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ColumnarShuffleDependency import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.expressions.{ + Attribute, + AttributeReference, + BoundReference, + UnsafeProjection +} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.exchange.Exchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffleWriteProcessor import org.apache.spark.sql.execution.metric.{ SQLMetric, @@ -34,67 +48,66 @@ import org.apache.spark.sql.execution.metric.{ SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter } -import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.{HashPartitioner, Partitioner, ShuffleDependency, TaskContext} +import org.apache.spark.util.MutablePair import scala.collection.JavaConverters._ -import scala.collection.mutable -case class ColumnarShuffleExchangeExec( +class ColumnarShuffleExchangeExec( override val outputPartitioning: Partitioning, child: SparkPlan, canChangeNumPartitions: Boolean = true) - extends Exchange { + extends ShuffleExchangeExec(outputPartitioning, child, canChangeNumPartitions) { private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) - private lazy val readMetrics = + override private[sql] lazy val readMetrics = SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics: Map[String, SQLMetric] = Map( - "dataSize" -> SQLMetrics - .createSizeMetric(sparkContext, "data size")) ++ readMetrics ++ writeMetrics + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "splitTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "split time"), + "avgReadBatchNumRows" -> SQLMetrics + .createAverageMetric(sparkContext, "avg read batch num rows")) ++ readMetrics ++ writeMetrics override def nodeName: String = "ColumnarExchange" override def supportsColumnar: Boolean = true - // Disable code generation - @transient lazy val inputRDD: RDD[ColumnarBatch] = child.executeColumnar() + @transient lazy val inputColumnarRDD: RDD[ColumnarBatch] = child.executeColumnar() - private val serializer: Serializer = new ArrowColumnarBatchSerializer - - override protected def doExecute(): RDD[InternalRow] = { - child.execute() - } + private val serializer: Serializer = new ArrowColumnarBatchSerializer( + longMetric("avgReadBatchNumRows")) @transient - lazy val shuffleDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + lazy val columnarShuffleDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { ColumnarShuffleExchangeExec.prepareShuffleDependency( - inputRDD, + inputColumnarRDD, child.output, outputPartitioning, serializer, - writeMetrics) + writeMetrics, + longMetric("dataSize"), + longMetric("splitTime")) } - def createShuffledRDD(partitionStartIndices: Option[Array[Int]]): 
ShuffledColumnarBatchRDD = {
-    new ShuffledColumnarBatchRDD(shuffleDependency, readMetrics, partitionStartIndices)
+  def createColumnarShuffledRDD(
+      partitionStartIndices: Option[Array[Int]]): ShuffledColumnarBatchRDD = {
+    new ShuffledColumnarBatchRDD(columnarShuffleDependency, readMetrics, partitionStartIndices)
   }
 
-  private var cachedShuffleRDD: ShuffledColumnarBatchRDD = null
+  private var cachedShuffleRDD: ShuffledColumnarBatchRDD = _
 
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
     if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = createShuffledRDD(None)
+      cachedShuffleRDD = createColumnarShuffledRDD(None)
     }
     cachedShuffleRDD
   }
 }
 
-object ColumnarShuffleExchangeExec {
+object ColumnarShuffleExchangeExec extends Logging {
 
   def prepareShuffleDependency(
       rdd: RDD[ColumnarBatch],
@@ -102,10 +115,12 @@ object ColumnarShuffleExchangeExec {
       newPartitioning: Partitioning,
       serializer: Serializer,
       writeMetrics: Map[String, SQLMetric],
-      enableArrowColumnVector: Boolean = true)
-    : ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
+      dataSize: SQLMetric,
+      splitTime: SQLMetric): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
 
-    assert(enableArrowColumnVector, "only support arrow column vector")
+    val arrowSchema: Schema =
+      ConverterUtils.toArrowSchema(
+        AttributeReference("pid", IntegerType, nullable = false)() +: outputAttributes)
 
     val part: Partitioner = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
@@ -116,125 +131,172 @@ object ColumnarShuffleExchangeExec {
         // `HashPartitioning.partitionIdExpression` to produce partitioning key.
         override def getPartition(key: Any): Int = key.asInstanceOf[Int]
       }
+      case RangePartitioning(sortingExpressions, numPartitions) =>
+        // Extract only the fields used for sorting, to avoid collecting large fields that do not
+        // affect the sorting result when deciding partition bounds in RangePartitioner
+        val rddForSampling = rdd.mapPartitionsInternal { iter =>
+          // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+          // partition bounds. To get accurate samples, we need to copy the mutable keys.
+          iter.flatMap(batch => {
+            val rows = batch.rowIterator.asScala
+            val projection =
+              UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
+            val mutablePair = new MutablePair[InternalRow, Null]()
+            rows.map(row => mutablePair.update(projection(row).copy(), null))
+          })
+        }
+        // Construct ordering on extracted sort key.
+        val orderingAttributes = sortingExpressions.zipWithIndex.map {
+          case (ord, i) =>
+            ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
+        }
+        implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
+        new RangePartitioner(
+          numPartitions,
+          rddForSampling,
+          ascending = true,
+          samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
       case SinglePartition =>
        new Partitioner {
          override def numPartitions: Int = 1
          override def getPartition(key: Any): Int = 0
        }
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
-      // TODO: Handle RangePartitioning.
       // TODO: Handle BroadcastPartitioning.
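+      // Each case above only constructs the Partitioner; the per-row partition ids are
+      // computed below when building rddWithDummyKey.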
    }
 
    val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
      newPartitioning.numPartitions > 1
-    val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
-
-    val schema = StructType.fromAttributes(outputAttributes)
-
-    def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
-      case RoundRobinPartitioning(numPartitions) =>
-        // Distributes elements evenly across output partitions, starting from a random partition.
-        var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions)
-        (row: InternalRow) => {
-          // The HashPartitioner will handle the `mod` by the number of partitions
-          position += 1
-          position
-        }
-      case h: HashPartitioning =>
-        val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
-        row => projection(row).getInt(0)
-      case RangePartitioning(sortingExpressions, _) =>
-        val projection =
-          UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
-        row => projection(row)
-      case SinglePartition => identity
-      case _ => sys.error(s"Exchange not implemented for $newPartitioning")
-    }
-
-    def getPartitionIdMask(cb: ColumnarBatch): Array[Int] = {
-      val getPartitionKey = getPartitionKeyExtractor()
-      val rowIterator = cb.rowIterator().asScala
-      val partitionIdMask = new Array[Int](cb.numRows())
-
-      rowIterator.zipWithIndex.foreach {
-        case (row, idx) =>
-          partitionIdMask(idx) = part.getPartition(getPartitionKey(row))
-      }
-      partitionIdMask
-    }
-
-    def createColumnVectors(numRows: Int): Array[WritableColumnVector] = {
-      val vectors: Seq[WritableColumnVector] =
-        ArrowWritableColumnVector.allocateColumns(numRows, schema)
-      vectors.toArray
-    }
-
-    def partitionIdToColumnVectors(
-        partitionCounter: PartitionCounter): Map[Int, Array[WritableColumnVector]] = {
-      val columnVectorsWithPartitionId = mutable.Map[Int, Array[WritableColumnVector]]()
-      for (pid <- partitionCounter.keys) {
-        val numRows = partitionCounter(pid)
-        if (numRows > 0) {
-          columnVectorsWithPartitionId.update(pid, createColumnVectors(numRows))
-        }
-      }
-      columnVectorsWithPartitionId.toMap[Int, Array[WritableColumnVector]]
-    }
+    // The RDD passed to ShuffleDependency should be in the form of key-value pairs.
+    // For Columnar Shuffle, we create a new column that stores the partition id of each row in
+    // the ColumnarBatch, and prepend it to the batch. ColumnarShuffleWriter extracts partition
+    // ids from the ColumnarBatch rather than reading the "key" part, so Columnar Shuffle never
+    // uses the "key" part.
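+    // For example, a batch with columns [f_1, f_2] whose three rows are routed to partitions
+    // 1, 0 and 1 is emitted as [pid, f_1, f_2] with pid = [1, 0, 1], paired with the constant
+    // dummy key 0.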
+ val rddWithDummyKey: RDD[Product2[Int, ColumnarBatch]] = { + val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition - val rddWithPartitionIds: RDD[Product2[Int, ColumnarBatch]] = { rdd.mapPartitionsWithIndexInternal( (_, cbIterator) => { - val converters = new RowToColumnConverter(schema) - cbIterator.flatMap { - cb => - val partitionCounter = new PartitionCounter - val partitionIdMask = getPartitionIdMask(cb) - partitionCounter.update(partitionIdMask) - - val toNewVectors = partitionIdToColumnVectors(partitionCounter) - val rowIterator = cb.rowIterator().asScala - rowIterator.zipWithIndex.foreach { rowWithIdx => - val idx = rowWithIdx._2 - val pid = partitionIdMask(idx) - converters.convert(rowWithIdx._1, toNewVectors(pid)) - } - toNewVectors.toSeq.map { - case (pid, vectors) => - (pid, new ColumnarBatch(vectors.toArray, partitionCounter(pid))) - } + newPartitioning match { + case SinglePartition => + CloseablePairedColumnarBatchIterator( + cbIterator + .filter(cb => cb.numRows != 0 && cb.numCols != 0) + .map(cb => { + val pids = Array.fill(cb.numRows)(0) + (0, pushFrontPartitionIds(pids, cb)) + })) + case RoundRobinPartitioning(numPartitions) => + // Distributes elements evenly across output partitions, starting from a random partition. + var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions) + CloseablePairedColumnarBatchIterator( + cbIterator + .filter(cb => cb.numRows != 0 && cb.numCols != 0) + .map(cb => { + val pids = cb.rowIterator.asScala.map { _ => + position += 1 + part.getPartition(position) + }.toArray + (0, pushFrontPartitionIds(pids, cb)) + })) + case h: HashPartitioning => + val projection = + UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) + CloseablePairedColumnarBatchIterator( + cbIterator + .filter(cb => cb.numRows != 0 && cb.numCols != 0) + .map(cb => { + val pids = cb.rowIterator.asScala.map { row => + part.getPartition(projection(row).getInt(0)) + }.toArray + (0, pushFrontPartitionIds(pids, cb)) + })) + case RangePartitioning(sortingExpressions, _) => + val projection = + UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) + CloseablePairedColumnarBatchIterator( + cbIterator + .filter(cb => cb.numRows != 0 && cb.numCols != 0) + .map(cb => { + val pids = cb.rowIterator.asScala.map { row => + part.getPartition(projection(row)) + }.toArray + (0, pushFrontPartitionIds(pids, cb)) + })) + case _ => sys.error(s"Exchange not implemented for $newPartitioning") } }, isOrderSensitive = isOrderSensitive) } val dependency = - new ShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( - rddWithPartitionIds, - new PartitionIdPassthrough(part.numPartitions), + new ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( + rddWithDummyKey, + new PartitionIdPassthrough(newPartitioning.numPartitions), serializer, - shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics)) + shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics), + serializedSchema = arrowSchema.toByteArray, + dataSize = dataSize, + splitTime = splitTime) dependency } + + def pushFrontPartitionIds(partitionIds: Seq[Int], cb: ColumnarBatch): ColumnarBatch = { + val length = partitionIds.length + + val vectors = (0 until cb.numCols()).map { idx => + val vector = cb + .column(idx) + .asInstanceOf[ArrowWritableColumnVector] + .getValueVector + .asInstanceOf[FieldVector] + vector.setValueCount(length) + vector + } + val pidVec = new IntVector("pid", vectors(0).getAllocator) + + pidVec.allocateNew(length) 
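+    // copy each row's partition id into the new leading "pid" vector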
+ (0 until length).foreach { i => + pidVec.set(i, partitionIds(i)) + } + pidVec.setValueCount(length) + + val newVectors = ArrowWritableColumnVector.loadColumns(length, (pidVec +: vectors).asJava) + new ColumnarBatch(newVectors.toArray, cb.numRows) + } } -private class PartitionCounter { - private val pidCounter = mutable.Map[Int, Int]() +case class CloseablePairedColumnarBatchIterator(iter: Iterator[(Int, ColumnarBatch)]) + extends Iterator[(Int, ColumnarBatch)] + with Logging { - def size: Int = pidCounter.size + private var cur: (Int, ColumnarBatch) = _ - def keys: Iterable[Int] = pidCounter.keys + TaskContext.get().addTaskCompletionListener[Unit] { _ => + closeAppendedVector() + } - def update(partitionIdMask: Array[Int]): Unit = { - partitionIdMask.foreach { partitionId => - pidCounter.update(partitionId, pidCounter.get(partitionId) match { - case Some(cnt) => cnt + 1 - case None => 1 - }) + private def closeAppendedVector(): Unit = { + if (cur != null) { + logDebug("Close appended partition id vector") + cur match { + case (_, cb: ColumnarBatch) => + cb.column(0).asInstanceOf[ArrowWritableColumnVector].close() + } + cur = null } } - def apply(partitionId: Int): Int = pidCounter.getOrElse(partitionId, 0) + override def hasNext: Boolean = { + iter.hasNext + } + + override def next(): (Int, ColumnarBatch) = { + closeAppendedVector() + if (iter.hasNext) { + cur = iter.next() + cur + } else Iterator.empty.next() + } } diff --git a/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/VectorizedFilePartitionReaderHandler.scala b/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/VectorizedFilePartitionReaderHandler.scala index bf265286b..1d678c0f7 100644 --- a/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/VectorizedFilePartitionReaderHandler.scala +++ b/oap-native-sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/VectorizedFilePartitionReaderHandler.scala @@ -26,7 +26,11 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.hadoop.fs.Path -import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, PartitionReader} +import org.apache.spark.sql.connector.read.{ + InputPartition, + PartitionReaderFactory, + PartitionReader +} import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} import org.apache.spark.sql.execution.datasources.v2.PartitionedFileReader import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetPartitionReaderFactory @@ -35,48 +39,54 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} object VectorizedFilePartitionReaderHandler { def get( - inputPartition: InputPartition, - parquetReaderFactory: ParquetPartitionReaderFactory) - : FilePartitionReader[ColumnarBatch] = { - val iter: Iterator[PartitionedFileReader[ColumnarBatch]] - = inputPartition.asInstanceOf[FilePartition].files.toIterator.map { file => - val filePath = new Path(new URI(file.filePath)) - val split = - new org.apache.parquet.hadoop.ParquetInputSplit( - filePath, - file.start, - file.start + file.length, - file.length, - Array.empty, - null) - //val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion - /*val convertTz = + inputPartition: InputPartition, + parquetReaderFactory: ParquetPartitionReaderFactory, + tmpDir: String): FilePartitionReader[ColumnarBatch] = { + val iter: Iterator[PartitionedFileReader[ColumnarBatch]] = + 
inputPartition.asInstanceOf[FilePartition].files.toIterator.map { file => + val filePath = new Path(new URI(file.filePath)) + val split = + new org.apache.parquet.hadoop.ParquetInputSplit( + filePath, + file.start, + file.start + file.length, + file.length, + Array.empty, + null) + //val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion + /*val convertTz = if (timestampConversion && !isCreatedByParquetMr) { Some(DateTimeUtils.getZoneId(conf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) } else { None }*/ - val capacity = 4096 - //partitionReaderFactory.createColumnarReader(inputPartition) - val dataSchema = parquetReaderFactory.dataSchema - val readDataSchema = parquetReaderFactory.readDataSchema - - val conf = parquetReaderFactory.broadcastedConf.value.value - val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) - val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + val capacity = 4096 + //partitionReaderFactory.createColumnarReader(inputPartition) + val dataSchema = parquetReaderFactory.dataSchema + val readDataSchema = parquetReaderFactory.readDataSchema - val vectorizedReader = new VectorizedParquetArrowReader(split.getPath().toString(), null, false, capacity, dataSchema, readDataSchema) - vectorizedReader.initialize(split, hadoopAttemptContext) - val partitionReader = new PartitionReader[ColumnarBatch] { - override def next(): Boolean = vectorizedReader.nextKeyValue() - override def get(): ColumnarBatch = - vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] - override def close(): Unit = vectorizedReader.close() - } + val conf = parquetReaderFactory.broadcastedConf.value.value + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + val vectorizedReader = new VectorizedParquetArrowReader( + split.getPath().toString(), + null, + false, + capacity, + dataSchema, + readDataSchema, + tmpDir) + vectorizedReader.initialize(split, hadoopAttemptContext) + val partitionReader = new PartitionReader[ColumnarBatch] { + override def next(): Boolean = vectorizedReader.nextKeyValue() + override def get(): ColumnarBatch = + vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] + override def close(): Unit = vectorizedReader.close() + } - PartitionedFileReader(file, partitionReader) - } + PartitionedFileReader(file, partitionReader) + } new FilePartitionReader[ColumnarBatch](iter) } } - diff --git a/oap-native-sql/core/src/test/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReadTest.java b/oap-native-sql/core/src/test/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReadTest.java index 785382a86..701963b9b 100644 --- a/oap-native-sql/core/src/test/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReadTest.java +++ b/oap-native-sql/core/src/test/java/com/intel/sparkColumnarPlugin/datasource/parquet/ParquetReadTest.java @@ -46,7 +46,6 @@ import io.netty.buffer.ArrowBuf; public class ParquetReadTest { - @Rule public TemporaryFolder testFolder = new TemporaryFolder(); private BufferAllocator allocator; @@ -63,29 +62,26 @@ public void teardown() { @Test public void testParquetRead() throws Exception { - File testFile = testFolder.newFile("_tmpfile_ParquetReadTest"); - //String path = testFile.getAbsolutePath(); - String path = "hdfs://sr602:9000/part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet?user=root&replication=1"; + // String path = testFile.getAbsolutePath(); + 
String path = + "hdfs://sr602:9000/part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet?user=root&replication=1"; int numColumns = 0; int[] rowGroupIndices = {0}; int[] columnIndices = new int[numColumns]; - Schema schema = - new Schema( - asList( - field("n_nationkey", new Int(64, true)), - field("n_name", new Utf8()), - field("n_regionkey", new Int(64, true)), - field("n_comment", new Utf8()) - )); + Schema schema = new Schema( + asList(field("n_nationkey", new Int(64, true)), field("n_name", new Utf8()), + field("n_regionkey", new Int(64, true)), field("n_comment", new Utf8()))); - ParquetReader reader = new ParquetReader(path, rowGroupIndices, columnIndices, 16, allocator); + ParquetReader reader = + new ParquetReader(path, rowGroupIndices, columnIndices, 16, allocator, ""); Schema readedSchema = reader.getSchema(); for (int i = 0; i < readedSchema.getFields().size(); i++) { - assertEquals(schema.getFields().get(i).getName(), readedSchema.getFields().get(i).getName()); + assertEquals( + schema.getFields().get(i).getName(), readedSchema.getFields().get(i).getName()); } VectorSchemaRoot actualSchemaRoot = VectorSchemaRoot.create(readedSchema, allocator); @@ -103,7 +99,8 @@ public void testParquetRead() throws Exception { testFile.delete(); } - private static Field field(String name, boolean nullable, ArrowType type, Field... children) { + private static Field field( + String name, boolean nullable, ArrowType type, Field... children) { return new Field(name, new FieldType(nullable, type, null, null), asList(children)); } diff --git a/oap-native-sql/core/src/test/resources/part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet b/oap-native-sql/core/src/test/resources/part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet new file mode 100644 index 000000000..9eb182285 Binary files /dev/null and b/oap-native-sql/core/src/test/resources/part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet differ diff --git a/oap-native-sql/core/src/test/scala/com.intel.sparkColumnarPlugin/ExtensionSuite.scala b/oap-native-sql/core/src/test/scala/com.intel.sparkColumnarPlugin/ExtensionSuite.scala deleted file mode 100644 index bebb60389..000000000 --- a/oap-native-sql/core/src/test/scala/com.intel.sparkColumnarPlugin/ExtensionSuite.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.intel.sparkColumnarPlugin - -import org.apache.spark.sql.{Row, SparkSession} - -import org.scalatest.FunSuite - -class ExtensionSuite extends FunSuite { - - private def stop(spark: SparkSession): Unit = { - spark.stop() - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - } - - test("inject columnar exchange") { - val session = SparkSession - .builder() - .master("local[1]") - .config("org.apache.spark.example.columnar.enabled", value = true) - .config("spark.sql.extensions", "com.intel.sparkColumnarPlugin.ColumnarPlugin") - .appName("inject columnar exchange") - .getOrCreate() - - try { - import session.sqlContext.implicits._ - - val input = Seq((100), (200), (300)) - val data = input.toDF("vals").repartition(1) - val result = data.collect() - assert(result sameElements input.map(x => Row(x))) - } finally { - stop(session) - } - } -} diff --git a/oap-native-sql/core/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala b/oap-native-sql/core/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala new file mode 100644 index 000000000..5216eb8aa --- /dev/null +++ b/oap-native-sql/core/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import java.io.File +import java.nio.file.Files + +import com.intel.sparkColumnarPlugin.vectorized.ArrowWritableColumnVector +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.ipc.ArrowStreamReader +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} +import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel +import org.apache.arrow.vector.{FieldVector, IntVector} +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.sort.ColumnarShuffleHandle +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.{Mock, MockitoAnnotations} + +import scala.collection.JavaConverters._ + +class ColumnarShuffleWriterSuite extends SharedSparkSession { + @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + @Mock(answer = RETURNS_SMART_NULLS) private var dependency + : ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = _ + + override def sparkConf: SparkConf = + super.sparkConf + .setAppName("test ColumnarShuffleWriter") + .set("spark.file.transferTo", "true") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") + .set("spark.sql.execution.arrow.maxRecordsPerBatch", "4096") + + private var taskMetrics: TaskMetrics = _ + private var tempDir: File = _ + private var outputFile: File = _ + + private var shuffleHandle: ColumnarShuffleHandle[Int, ColumnarBatch] = _ + private val schema = new Schema( + List( + Field.nullable("pid", new ArrowType.Int(32, true)), + Field.nullable("f_1", new ArrowType.Int(32, true)), + Field.nullable("f_2", new ArrowType.Int(32, true)), + ).asJava) + private val allocator = new RootAllocator(Long.MaxValue) + + override def beforeEach() = { + super.beforeEach() + + tempDir = Utils.createTempDir() + outputFile = File.createTempFile("shuffle", null, tempDir) + taskMetrics = new TaskMetrics + + MockitoAnnotations.initMocks(this) + + shuffleHandle = + new ColumnarShuffleHandle[Int, ColumnarBatch](shuffleId = 0, dependency = dependency) + + when(dependency.partitioner).thenReturn(new HashPartitioner(11)) + when(dependency.serializer).thenReturn(new JavaSerializer(sparkConf)) + when(dependency.serializedSchema).thenReturn(schema.toByteArray) + when(dependency.dataSize) + .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "data size")) + when(dependency.splitTime) + .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "split time")) + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) + + doAnswer { (invocationOnMock: InvocationOnMock) => + val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + outputFile.delete + tmp.renameTo(outputFile) + } + null + }.when(blockResolver) + .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) + } + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + 
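+      // run the shared teardown even if deleting the temp directory throws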
super.afterEach() + } + } + + override def afterAll(): Unit = { + allocator.close() + super.afterAll() + } + + test("write empty iterator") { + val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( + blockResolver, + shuffleHandle, + 0, // MapId + taskContext.taskMetrics().shuffleWriteMetrics) + writer.write(Iterator.empty) + writer.stop( /* success = */ true) + + assert(writer.getPartitionLengths.sum === 0) + assert(outputFile.exists()) + assert(outputFile.length() === 0) + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === 0) + assert(shuffleWriteMetrics.recordsWritten === 0) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + test("write empty column batch") { + val vectorPid = new IntVector("pid", allocator) + val vector1 = new IntVector("v1", allocator) + val vector2 = new IntVector("v2", allocator) + + ColumnarShuffleWriterSuite.setIntVector(vectorPid) + ColumnarShuffleWriterSuite.setIntVector(vector1) + ColumnarShuffleWriterSuite.setIntVector(vector2) + val cb = ColumnarShuffleWriterSuite.makeColumnarBatch( + vectorPid.getValueCount, + List(vectorPid, vector1, vector2)) + + def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb), (0, cb)) + + val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( + blockResolver, + shuffleHandle, + 0L, // MapId + taskContext.taskMetrics().shuffleWriteMetrics) + + writer.write(records) + writer.stop(success = true) + cb.close() + + assert(writer.getPartitionLengths.sum === 0) + assert(outputFile.exists()) + assert(outputFile.length() === 0) + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === 0) + assert(shuffleWriteMetrics.recordsWritten === 0) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + test("write with some empty partitions") { + val vectorPid = new IntVector("pid", allocator) + val vector1 = new IntVector("v1", allocator) + val vector2 = new IntVector("v2", allocator) + ColumnarShuffleWriterSuite.setIntVector(vectorPid, 1, 2, 1, 10) + ColumnarShuffleWriterSuite.setIntVector(vector1, null, null, null, null) + ColumnarShuffleWriterSuite.setIntVector(vector2, 100, 100, null, null) + val cb = ColumnarShuffleWriterSuite.makeColumnarBatch( + vectorPid.getValueCount, + List(vectorPid, vector1, vector2)) + + def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb), (0, cb)) + + val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( + blockResolver, + shuffleHandle, + 0L, // MapId + taskContext.taskMetrics().shuffleWriteMetrics) + + writer.write(records) + writer.stop(success = true) + + assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === 8) // should be (11 - 3) zero length files + + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) + + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + + val bytes = Files.readAllBytes(outputFile.toPath) + val reader = new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(bytes), allocator) + try { + val schema = reader.getVectorSchemaRoot.getSchema + assert(schema.getFields == this.schema.getFields.subList(1, this.schema.getFields.size())) + } finally { + reader.close() + cb.close() + } + } +} 
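+
+// Helpers shared by the tests above: setIntVector fills an IntVector, leaving null entries
+// unset so their validity bits stay clear; makeColumnarBatch wraps the given Arrow vectors
+// into a Spark ColumnarBatch.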
+object ColumnarShuffleWriterSuite {
+
+  def setIntVector(vector: IntVector, values: Integer*): Unit = {
+    val length = values.length
+    vector.allocateNew(length)
+    (0 until length).foreach { i =>
+      if (values(i) != null) {
+        vector.set(i, values(i).asInstanceOf[Int])
+      }
+    }
+    vector.setValueCount(length)
+  }
+
+  def makeColumnarBatch(capacity: Int, vectors: List[FieldVector]): ColumnarBatch = {
+    val columnVectors = ArrowWritableColumnVector.loadColumns(capacity, vectors.asJava)
+    new ColumnarBatch(columnVectors.toArray, capacity)
+  }
+
+}
diff --git a/oap-native-sql/core/src/test/scala/org/apache/spark/sql/RepartitionSuite.scala b/oap-native-sql/core/src/test/scala/org/apache/spark/sql/RepartitionSuite.scala
new file mode 100644
index 000000000..7bb6e5cb9
--- /dev/null
+++ b/oap-native-sql/core/src/test/scala/org/apache/spark/sql/RepartitionSuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.{
+  ColumnarShuffleExchangeExec,
+  ColumnarToRowExec,
+  RowToColumnarExec
+}
+import org.apache.spark.sql.test.SharedSparkSession
+
+class RepartitionSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  override def sparkConf: SparkConf =
+    super.sparkConf
+      .setAppName("test repartition")
+      .set("spark.sql.parquet.columnarReaderBatchSize", "4096")
+      .set("spark.sql.sources.useV1SourceList", "avro")
+      .set("spark.sql.extensions", "com.intel.sparkColumnarPlugin.ColumnarPlugin")
+      .set("spark.sql.execution.arrow.maxRecordsPerBatch", "4096")
+      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
+
+  def checkColumnarExec(data: DataFrame) = {
+    val found = data.queryExecution.executedPlan
+      .collect {
+        case r2c: RowToColumnarExec => 1
+        case c2r: ColumnarToRowExec => 10
+        case exc: ColumnarShuffleExchangeExec => 100
+      }
+      .distinct
+      .sum
+    assert(found == 110)
+  }
+
+  def withInput(input: DataFrame)(
+      transformation: Option[DataFrame => DataFrame],
+      repartition: DataFrame => DataFrame): Unit = {
+    val expected = transformation.getOrElse(identity[DataFrame](_))(input)
+    val data = repartition(expected)
+    checkColumnarExec(data)
+    checkAnswer(data, expected)
+  }
+
+  lazy val input: DataFrame = Seq((1, "1"), (2, "20"), (3, "300")).toDF("id", "val")
+
+  def withTransformationAndRepartition(
+      transformation: DataFrame => DataFrame,
+      repartition: DataFrame => DataFrame): Unit =
+    withInput(input)(Some(transformation), repartition)
+
+  def withRepartition: (DataFrame => DataFrame) => Unit = withInput(input)(None, _)
+}
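+
+// checkColumnarExec encodes each plan-node kind as a power of ten (RowToColumnarExec = 1,
+// ColumnarToRowExec = 10, ColumnarShuffleExchangeExec = 100), so `distinct.sum == 110` asserts
+// that the executed plan contains a columnar exchange and a columnar-to-row transition but no
+// row-to-columnar conversion.
+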
+class SmallDataRepartitionSuite extends RepartitionSuite {
+  import testImplicits._
+
+  test("test round robin partitioning") {
+    withRepartition(df => df.repartition(2))
+  }
+
+  test("test hash partitioning") {
+    withRepartition(df => df.repartition('id))
+  }
+
+  test("test range partitioning") {
+    withRepartition(df => df.repartitionByRange('id))
+  }
+
+  ignore("test cached repartition") {
+    val data = input.cache.repartition(2)
+
+    val found = data.queryExecution.executedPlan.collect {
+      case cache: InMemoryTableScanExec => 1
+      case c2r: ColumnarToRowExec => 10
+      case exc: ColumnarShuffleExchangeExec => 100
+    }.sum
+    assert(found == 111)
+
+    checkAnswer(data, input)
+  }
+}
+
+class TPCHTableRepartitionSuite extends RepartitionSuite {
+  import testImplicits._
+
+  val filePath = getClass.getClassLoader
+    .getResource("part-00000-d648dd34-c9d2-4fe9-87f2-770ef3551442-c000.snappy.parquet")
+    .getFile
+
+  override lazy val input = spark.read.parquet(filePath)
+
+  test("test tpch round robin partitioning") {
+    withRepartition(df => df.repartition(2))
+  }
+
+  test("test tpch hash partitioning") {
+    withRepartition(df => df.repartition('n_nationkey))
+  }
+
+  test("test tpch range partitioning") {
+    withRepartition(df => df.repartitionByRange('n_name))
+  }
+
+  test("test tpch sum after repartition") {
+    withTransformationAndRepartition(
+      df => df.groupBy("n_regionkey").agg(Map("n_nationkey" -> "sum")),
+      df => df.repartition(2))
+  }
+}
+
+class DisableColumnarShuffle extends SmallDataRepartitionSuite {
+  override def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.shuffle.manager", "sort")
+      .set("spark.sql.codegen.wholeStage", "true")
+  }
+
+  override def checkColumnarExec(data: DataFrame) = {
+    val found = data.queryExecution.executedPlan
+      .collect {
+        case c2r: ColumnarToRowExec => 1
+        case exc: ColumnarShuffleExchangeExec => 10
+        case exc: ShuffleExchangeExec => 100
+      }
+      .distinct
+      .sum
+    assert(found == 101)
+  }
+}
diff --git a/oap-native-sql/cpp/.gitignore b/oap-native-sql/cpp/.gitignore
new file mode 100644
index 000000000..22da1f13b
--- /dev/null
+++ b/oap-native-sql/cpp/.gitignore
@@ -0,0 +1,24 @@
+thirdparty/*.tar*
+CMakeFiles/
+CMakeCache.txt
+CTestTestfile.cmake
+Makefile
+cmake_install.cmake
+build/
+*-build/
+Testing/
+cmake-build-debug/
+cmake-build-release/
+
+#########################################
+# Editor temporary/working/backup files #
+.#*
+*\#*\#
+[#]*#
+*~
+*$
+*.bak
+*flymake*
+*.kdev4
+*.log
+*.swp
diff --git a/oap-native-sql/cpp/compile.sh b/oap-native-sql/cpp/compile.sh
new file mode 100755
index 000000000..8f89e491b
--- /dev/null
+++ b/oap-native-sql/cpp/compile.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+
+set -eu
+
+CURRENT_DIR=$(cd "$(dirname "$BASH_SOURCE")"; pwd)
+echo $CURRENT_DIR
+
+cd ${CURRENT_DIR}
+if [ -d build ]; then
+  rm -r build
+fi
+mkdir build
+cd build
+cmake ..
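+# TESTS and BENCHMARKS (declared in src/CMakeLists.txt) default to OFF; pass flags such as
+# -DTESTS=ON to the cmake invocation above to build them as well.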
+make
+
+set +eu
+
diff --git a/oap-native-sql/cpp/src/CMakeLists.txt b/oap-native-sql/cpp/src/CMakeLists.txt
index 06352cae2..604a9206d 100644
--- a/oap-native-sql/cpp/src/CMakeLists.txt
+++ b/oap-native-sql/cpp/src/CMakeLists.txt
@@ -18,6 +18,11 @@ option(TESTS "Build the tests" OFF)
 option(BENCHMARKS "Build the benchmarks" OFF)
 option(DEBUG "Enable Debug Info" OFF)
 
+# same as the version required in arrow/ci/conda_env_cpp.yml
+set(BOOST_MIN_VERSION "1.42.0")
+find_package(Boost ${BOOST_MIN_VERSION} REQUIRED)
+INCLUDE_DIRECTORIES(${Boost_INCLUDE_DIRS})
+
 find_package(JNI REQUIRED)
 
 set(source_root_directory ${CMAKE_CURRENT_SOURCE_DIR})
@@ -133,8 +138,8 @@ endif()
 
 if(BENCHMARKS)
   find_package(GTest)
-  #add_definitions(-DBENCHMARK_FILE_PATH=${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/source_files/)
-  add_compile_definitions(BENCHMARK_FILE_PATH="file://${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/source_files/")
+  add_definitions(-DBENCHMARK_FILE_PATH="file://${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/source_files/")
+  #add_compile_definitions(BENCHMARK_FILE_PATH="file://${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/source_files/")
   macro(package_add_benchmark TESTNAME)
     #configure_file(${ARGN}.in ${ARGN})
     add_executable(${TESTNAME} ${ARGN})
@@ -167,6 +172,22 @@ if(NOT GANDIVA_LIB)
   message(FATAL_ERROR "Gandiva library not found")
 endif()
 
+set(CODEGEN_HEADERS
+  third_party/
+  )
+file(MAKE_DIRECTORY ${root_directory}/releases/include)
+file(MAKE_DIRECTORY ${root_directory}/releases/include/codegen/common/)
+file(MAKE_DIRECTORY ${root_directory}/releases/include/codegen/third_party/)
+file(MAKE_DIRECTORY ${root_directory}/releases/include/codegen/arrow_compute/ext/)
+file(COPY third_party/ DESTINATION ${root_directory}/releases/include/)
+file(COPY third_party/ DESTINATION ${root_directory}/releases/include/third_party/)
+file(COPY codegen/arrow_compute/ext/array_item_index.h DESTINATION ${root_directory}/releases/include/codegen/arrow_compute/ext/)
+file(COPY codegen/arrow_compute/ext/code_generator_base.h DESTINATION ${root_directory}/releases/include/codegen/arrow_compute/ext/)
+file(COPY codegen/arrow_compute/ext/kernels_ext.h DESTINATION ${root_directory}/releases/include/codegen/arrow_compute/ext/)
+file(COPY codegen/common/result_iterator.h DESTINATION ${root_directory}/releases/include/codegen/common/)
+
+add_definitions(-DNATIVESQL_SRC_PATH="${root_directory}/releases")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated-declarations")
 set(SPARK_COLUMNAR_PLUGIN_SRCS
     jni/jni_wrapper.cc
     ${PROTO_SRCS}
@@ -174,15 +195,13 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS
     proto/protobuf_utils.cc
     codegen/expr_visitor.cc
     codegen/arrow_compute/expr_visitor.cc
-    codegen/arrow_compute/ext/item_iterator.cc
-    codegen/arrow_compute/ext/shuffle_v2_action.cc
-    codegen/arrow_compute/ext/codegen_node_visitor.cc
-    codegen/arrow_compute/ext/codegen_node_visitor_v2.cc
-    codegen/arrow_compute/ext/conditioned_shuffle_kernel.cc
-    codegen/arrow_compute/ext/conditioned_probe_kernel.cc
+    codegen/arrow_compute/ext/probe_kernel.cc
     codegen/arrow_compute/ext/sort_kernel.cc
     codegen/arrow_compute/ext/kernels_ext.cc
     codegen/arrow_compute/ext/codegen_common.cc
+    codegen/arrow_compute/ext/codegen_node_visitor.cc
+    shuffle/splitter.cc
+    shuffle/partition_writer.cc
 )
 
 file(MAKE_DIRECTORY ${root_directory}/releases)
diff --git a/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_join.cc b/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_join.cc
index 51700efb3..258cc11e2 100644
--- a/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_join.cc
+++
b/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_join.cc @@ -139,14 +139,6 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) { "codegen_withTwoInputs", {n_probeArrays, n_left, n_right}, uint32()); auto probeArrays_expr = TreeExprBuilder::MakeExpression(n_codegen_probe, f_indices); - auto n_conditionedShuffleArrayList = - TreeExprBuilder::MakeFunction("conditionedShuffleArrayList", {}, uint32()); - auto n_codegen_shuffle = TreeExprBuilder::MakeFunction( - "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right}, - uint32()); - - auto conditionShuffleExpr = TreeExprBuilder::MakeExpression(n_codegen_shuffle, f_res); - auto schema_table_0 = arrow::schema(left_field_list); auto schema_table_1 = arrow::schema(right_field_list); std::vector> field_list(left_field_list.size() + @@ -160,7 +152,6 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) { std::vector> dummy_result_batches; std::shared_ptr> probe_result_iterator; - std::shared_ptr> shuffle_result_iterator; ////////////////////// evaluate ////////////////////// std::shared_ptr left_record_batch; @@ -169,16 +160,14 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) { uint64_t elapse_left_read = 0; uint64_t elapse_right_read = 0; uint64_t elapse_eval = 0; + uint64_t elapse_finish = 0; uint64_t elapse_probe_process = 0; uint64_t elapse_shuffle_process = 0; uint64_t num_batches = 0; uint64_t num_rows = 0; TIME_MICRO_OR_THROW(elapse_gen, CreateCodeGenerator(left_schema, {probeArrays_expr}, - {f_indices}, &expr_probe, true)); - TIME_MICRO_OR_THROW( - elapse_gen, CreateCodeGenerator(schema_table_0, {conditionShuffleExpr}, field_list, - &expr_shuffle, true)); + field_list, &expr_probe, true)); do { TIME_MICRO_OR_THROW(elapse_left_read, @@ -186,16 +175,12 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) { if (left_record_batch) { TIME_MICRO_OR_THROW(elapse_eval, expr_probe->evaluate(left_record_batch, &dummy_result_batches)); - TIME_MICRO_OR_THROW( - elapse_eval, expr_shuffle->evaluate(left_record_batch, &dummy_result_batches)); num_batches += 1; } } while (left_record_batch); std::cout << "Readed left table with " << num_batches << " batches." << std::endl; - TIME_MICRO_OR_THROW(elapse_eval, expr_probe->finish(&probe_result_iterator)); - TIME_MICRO_OR_THROW(elapse_eval, expr_shuffle->SetDependency(probe_result_iterator)); - TIME_MICRO_OR_THROW(elapse_eval, expr_shuffle->finish(&shuffle_result_iterator)); + TIME_MICRO_OR_THROW(elapse_finish, expr_probe->finish(&probe_result_iterator)); num_batches = 0; uint64_t num_output_batches = 0; @@ -209,9 +194,7 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) { right_column_vector.push_back(right_record_batch->column(i)); } TIME_MICRO_OR_THROW(elapse_probe_process, - probe_result_iterator->ProcessAndCacheOne(right_column_vector)); - TIME_MICRO_OR_THROW(elapse_shuffle_process, - shuffle_result_iterator->Process(right_column_vector, &out)); + probe_result_iterator->Process(right_column_vector, &out)); num_batches += 1; num_output_batches++; num_rows += out->num_rows(); @@ -219,16 +202,17 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmark) { } while (right_record_batch); std::cout << "Readed right table with " << num_batches << " batches." 
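// In the hunks below, elapse_eval now times only the left-table hash inserts,
// the new elapse_finish counter isolates building the probe result iterator,
// and a single probe_result_iterator->Process() call replaces the former
// ProcessAndCacheOne()/shuffle-Process() pair.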
<< std::endl; - std::cout << "BenchmarkArrowComputeJoin processed " << num_batches - << " batches, then output " << num_output_batches << " batches with " - << num_rows << " rows, to complete, it took " << TIME_TO_STRING(elapse_gen) - << " doing codegen, took " << TIME_TO_STRING(elapse_left_read) - << " doing left BatchRead, took " << TIME_TO_STRING(elapse_right_read) - << " doing right BatchRead, took " << TIME_TO_STRING(elapse_eval) - << " doing left table hashmap insert, took " - << TIME_TO_STRING(elapse_probe_process) << " doing probe indice fetch, took " - << TIME_TO_STRING(elapse_shuffle_process) << " doing final shuffle." - << std::endl; + std::cout << "==========================================" + << "\nBenchmarkArrowComputeJoin processed " << num_batches << " batches" + << "\noutput " << num_output_batches << " batches with " << num_rows + << " rows" + << "\nCodeGen took " << TIME_TO_STRING(elapse_gen) + << "\nLeft Batch Read took " << TIME_TO_STRING(elapse_left_read) + << "\nRight Batch Read took " << TIME_TO_STRING(elapse_right_read) + << "\nLeft Table Hash Insert took " << TIME_TO_STRING(elapse_eval) + << "\nMake Result Iterator took " << TIME_TO_STRING(elapse_finish) + << "\nProbe and Shuffle took " << TIME_TO_STRING(elapse_probe_process) << "\n" + << "===========================================" << std::endl; } TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) { @@ -266,14 +250,6 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) { "codegen_withTwoInputs", {n_probeArrays, n_left, n_right}, uint32()); auto probeArrays_expr = TreeExprBuilder::MakeExpression(n_codegen_probe, f_indices); - auto n_conditionedShuffleArrayList = - TreeExprBuilder::MakeFunction("conditionedShuffleArrayList", {}, uint32()); - auto n_codegen_shuffle = TreeExprBuilder::MakeFunction( - "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right}, - uint32()); - - auto conditionShuffleExpr = TreeExprBuilder::MakeExpression(n_codegen_shuffle, f_res); - auto schema_table_0 = arrow::schema(left_field_list); auto schema_table_1 = arrow::schema(right_field_list); std::vector> field_list(left_field_list.size() + @@ -283,32 +259,25 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) { auto schema_table = arrow::schema(field_list); ///////////////////// Calculation ////////////////// std::shared_ptr expr_probe; - std::shared_ptr expr_shuffle; std::vector> dummy_result_batches; std::shared_ptr> probe_result_iterator; - std::shared_ptr> shuffle_result_iterator; ////////////////////// evaluate ////////////////////// std::shared_ptr left_record_batch; std::shared_ptr right_record_batch; uint64_t elapse_gen = 0; - auto n_codegen = TreeExprBuilder::MakeFunction( - "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right}, - uint32()); uint64_t elapse_left_read = 0; uint64_t elapse_right_read = 0; uint64_t elapse_eval = 0; + uint64_t elapse_finish = 0; uint64_t elapse_probe_process = 0; uint64_t elapse_shuffle_process = 0; uint64_t num_batches = 0; uint64_t num_rows = 0; TIME_MICRO_OR_THROW(elapse_gen, CreateCodeGenerator(left_schema, {probeArrays_expr}, - {f_indices}, &expr_probe, true)); - TIME_MICRO_OR_THROW( - elapse_gen, CreateCodeGenerator(schema_table_0, {conditionShuffleExpr}, field_list, - &expr_shuffle, true)); + field_list, &expr_probe, true)); do { TIME_MICRO_OR_THROW(elapse_left_read, @@ -316,16 +285,12 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) { if (left_record_batch) { TIME_MICRO_OR_THROW(elapse_eval, 
expr_probe->evaluate(left_record_batch, &dummy_result_batches)); - TIME_MICRO_OR_THROW( - elapse_eval, expr_shuffle->evaluate(left_record_batch, &dummy_result_batches)); num_batches += 1; } } while (left_record_batch); std::cout << "Readed left table with " << num_batches << " batches." << std::endl; - TIME_MICRO_OR_THROW(elapse_eval, expr_probe->finish(&probe_result_iterator)); - TIME_MICRO_OR_THROW(elapse_eval, expr_shuffle->SetDependency(probe_result_iterator)); - TIME_MICRO_OR_THROW(elapse_eval, expr_shuffle->finish(&shuffle_result_iterator)); + TIME_MICRO_OR_THROW(elapse_finish, expr_probe->finish(&probe_result_iterator)); num_batches = 0; uint64_t num_output_batches = 0; @@ -339,9 +304,7 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) { right_column_vector.push_back(right_record_batch->column(i)); } TIME_MICRO_OR_THROW(elapse_probe_process, - probe_result_iterator->ProcessAndCacheOne(right_column_vector)); - TIME_MICRO_OR_THROW(elapse_shuffle_process, - shuffle_result_iterator->Process(right_column_vector, &out)); + probe_result_iterator->Process(right_column_vector, &out)); num_batches += 1; num_output_batches++; num_rows += out->num_rows(); @@ -349,16 +312,17 @@ TEST_F(BenchmarkArrowComputeJoin, JoinBenchmarkWithCondition) { } while (right_record_batch); std::cout << "Readed right table with " << num_batches << " batches." << std::endl; - std::cout << "BenchmarkArrowComputeJoin processed " << num_batches - << " batches, then output " << num_output_batches << " batches with " - << num_rows << " rows, to complete, it took " << TIME_TO_STRING(elapse_gen) - << " doing codegen, took " << TIME_TO_STRING(elapse_left_read) - << " doing left BatchRead, took " << TIME_TO_STRING(elapse_right_read) - << " doing right BatchRead, took " << TIME_TO_STRING(elapse_eval) - << " doing left table hashmap insert, took " - << TIME_TO_STRING(elapse_probe_process) << " doing probe indice fetch, took " - << TIME_TO_STRING(elapse_shuffle_process) << " doing final shuffle." 
- << std::endl; + std::cout << "==========================================" + << "\nBenchmarkArrowComputeJoin processed " << num_batches << " batches" + << "\noutput " << num_output_batches << " batches with " << num_rows + << " rows" + << "\nCodeGen took " << TIME_TO_STRING(elapse_gen) + << "\nLeft Batch Read took " << TIME_TO_STRING(elapse_left_read) + << "\nRight Batch Read took " << TIME_TO_STRING(elapse_right_read) + << "\nLeft Table Hash Insert took " << TIME_TO_STRING(elapse_eval) + << "\nMake Result Iterator took " << TIME_TO_STRING(elapse_finish) + << "\nProbe and Shuffle took " << TIME_TO_STRING(elapse_probe_process) << "\n" + << "===========================================" << std::endl; } } // namespace codegen } // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_sort.cc b/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_sort.cc index 92bb22a83..9fc6b5b8a 100644 --- a/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_sort.cc +++ b/oap-native-sql/cpp/src/benchmarks/arrow_compute_benchmark_sort.cc @@ -67,25 +67,12 @@ class BenchmarkArrowComputeSort : public ::testing::Test { ASSERT_NOT_OK( parquet_reader->GetRecordBatchReader({0}, {0, 1, 2}, &record_batch_reader)); - schema = record_batch_reader->schema(); - std::cout << schema->ToString() << std::endl; - ////////////////// expr prepration //////////////// field_list = record_batch_reader->schema()->fields(); ret_field_list = record_batch_reader->schema()->fields(); } - void StartWithIterator() { - uint64_t elapse_gen = 0; - uint64_t elapse_read = 0; - uint64_t elapse_eval = 0; - uint64_t elapse_sort = 0; - uint64_t elapse_shuffle = 0; - uint64_t num_batches = 0; - std::shared_ptr sort_expr; - TIME_MICRO_OR_THROW(elapse_gen, CreateCodeGenerator(schema, {sortArrays_expr}, - {f_indices}, &sort_expr, true)); - + void StartWithIterator(std::shared_ptr sort_expr) { std::vector> input_batch_list; std::vector> dummy_result_batches; std::shared_ptr> sort_result_iterator; @@ -126,45 +113,84 @@ class BenchmarkArrowComputeSort : public ::testing::Test { std::shared_ptr file; std::unique_ptr<::parquet::arrow::FileReader> parquet_reader; std::shared_ptr record_batch_reader; - std::shared_ptr schema; std::vector> field_list; std::vector> ret_field_list; int primary_key_index = 0; - std::shared_ptr f_indices; std::shared_ptr f_res; - ::gandiva::ExpressionPtr sortArrays_expr; - ::gandiva::ExpressionPtr conditionShuffleExpr; + + uint64_t elapse_gen = 0; + uint64_t elapse_read = 0; + uint64_t elapse_eval = 0; + uint64_t elapse_sort = 0; + uint64_t elapse_shuffle = 0; + uint64_t num_batches = 0; }; TEST_F(BenchmarkArrowComputeSort, SortBenchmark) { + elapse_gen = 0; + elapse_read = 0; + elapse_eval = 0; + elapse_sort = 0; + elapse_shuffle = 0; + num_batches = 0; + ////////////////////// prepare expr_vector /////////////////////// + auto indices_type = std::make_shared(16); + f_res = field("res", arrow::uint64()); + + std::vector> gandiva_field_list; + for (auto field : field_list) { + gandiva_field_list.push_back(TreeExprBuilder::MakeField(field)); + } + auto n_sort_to_indices = + TreeExprBuilder::MakeFunction("sortArraysToIndicesNullsFirstAsc", + {gandiva_field_list[primary_key_index]}, uint64()); + std::shared_ptr schema; + schema = arrow::schema(field_list); + std::cout << schema->ToString() << std::endl; + + ::gandiva::ExpressionPtr sortArrays_expr; + sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort_to_indices, f_res); + + std::shared_ptr sort_expr; + 
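// In the two tests below, each TEST_F now builds its own gandiva expression and
// hands the finished sort_expr to the shared StartWithIterator() driver;
// SortBenchmarkWOPayLoad runs the same flow over just the primary-key column,
// so no payload columns are materialized during the sort.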
TIME_MICRO_OR_THROW(elapse_gen, CreateCodeGenerator(schema, {sortArrays_expr}, + ret_field_list, &sort_expr, true)); + + ///////////////////// Calculation ////////////////// + StartWithIterator(sort_expr); +} + +TEST_F(BenchmarkArrowComputeSort, SortBenchmarkWOPayLoad) { + elapse_gen = 0; + elapse_read = 0; + elapse_eval = 0; + elapse_sort = 0; + elapse_shuffle = 0; + num_batches = 0; ////////////////////// prepare expr_vector /////////////////////// auto indices_type = std::make_shared(16); - f_indices = field("indices", indices_type); f_res = field("res", arrow::uint64()); std::vector> gandiva_field_list; for (auto field : field_list) { gandiva_field_list.push_back(TreeExprBuilder::MakeField(field)); } - auto n_left = - TreeExprBuilder::MakeFunction("codegen_left_schema", gandiva_field_list, uint64()); - auto n_right = TreeExprBuilder::MakeFunction("codegen_right_schema", {}, uint64()); auto n_sort_to_indices = TreeExprBuilder::MakeFunction("sortArraysToIndicesNullsFirstAsc", {gandiva_field_list[primary_key_index]}, uint64()); - sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort_to_indices, f_indices); + std::shared_ptr schema; + schema = arrow::schema({field_list[primary_key_index]}); + ::gandiva::ExpressionPtr sortArrays_expr; + sortArrays_expr = TreeExprBuilder::MakeExpression(n_sort_to_indices, f_res); - auto n_conditionedShuffleArrayList = - TreeExprBuilder::MakeFunction("conditionedShuffleArrayList", {}, uint64()); - auto n_codegen_shuffle = TreeExprBuilder::MakeFunction( - "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right}, - uint64()); - conditionShuffleExpr = TreeExprBuilder::MakeExpression(n_codegen_shuffle, f_res); + std::shared_ptr sort_expr; + TIME_MICRO_OR_THROW(elapse_gen, CreateCodeGenerator(schema, {sortArrays_expr}, + {ret_field_list[primary_key_index]}, + &sort_expr, true)); ///////////////////// Calculation ////////////////// - StartWithIterator(); + StartWithIterator(sort_expr); } } // namespace codegen diff --git a/oap-native-sql/cpp/src/benchmarks/make.sh b/oap-native-sql/cpp/src/benchmarks/make.sh new file mode 100755 index 000000000..c85c35956 --- /dev/null +++ b/oap-native-sql/cpp/src/benchmarks/make.sh @@ -0,0 +1,2 @@ +#gcc $1.cc -std=c++17 -O3 -shared -fPIC /mnt/nvme2/chendi/intel-bigdata/arrow/cpp/release-build/release/libarrow.a -o $1.so +gcc -I /mnt/nvme2/chendi/intel-bigdata/OAP/oap-native-sql/cpp/src/ -I/mnt/nvme2/chendi/intel-bigdata/OAP/oap-native-sql/cpp/src/third_party/sparsehash $1.cc -std=c++17 -O3 -shared -fPIC -larrow -o $1.so diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor.cc index 74e7caf75..7a37ea570 100644 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor.cc +++ b/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor.cc @@ -226,15 +226,6 @@ arrow::Status ExprVisitor::MakeExprVisitorImpl( std::vector> left_field_list, std::vector> right_field_list, std::vector> ret_fields, ExprVisitor* p) { - if (func_name.compare("conditionedShuffleArrayList") == 0) { - std::shared_ptr child_node; - if (func_node->children().size() > 0) { - child_node = func_node->children()[0]; - } - RETURN_NOT_OK(ConditionedShuffleArrayListVisitorImpl::Make( - child_node, left_field_list, right_field_list, ret_fields, p, &impl_)); - goto finish; - } if (func_name.compare("conditionedProbeArraysInner") == 0 || func_name.compare("conditionedProbeArraysOuter") == 0 || func_name.compare("conditionedProbeArraysAnti") == 0 || @@ -272,7 +263,7 @@ 
arrow::Status ExprVisitor::MakeExprVisitorImpl( } RETURN_NOT_OK(ConditionedProbeArraysVisitorImpl::Make( left_key_list, right_key_list, condition_node, join_type, left_field_list, - right_field_list, p, &impl_)); + right_field_list, ret_fields, p, &impl_)); goto finish; } finish: diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor_impl.h b/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor_impl.h index 3fa7c6c69..8e216b87b 100644 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor_impl.h +++ b/oap-native-sql/cpp/src/codegen/arrow_compute/expr_visitor_impl.h @@ -414,93 +414,6 @@ class SortArraysToIndicesVisitorImpl : public ExprVisitorImpl { bool asc_; }; -////////////////////////// ConditionedShuffleArrayListVisitorImpl -////////////////////////// -class ConditionedShuffleArrayListVisitorImpl : public ExprVisitorImpl { - public: - ConditionedShuffleArrayListVisitorImpl( - std::shared_ptr func_node, - std::vector> left_field_list, - std::vector> right_field_list, - std::vector> ret_field_list, ExprVisitor* p) - : func_node_(func_node), - left_field_list_(left_field_list), - right_field_list_(right_field_list), - output_field_list_(ret_field_list), - ExprVisitorImpl(p) {} - static arrow::Status Make(std::shared_ptr func_node, - std::vector> left_field_list, - std::vector> right_field_list, - std::vector> ret_field_list, - ExprVisitor* p, std::shared_ptr* out) { - auto impl = std::make_shared( - func_node, left_field_list, right_field_list, ret_field_list, p); - *out = impl; - return arrow::Status::OK(); - } - - arrow::Status Init() override { - if (initialized_) { - return arrow::Status::OK(); - } - RETURN_NOT_OK(extra::ConditionedShuffleArrayListKernel::Make( - &p_->ctx_, func_node_, left_field_list_, right_field_list_, output_field_list_, - &kernel_)); - initialized_ = true; - return arrow::Status::OK(); - } - - arrow::Status Eval() override { - switch (p_->dependency_result_type_) { - case ArrowComputeResultType::None: { - ArrayList in; - for (int i = 0; i < p_->in_record_batch_->num_columns(); i++) { - in.push_back(p_->in_record_batch_->column(i)); - } - TIME_MICRO_OR_RAISE(p_->elapse_time_, kernel_->Evaluate(in)); - finish_return_type_ = ArrowComputeResultType::Batch; - } break; - default: - return arrow::Status::NotImplemented( - "ConditionedShuffleArrayListVisitorImpl: Does not support this type of " - "input."); - } - return arrow::Status::OK(); - } - - arrow::Status SetDependency( - const std::shared_ptr>& dependency_iter, - int index) override { - RETURN_NOT_OK(kernel_->SetDependencyIter(dependency_iter, index)); - p_->dependency_result_type_ = ArrowComputeResultType::BatchIterator; - return arrow::Status::OK(); - } - - arrow::Status MakeResultIterator( - std::shared_ptr schema, - std::shared_ptr>* out) override { - switch (finish_return_type_) { - case ArrowComputeResultType::Batch: { - TIME_MICRO_OR_RAISE(p_->elapse_time_, kernel_->MakeResultIterator(schema, out)); - p_->return_type_ = ArrowComputeResultType::BatchIterator; - } break; - default: - return arrow::Status::Invalid( - "ConditionedShuffleArrayListVisitorImpl MakeResultIterator does not " - "support " - "dependency type other than Batch."); - } - return arrow::Status::OK(); - } - - private: - int col_id_; - std::shared_ptr func_node_; - std::vector> left_field_list_; - std::vector> right_field_list_; - std::vector> output_field_list_; -}; - ////////////////////////// ConditionedProbeArraysVisitorImpl /////////////////////// class ConditionedProbeArraysVisitorImpl : public 
ExprVisitorImpl { public: @@ -509,23 +422,26 @@ class ConditionedProbeArraysVisitorImpl : public ExprVisitorImpl { std::vector> right_key_list, std::shared_ptr func_node, int join_type, std::vector> left_field_list, - std::vector> right_field_list, ExprVisitor* p) + std::vector> right_field_list, + std::vector> ret_fields, ExprVisitor* p) : left_key_list_(left_key_list), right_key_list_(right_key_list), join_type_(join_type), func_node_(func_node), left_field_list_(left_field_list), right_field_list_(right_field_list), + ret_fields_(ret_fields), ExprVisitorImpl(p) {} static arrow::Status Make(std::vector> left_key_list, std::vector> right_key_list, std::shared_ptr func_node, int join_type, std::vector> left_field_list, std::vector> right_field_list, + std::vector> ret_fields, ExprVisitor* p, std::shared_ptr* out) { auto impl = std::make_shared( left_key_list, right_key_list, func_node, join_type, left_field_list, - right_field_list, p); + right_field_list, ret_fields, p); *out = impl; return arrow::Status::OK(); } @@ -536,7 +452,7 @@ class ConditionedProbeArraysVisitorImpl : public ExprVisitorImpl { } RETURN_NOT_OK(extra::ConditionedProbeArraysKernel::Make( &p_->ctx_, left_key_list_, right_key_list_, func_node_, join_type_, - left_field_list_, right_field_list_, &kernel_)); + left_field_list_, right_field_list_, arrow::schema(ret_fields_), &kernel_)); initialized_ = true; return arrow::Status::OK(); } @@ -583,6 +499,7 @@ class ConditionedProbeArraysVisitorImpl : public ExprVisitorImpl { std::vector> right_key_list_; std::vector> left_field_list_; std::vector> right_field_list_; + std::vector> ret_fields_; }; } // namespace arrowcompute } // namespace codegen diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/array_item_index.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/array_item_index.h index 01f22ea0a..4cc39bd10 100644 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/array_item_index.h +++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/array_item_index.h @@ -23,10 +23,10 @@ namespace codegen { namespace arrowcompute { namespace extra { struct ArrayItemIndex { - uint64_t id = 0; - uint64_t array_id = 0; + uint16_t id = 0; + uint16_t array_id = 0; ArrayItemIndex() : array_id(0), id(0) {} - ArrayItemIndex(uint64_t array_id, uint64_t id) : array_id(array_id), id(id) {} + ArrayItemIndex(uint16_t array_id, uint16_t id) : array_id(array_id), id(id) {} }; } // namespace extra } // namespace arrowcompute diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/code_generator_base.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/code_generator_base.h index afd15d261..86b8bde79 100644 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/code_generator_base.h +++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/code_generator_base.h @@ -31,16 +31,17 @@ using ArrayList = std::vector>; class CodeGenBase { public: virtual arrow::Status Evaluate(const ArrayList& in) { - return arrow::Status::NotImplemented("SortBase Evaluate is an abstract interface."); + return arrow::Status::NotImplemented( + "CodeGenBase Evaluate is an abstract interface."); } virtual arrow::Status Finish(std::shared_ptr* out) { - return arrow::Status::NotImplemented("SortBase Finish is an abstract interface."); + return arrow::Status::NotImplemented("CodeGenBase Finish is an abstract interface."); } virtual arrow::Status MakeResultIterator( std::shared_ptr schema, std::shared_ptr>* out) { return arrow::Status::NotImplemented( - "SortBase MakeResultIterator is an abstract interface."); + 
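// ConditionedProbeArraysVisitorImpl now carries ret_fields_ from Make() through
// Init() and passes arrow::schema(ret_fields_) into
// ConditionedProbeArraysKernel::Make, taking over the output-schema role of the
// removed ConditionedShuffleArrayListVisitorImpl.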
"CodeGenBase MakeResultIterator is an abstract interface."); } }; } // namespace extra diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.cc index bce0918ec..b51b669c4 100644 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.cc +++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.cc @@ -48,104 +48,159 @@ std::string BaseCodes() { #include #include -template class ResultIterator { -public: - virtual bool HasNext() { return false; } - virtual arrow::Status Next(std::shared_ptr *out) { - return arrow::Status::NotImplemented("ResultIterator abstract Next function"); - } - virtual arrow::Status - Process(std::vector> in, - std::shared_ptr *out, - const std::shared_ptr &selection = nullptr) { - return arrow::Status::NotImplemented("ResultIterator abstract Process function"); - } - virtual arrow::Status - ProcessAndCacheOne(std::vector> in, - const std::shared_ptr &selection = nullptr) { - return arrow::Status::NotImplemented( - "ResultIterator abstract ProcessAndCacheOne function"); - } - virtual arrow::Status GetResult(std::shared_ptr* out) { - return arrow::Status::NotImplemented("ResultIterator abstract GetResult function"); - } - virtual std::string ToString() { return ""; } -}; - -using ArrayList = std::vector>; -struct ArrayItemIndex { - uint64_t id = 0; - uint64_t array_id = 0; - ArrayItemIndex(uint64_t array_id, uint64_t id) : array_id(array_id), id(id) {} -}; - -class CodeGenBase { - public: - virtual arrow::Status Evaluate(const ArrayList& in) { - return arrow::Status::NotImplemented("SortBase Evaluate is an abstract interface."); - } - virtual arrow::Status Finish(std::shared_ptr* out) { - return arrow::Status::NotImplemented("SortBase Finish is an abstract interface."); - } - virtual arrow::Status MakeResultIterator( - std::shared_ptr schema, - std::shared_ptr>* out) { - return arrow::Status::NotImplemented( - "SortBase MakeResultIterator is an abstract interface."); - } -};)"; -} - -int FileSpinLock(std::string path) { - std::string lockfile = path + "/nativesql_compile.lock"; +#include "codegen/arrow_compute/ext/array_item_index.h" +#include "codegen/arrow_compute/ext/code_generator_base.h" +#include "codegen/arrow_compute/ext/kernels_ext.h" +#include "codegen/common/result_iterator.h" +#include "sparsehash/sparse_hash_map.h" +#include "third_party/arrow/utils/hashing.h" - auto fd = open(lockfile.c_str(), O_CREAT, S_IRWXU | S_IRWXG); - flock(fd, LOCK_EX); +using namespace sparkcolumnarplugin::codegen::arrowcompute::extra; - return fd; +)"; } -void FileSpinUnLock(int fd) { - flock(fd, LOCK_UN); - close(fd); +std::string GetArrowTypeDefString(std::shared_ptr type) { + switch (type->id()) { + case arrow::UInt8Type::type_id: + return "uint8()"; + case arrow::Int8Type::type_id: + return "int8()"; + case arrow::UInt16Type::type_id: + return "uint16()"; + case arrow::Int16Type::type_id: + return "int16()"; + case arrow::UInt32Type::type_id: + return "uint32()"; + case arrow::Int32Type::type_id: + return "int32()"; + case arrow::UInt64Type::type_id: + return "uint64()"; + case arrow::Int64Type::type_id: + return "int64()"; + case arrow::FloatType::type_id: + return "float632()"; + case arrow::DoubleType::type_id: + return "float64()"; + case arrow::Date32Type::type_id: + return "date32()"; + case arrow::StringType::type_id: + return "utf8()"; + default: + std::cout << "GetTypeString can't convert " << type->ToString() << std::endl; + throw; + } } - 
-std::string GetTypeString(std::shared_ptr type) {
+std::string GetCTypeString(std::shared_ptr type) {
+  switch (type->id()) {
+    case arrow::UInt8Type::type_id:
+      return "uint8_t";
+    case arrow::Int8Type::type_id:
+      return "int8_t";
+    case arrow::UInt16Type::type_id:
+      return "uint16_t";
+    case arrow::Int16Type::type_id:
+      return "int16_t";
+    case arrow::UInt32Type::type_id:
+      return "uint32_t";
+    case arrow::Int32Type::type_id:
+      return "int32_t";
+    case arrow::UInt64Type::type_id:
+      return "uint64_t";
+    case arrow::Int64Type::type_id:
+      return "int64_t";
+    case arrow::FloatType::type_id:
+      return "float";
+    case arrow::DoubleType::type_id:
+      return "double";
+    case arrow::Date32Type::type_id:
+      std::cout << "Can't handle Date32Type yet" << std::endl;
+      throw;
+    case arrow::StringType::type_id:
+      return "std::string";
+    default:
+      std::cout << "GetCTypeString can't convert " << type->ToString() << std::endl;
+      throw;
+  }
+}
+std::string GetTypeString(std::shared_ptr type, std::string tail) {
   switch (type->id()) {
     case arrow::UInt8Type::type_id:
-      return "UInt8Type";
+      return "UInt8" + tail;
     case arrow::Int8Type::type_id:
-      return "Int8Type";
+      return "Int8" + tail;
     case arrow::UInt16Type::type_id:
-      return "UInt16Type";
+      return "UInt16" + tail;
     case arrow::Int16Type::type_id:
-      return "Int16Type";
+      return "Int16" + tail;
     case arrow::UInt32Type::type_id:
-      return "UInt32Type";
+      return "UInt32" + tail;
     case arrow::Int32Type::type_id:
-      return "Int32Type";
+      return "Int32" + tail;
    case arrow::UInt64Type::type_id:
-      return "UInt64Type";
+      return "UInt64" + tail;
    case arrow::Int64Type::type_id:
-      return "Int64Type";
+      return "Int64" + tail;
    case arrow::FloatType::type_id:
-      return "FloatType";
+      return "Float" + tail;
    case arrow::DoubleType::type_id:
-      return "DoubleType";
+      return "Double" + tail;
    case arrow::Date32Type::type_id:
-      return "Date32Type";
+      return "Date32" + tail;
    case arrow::StringType::type_id:
-      return "StringType";
+      return "String" + tail;
    default:
      std::cout << "GetTypeString can't convert " << type->ToString() << std::endl;
      throw;
   }
 }
+std::string GetTempPath() {
+  std::string tmp_dir_;
+  const char* env_tmp_dir = std::getenv("NATIVESQL_TMP_DIR");
+  if (env_tmp_dir != nullptr) {
+    tmp_dir_ = std::string(env_tmp_dir);
+  } else {
+#ifdef NATIVESQL_SRC_PATH
+    tmp_dir_ = NATIVESQL_SRC_PATH;
+#else
+    std::cerr << "environment variable NATIVESQL_TMP_DIR is not set" << std::endl;
+    throw;
+#endif
+  }
+  return tmp_dir_;
+}
+
+int GetBatchSize() {
+  int batch_size;
+  const char* env_batch_size = std::getenv("NATIVESQL_BATCH_SIZE");
+  if (env_batch_size != nullptr) {
+    batch_size = atoi(env_batch_size);
+  } else {
+    batch_size = 10000;
+  }
+  return batch_size;
+}
+
+int FileSpinLock() {
+  std::string lockfile = GetTempPath() + "/nativesql_compile.lock";
+
+  auto fd = open(lockfile.c_str(), O_CREAT, S_IRWXU | S_IRWXG);
+  flock(fd, LOCK_EX);
+
+  return fd;
+}
+
+void FileSpinUnLock(int fd) {
+  flock(fd, LOCK_UN);
+  close(fd);
+}
+
 arrow::Status CompileCodes(std::string codes, std::string signature) {
   // temporary cpp/library output files
   srand(time(NULL));
-  std::string outpath = "/tmp";
+  std::string outpath = GetTempPath() + "/tmp/";
+  mkdir(outpath.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH);
   std::string prefix = "/spark-columnar-plugin-codegen-";
   std::string cppfile = outpath + prefix + signature + ".cc";
   std::string libfile = outpath + prefix + signature + ".so";
@@ -158,6 +213,10 @@ arrow::Status CompileCodes(std::string codes, std::string signature) {
     exit(EXIT_FAILURE);
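// Every generated-code path is now rooted at GetTempPath(): NATIVESQL_TMP_DIR
// overrides the compiled-in NATIVESQL_SRC_PATH fallback, the emitted .cc/.so
// pair lands under a tmp/ subdirectory of that root, and GetBatchSize() reads
// NATIVESQL_BATCH_SIZE with a default of 10000.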
   }
   out << codes;
+#ifdef DEBUG
+  std::cout << "BatchSize is " << GetBatchSize() << std::endl;
+  std::cout << codes << std::endl;
+#endif
   out.flush();
   out.close();
@@ -170,18 +229,29 @@ arrow::Status CompileCodes(std::string codes, std::string signature) {
   const char* env_arrow_dir = std::getenv("LIBARROW_DIR");
   std::string arrow_header;
-  std::string arrow_lib;
+  std::string arrow_lib, arrow_lib2;
+  std::string nativesql_header = " -I" + GetTempPath() + "/nativesql_include/ ";
+  std::string nativesql_lib = " -L" + GetTempPath() + " ";
   if (env_arrow_dir != nullptr) {
     arrow_header = " -I" + std::string(env_arrow_dir) + "/include ";
     arrow_lib = " -L" + std::string(env_arrow_dir) + "/lib64 ";
+    // in case there's a different location for libarrow.so
+    arrow_lib2 = " -L" + std::string(env_arrow_dir) + "/lib ";
   }
   // compile the code
-  std::string cmd = env_gcc + " -std=c++11 -Wall -Wextra " + arrow_header + arrow_lib +
-                    cppfile + " -o " + libfile + " -O3 -shared -fPIC -larrow 2> " +
+  std::string cmd = env_gcc + " -std=c++14 -Wno-deprecated-declarations " + arrow_header +
+                    arrow_lib + arrow_lib2 + nativesql_header + nativesql_lib + cppfile + " -o " +
+                    libfile + " -O3 -march=native -shared -fPIC -larrow -lspark_columnar_jni 2> " +
                     logfile;
+//#ifdef DEBUG
+  std::cout << cmd << std::endl;
+//#endif
   int ret = system(cmd.c_str());
   if (WEXITSTATUS(ret) != EXIT_SUCCESS) {
     std::cout << "compilation failed, see " << logfile << std::endl;
+    std::cout << cmd << std::endl;
+    cmd = "ls -R -l " + GetTempPath() + "; cat " + logfile;
+    system(cmd.c_str());
     exit(EXIT_FAILURE);
   }
@@ -197,10 +267,9 @@ arrow::Status CompileCodes(std::string codes, std::string signature) {
 arrow::Status LoadLibrary(std::string signature, arrow::compute::FunctionContext* ctx,
                           std::shared_ptr* out) {
-  std::string outpath = "/tmp";
+  std::string outpath = GetTempPath() + "/tmp/";
   std::string prefix = "/spark-columnar-plugin-codegen-";
   std::string libfile = outpath + prefix + signature + ".so";
-  std::cout << "LoadLibrary " << libfile << std::endl;
   // load dynamic library
   void* dynlib = dlopen(libfile.c_str(), RTLD_LAZY);
   if (!dynlib) {
@@ -210,6 +279,7 @@ arrow::Status LoadLibrary(std::string signature, arrow::compute::FunctionContext
   // loading symbol from library and assign to pointer
   // (to be cast to function pointer later)
+  std::cout << "LoadLibrary " << libfile << std::endl;
   void (*MakeCodeGen)(arrow::compute::FunctionContext * ctx,
                       std::shared_ptr * out);
   *(void**)(&MakeCodeGen) = dlsym(dynlib, "MakeCodeGen");
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.h
index 1316659ae..a47590586 100644
--- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.h
+++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_common.h
@@ -30,11 +30,15 @@ namespace extra {
 std::string BaseCodes();
-int FileSpinLock(std::string path);
+int FileSpinLock();
 void FileSpinUnLock(int fd);
-std::string GetTypeString(std::shared_ptr type);
+int GetBatchSize();
+std::string GetArrowTypeDefString(std::shared_ptr type);
+std::string GetCTypeString(std::shared_ptr type);
+std::string GetTypeString(std::shared_ptr type,
+                          std::string tail = "Type");
 arrow::Status CompileCodes(std::string codes, std::string signature);
diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc
index 3196f5928..da1f1313b 100644
--- 
a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc +++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc @@ -32,87 +32,52 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::FunctionNode& node) { std::shared_ptr child_visitor; *func_count_ = *func_count_ + 1; RETURN_NOT_OK(MakeCodeGenNodeVisitor(child, field_list_v_, func_count_, codes_ss_, - &child_visitor)); + left_indices_, right_indices_, &child_visitor)); child_visitor_list.push_back(child_visitor); } auto func_name = node.descriptor()->name(); std::stringstream ss; if (func_name.compare("less_than") == 0) { - auto check_str = child_visitor_list[0]->GetPreCheck(); - if (!check_str.empty()) { - ss << check_str << " || "; - } - check_str = child_visitor_list[1]->GetPreCheck(); - if (!check_str.empty()) { - ss << check_str << " || "; - } ss << "(" << child_visitor_list[0]->GetResult() << " < " << child_visitor_list[1]->GetResult() << ")"; - } - if (func_name.compare("greater_than") == 0) { - auto check_str = child_visitor_list[0]->GetPreCheck(); - if (!check_str.empty()) { - ss << check_str << " || "; - } - check_str = child_visitor_list[1]->GetPreCheck(); - if (!check_str.empty()) { - ss << check_str << " || "; - } + } else if (func_name.compare("greater_than") == 0) { ss << "(" << child_visitor_list[0]->GetResult() << " > " << child_visitor_list[1]->GetResult() << ")"; - } - if (func_name.compare("less_than_or_equal_to") == 0) { - auto check_str = child_visitor_list[0]->GetPreCheck(); - if (!check_str.empty()) { - ss << check_str << " || "; - } - check_str = child_visitor_list[1]->GetPreCheck(); - if (!check_str.empty()) { - ss << check_str << " || "; - } + } else if (func_name.compare("less_than_or_equal_to") == 0) { ss << "(" << child_visitor_list[0]->GetResult() << " <= " << child_visitor_list[1]->GetResult() << ")"; - } - if (func_name.compare("greater_than_or_equal_to") == 0) { - auto check_str = child_visitor_list[0]->GetPreCheck(); - if (!check_str.empty()) { - ss << check_str << " || "; - } - check_str = child_visitor_list[1]->GetPreCheck(); - if (!check_str.empty()) { - ss << check_str << " || "; - } + } else if (func_name.compare("greater_than_or_equal_to") == 0) { ss << "(" << child_visitor_list[0]->GetResult() << " >= " << child_visitor_list[1]->GetResult() << ")"; - } - if (func_name.compare("equal") == 0) { - auto check_str = child_visitor_list[0]->GetPreCheck(); - if (!check_str.empty()) { - ss << check_str << " || "; - } - check_str = child_visitor_list[1]->GetPreCheck(); - if (!check_str.empty()) { - ss << check_str << " || "; - } + } else if (func_name.compare("equal") == 0) { ss << "(" << child_visitor_list[0]->GetResult() << " == " << child_visitor_list[1]->GetResult() << ")"; - } - if (func_name.compare("not") == 0) { + } else if (func_name.compare("not") == 0) { ss << "!(" << child_visitor_list[0]->GetResult() << ")"; + } else { + ss << child_visitor_list[0]->GetResult(); + } + if (cur_func_id == 0) { + codes_str_ = ss.str(); + } else { + codes_str_ = ss.str(); } - codes_str_ = ss.str(); return arrow::Status::OK(); } + arrow::Status CodeGenNodeVisitor::Visit(const gandiva::FieldNode& node) { auto cur_func_id = *func_count_; auto this_field = node.field(); int arg_id = 0; bool found = false; + int index = 0; for (auto field_list : field_list_v_) { + arg_id = 0; for (auto field : field_list) { if (field->name() == this_field->name()) { found = true; + InsertToIndices(index, arg_id); break; } arg_id++; @@ -120,27 +85,26 @@ arrow::Status 
CodeGenNodeVisitor::Visit(const gandiva::FieldNode& node) { if (found) { break; } + index = 1; } - std::string type_name; - if (this_field->type()->id() == arrow::Type::STRING) { - type_name = "std::string"; + if (index == 0) { + *codes_ss_ << "if (cached_0_" << arg_id + << "_[x.array_id]->IsNull(x.id)) {return false;}" << std::endl; + *codes_ss_ << "auto input_field_" << cur_func_id << " = cached_0_" << arg_id + << "_[x.array_id]->GetView(x.id);" << std::endl; + } else { - type_name = this_field->type()->name(); + *codes_ss_ << "if (cached_1_" << arg_id << "_->IsNull(y)) {return false;}" + << std::endl; + *codes_ss_ << "auto input_field_" << cur_func_id << " = cached_1_" << arg_id + << "_->GetView(y);" << std::endl; } - *codes_ss_ << type_name << " *input_field_" << cur_func_id << " = nullptr;" - << std::endl; - *codes_ss_ << "if (!is_null_func_list[" << arg_id << "]()) {" << std::endl; - *codes_ss_ << " input_field_" << cur_func_id << " = (" << type_name - << "*)get_func_list[" << arg_id << "]();" << std::endl; - *codes_ss_ << "}" << std::endl; std::stringstream ss; - ss << "*input_field_" << cur_func_id; + ss << "input_field_" << cur_func_id; codes_str_ = ss.str(); - std::stringstream ss_check; - ss_check << "(input_field_" << cur_func_id << " == nullptr)"; - check_str_ = ss_check.str(); + check_str_ = ""; return arrow::Status::OK(); } @@ -151,12 +115,12 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::IfNode& node) { arrow::Status CodeGenNodeVisitor::Visit(const gandiva::LiteralNode& node) { auto cur_func_id = *func_count_; if (node.return_type()->id() == arrow::Type::STRING) { - *codes_ss_ << "const std::string input_field_" << cur_func_id << R"( = ")" + *codes_ss_ << "auto input_field_" << cur_func_id << R"( = ")" << gandiva::ToString(node.holder()) << R"(";)" << std::endl; } else { - *codes_ss_ << "const " << node.return_type()->name() << " input_field_" << cur_func_id - << " = " << gandiva::ToString(node.holder()) << ";" << std::endl; + *codes_ss_ << "auto input_field_" << cur_func_id << " = " + << gandiva::ToString(node.holder()) << ";" << std::endl; } std::stringstream ss; @@ -172,7 +136,7 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::BooleanNode& node) { std::shared_ptr child_visitor; *func_count_ = *func_count_ + 1; RETURN_NOT_OK(MakeCodeGenNodeVisitor(child, field_list_v_, func_count_, codes_ss_, - &child_visitor)); + left_indices_, right_indices_, &child_visitor)); child_visitor_list.push_back(child_visitor); } @@ -188,14 +152,106 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::BooleanNode& node) { codes_str_ = ss.str(); return arrow::Status::OK(); } + arrow::Status CodeGenNodeVisitor::Visit(const gandiva::InExpressionNode& node) { + auto cur_func_id = *func_count_; + std::shared_ptr child_visitor; + *func_count_ = *func_count_ + 1; + RETURN_NOT_OK(MakeCodeGenNodeVisitor(node.eval_expr(), field_list_v_, func_count_, + codes_ss_, left_indices_, right_indices_, + &child_visitor)); + *codes_ss_ << "std::vector input_field_" << cur_func_id << " = {"; + bool add_comma = false; + for (auto& value : node.values()) { + if (add_comma) { + *codes_ss_ << ", "; + } + // add type in the front to differentiate + *codes_ss_ << value; + add_comma = true; + } + *codes_ss_ << "};" << std::endl; + + std::stringstream ss; + ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id + << ".end(), " << child_visitor->GetResult() << ") != " + << "input_field_" << cur_func_id << ".end()"; + codes_str_ = ss.str(); + check_str_ = 
child_visitor->GetPreCheck(); return arrow::Status::OK(); } + arrow::Status CodeGenNodeVisitor::Visit(const gandiva::InExpressionNode& node) { + auto cur_func_id = *func_count_; + std::shared_ptr child_visitor; + *func_count_ = *func_count_ + 1; + RETURN_NOT_OK(MakeCodeGenNodeVisitor(node.eval_expr(), field_list_v_, func_count_, + codes_ss_, left_indices_, right_indices_, + &child_visitor)); + *codes_ss_ << "std::vector input_field_" << cur_func_id << " = {"; + bool add_comma = false; + for (auto& value : node.values()) { + if (add_comma) { + *codes_ss_ << ", "; + } + // add type in the front to differentiate + *codes_ss_ << value; + add_comma = true; + } + *codes_ss_ << "};" << std::endl; + + std::stringstream ss; + ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id + << ".end(), " << child_visitor->GetResult() << ") != " + << "input_field_" << cur_func_id << ".end()"; + codes_str_ = ss.str(); + check_str_ = child_visitor->GetPreCheck(); return arrow::Status::OK(); } + arrow::Status CodeGenNodeVisitor::Visit( const gandiva::InExpressionNode& node) { + auto cur_func_id = *func_count_; + std::shared_ptr child_visitor; + *func_count_ = *func_count_ + 1; + RETURN_NOT_OK(MakeCodeGenNodeVisitor(node.eval_expr(), field_list_v_, func_count_, + codes_ss_, left_indices_, right_indices_, + &child_visitor)); + *codes_ss_ << "std::vector input_field_" << cur_func_id << " = {"; + bool add_comma = false; + for (auto& value : node.values()) { + if (add_comma) { + *codes_ss_ << ", "; + } + // add type in the front to differentiate + *codes_ss_ << R"(")" << value << R"(")"; + add_comma = true; + } + *codes_ss_ << "};" << std::endl; + + std::stringstream ss; + ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id + << ".end(), " << child_visitor->GetResult() << ") != " + << "input_field_" << cur_func_id << ".end()"; + codes_str_ = ss.str(); + check_str_ = child_visitor->GetPreCheck(); + return arrow::Status::OK(); +} + +arrow::Status CodeGenNodeVisitor::InsertToIndices(int index, int arg_id) { + if (index == 0) { + if (std::find((*left_indices_).begin(), (*left_indices_).end(), arg_id) == + (*left_indices_).end()) { + (*left_indices_).push_back(arg_id); + } + } + if (index == 1) { + if (std::find((*right_indices_).begin(), (*right_indices_).end(), arg_id) == + (*right_indices_).end()) { + (*right_indices_).push_back(arg_id); + } + } + return arrow::Status::OK(); } diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.h index 127d95534..bbb77eebe 100644 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.h +++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.h @@ -28,11 +28,14 @@ class CodeGenNodeVisitor : public VisitorBase { public: CodeGenNodeVisitor(std::shared_ptr func, std::vector>> field_list_v, - int* func_count, std::stringstream* codes_ss) + int* func_count, std::stringstream* codes_ss, + std::vector* left_indices, std::vector* right_indices) : func_(func), field_list_v_(field_list_v), func_count_(func_count), - codes_ss_(codes_ss) {} + codes_ss_(codes_ss), + left_indices_(left_indices), + right_indices_(right_indices) {} arrow::Status Eval() { RETURN_NOT_OK(func_->Accept(*this)); @@ -57,13 +60,17 @@ class CodeGenNodeVisitor : public VisitorBase { std::stringstream* codes_ss_; std::string codes_str_; std::string check_str_; + std::vector* left_indices_; + std::vector* 
right_indices_; + arrow::Status InsertToIndices(int index, int arg_id); }; static arrow::Status MakeCodeGenNodeVisitor( std::shared_ptr func, std::vector>> field_list_v, int* func_count, - std::stringstream* codes_ss, std::shared_ptr* out) { - auto visitor = - std::make_shared(func, field_list_v, func_count, codes_ss); + std::stringstream* codes_ss, std::vector* left_indices, + std::vector* right_indices, std::shared_ptr* out) { + auto visitor = std::make_shared( + func, field_list_v, func_count, codes_ss, left_indices, right_indices); RETURN_NOT_OK(visitor->Eval()); *out = visitor; return arrow::Status::OK(); diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.cc deleted file mode 100644 index 7a9af5781..000000000 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.cc +++ /dev/null @@ -1,251 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "codegen/arrow_compute/ext/codegen_node_visitor_v2.h" - -#include - -#include -#include - -namespace sparkcolumnarplugin { -namespace codegen { -namespace arrowcompute { -namespace extra { -std::string CodeGenNodeVisitorV2::GetResult() { return codes_str_; } -std::string CodeGenNodeVisitorV2::GetPreCheck() { return check_str_; } -arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::FunctionNode& node) { - std::vector> child_visitor_list; - auto cur_func_id = *func_count_; - for (auto child : node.children()) { - std::shared_ptr child_visitor; - *func_count_ = *func_count_ + 1; - RETURN_NOT_OK(MakeCodeGenNodeVisitorV2(child, field_list_v_, func_count_, codes_ss_, - &child_visitor)); - child_visitor_list.push_back(child_visitor); - } - - auto func_name = node.descriptor()->name(); - std::stringstream ss; - if (func_name.compare("less_than") == 0) { - ss << "(" << child_visitor_list[0]->GetResult() << " < " - << child_visitor_list[1]->GetResult() << ")"; - } else if (func_name.compare("greater_than") == 0) { - ss << "(" << child_visitor_list[0]->GetResult() << " > " - << child_visitor_list[1]->GetResult() << ")"; - } else if (func_name.compare("less_than_or_equal_to") == 0) { - ss << "(" << child_visitor_list[0]->GetResult() - << " <= " << child_visitor_list[1]->GetResult() << ")"; - } else if (func_name.compare("greater_than_or_equal_to") == 0) { - ss << "(" << child_visitor_list[0]->GetResult() - << " >= " << child_visitor_list[1]->GetResult() << ")"; - } else if (func_name.compare("equal") == 0) { - ss << "(" << child_visitor_list[0]->GetResult() - << " == " << child_visitor_list[1]->GetResult() << ")"; - } else if (func_name.compare("not") == 0) { - ss << "!(" << child_visitor_list[0]->GetResult() << ")"; - } else { - ss << 
child_visitor_list[0]->GetResult(); - } - if (cur_func_id == 0) { - codes_str_ = ss.str(); - } else { - codes_str_ = ss.str(); - } - return arrow::Status::OK(); -} - -arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::FieldNode& node) { - auto cur_func_id = *func_count_; - auto this_field = node.field(); - int arg_id = 0; - bool found = false; - std::string is_null_func_list = "left_is_null_func_list_"; - std::string get_func_list = "left_get_func_list_"; - std::string index = "left_index"; - for (auto field_list : field_list_v_) { - arg_id = 0; - for (auto field : field_list) { - if (field->name() == this_field->name()) { - found = true; - break; - } - arg_id++; - } - if (found) { - break; - } - is_null_func_list = "right_is_null_func_list_"; - get_func_list = "right_get_func_list_"; - index = "right_index"; - } - std::string type_name; - if (this_field->type()->id() == arrow::Type::STRING) { - type_name = "std::string"; - } else { - type_name = this_field->type()->name(); - } - *codes_ss_ << type_name << " *input_field_" << cur_func_id << " = nullptr;" - << std::endl; - *codes_ss_ << "if (!" << is_null_func_list << "[" << arg_id << "](" + index + ")) {" - << std::endl; - *codes_ss_ << " input_field_" << cur_func_id << " = (" << type_name << "*)" - << get_func_list << "[" << arg_id << "](" + index + ");" << std::endl; - *codes_ss_ << "} else {" << std::endl; - *codes_ss_ << " return false;" << std::endl; - *codes_ss_ << "}" << std::endl; - - std::stringstream ss; - ss << "*input_field_" << cur_func_id; - codes_str_ = ss.str(); - - check_str_ = ""; - return arrow::Status::OK(); -} - -arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::IfNode& node) { - return arrow::Status::OK(); -} - -arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::LiteralNode& node) { - auto cur_func_id = *func_count_; - if (node.return_type()->id() == arrow::Type::STRING) { - *codes_ss_ << "const std::string input_field_" << cur_func_id << R"( = ")" - << gandiva::ToString(node.holder()) << R"(";)" << std::endl; - - } else { - *codes_ss_ << "const " << node.return_type()->name() << " input_field_" << cur_func_id - << " = " << gandiva::ToString(node.holder()) << ";" << std::endl; - } - - std::stringstream ss; - ss << "input_field_" << cur_func_id; - codes_str_ = ss.str(); - return arrow::Status::OK(); -} - -arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::BooleanNode& node) { - std::vector> child_visitor_list; - auto cur_func_id = *func_count_; - for (auto child : node.children()) { - std::shared_ptr child_visitor; - *func_count_ = *func_count_ + 1; - RETURN_NOT_OK(MakeCodeGenNodeVisitorV2(child, field_list_v_, func_count_, codes_ss_, - &child_visitor)); - child_visitor_list.push_back(child_visitor); - } - - std::stringstream ss; - if (node.expr_type() == gandiva::BooleanNode::AND) { - ss << "(" << child_visitor_list[0]->GetResult() << ") && (" - << child_visitor_list[1]->GetResult() << ")"; - } - if (node.expr_type() == gandiva::BooleanNode::OR) { - ss << "(" << child_visitor_list[0]->GetResult() << ") || (" - << child_visitor_list[1]->GetResult() << ")"; - } - codes_str_ = ss.str(); - return arrow::Status::OK(); -} - -arrow::Status CodeGenNodeVisitorV2::Visit(const gandiva::InExpressionNode& node) { - auto cur_func_id = *func_count_; - std::shared_ptr child_visitor; - *func_count_ = *func_count_ + 1; - RETURN_NOT_OK(MakeCodeGenNodeVisitorV2(node.eval_expr(), field_list_v_, func_count_, - codes_ss_, &child_visitor)); - *codes_ss_ << "std::vector input_field_" << cur_func_id << " = {"; - bool 
add_comma = false; - for (auto& value : node.values()) { - if (add_comma) { - *codes_ss_ << ", "; - } - // add type in the front to differentiate - *codes_ss_ << value; - add_comma = true; - } - *codes_ss_ << "};" << std::endl; - - std::stringstream ss; - ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id - << ".end(), " << child_visitor->GetResult() << ") != " - << "input_field_" << cur_func_id << ".end()"; - codes_str_ = ss.str(); - check_str_ = child_visitor->GetPreCheck(); - return arrow::Status::OK(); -} - -arrow::Status CodeGenNodeVisitorV2::Visit( - const gandiva::InExpressionNode& node) { - auto cur_func_id = *func_count_; - std::shared_ptr child_visitor; - *func_count_ = *func_count_ + 1; - RETURN_NOT_OK(MakeCodeGenNodeVisitorV2(node.eval_expr(), field_list_v_, func_count_, - codes_ss_, &child_visitor)); - *codes_ss_ << "std::vector input_field_" << cur_func_id << " = {"; - bool add_comma = false; - for (auto& value : node.values()) { - if (add_comma) { - *codes_ss_ << ", "; - } - // add type in the front to differentiate - *codes_ss_ << value; - add_comma = true; - } - *codes_ss_ << "};" << std::endl; - - std::stringstream ss; - ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id - << ".end(), " << child_visitor->GetResult() << ") != " - << "input_field_" << cur_func_id << ".end()"; - codes_str_ = ss.str(); - check_str_ = child_visitor->GetPreCheck(); - return arrow::Status::OK(); -} - -arrow::Status CodeGenNodeVisitorV2::Visit( - const gandiva::InExpressionNode& node) { - auto cur_func_id = *func_count_; - std::shared_ptr child_visitor; - *func_count_ = *func_count_ + 1; - RETURN_NOT_OK(MakeCodeGenNodeVisitorV2(node.eval_expr(), field_list_v_, func_count_, - codes_ss_, &child_visitor)); - *codes_ss_ << "std::vector input_field_" << cur_func_id << " = {"; - bool add_comma = false; - for (auto& value : node.values()) { - if (add_comma) { - *codes_ss_ << ", "; - } - // add type in the front to differentiate - *codes_ss_ << R"(")" << value << R"(")"; - add_comma = true; - } - *codes_ss_ << "};" << std::endl; - - std::stringstream ss; - ss << "std::find(input_field_" << cur_func_id << ".begin(), input_field_" << cur_func_id - << ".end(), " << child_visitor->GetResult() << ") != " - << "input_field_" << cur_func_id << ".end()"; - codes_str_ = ss.str(); - check_str_ = child_visitor->GetPreCheck(); - return arrow::Status::OK(); -} - -} // namespace extra -} // namespace arrowcompute -} // namespace codegen -} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.h deleted file mode 100644 index 88cbd3caa..000000000 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor_v2.h +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include "codegen/common/visitor_base.h" - -namespace sparkcolumnarplugin { -namespace codegen { -namespace arrowcompute { -namespace extra { -class CodeGenNodeVisitorV2 : public VisitorBase { - public: - CodeGenNodeVisitorV2( - std::shared_ptr func, - std::vector>> field_list_v, - int* func_count, std::stringstream* codes_ss) - : func_(func), - field_list_v_(field_list_v), - func_count_(func_count), - codes_ss_(codes_ss) {} - - arrow::Status Eval() { - RETURN_NOT_OK(func_->Accept(*this)); - return arrow::Status::OK(); - } - std::string GetResult(); - std::string GetPreCheck(); - arrow::Status Visit(const gandiva::FunctionNode& node) override; - arrow::Status Visit(const gandiva::FieldNode& node) override; - arrow::Status Visit(const gandiva::IfNode& node) override; - arrow::Status Visit(const gandiva::LiteralNode& node) override; - arrow::Status Visit(const gandiva::BooleanNode& node) override; - arrow::Status Visit(const gandiva::InExpressionNode& node) override; - arrow::Status Visit(const gandiva::InExpressionNode& node) override; - arrow::Status Visit(const gandiva::InExpressionNode& node) override; - - private: - std::shared_ptr func_; - std::vector>> field_list_v_; - int* func_count_; - // output - std::stringstream* codes_ss_; - std::string codes_str_; - std::string check_str_; -}; -static arrow::Status MakeCodeGenNodeVisitorV2( - std::shared_ptr func, - std::vector>> field_list_v, int* func_count, - std::stringstream* codes_ss, std::shared_ptr* out) { - auto visitor = - std::make_shared(func, field_list_v, func_count, codes_ss); - RETURN_NOT_OK(visitor->Eval()); - *out = visitor; - return arrow::Status::OK(); -} -} // namespace extra -} // namespace arrowcompute -} // namespace codegen -} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioned_probe_kernel.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioned_probe_kernel.cc deleted file mode 100644 index a10b5ebf5..000000000 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioned_probe_kernel.cc +++ /dev/null @@ -1,815 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "codegen/arrow_compute/ext/codegen_node_visitor_v2.h" -#include "codegen/arrow_compute/ext/conditioner.h" -#include "codegen/arrow_compute/ext/item_iterator.h" -#include "codegen/arrow_compute/ext/kernels_ext.h" -#include "codegen/arrow_compute/ext/shuffle_v2_action.h" -#include "third_party/arrow/utils/hashing.h" - -namespace sparkcolumnarplugin { -namespace codegen { -namespace arrowcompute { -namespace extra { - -using ArrayList = std::vector>; - -/////////////// ConditionedProbeArrays //////////////// -class ConditionedProbeArraysKernel::Impl { - public: - Impl() {} - virtual ~Impl() {} - virtual arrow::Status Evaluate(const ArrayList& in) { - return arrow::Status::NotImplemented( - "ConditionedProbeArraysKernel::Impl Evaluate is abstract"); - } // namespace extra - virtual arrow::Status MakeResultIterator( - std::shared_ptr schema, - std::shared_ptr>* out) { - return arrow::Status::NotImplemented( - "ConditionedProbeArraysKernel::Impl MakeResultIterator is abstract"); - } -}; // namespace arrowcompute - -template -class ConditionedProbeArraysTypedImpl : public ConditionedProbeArraysKernel::Impl { - public: - ConditionedProbeArraysTypedImpl( - arrow::compute::FunctionContext* ctx, - std::vector> left_key_list, - std::vector> right_key_list, - std::shared_ptr func_node, int join_type, - std::vector> left_field_list, - std::vector> right_field_list) - : ctx_(ctx), - left_key_list_(left_key_list), - right_key_list_(right_key_list), - func_node_(func_node), - join_type_(join_type), - left_field_list_(left_field_list), - right_field_list_(right_field_list) { - hash_table_ = std::make_shared(ctx_->memory_pool()); - for (auto key_field : left_key_list) { - int i = 0; - for (auto field : left_field_list) { - if (key_field->name() == field->name()) { - break; - } - i++; - } - left_key_indices_.push_back(i); - } - for (auto key_field : right_key_list) { - int i = 0; - for (auto field : right_field_list) { - if (key_field->name() == field->name()) { - break; - } - i++; - } - right_key_indices_.push_back(i); - } - if (left_key_list.size() > 1) { - std::vector> type_list; - for (auto key_field : left_key_list) { - type_list.push_back(key_field->type()); - } - extra::HashAggrArrayKernel::Make(ctx_, type_list, &concat_kernel_); - } - for (auto field : left_field_list) { - std::shared_ptr iter; - MakeArrayListItemIterator(field->type(), &iter); - input_iterator_list_.push_back(iter); - } - for (auto field : right_field_list) { - std::shared_ptr iter; - MakeArrayItemIterator(field->type(), &iter); - input_iterator_list_.push_back(iter); - } - input_cache_.resize(left_field_list.size()); - - // This Function suppose to return a lambda function for later ResultIterator - auto start = std::chrono::steady_clock::now(); - if (func_node_) { - LoadJITFunction(func_node_, left_field_list_, right_field_list_, &conditioner_); - } - auto end = std::chrono::steady_clock::now(); - std::cout - << "Code Generation took " - << (std::chrono::duration_cast(end - start).count() / - 1000) - << " ms." 
<< std::endl; - } - - arrow::Status Evaluate(const ArrayList& in_arr_list) override { - if (in_arr_list.size() != input_cache_.size()) { - return arrow::Status::Invalid( - "ConditionedShuffleArrayListKernel input arrayList size does not match numCols " - "in cache, which are ", - in_arr_list.size(), " and ", input_cache_.size()); - } - // we need to convert std::vector to std::vector - for (int col_id = 0; col_id < input_cache_.size(); col_id++) { - input_cache_[col_id].push_back(in_arr_list[col_id]); - } - - // do concat_join if necessary - std::shared_ptr in; - if (concat_kernel_) { - ArrayList concat_kernel_arr_list; - for (auto i : left_key_indices_) { - concat_kernel_arr_list.push_back(in_arr_list[i]); - } - RETURN_NOT_OK(concat_kernel_->Evaluate(concat_kernel_arr_list, &in)); - } else { - in = in_arr_list[left_key_indices_[0]]; - } - - // we should put items into hashmap - auto typed_array = std::dynamic_pointer_cast(in); - auto insert_on_found = [this](int32_t i) { - left_table_size_++; - memo_index_to_arrayid_[i].emplace_back(cur_array_id_, cur_id_); - }; - auto insert_on_not_found = [this](int32_t i) { - left_table_size_++; - memo_index_to_arrayid_.push_back({ArrayItemIndex(cur_array_id_, cur_id_)}); - }; - - cur_id_ = 0; - int memo_index = 0; - if (typed_array->null_count() == 0) { - for (; cur_id_ < typed_array->length(); cur_id_++) { - hash_table_->GetOrInsert(typed_array->GetView(cur_id_), insert_on_found, - insert_on_not_found, &memo_index); - } - } else { - for (; cur_id_ < typed_array->length(); cur_id_++) { - if (typed_array->IsNull(cur_id_)) { - hash_table_->GetOrInsertNull(insert_on_found, insert_on_not_found); - } else { - hash_table_->GetOrInsert(typed_array->GetView(cur_id_), insert_on_found, - insert_on_not_found, &memo_index); - } - } - } - cur_array_id_++; - return arrow::Status::OK(); - } - - arrow::Status MakeResultIterator( - std::shared_ptr schema, - std::shared_ptr>* out) override { - std::function& in, - std::function ConditionCheck, - std::shared_ptr* out)> - eval_func; - // prepare process next function - std::cout << "HashMap lenghth is " << memo_index_to_arrayid_.size() - << ", total stored items count is " << left_table_size_ << std::endl; - switch (join_type_) { - case 0: { /*Inner Join*/ - eval_func = [this, schema]( - const std::shared_ptr& in, - std::function ConditionCheck, - std::shared_ptr* out) { - // prepare - std::unique_ptr left_indices_builder; - auto left_array_type = arrow::fixed_size_binary(sizeof(ArrayItemIndex)); - left_indices_builder.reset( - new arrow::FixedSizeBinaryBuilder(left_array_type, ctx_->memory_pool())); - - std::unique_ptr right_indices_builder; - right_indices_builder.reset( - new arrow::UInt32Builder(arrow::uint32(), ctx_->memory_pool())); - - auto typed_array = std::dynamic_pointer_cast(in); - for (int i = 0; i < typed_array->length(); i++) { - if (!typed_array->IsNull(i)) { - auto index = hash_table_->Get(typed_array->GetView(i)); - if (index != -1) { - for (auto tmp : memo_index_to_arrayid_[index]) { - if (ConditionCheck(tmp, i)) { - RETURN_NOT_OK(left_indices_builder->Append((uint8_t*)&tmp)); - RETURN_NOT_OK(right_indices_builder->Append(i)); - } - } - } - } - } - // create buffer and null_vector to FixedSizeBinaryArray - std::shared_ptr left_arr_out; - std::shared_ptr right_arr_out; - RETURN_NOT_OK(left_indices_builder->Finish(&left_arr_out)); - RETURN_NOT_OK(right_indices_builder->Finish(&right_arr_out)); - auto result_schema = - arrow::schema({arrow::field("left_indices", left_array_type), - 
arrow::field("right_indices", arrow::uint32())}); - *out = arrow::RecordBatch::Make(result_schema, right_arr_out->length(), - {left_arr_out, right_arr_out}); - return arrow::Status::OK(); - }; - } break; - case 1: { /*Outer Join*/ - eval_func = [this, schema]( - const std::shared_ptr& in, - std::function ConditionCheck, - std::shared_ptr* out) { - std::unique_ptr left_indices_builder; - auto left_array_type = arrow::fixed_size_binary(sizeof(ArrayItemIndex)); - left_indices_builder.reset( - new arrow::FixedSizeBinaryBuilder(left_array_type, ctx_->memory_pool())); - - std::unique_ptr right_indices_builder; - right_indices_builder.reset( - new arrow::UInt32Builder(arrow::uint32(), ctx_->memory_pool())); - - auto typed_array = std::dynamic_pointer_cast(in); - for (int i = 0; i < typed_array->length(); i++) { - if (typed_array->IsNull(i)) { - auto index = hash_table_->GetNull(); - if (index == -1) { - RETURN_NOT_OK(left_indices_builder->AppendNull()); - RETURN_NOT_OK(right_indices_builder->Append(i)); - } else { - for (auto tmp : memo_index_to_arrayid_[index]) { - if (ConditionCheck(tmp, i)) { - RETURN_NOT_OK(left_indices_builder->Append((uint8_t*)&tmp)); - RETURN_NOT_OK(right_indices_builder->Append(i)); - } - } - } - } else { - auto index = hash_table_->Get(typed_array->GetView(i)); - if (index == -1) { - RETURN_NOT_OK(left_indices_builder->AppendNull()); - RETURN_NOT_OK(right_indices_builder->Append(i)); - } else { - for (auto tmp : memo_index_to_arrayid_[index]) { - if (ConditionCheck(tmp, i)) { - RETURN_NOT_OK(left_indices_builder->Append((uint8_t*)&tmp)); - RETURN_NOT_OK(right_indices_builder->Append(i)); - } - } - } - } - } - // create buffer and null_vector to FixedSizeBinaryArray - std::shared_ptr left_arr_out; - std::shared_ptr right_arr_out; - RETURN_NOT_OK(left_indices_builder->Finish(&left_arr_out)); - RETURN_NOT_OK(right_indices_builder->Finish(&right_arr_out)); - auto result_schema = - arrow::schema({arrow::field("left_indices", left_array_type), - arrow::field("right_indices", arrow::uint32())}); - *out = arrow::RecordBatch::Make(result_schema, right_arr_out->length(), - {left_arr_out, right_arr_out}); - - return arrow::Status::OK(); - }; - } break; - case 2: { /*Anti Join*/ - eval_func = [this, schema]( - const std::shared_ptr& in, - std::function ConditionCheck, - std::shared_ptr* out) { - std::unique_ptr left_indices_builder; - auto left_array_type = arrow::fixed_size_binary(sizeof(ArrayItemIndex)); - left_indices_builder.reset( - new arrow::FixedSizeBinaryBuilder(left_array_type, ctx_->memory_pool())); - - std::unique_ptr right_indices_builder; - right_indices_builder.reset( - new arrow::UInt32Builder(arrow::uint32(), ctx_->memory_pool())); - - auto typed_array = std::dynamic_pointer_cast(in); - for (int i = 0; i < typed_array->length(); i++) { - if (!typed_array->IsNull(i)) { - auto index = hash_table_->Get(typed_array->GetView(i)); - if (index == -1) { - RETURN_NOT_OK(left_indices_builder->AppendNull()); - RETURN_NOT_OK(right_indices_builder->Append(i)); - } else { - bool found = false; - for (auto tmp : memo_index_to_arrayid_[index]) { - if (ConditionCheck(tmp, i)) { - found = true; - break; - } - } - if (!found) { - RETURN_NOT_OK(left_indices_builder->AppendNull()); - RETURN_NOT_OK(right_indices_builder->Append(i)); - } - } - } else { - auto index = hash_table_->GetNull(); - if (index == -1) { - RETURN_NOT_OK(left_indices_builder->AppendNull()); - RETURN_NOT_OK(right_indices_builder->Append(i)); - } else { - bool found = false; - for (auto tmp : 
memo_index_to_arrayid_[index]) { - if (ConditionCheck(tmp, i)) { - found = true; - break; - } - } - if (!found) { - RETURN_NOT_OK(left_indices_builder->AppendNull()); - RETURN_NOT_OK(right_indices_builder->Append(i)); - } - } - } - } - - // create buffer and null_vector to FixedSizeBinaryArray - std::shared_ptr left_arr_out; - std::shared_ptr right_arr_out; - RETURN_NOT_OK(left_indices_builder->Finish(&left_arr_out)); - RETURN_NOT_OK(right_indices_builder->Finish(&right_arr_out)); - auto result_schema = - arrow::schema({arrow::field("left_indices", left_array_type), - arrow::field("right_indices", arrow::uint32())}); - *out = arrow::RecordBatch::Make(result_schema, right_arr_out->length(), - {left_arr_out, right_arr_out}); - return arrow::Status::OK(); - }; - } break; - case 3: { /*Semi Join*/ - eval_func = [this, schema]( - const std::shared_ptr& in, - std::function ConditionCheck, - std::shared_ptr* out) { - // prepare - std::unique_ptr left_indices_builder; - auto left_array_type = arrow::fixed_size_binary(sizeof(ArrayItemIndex)); - left_indices_builder.reset( - new arrow::FixedSizeBinaryBuilder(left_array_type, ctx_->memory_pool())); - - std::unique_ptr right_indices_builder; - right_indices_builder.reset( - new arrow::UInt32Builder(arrow::uint32(), ctx_->memory_pool())); - - auto typed_array = std::dynamic_pointer_cast(in); - for (int i = 0; i < typed_array->length(); i++) { - if (!typed_array->IsNull(i)) { - auto index = hash_table_->Get(typed_array->GetView(i)); - if (index != -1) { - for (auto tmp : memo_index_to_arrayid_[index]) { - if (ConditionCheck(tmp, i)) { - RETURN_NOT_OK(left_indices_builder->AppendNull()); - RETURN_NOT_OK(right_indices_builder->Append(i)); - break; - } - } - } - } - } - // create buffer and null_vector to FixedSizeBinaryArray - std::shared_ptr left_arr_out; - std::shared_ptr right_arr_out; - RETURN_NOT_OK(left_indices_builder->Finish(&left_arr_out)); - RETURN_NOT_OK(right_indices_builder->Finish(&right_arr_out)); - auto result_schema = - arrow::schema({arrow::field("left_indices", left_array_type), - arrow::field("right_indices", arrow::uint32())}); - *out = arrow::RecordBatch::Make(result_schema, right_arr_out->length(), - {left_arr_out, right_arr_out}); - return arrow::Status::OK(); - }; - } break; - default: - return arrow::Status::Invalid( - "ConditionedProbeArraysTypedImpl only support join type: InnerJoin, " - "RightJoin"); - } - - *out = std::make_shared( - ctx_, conditioner_, right_key_indices_, concat_kernel_, input_iterator_list_, - input_cache_, eval_func); - return arrow::Status::OK(); - } - - private: - using ArrayType = typename arrow::TypeTraits::ArrayType; - - uint64_t left_table_size_ = 0; - std::shared_ptr func_node_; - std::vector left_key_indices_; - std::vector right_key_indices_; - std::vector> left_key_list_; - std::vector> right_key_list_; - std::vector> left_field_list_; - std::vector> right_field_list_; - std::vector> input_iterator_list_; - int join_type_; - std::shared_ptr concat_kernel_; - std::vector input_cache_; - std::shared_ptr out_type_; - arrow::compute::FunctionContext* ctx_; - std::shared_ptr hash_table_; - std::vector> memo_index_to_arrayid_; - std::shared_ptr conditioner_; - - uint64_t cur_array_id_ = 0; - uint64_t cur_id_ = 0; - - arrow::Status LoadJITFunction( - std::shared_ptr func_node, - std::vector> left_field_list, - std::vector> right_field_list, - std::shared_ptr* out) { - // generate ddl signature - std::stringstream func_args_ss; - func_args_ss << func_node->ToString(); - for (auto field : left_field_list) 
{ - func_args_ss << field->ToString(); - } - for (auto field : right_field_list) { - func_args_ss << field->ToString(); - } - - std::stringstream signature_ss; - signature_ss << std::hex << std::hash{}(func_args_ss.str()); - std::string signature = signature_ss.str(); - std::cout << "LoadJITFunction signature is " << signature << std::endl; - - auto file_lock = FileSpinLock("/tmp"); - std::cout << "GetFileLock" << std::endl; - auto status = LoadLibrary(signature, out); - if (!status.ok()) { - // process - auto codes = ProduceCodes(func_node_, left_field_list_, right_field_list_); - // compile codes - RETURN_NOT_OK(CompileCodes(codes, signature)); - RETURN_NOT_OK(LoadLibrary(signature, out)); - } - FileSpinUnLock(file_lock); - std::cout << "ReleaseFileLock" << std::endl; - return arrow::Status::OK(); - } - - std::string ProduceCodes(std::shared_ptr func_node, - std::vector> left_field_list, - std::vector> right_field_list) { - // CodeGen - std::stringstream codes_ss; - codes_ss << R"(#include -#include -#define int8 int8_t -#define int16 int16_t -#define int32 int32_t -#define int64 int64_t -#define uint8 uint8_t -#define uint16 uint16_t -#define uint32 uint32_t -#define uint64 uint64_t - -struct ArrayItemIndex { - uint64_t id = 0; - uint64_t array_id = 0; - ArrayItemIndex(uint64_t array_id, uint64_t id) : array_id(array_id), id(id) {} -}; - -class ConditionerBase { -public: - virtual arrow::Status Submit( - std::vector> left_is_null_func_list, - std::vector> left_get_func_list, - std::vector> right_is_null_func_list, - std::vector> right_get_func_list, - std::function *out) { - return arrow::Status::NotImplemented( - "ConditionerBase Submit is an abstract interface."); - } -}; - -class Conditioner : public ConditionerBase { -public: - arrow::Status Submit( - std::vector> left_is_null_func_list, - std::vector> left_get_func_list, - std::vector> right_is_null_func_list, - std::vector> right_get_func_list, - std::function* out) override { - left_is_null_func_list_ = left_is_null_func_list; - left_get_func_list_ = left_get_func_list; - right_is_null_func_list_ = right_is_null_func_list; - right_get_func_list_ = right_get_func_list; - *out = [this](ArrayItemIndex left_index, int right_index) { -)"; - - std::shared_ptr func_node_visitor; - int func_count = 0; - MakeCodeGenNodeVisitorV2(func_node, {left_field_list, right_field_list}, &func_count, - &codes_ss, &func_node_visitor); - codes_ss << " return (" << func_node_visitor->GetResult() << ");" << std::endl; - codes_ss << R"( - }; - return arrow::Status::OK(); - } -private: - std::vector> left_is_null_func_list_; - std::vector> left_get_func_list_; - std::vector> right_is_null_func_list_; - std::vector> right_get_func_list_; -}; - -extern "C" void MakeConditioner(std::shared_ptr *out) { - *out = std::make_shared(); -})"; - return codes_ss.str(); - } - - int FileSpinLock(std::string path) { - std::string lockfile = path + "/nativesql_compile.lock"; - - auto fd = open(lockfile.c_str(), O_CREAT, S_IRWXU|S_IRWXG); - flock(fd, LOCK_EX); - - return fd; - } - - void FileSpinUnLock(int fd) { - flock(fd, LOCK_UN); - close(fd); - } - - arrow::Status CompileCodes(std::string codes, std::string signature) { - // temporary cpp/library output files - srand(time(NULL)); - std::string outpath = "/tmp"; - std::string prefix = "/spark-columnar-plugin-codegen-"; - std::string cppfile = outpath + prefix + signature + ".cc"; - std::string libfile = outpath + prefix + signature + ".so"; - std::string logfile = outpath + prefix + signature + ".log"; - std::ofstream 
out(cppfile.c_str(), std::ofstream::out); - - // output code to file - if (out.bad()) { - std::cout << "cannot open " << cppfile << std::endl; - exit(EXIT_FAILURE); - } - out << codes; - out.flush(); - out.close(); - - // compile the code - std::string cmd = "gcc -std=c++11 -Wall -Wextra " + cppfile + " -o " + libfile + - " -O3 -shared -fPIC -larrow 2> " + logfile; - int ret = system(cmd.c_str()); - if (WEXITSTATUS(ret) != EXIT_SUCCESS) { - std::cout << "compilation failed, see " << logfile << std::endl; - exit(EXIT_FAILURE); - } - - struct stat tstat; - ret = stat(libfile.c_str(), &tstat); - if (ret == -1) { - std::cout << "stat failed: " << strerror(errno) << std::endl; - exit(EXIT_FAILURE); - } - - return arrow::Status::OK(); - } - - arrow::Status LoadLibrary(std::string signature, - std::shared_ptr* out) { - std::string outpath = "/tmp"; - std::string prefix = "/spark-columnar-plugin-codegen-"; - std::string libfile = outpath + prefix + signature + ".so"; - // load dynamic library - void* dynlib = dlopen(libfile.c_str(), RTLD_LAZY); - if (!dynlib) { - return arrow::Status::Invalid(libfile, " is not generated"); - } - - // loading symbol from library and assign to pointer - // (to be cast to function pointer later) - - void (*MakeConditioner)(std::shared_ptr * out); - *(void**)(&MakeConditioner) = dlsym(dynlib, "MakeConditioner"); - const char* dlsym_error = dlerror(); - if (dlsym_error != NULL) { - std::cerr << "error loading symbol:\n" << dlsym_error << std::endl; - exit(EXIT_FAILURE); - } - - MakeConditioner(out); - return arrow::Status::OK(); - } - - class ConditionedProbeArraysResultIterator : public ResultIterator { - public: - ConditionedProbeArraysResultIterator( - arrow::compute::FunctionContext* ctx, - std::shared_ptr conditioner, std::vector right_key_indices, - std::shared_ptr concat_kernel, - std::vector> input_iterator_list, - std::vector>> left_input_cache, - std::function< - arrow::Status(const std::shared_ptr& in, - std::function ConditionCheck, - std::shared_ptr* out)> - eval_func) - : eval_func_(eval_func), - conditioner_(conditioner), - ctx_(ctx), - right_key_indices_(right_key_indices), - concat_kernel_(concat_kernel), - input_iterator_list_(input_iterator_list) { - int i = 0; - for (; i < left_input_cache.size(); i++) { - std::function is_null; - std::function get; - input_iterator_list_[i]->Submit(left_input_cache[i], &is_null, &get); - left_is_null_func_list_.push_back(is_null); - left_get_func_list_.push_back(get); - } - right_id_ = i; - condition_check_ = [this](ArrayItemIndex left_index, int right_index) { - return true; - }; - } - - std::string ToString() override { return "ConditionedProbeArraysResultIterator"; } - - arrow::Status ProcessAndCacheOne( - ArrayList in, const std::shared_ptr& selection) override { - // key preparation - std::shared_ptr in_arr; - if (right_key_indices_.size() > 1) { - ArrayList in_arr_list; - for (auto i : right_key_indices_) { - in_arr_list.push_back(in[i]); - } - RETURN_NOT_OK(concat_kernel_->Evaluate(in_arr_list, &in_arr)); - } else { - in_arr = in[right_key_indices_[0]]; - } - // condition lambda preparation - if (conditioner_) { - std::vector> right_is_null_func_list; - std::vector> right_get_func_list; - for (int i = 0; i < in.size(); i++) { - auto arr = in[i]; - std::function is_null; - std::function get; - input_iterator_list_[i + right_id_]->Submit(arr, &is_null, &get); - right_is_null_func_list.push_back(is_null); - right_get_func_list.push_back(get); - } - conditioner_->Submit(left_is_null_func_list_, 
left_get_func_list_, - right_is_null_func_list, right_get_func_list, - &condition_check_); - } - RETURN_NOT_OK(eval_func_(in_arr, condition_check_, &out_cache_)); - return arrow::Status::OK(); - } - - arrow::Status GetResult(std::shared_ptr* out) { - *out = out_cache_; - return arrow::Status::OK(); - } - - private: - std::function& in, - std::function ConditionCheck, - std::shared_ptr* out)> - eval_func_; - int right_id_; - std::shared_ptr conditioner_; - std::function condition_check_; - std::vector right_key_indices_; - std::shared_ptr out_cache_; - arrow::compute::FunctionContext* ctx_; - std::shared_ptr concat_kernel_; - std::vector> input_iterator_list_; - std::vector> left_is_null_func_list_; - std::vector> left_get_func_list_; - }; -}; // namespace extra - -arrow::Status ConditionedProbeArraysKernel::Make( - arrow::compute::FunctionContext* ctx, - std::vector> left_key_list, - std::vector> right_key_list, - std::shared_ptr func_node, int join_type, - std::vector> left_field_list, - std::vector> right_field_list, - std::shared_ptr* out) { - *out = std::make_shared( - ctx, left_key_list, right_key_list, func_node, join_type, left_field_list, - right_field_list); - return arrow::Status::OK(); -} - -#define PROCESS_SUPPORTED_TYPES(PROCESS) \ - PROCESS(arrow::BooleanType) \ - PROCESS(arrow::UInt8Type) \ - PROCESS(arrow::Int8Type) \ - PROCESS(arrow::UInt16Type) \ - PROCESS(arrow::Int16Type) \ - PROCESS(arrow::UInt32Type) \ - PROCESS(arrow::Int32Type) \ - PROCESS(arrow::UInt64Type) \ - PROCESS(arrow::Int64Type) \ - PROCESS(arrow::FloatType) \ - PROCESS(arrow::DoubleType) \ - PROCESS(arrow::Date32Type) \ - PROCESS(arrow::Date64Type) \ - PROCESS(arrow::Time32Type) \ - PROCESS(arrow::Time64Type) \ - PROCESS(arrow::TimestampType) \ - PROCESS(arrow::BinaryType) \ - PROCESS(arrow::StringType) \ - PROCESS(arrow::FixedSizeBinaryType) \ - PROCESS(arrow::Decimal128Type) -ConditionedProbeArraysKernel::ConditionedProbeArraysKernel( - arrow::compute::FunctionContext* ctx, - std::vector> left_key_list, - std::vector> right_key_list, - std::shared_ptr func_node, int join_type, - std::vector> left_field_list, - std::vector> right_field_list) { - auto type = left_key_list[0]->type(); - if (left_key_list.size() > 1) { - type = arrow::int64(); - } - switch (type->id()) { -#define PROCESS(InType) \ - case InType::type_id: { \ - using MemoTableType = typename arrow::internal::HashTraits::MemoTableType; \ - impl_.reset(new ConditionedProbeArraysTypedImpl( \ - ctx, left_key_list, right_key_list, func_node, join_type, left_field_list, \ - right_field_list)); \ - } break; - PROCESS_SUPPORTED_TYPES(PROCESS) -#undef PROCESS - default: - break; - } - kernel_name_ = "ConditionedProbeArraysKernel"; -} -#undef PROCESS_SUPPORTED_TYPES - -arrow::Status ConditionedProbeArraysKernel::Evaluate(const ArrayList& in) { - return impl_->Evaluate(in); -} - -arrow::Status ConditionedProbeArraysKernel::MakeResultIterator( - std::shared_ptr schema, - std::shared_ptr>* out) { - return impl_->MakeResultIterator(schema, out); -} -} // namespace extra -} // namespace arrowcompute -} // namespace codegen -} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioned_shuffle_kernel.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioned_shuffle_kernel.cc deleted file mode 100644 index 430018fa7..000000000 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioned_shuffle_kernel.cc +++ /dev/null @@ -1,482 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under 
one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "codegen/arrow_compute/ext/array_item_index.h" -#include "codegen/arrow_compute/ext/codegen_node_visitor.h" -#include "codegen/arrow_compute/ext/item_iterator.h" -#include "codegen/arrow_compute/ext/kernels_ext.h" -#include "codegen/arrow_compute/ext/shuffle_v2_action.h" - -namespace sparkcolumnarplugin { -namespace codegen { -namespace arrowcompute { -namespace extra { - -using ArrayList = std::vector>; - -/////////////// ConditionedShuffleArrayList //////////////// -class ConditionedShuffleArrayListKernel::Impl { - public: - Impl(arrow::compute::FunctionContext* ctx, std::shared_ptr func_node, - std::vector> left_field_list, - std::vector> right_field_list, - std::vector> output_field_list) - : ctx_(ctx), - func_node_(func_node), - left_field_list_(left_field_list), - right_field_list_(right_field_list) { - if (func_node) { - for (auto field : left_field_list) { - std::shared_ptr iter; - MakeArrayListItemIterator(field->type(), &iter); - input_iterator_list_.push_back(iter); - } - for (auto field : right_field_list) { - std::shared_ptr iter; - MakeArrayItemIterator(field->type(), &iter); - input_iterator_list_.push_back(iter); - } - } - - auto input_field_list = {left_field_list, right_field_list}; - input_cache_.resize(left_field_list.size()); - auto left_field_list_size = left_field_list.size(); - for (auto out_field : output_field_list) { - int col_id = 0; - bool found = false; - for (auto arg : input_field_list) { - for (auto col : arg) { - if (col->name() == out_field->name()) { - found = true; - break; - } - col_id++; - } - if (found) break; - } - - std::shared_ptr action; - if (col_id < left_field_list_size) { - MakeShuffleV2Action(ctx_, out_field->type(), true, &action); - } else { - MakeShuffleV2Action(ctx_, out_field->type(), false, &action); - } - output_indices_.push_back(col_id); - output_action_list_.push_back(action); - } - } - - ~Impl() {} - arrow::Status Evaluate(const ArrayList& in) { - if (in.size() != input_cache_.size()) { - return arrow::Status::Invalid( - "ConditionedShuffleArrayListKernel input arrayList size does not match " - "numCols " - "in cache, which are ", - in.size(), " and ", input_cache_.size()); - } - // we need to convert std::vector to std::vector - for (int col_id = 0; col_id < input_cache_.size(); col_id++) { - input_cache_[col_id].push_back(in[col_id]); - } - - return arrow::Status::OK(); - } - - arrow::Status SetDependencyIter( - const std::shared_ptr>& in, int index) { - in_indices_iter_ = in; - return arrow::Status::OK(); - } - - arrow::Status MakeResultIterator( - 
std::shared_ptr schema, - std::shared_ptr>* out) { - // This Function suppose to return a lambda function for later ResultIterator - auto start = std::chrono::steady_clock::now(); - if (func_node_) { - LoadJITFunction(func_node_, left_field_list_, right_field_list_); - } - auto end = std::chrono::steady_clock::now(); - std::cout - << "Code Generation took " - << (std::chrono::duration_cast(end - start).count() / - 1000) - << " ms." << std::endl; - std::function, std::vector, - std::shared_ptr, ArrayList, - std::shared_ptr*)> - eval_func; - if (func_node_) { - eval_func = [this, schema](std::shared_ptr left_selection, - std::vector left_in, - std::shared_ptr right_selection, - ArrayList right_in, - std::shared_ptr* out) { - int num_rows = right_selection->length(); - std::vector> is_null_func_list; - std::vector> next_func_list; - std::vector> get_func_list; - std::vector> shuffle_func_list; - std::function is_null; - std::function next; - std::function get; - std::function shuffle; - auto left_field_list_size = left_field_list_.size(); - int i = 0; - for (auto array_list : left_in) { - RETURN_NOT_OK(input_iterator_list_[i++]->Submit(array_list, left_selection, - &next, &is_null, &get)); - is_null_func_list.push_back(is_null); - next_func_list.push_back(next); - get_func_list.push_back(get); - } - for (auto array : right_in) { - RETURN_NOT_OK(input_iterator_list_[i++]->Submit(array, right_selection, &next, - &is_null, &get)); - is_null_func_list.push_back(is_null); - next_func_list.push_back(next); - get_func_list.push_back(get); - } - i = 0; - for (auto id : output_indices_) { - if (id < left_field_list_size) { - RETURN_NOT_OK( - output_action_list_[i++]->Submit(left_in[id], left_selection, &shuffle)); - shuffle_func_list.push_back(shuffle); - } else { - RETURN_NOT_OK(output_action_list_[i++]->Submit( - right_in[id - left_field_list_size], right_selection, &shuffle)); - shuffle_func_list.push_back(shuffle); - } - } - ConditionShuffleCodeGen(num_rows, next_func_list, is_null_func_list, - get_func_list, shuffle_func_list); - ArrayList out_arr_list; - for (auto output_action : output_action_list_) { - RETURN_NOT_OK(output_action->FinishAndReset(&out_arr_list)); - } - *out = arrow::RecordBatch::Make(schema, out_arr_list[0]->length(), out_arr_list); - // arrow::PrettyPrint(*(*out).get(), 2, &std::cout); - - return arrow::Status::OK(); - }; - } else { - eval_func = [this, schema](std::shared_ptr left_selection, - std::vector left_in, - std::shared_ptr right_selection, - ArrayList right_in, - std::shared_ptr* out) { - int num_rows = right_selection->length(); - std::vector> is_null_func_list; - std::vector> next_func_list; - std::vector> get_func_list; - std::vector> shuffle_func_list; - std::function is_null; - std::function next; - std::function get; - std::function shuffle; - auto left_field_list_size = left_field_list_.size(); - int i = 0; - for (auto id : output_indices_) { - if (id < left_field_list_size) { - RETURN_NOT_OK( - output_action_list_[i++]->Submit(left_in[id], left_selection, &shuffle)); - shuffle_func_list.push_back(shuffle); - } else { - RETURN_NOT_OK(output_action_list_[i++]->Submit( - right_in[id - left_field_list_size], right_selection, &shuffle)); - shuffle_func_list.push_back(shuffle); - } - } - for (int row_id = 0; row_id < num_rows; row_id++) { - for (auto exec : shuffle_func_list) { - exec(); - } - } - ArrayList out_arr_list; - for (auto output_action : output_action_list_) { - RETURN_NOT_OK(output_action->FinishAndReset(&out_arr_list)); - } - *out = 
arrow::RecordBatch::Make(schema, out_arr_list[0]->length(), out_arr_list); - // arrow::PrettyPrint(*(*out).get(), 2, &std::cout); - - return arrow::Status::OK(); - }; - } - *out = std::make_shared( - ctx_, input_cache_, eval_func, in_indices_iter_); - - return arrow::Status::OK(); - } - - private: - arrow::compute::FunctionContext* ctx_; - // input_cache_ is used to cache left table for later on probe - std::shared_ptr> in_indices_iter_; - std::vector input_cache_; - - std::shared_ptr func_node_; - std::vector> left_field_list_; - std::vector> right_field_list_; - std::vector> input_iterator_list_; - std::vector output_indices_; - std::vector> output_action_list_; - void (*ConditionShuffleCodeGen)( - int num_rows, std::vector> next_func_list, - std::vector> is_null_func_list, - std::vector> get_func_list, - std::vector> output_function_list); - arrow::Status LoadJITFunction( - std::shared_ptr func_node, - std::vector> left_field_list, - std::vector> right_field_list) { - // generate ddl signature - std::stringstream func_args_ss; - func_args_ss << func_node->ToString(); - for (auto field : left_field_list) { - func_args_ss << field->ToString(); - } - for (auto field : right_field_list) { - func_args_ss << field->ToString(); - } - - std::stringstream signature_ss; - signature_ss << std::hex << std::hash{}(func_args_ss.str()); - std::string signature = signature_ss.str(); - - auto status = LoadLibrary(signature); - if (!status.ok()) { - // process - auto codes = ProduceCodes(func_node_, left_field_list_, right_field_list_); - // compile codes - RETURN_NOT_OK(CompileCodes(codes, signature)); - RETURN_NOT_OK(LoadLibrary(signature)); - } - return arrow::Status::OK(); - } - - std::string ProduceCodes(std::shared_ptr func_node, - std::vector> left_field_list, - std::vector> right_field_list) { - // CodeGen - std::stringstream codes_ss; - codes_ss << "#include " << std::endl; - codes_ss - << R"(#include )" - << std::endl; - codes_ss << "#define int8 int8_t" << std::endl; - codes_ss << "#define int16 int16_t" << std::endl; - codes_ss << "#define int32 int32_t" << std::endl; - codes_ss << "#define int64 int64_t" << std::endl; - codes_ss << "#define uint8 uint8_t" << std::endl; - codes_ss << "#define uint16 uint16_t" << std::endl; - codes_ss << "#define uint32 uint32_t" << std::endl; - codes_ss << "#define uint64 uint64_t" << std::endl; - codes_ss << "using namespace sparkcolumnarplugin::codegen::arrowcompute::extra;" - << std::endl; - codes_ss << R"(extern "C" void ConditionShuffleCodeGen(int num_rows, - std::vector> next_func_list, - std::vector> is_null_func_list, - std::vector> get_func_list, - std::vector> output_func_list) {)" - << std::endl; - codes_ss << "for (int row_id = 0; row_id < num_rows; row_id++) {" << std::endl; - codes_ss << " for (auto next : next_func_list) {" << std::endl; - codes_ss << " next();" << std::endl; - codes_ss << " }" << std::endl; - std::shared_ptr func_node_visitor; - int func_count = 0; - MakeCodeGenNodeVisitor(func_node, {left_field_list, right_field_list}, &func_count, - &codes_ss, &func_node_visitor); - codes_ss << "if (" << func_node_visitor->GetResult() << ") {" << std::endl; - codes_ss << " for (auto exec : output_func_list) {" << std::endl; - codes_ss << " exec();" << std::endl; - codes_ss << " }" << std::endl; - codes_ss << "}" << std::endl; - codes_ss << "}" << std::endl; - codes_ss << "}" << std::endl; - return codes_ss.str(); - } - - arrow::Status CompileCodes(std::string codes, std::string signature) { - // temporary cpp/library output files - 
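The CompileCodes/LoadLibrary pair here (and its sibling in the probe kernel above) shells out to gcc and then dlopens the result. A condensed, self-contained sketch of that round trip, with hypothetical file names (/tmp/demo-codegen-*), not the plugin's actual paths or flags:

    #include <dlfcn.h>
    #include <sys/wait.h>
    #include <cstdlib>
    #include <fstream>
    #include <string>

    // Write generated source to /tmp, compile it into a shared library, and
    // dlopen it; returns the library handle, or nullptr on any failure.
    void* CompileAndLoad(const std::string& codes, const std::string& signature) {
      const std::string src = "/tmp/demo-codegen-" + signature + ".cc";
      const std::string lib = "/tmp/demo-codegen-" + signature + ".so";
      std::ofstream(src) << codes;  // dump the generated translation unit
      const std::string cmd =
          "gcc -std=c++11 -O3 -shared -fPIC " + src + " -o " + lib;
      if (WEXITSTATUS(system(cmd.c_str())) != EXIT_SUCCESS) return nullptr;  // compile
      return dlopen(lib.c_str(), RTLD_LAZY);  // load; dlsym() then resolves entry points
    }

The original body, which resumes below, additionally builds the library under a randomly named temporary path and publishes it with `mv -n`, so two executors compiling the same signature never dlopen a half-written file; the probe kernel wraps the whole sequence in a flock()-based FileSpinLock for the same reason.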
srand(time(NULL)); - std::string randname = std::to_string(rand()); - std::string outpath = "/tmp"; - std::string prefix = "/spark-columnar-plugin-codegen-"; - std::string cppfile = outpath + prefix + signature + ".cc"; - std::string tmplibfile = outpath + prefix + signature + "." + randname + ".so"; - std::string libfile = outpath + prefix + signature + ".so"; - std::string logfile = outpath + prefix + signature + ".log"; - std::ofstream out(cppfile.c_str(), std::ofstream::out); - - // output code to file - if (out.bad()) { - std::cout << "cannot open " << cppfile << std::endl; - exit(EXIT_FAILURE); - } - out << codes; - out.flush(); - out.close(); - - // compile the code - std::string cmd = "gcc -Wall -Wextra " + cppfile + " -o " + tmplibfile + - " -O3 -shared -fPIC -lspark_columnar_jni > " + logfile; - int ret = system(cmd.c_str()); - if (WEXITSTATUS(ret) != EXIT_SUCCESS) { - std::cout << "compilation failed, see " << logfile << std::endl; - exit(EXIT_FAILURE); - } - - cmd = "mv -n " + tmplibfile + " " + libfile; - ret = system(cmd.c_str()); - cmd = "rm -rf " + tmplibfile; - ret = system(cmd.c_str()); - - return arrow::Status::OK(); - } - - arrow::Status LoadLibrary(std::string signature) { - std::string outpath = "/tmp"; - std::string prefix = "/spark-columnar-plugin-codegen-"; - std::string libfile = outpath + prefix + signature + ".so"; - // load dynamic library - void* dynlib = dlopen(libfile.c_str(), RTLD_LAZY); - if (!dynlib) { - return arrow::Status::Invalid(libfile, " is not generated"); - } - - // loading symbol from library and assign to pointer - // (to be cast to function pointer later) - *(void**)(&ConditionShuffleCodeGen) = dlsym(dynlib, "ConditionShuffleCodeGen"); - const char* dlsym_error = dlerror(); - if (dlsym_error != NULL) { - std::cerr << "error loading symbol:\n" << dlsym_error << std::endl; - exit(EXIT_FAILURE); - } - - return arrow::Status::OK(); - } - - class ConditionShuffleCodeGenResultIterator - : public ResultIterator { - public: - ConditionShuffleCodeGenResultIterator( - arrow::compute::FunctionContext* ctx, std::vector input_cache, - std::function, std::vector, - std::shared_ptr, ArrayList, - std::shared_ptr* out)> - eval_func, - std::shared_ptr> in_indices_iter) - : ctx_(ctx), - input_cache_(input_cache), - eval_func_(eval_func), - in_indices_iter_(in_indices_iter) {} - - std::string ToString() override { return "ConditionedShuffleArraysResultIterator"; } - - arrow::Status Process( - std::vector> in, - std::shared_ptr* out, - const std::shared_ptr& selection = nullptr) override { - std::shared_ptr indices_out; - if (!in_indices_iter_) { - return arrow::Status::Invalid( - "ConditionShuffleCodeGenResultIterator in_indices_iter_ is not set."); - } - RETURN_NOT_OK(in_indices_iter_->GetResult(&indices_out)); - auto left_selection = indices_out->column(0); - auto right_selection = indices_out->column(1); - RETURN_NOT_OK(eval_func_(left_selection, input_cache_, right_selection, in, out)); - return arrow::Status::OK(); - } - - private: - arrow::compute::FunctionContext* ctx_; - std::function, std::vector, - std::shared_ptr, ArrayList, - std::shared_ptr* out)> - eval_func_; - std::shared_ptr> in_indices_iter_; - std::vector input_cache_; - }; -}; - -arrow::Status ConditionedShuffleArrayListKernel::Make( - arrow::compute::FunctionContext* ctx, std::shared_ptr func_node, - std::vector> left_field_list, - std::vector> right_field_list, - std::vector> output_field_list, - std::shared_ptr* out) { - *out = std::make_shared( - ctx, func_node, left_field_list, 
right_field_list, output_field_list); - return arrow::Status::OK(); -} - -ConditionedShuffleArrayListKernel::ConditionedShuffleArrayListKernel( - arrow::compute::FunctionContext* ctx, std::shared_ptr func_node, - std::vector> left_field_list, - std::vector> right_field_list, - std::vector> output_field_list) { - impl_.reset( - new Impl(ctx, func_node, left_field_list, right_field_list, output_field_list)); - kernel_name_ = "ConditionedShuffleArrayListKernel"; -} - -arrow::Status ConditionedShuffleArrayListKernel::Evaluate(const ArrayList& in) { - return impl_->Evaluate(in); -} - -arrow::Status ConditionedShuffleArrayListKernel::MakeResultIterator( - std::shared_ptr schema, - std::shared_ptr>* out) { - return impl_->MakeResultIterator(schema, out); -} - -arrow::Status ConditionedShuffleArrayListKernel::SetDependencyIter( - const std::shared_ptr>& in, int index) { - return impl_->SetDependencyIter(in, index); -} -} // namespace extra -} // namespace arrowcompute -} // namespace codegen -} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioner.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioner.h deleted file mode 100644 index a92378327..000000000 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/conditioner.h +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "codegen/arrow_compute/ext/array_item_index.h" - -namespace sparkcolumnarplugin { -namespace codegen { -namespace arrowcompute { -namespace extra { - -class ConditionerBase { - public: - virtual arrow::Status Submit( - std::vector> left_is_null_func_list, - std::vector> left_get_func_list, - std::vector> right_is_null_func_list, - std::vector> right_get_func_list, - std::function* out) { - return arrow::Status::NotImplemented( - "ConditionerBase Submit is an abstract interface."); - } -}; -} // namespace extra -} // namespace arrowcompute -} // namespace codegen -} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/item_iterator.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/item_iterator.cc deleted file mode 100644 index b091e838c..000000000 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/item_iterator.cc +++ /dev/null @@ -1,278 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
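The ConditionerBase interface just removed is the contract each JIT-compiled condition implements: Submit() receives is-null and getter lambdas for every referenced column on both join sides and hands back a row-pair predicate. A self-contained sketch of that pattern, using plain int columns instead of the real void*-based accessors and ArrayItemIndex; MiniConditioner and its hard-coded condition are illustrative only:

    #include <functional>
    #include <utility>
    #include <vector>

    // Stand-in for a generated Conditioner: Submit() captures the per-column
    // accessor lambdas and publishes a predicate closing over them.
    struct MiniConditioner {
      std::vector<std::function<bool(int)>> is_null_;
      std::vector<std::function<int(int)>> get_;
      void Submit(std::vector<std::function<bool(int)>> is_null,
                  std::vector<std::function<int(int)>> get,
                  std::function<bool(int, int)>* out) {
        is_null_ = std::move(is_null);
        get_ = std::move(get);
        // A generated body for the condition `left.col0 < right.col1`:
        *out = [this](int left_row, int right_row) {
          if (is_null_[0](left_row) || is_null_[1](right_row)) return false;
          return get_[0](left_row) < get_[1](right_row);
        };
      }
    };

Returning a closure rather than a virtual Eval() keeps the per-row call a plain std::function invocation after a single Submit() per batch.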
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "codegen/arrow_compute/ext/item_iterator.h" -#include -#include -#include -#include -#include - -namespace sparkcolumnarplugin { -namespace codegen { -namespace arrowcompute { -namespace extra { - -using namespace arrow; -using ArrayList = std::vector>; - -template -class ItemIteratorTypedImpl; - -//////////////// ItemIterator /////////////// -class ItemIterator::Impl { - public: -#define PROCESS_SUPPORTED_TYPES(PROCESS) \ - PROCESS(UInt8Type) \ - PROCESS(Int8Type) \ - PROCESS(UInt16Type) \ - PROCESS(Int16Type) \ - PROCESS(UInt32Type) \ - PROCESS(Int32Type) \ - PROCESS(UInt64Type) \ - PROCESS(Int64Type) \ - PROCESS(FloatType) \ - PROCESS(DoubleType) \ - PROCESS(Date32Type) - - static arrow::Status MakeItemIteratorImpl(std::shared_ptr type, - bool is_array_list, - std::shared_ptr* out) { - switch (type->id()) { -#define PROCESS(InType) \ - case InType::type_id: { \ - using CType = typename arrow::TypeTraits::CType; \ - auto res = std::make_shared>(is_array_list); \ - *out = std::dynamic_pointer_cast(res); \ - } break; - PROCESS_SUPPORTED_TYPES(PROCESS) -#undef PROCESS - case arrow::StringType::type_id: { - auto res = std::make_shared>( - is_array_list); - *out = std::dynamic_pointer_cast(res); - } break; - default: { - std::cout << "Not Found " << type->ToString() << ", type id is " << type->id() - << std::endl; - } break; - } - return arrow::Status::OK(); - } - -#undef PROCESS_SUPPORTED_TYPES - Impl() {} - virtual ~Impl() {} - virtual arrow::Status Submit(ArrayList in_arr_list, - std::shared_ptr selection, - std::function* next, - std::function* is_null, - std::function* get) { - throw 1; - } - virtual arrow::Status Submit(std::shared_ptr in_arr, - std::shared_ptr selection, - std::function* next, - std::function* is_null, - std::function* get) { - throw 1; - } - virtual arrow::Status Submit(ArrayList in_arr_list, - std::function* is_null, - std::function* get) { - throw 1; - } - virtual arrow::Status Submit(std::shared_ptr in_arr, - std::function* is_null, - std::function* get) { - throw 1; - } -}; - -template -class ItemIteratorTypedImpl : public ItemIterator::Impl { - public: - ItemIteratorTypedImpl(bool is_array_list) { -#ifdef DEBUG - std::cout << "ItemIteratorTypedImpl constructed" << std::endl; -#endif - is_array_list_ = is_array_list; - if (is_array_list_) { - next_ = [this]() { - row_id_++; - is_null_cache_ = true; - if (!selection_->IsNull(row_id_)) { - auto item = structed_selection_[row_id_]; - is_null_cache_ = typed_in_arr_list_[item.array_id]->IsNull(item.id); - if (!is_null_cache_) { - CType res(typed_in_arr_list_[item.array_id]->GetView(item.id)); - res_cache_ = res; - } - } - return arrow::Status::OK(); - }; - - } else { - next_ = [this]() { - row_id_++; - is_null_cache_ = true; - if (!selection_->IsNull(row_id_)) { - auto item = uint32_selection_[row_id_]; - is_null_cache_ = typed_in_arr_->IsNull(item); - if (!is_null_cache_) { - CType res(typed_in_arr_->GetView(item)); - 
res_cache_ = res; - } - } - return arrow::Status::OK(); - }; - } - is_null_ = [this]() { return is_null_cache_; }; - get_ = [this]() { return (void*)&res_cache_; }; - is_null_with_item_index_ = [this](ArrayItemIndex item) { - return typed_in_arr_list_[item.array_id]->IsNull(item.id); - }; - is_null_with_index_ = [this](int item) { return typed_in_arr_->IsNull(item); }; - get_with_item_index_ = [this](ArrayItemIndex item) { - CType res(typed_in_arr_list_[item.array_id]->GetView(item.id)); - res_cache_ = res; - return (void*)&res_cache_; - }; - get_with_index_ = [this](int item) { - CType res(typed_in_arr_->GetView(item)); - res_cache_ = res; - return (void*)&res_cache_; - }; - } - - ~ItemIteratorTypedImpl() { -#ifdef DEBUG - std::cout << "ItemIteratorTypedImpl destructed" << std::endl; -#endif - } - - arrow::Status Submit(ArrayList in_arr_list, std::shared_ptr selection, - std::function* next, - std::function* is_null, std::function* get) { - if (typed_in_arr_list_.size() == 0) { - for (auto arr : in_arr_list) { - typed_in_arr_list_.push_back(std::dynamic_pointer_cast(arr)); - } - } - row_id_ = -1; - selection_ = selection; - structed_selection_ = - (ArrayItemIndex*)std::dynamic_pointer_cast(selection) - ->raw_values(); - *next = next_; - *is_null = is_null_; - *get = get_; - return arrow::Status::OK(); - } - - arrow::Status Submit(std::shared_ptr in_arr, - std::shared_ptr selection, - std::function* next, - std::function* is_null, std::function* get) { - typed_in_arr_ = std::dynamic_pointer_cast(in_arr); - row_id_ = -1; - selection_ = selection; - uint32_selection_ = - (uint32_t*)std::dynamic_pointer_cast(selection)->raw_values(); - *next = next_; - *is_null = is_null_; - *get = get_; - return arrow::Status::OK(); - } - - arrow::Status Submit(ArrayList in_arr_list, - std::function* is_null, - std::function* get) { - if (typed_in_arr_list_.size() == 0) { - for (auto arr : in_arr_list) { - typed_in_arr_list_.push_back(std::dynamic_pointer_cast(arr)); - } - } - *is_null = is_null_with_item_index_; - *get = get_with_item_index_; - return arrow::Status::OK(); - } - - arrow::Status Submit(std::shared_ptr in_arr, - std::function* is_null, - std::function* get) { - typed_in_arr_ = std::dynamic_pointer_cast(in_arr); - *is_null = is_null_with_index_; - *get = get_with_index_; - return arrow::Status::OK(); - } - - private: - using ArrayType = typename arrow::TypeTraits::ArrayType; - std::vector> typed_in_arr_list_; - std::shared_ptr typed_in_arr_; - ArrayItemIndex* structed_selection_; - std::shared_ptr selection_; - uint32_t* uint32_selection_; - uint32_t row_id_ = -1; - std::function is_null_; - std::function get_; - std::function is_null_with_item_index_; - std::function get_with_item_index_; - std::function is_null_with_index_; - std::function get_with_index_; - std::function next_; - bool is_array_list_; - // cache result when calling Next() for optimization purpose - uint64_t total_length_; - CType res_cache_; - bool is_null_cache_; -}; - -///////////////////// Public Functions ////////////////// -ItemIterator::ItemIterator(std::shared_ptr type, bool is_array_list) { - auto status = Impl::MakeItemIteratorImpl(type, is_array_list, &impl_); -} - -arrow::Status ItemIterator::Submit(ArrayList in_arr_list, - std::shared_ptr selection, - std::function* next, - std::function* is_null, - std::function* get) { - return impl_->Submit(in_arr_list, selection, next, is_null, get); -} -arrow::Status ItemIterator::Submit(std::shared_ptr in_arr, - std::shared_ptr selection, - std::function* next, - 
std::function* is_null, - std::function* get) { - return impl_->Submit(in_arr, selection, next, is_null, get); -} -arrow::Status ItemIterator::Submit(ArrayList in_arr_list, - std::function* is_null, - std::function* get) { - return impl_->Submit(in_arr_list, is_null, get); -} -arrow::Status ItemIterator::Submit(std::shared_ptr in_arr, - std::function* is_null, - std::function* get) { - return impl_->Submit(in_arr, is_null, get); -} -} // namespace extra -} // namespace arrowcompute -} // namespace codegen -} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/item_iterator.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/item_iterator.h deleted file mode 100644 index 4ee434fc4..000000000 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/item_iterator.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "codegen/arrow_compute/ext/array_item_index.h" - -namespace sparkcolumnarplugin { -namespace codegen { -namespace arrowcompute { -namespace extra { -using ArrayList = std::vector>; -class ItemIterator { - public: - ItemIterator(std::shared_ptr type, bool is_array_list); - ~ItemIterator() {} - arrow::Status Submit(ArrayList in_arr_list, std::shared_ptr selection, - std::function* next, - std::function* is_null, std::function* get); - arrow::Status Submit(std::shared_ptr in_arr, - std::shared_ptr selection, - std::function* next, - std::function* is_null, std::function* get); - arrow::Status Submit(ArrayList in_arr_list, - std::function* is_null, - std::function* get); - arrow::Status Submit(std::shared_ptr in_arr, - std::function* is_null, std::function* get); - class Impl; - - private: - std::shared_ptr impl_; -}; -static arrow::Status MakeArrayListItemIterator(std::shared_ptr type, - std::shared_ptr* out) { - *out = std::make_shared(type, true); - return arrow::Status::OK(); -} -static arrow::Status MakeArrayItemIterator(std::shared_ptr type, - std::shared_ptr* out) { - *out = std::make_shared(type, false); - return arrow::Status::OK(); -} -} // namespace extra -} // namespace arrowcompute -} // namespace codegen -} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc index 562fe6b06..5f589e003 100644 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc +++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc @@ -38,7 +38,6 @@ #include #include #include -#include "third_party/arrow/utils/hashing.h" #include #include @@ -48,9 +47,9 @@ #include 
"codegen/arrow_compute/ext/actions_impl.h" #include "codegen/arrow_compute/ext/array_item_index.h" +#include "codegen/arrow_compute/ext/codegen_common.h" #include "codegen/arrow_compute/ext/codegen_node_visitor.h" -#include "codegen/arrow_compute/ext/item_iterator.h" -#include "codegen/arrow_compute/ext/shuffle_v2_action.h" +#include "third_party/arrow/utils/hashing.h" #include "utils/macros.h" namespace sparkcolumnarplugin { @@ -58,7 +57,6 @@ namespace codegen { namespace arrowcompute { namespace extra { -#define MAXBATCHNUMROWS 10000 using ArrayList = std::vector>; /////////////// SplitArrayListWithAction //////////////// @@ -217,7 +215,9 @@ class SplitArrayListWithActionKernel::Impl { *out = nullptr; return arrow::Status::OK(); } - auto length = (total_length_ - offset_) > 4096 ? 4096 : (total_length_ - offset_); + auto length = (total_length_ - offset_) > GetBatchSize() + ? GetBatchSize() + : (total_length_ - offset_); TIME_MICRO_OR_RAISE(elapse_time_, eval_func_(offset_, length, out)); offset_ += length; // arrow::PrettyPrint(*(*out).get(), 2, &std::cout); @@ -922,22 +922,22 @@ class HashAggrArrayKernel::Impl { field_list.push_back(field); auto field_node = gandiva::TreeExprBuilder::MakeField(field); auto func_node = - gandiva::TreeExprBuilder::MakeFunction("hash64", {field_node}, arrow::int64()); + gandiva::TreeExprBuilder::MakeFunction("hash32", {field_node}, arrow::int32()); func_node_list.push_back(func_node); if (func_node_list.size() == 2) { auto shift_func_node = gandiva::TreeExprBuilder::MakeFunction( "multiply", - {func_node_list[0], gandiva::TreeExprBuilder::MakeLiteral((int64_t)10)}, - arrow::int64()); + {func_node_list[0], gandiva::TreeExprBuilder::MakeLiteral((int32_t)10)}, + arrow::int32()); auto tmp_func_node = gandiva::TreeExprBuilder::MakeFunction( - "add", {shift_func_node, func_node_list[1]}, arrow::int64()); + "add", {shift_func_node, func_node_list[1]}, arrow::int32()); func_node_list.clear(); func_node_list.push_back(tmp_func_node); } index++; } expr = gandiva::TreeExprBuilder::MakeExpression(func_node_list[0], - arrow::field("res", arrow::int64())); + arrow::field("res", arrow::int32())); std::cout << expr->ToString() << std::endl; schema_ = arrow::schema(field_list); auto configuration = gandiva::ConfigurationBuilder().DefaultConfiguration(); diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/kernels_ext.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/kernels_ext.h index 865de7e61..75cfc1c0e 100644 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/kernels_ext.h +++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/kernels_ext.h @@ -36,9 +36,6 @@ class KernalBase { public: KernalBase() {} ~KernalBase() {} - virtual arrow::Status SetMember(const std::shared_ptr& in) { - return arrow::Status::NotImplemented("KernalBase abstract interface."); - } virtual arrow::Status Evaluate(const ArrayList& in) { return arrow::Status::NotImplemented("Evaluate is abstract interface for ", kernel_name_, ", input is arrayList."); @@ -59,20 +56,6 @@ class KernalBase { kernel_name_, ", input is arrayList, output is array."); } - virtual arrow::Status Evaluate(const std::shared_ptr& in) { - return arrow::Status::NotImplemented("Evaluate is abstract interface for ", - kernel_name_, ", input is array."); - } - virtual arrow::Status Evaluate(const std::shared_ptr& selection_arr, - const std::shared_ptr& in) { - return arrow::Status::NotImplemented("Evaluate is abstract interface for ", - kernel_name_, - ", input is selection_array and array."); - } - virtual 
arrow::Status Evaluate(const std::shared_ptr& in, int group_id) { - return arrow::Status::NotImplemented("Evaluate is abstract interface for ", - kernel_name_, ", input is array and group_id"); - } virtual arrow::Status Evaluate(const std::shared_ptr& in, std::shared_ptr* out) { return arrow::Status::NotImplemented("Evaluate is abstract interface for ", @@ -83,23 +66,6 @@ class KernalBase { return arrow::Status::NotImplemented("Finish is abstract interface for ", kernel_name_, ", output is arrayList"); } - virtual arrow::Status Finish(std::vector* out) { - return arrow::Status::NotImplemented("Finish is abstract interface for ", - kernel_name_, ", output is batchList"); - } - virtual arrow::Status Finish(std::shared_ptr* out) { - return arrow::Status::NotImplemented("Finish is abstract interface for ", - kernel_name_, ", output is array"); - } - virtual arrow::Status SetDependencyInput(const std::shared_ptr& in) { - return arrow::Status::NotImplemented("SetDependencyInput is abstract interface for ", - kernel_name_, ", input is array"); - } - virtual arrow::Status SetDependencyIter( - const std::shared_ptr>& in, int index) { - return arrow::Status::NotImplemented("SetDependencyIter is abstract interface for ", - kernel_name_, ", input is array"); - } virtual arrow::Status MakeResultIterator( std::shared_ptr schema, std::shared_ptr>* out) { @@ -294,48 +260,25 @@ class SortArraysToIndicesKernel : public KernalBase { arrow::compute::FunctionContext* ctx_; };*/ -class ConditionedShuffleArrayListKernel : public KernalBase { - public: - static arrow::Status Make(arrow::compute::FunctionContext* ctx, - std::shared_ptr func_node, - std::vector> left_field_list, - std::vector> right_field_list, - std::vector> output_field_list, - std::shared_ptr* out); - ConditionedShuffleArrayListKernel( - arrow::compute::FunctionContext* ctx, std::shared_ptr func_node, - std::vector> left_field_list, - std::vector> right_field_list, - std::vector> output_field_list); - arrow::Status Evaluate(const ArrayList& in) override; - arrow::Status SetDependencyIter( - const std::shared_ptr>& in, int index) override; - arrow::Status MakeResultIterator( - std::shared_ptr schema, - std::shared_ptr>* out) override; - - private: - class Impl; - std::unique_ptr impl_; - arrow::compute::FunctionContext* ctx_; -}; - class ConditionedProbeArraysKernel : public KernalBase { public: - static arrow::Status Make(arrow::compute::FunctionContext* ctx, - std::vector> left_key_list, - std::vector> right_key_list, - std::shared_ptr func_node, int join_type, - std::vector> left_field_list, - std::vector> right_field_list, - std::shared_ptr* out); + static arrow::Status Make( + arrow::compute::FunctionContext* ctx, + const std::vector>& left_key_list, + const std::vector>& right_key_list, + const std::shared_ptr& func_node, int join_type, + const std::vector>& left_field_list, + const std::vector>& right_field_list, + const std::shared_ptr& result_schema, + std::shared_ptr* out); ConditionedProbeArraysKernel( arrow::compute::FunctionContext* ctx, - std::vector> left_key_list, - std::vector> right_key_list, - std::shared_ptr func_node, int join_type, - std::vector> left_field_list, - std::vector> right_field_list); + const std::vector>& left_key_list, + const std::vector>& right_key_list, + const std::shared_ptr& func_node, int join_type, + const std::vector>& left_field_list, + const std::vector>& right_field_list, + const std::shared_ptr& result_schema); arrow::Status Evaluate(const ArrayList& in) override; arrow::Status MakeResultIterator( 
std::shared_ptr schema, diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/probe_kernel.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/probe_kernel.cc new file mode 100644 index 000000000..2688ff791 --- /dev/null +++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/probe_kernel.cc @@ -0,0 +1,968 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "codegen/arrow_compute/ext/array_item_index.h" +#include "codegen/arrow_compute/ext/code_generator_base.h" +#include "codegen/arrow_compute/ext/codegen_common.h" +#include "codegen/arrow_compute/ext/codegen_node_visitor.h" +#include "codegen/arrow_compute/ext/kernels_ext.h" +#include "utils/macros.h" + +namespace sparkcolumnarplugin { +namespace codegen { +namespace arrowcompute { +namespace extra { + +using ArrayList = std::vector>; + +/////////////// ConditionedProbeArrays //////////////// +class ConditionedProbeArraysKernel::Impl { + public: + Impl(arrow::compute::FunctionContext* ctx, + const std::vector>& left_key_list, + const std::vector>& right_key_list, + const std::shared_ptr& func_node, int join_type, + const std::vector>& left_field_list, + const std::vector>& right_field_list, + const std::shared_ptr& result_schema) + : ctx_(ctx) { + std::vector left_key_index_list; + THROW_NOT_OK(GetIndexList(left_key_list, left_field_list, &left_key_index_list)); + std::vector right_key_index_list; + THROW_NOT_OK(GetIndexList(right_key_list, right_field_list, &right_key_index_list)); + + std::vector left_shuffle_index_list; + std::vector right_shuffle_index_list; + THROW_NOT_OK( + GetIndexListFromSchema(result_schema, left_field_list, &left_shuffle_index_list)); + THROW_NOT_OK(GetIndexListFromSchema(result_schema, right_field_list, + &right_shuffle_index_list)); + + std::vector> result_schema_index_list; + THROW_NOT_OK(GetResultIndexList(result_schema, left_field_list, right_field_list, + &result_schema_index_list)); + + THROW_NOT_OK(LoadJITFunction(func_node, join_type, left_key_index_list, + right_key_index_list, left_shuffle_index_list, + right_shuffle_index_list, left_field_list, + right_field_list, result_schema_index_list, &prober_)); + } + + arrow::Status Evaluate(const ArrayList& in_arr_list) { + // cache in_arr_list for prober data shuffling + RETURN_NOT_OK(prober_->Evaluate(in_arr_list)); + return arrow::Status::OK(); + } + + arrow::Status MakeResultIterator( + std::shared_ptr schema, + std::shared_ptr>* out) { + RETURN_NOT_OK(prober_->MakeResultIterator(schema, out)); + return arrow::Status::OK(); + } + + private: + using ArrayType = typename arrow::TypeTraits::ArrayType; + + 
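+
+  // Usage sketch (illustrative only -- the driver plumbing and the key/field
+  // list variables are hypothetical, but the call sequence matches the
+  // interface above): the build side is fed batch-by-batch through Evaluate(),
+  // then MakeResultIterator() hands back a prober that streams the probe side.
+  //
+  // @code
+  //   std::shared_ptr<KernalBase> kernel;
+  //   RETURN_NOT_OK(ConditionedProbeArraysKernel::Make(
+  //       ctx, left_key_list, right_key_list, /*func_node=*/nullptr,
+  //       /*join_type=*/0 /*inner*/, left_field_list, right_field_list,
+  //       result_schema, &kernel));
+  //   for (const auto& batch : build_side_batches) {   // cache build side
+  //     RETURN_NOT_OK(kernel->Evaluate(batch->columns()));
+  //   }
+  //   std::shared_ptr<ResultIterator<arrow::RecordBatch>> prober;
+  //   RETURN_NOT_OK(kernel->MakeResultIterator(result_schema, &prober));
+  //   for (const auto& batch : probe_side_batches) {   // stream probe side
+  //     std::shared_ptr<arrow::RecordBatch> out;
+  //     RETURN_NOT_OK(prober->Process(batch->columns(), &out));
+  //   }
+  // @endcode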
arrow::compute::FunctionContext* ctx_; + std::shared_ptr prober_; + + arrow::Status GetIndexList( + const std::vector>& target_list, + const std::vector>& source_list, + std::vector* out) { + for (auto key_field : target_list) { + int i = 0; + for (auto field : source_list) { + if (key_field->name() == field->name()) { + break; + } + i++; + } + (*out).push_back(i); + } + return arrow::Status::OK(); + } + + arrow::Status GetIndexListFromSchema( + const std::shared_ptr& result_schema, + const std::vector>& field_list, + std::vector* index_list) { + int i = 0; + for (auto field : field_list) { + auto indices = result_schema->GetAllFieldIndices(field->name()); + if (indices.size() == 1) { + (*index_list).push_back(i); + } + i++; + } + return arrow::Status::OK(); + } + + arrow::Status GetResultIndexList( + const std::shared_ptr& result_schema, + const std::vector>& left_field_list, + const std::vector>& right_field_list, + std::vector>* result_schema_index_list) { + int i = 0; + bool found = false; + for (auto target_field : result_schema->fields()) { + i = 0; + found = false; + for (auto field : left_field_list) { + if (target_field->name() == field->name()) { + (*result_schema_index_list).push_back(std::make_pair(0, i)); + found = true; + break; + } + i++; + } + if (found == true) continue; + i = 0; + for (auto field : right_field_list) { + if (target_field->name() == field->name()) { + (*result_schema_index_list).push_back(std::make_pair(1, i)); + break; + } + i++; + } + } + return arrow::Status::OK(); + } + arrow::Status LoadJITFunction( + const std::shared_ptr& func_node, int join_type, + const std::vector& left_key_index_list, + const std::vector& right_key_index_list, + const std::vector& left_shuffle_index_list, + const std::vector& right_shuffle_index_list, + const std::vector>& left_field_list, + const std::vector>& right_field_list, + const std::vector>& result_schema_index_list, + std::shared_ptr* out) { + // generate ddl signature + std::stringstream func_args_ss; + func_args_ss << "" + << "[JoinType]" << join_type; + if (func_node) { + func_args_ss << "[cond]" << func_node->ToString(); + } + func_args_ss << "[BuildSchema]"; + for (auto field : left_field_list) { + func_args_ss << field->ToString(); + } + func_args_ss << "[ProbeSchema]"; + for (auto field : right_field_list) { + func_args_ss << field->ToString(); + } + func_args_ss << "[LeftKeyIndex]"; + for (auto i : left_key_index_list) { + func_args_ss << i << ","; + } + func_args_ss << "[RightKeyIndex]"; + for (auto i : right_key_index_list) { + func_args_ss << i << ","; + } + func_args_ss << "[LeftShuffleIndex]"; + for (auto i : left_shuffle_index_list) { + func_args_ss << i << ","; + } + func_args_ss << "[RightShuffleIndex]"; + for (auto i : right_shuffle_index_list) { + func_args_ss << i << ","; + } + + std::stringstream signature_ss; + signature_ss << std::hex << std::hash{}(func_args_ss.str()); + std::string signature = signature_ss.str(); + + auto file_lock = FileSpinLock(); + auto status = LoadLibrary(signature, ctx_, out); + if (!status.ok()) { + // process + auto codes = + ProduceCodes(func_node, join_type, left_key_index_list, right_key_index_list, + left_shuffle_index_list, right_shuffle_index_list, left_field_list, + right_field_list, result_schema_index_list); + // compile codes + RETURN_NOT_OK(CompileCodes(codes, signature)); + RETURN_NOT_OK(LoadLibrary(signature, ctx_, out)); + } + FileSpinUnLock(file_lock); + return arrow::Status::OK(); + } + + class TypedProberCodeGenImpl { + public: + 
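+    // Per-column code generator: given a column tag ("indice", e.g. "0_2" for
+    // build-side column 2) and an Arrow type name, each Get*() method below
+    // returns a C++ source fragment (typed cache definitions, builder setup,
+    // Finish() calls) that ProduceCodes() splices into the JIT-compiled prober.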
TypedProberCodeGenImpl(std::string indice, std::string dataTypeName, bool left = true) + : indice_(indice), dataTypeName_(dataTypeName), left_(left) {} + std::string GetImplCachedDefine() { + std::stringstream ss; + ss << "using DataType_" << indice_ << " = typename arrow::" << dataTypeName_ << ";" + << std::endl; + ss << "using ArrayType_" << indice_ << " = typename arrow::TypeTraits::ArrayType;" << std::endl; + ss << "std::vector> cached_" << indice_ + << "_;" << std::endl; + return ss.str(); + } + std::string GetResultIteratorPrepare() { + std::stringstream ss; + ss << "std::unique_ptr builder_" << indice_ << ";" + << std::endl; + ss << "arrow::MakeBuilder(ctx_->memory_pool(), data_type_" << indice_ + << ", &builder_" << indice_ << ");" << std::endl; + ss << "builder_" << indice_ << "_.reset(arrow::internal::checked_cast(builder_" << indice_ << ".release()));" << std::endl; + return ss.str(); + } + std::string GetProcessFinish() { + std::stringstream ss; + ss << "std::shared_ptr out_" << indice_ << ";" << std::endl; + ss << "RETURN_NOT_OK(builder_" << indice_ << "_->Finish(&out_" << indice_ << "));" + << std::endl; + ss << "builder_" << indice_ << "_->Reset();" << std::endl; + return ss.str(); + } + std::string GetProcessOutList() { + std::stringstream ss; + ss << "out_" << indice_; + return ss.str(); + } + std::string GetResultIterCachedDefine() { + std::stringstream ss; + ss << "using DataType_" << indice_ << " = typename arrow::" << dataTypeName_ << ";" + << std::endl; + ss << "using ArrayType_" << indice_ << " = typename arrow::TypeTraits::ArrayType;" << std::endl; + ss << "using BuilderType_" << indice_ + << " = typename " + "arrow::TypeTraits::BuilderType;" << std::endl; + if (left_) { + ss << "std::vector> cached_" + << indice_ << "_;" << std::endl; + } else { + ss << "std::shared_ptr cached_" << indice_ << "_;" + << std::endl; + } + ss << "std::shared_ptr data_type_" << indice_ + << " = arrow::TypeTraits::type_singleton();" + << std::endl; + ss << "std::shared_ptr builder_" << indice_ << "_;" + << std::endl; + return ss.str(); + } + + private: + std::string indice_; + std::string dataTypeName_; + bool left_; + }; + std::string GetJoinKeyTypeListDefine( + std::vector key_index_list, + const std::vector>& field_list) { + std::stringstream ss; + for (int i = 0; i < key_index_list.size(); i++) { + auto field = field_list[key_index_list[i]]; + if (i != (key_index_list.size() - 1)) { + ss << "arrow::" << GetArrowTypeDefString(field->type()) << ", "; + } else { + ss << "arrow::" << GetArrowTypeDefString(field->type()); + } + } + return ss.str(); + } + std::string GetEvaluateCacheInsert(const std::vector& index_list) { + std::stringstream ss; + for (auto i : index_list) { + ss << "cached_0_" << i << "_.push_back(std::dynamic_pointer_cast(in[" << i << "]));" << std::endl; + } + return ss.str(); + } + std::string GetEncodeJoinKey(std::vector key_indices) { + std::stringstream ss; + for (int i = 0; i < key_indices.size(); i++) { + if (i != (key_indices.size() - 1)) { + ss << "in[" << key_indices[i] << "], "; + } else { + ss << "in[" << key_indices[i] << "]"; + } + } + return ss.str(); + } + std::string GetFinishCachedParameter(const std::vector& key_indices) { + std::stringstream ss; + for (int i = 0; i < key_indices.size(); i++) { + if (i != (key_indices.size() - 1)) { + ss << "cached_0_" << key_indices[i] << "_, "; + } else { + ss << "cached_0_" << key_indices[i] << "_"; + } + } + auto ret = ss.str(); + if (ret.empty()) { + return ret; + } else { + return ", " + ret; + } + } + 
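+
+  // For illustration (hypothetical column: build-side index 2 of type
+  // arrow::Int32Type), TypedProberCodeGenImpl("0_2", "Int32Type")
+  // .GetImplCachedDefine() expands to roughly:
+  //
+  //   using DataType_0_2 = typename arrow::Int32Type;
+  //   using ArrayType_0_2 = typename arrow::TypeTraits<DataType_0_2>::ArrayType;
+  //   std::vector<std::shared_ptr<ArrayType_0_2>> cached_0_2_;
+  //
+  // i.e. one typed cache per referenced column, named "<side>_<index>".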
std::string GetImplCachedDefine( + std::vector> codegen_list) { + std::stringstream ss; + for (auto codegen : codegen_list) { + ss << codegen->GetImplCachedDefine() << std::endl; + } + return ss.str(); + } + std::string GetResultIteratorParams(std::vector key_indices) { + std::stringstream ss; + for (int i = 0; i < key_indices.size(); i++) { + if (i != (key_indices.size() - 1)) { + ss << "const std::vector> &cached_0_" << key_indices[i] << ", " << std::endl; + } else { + ss << "const std::vector> &cached_0_" << key_indices[i]; + } + } + auto ret = ss.str(); + if (ret.empty()) { + return ret; + } else { + return ", " + ret; + } + } + std::string GetResultIteratorSet(std::vector key_indices) { + std::stringstream ss; + for (auto i : key_indices) { + ss << "cached_0_" << i << "_ = cached_0_" << i << ";" << std::endl; + } + return ss.str(); + } + std::string GetResultIteratorPrepare( + std::vector> left_codegen_list, + std::vector> right_codegen_list) { + std::stringstream ss; + for (auto codegen : left_codegen_list) { + ss << codegen->GetResultIteratorPrepare() << std::endl; + } + for (auto codegen : right_codegen_list) { + ss << codegen->GetResultIteratorPrepare() << std::endl; + } + return ss.str(); + } + std::string GetProcessRightSet(std::vector indices) { + std::stringstream ss; + for (auto i : indices) { + ss << "cached_1_" << i << "_ = std::dynamic_pointer_cast(in[" << i << "]);" << std::endl; + } + return ss.str(); + } + std::string GetProcessFinish( + std::vector> left_codegen_list, + std::vector> right_codegen_list) { + std::stringstream ss; + for (auto codegen : left_codegen_list) { + ss << codegen->GetProcessFinish() << std::endl; + } + for (auto codegen : right_codegen_list) { + ss << codegen->GetProcessFinish() << std::endl; + } + return ss.str(); + } + std::string GetProcessOutList( + const std::vector>& result_schema_index_list, + std::vector> left_codegen_list, + std::vector> right_codegen_list) { + std::stringstream ss; + auto item_count = result_schema_index_list.size(); + int i = 0; + for (auto index : result_schema_index_list) { + std::shared_ptr codegen; + if (index.first == 0) { + codegen = left_codegen_list[index.second]; + } else { + codegen = right_codegen_list[index.second]; + } + if (i++ != (item_count - 1)) { + ss << codegen->GetProcessOutList() << ", "; + } else { + ss << codegen->GetProcessOutList(); + } + } + return ss.str(); + } + std::string GetResultIterCachedDefine( + std::vector> left_codegen_list, + std::vector> right_codegen_list) { + std::stringstream ss; + for (auto codegen : left_codegen_list) { + ss << codegen->GetResultIterCachedDefine() << std::endl; + } + for (auto codegen : right_codegen_list) { + ss << codegen->GetResultIterCachedDefine() << std::endl; + } + return ss.str(); + } + std::string GetInnerJoin(bool cond_check, + const std::vector& left_shuffle_index_list, + const std::vector& right_shuffle_index_list) { + std::stringstream ss; + for (auto i : left_shuffle_index_list) { + ss << "RETURN_NOT_OK(builder_0_" << i << "_->Append(cached_0_" << i + << "_[tmp.array_id]->GetView(tmp." 
+ "id)));" + << std::endl; + } + for (auto i : right_shuffle_index_list) { + ss << "RETURN_NOT_OK(builder_1_" << i << "_->Append(cached_1_" << i + << "_->GetView(i)));" << std::endl; + } + std::string shuffle_str; + if (cond_check) { + shuffle_str = R"( + if (ConditionCheck(tmp, i)) { + )" + ss.str() + + R"( + out_length += 1; + } + )"; + } else { + shuffle_str = R"( + )" + ss.str() + + R"( + out_length += 1; + )"; + } + return R"( + if (!typed_array->IsNull(i)) { + auto index = hash_table_->Get(typed_array->GetView(i)); + if (index != -1) { + for (auto tmp : (*memo_index_to_arrayid_)[index]) { + )" + + shuffle_str + R"( + } + } + } + )"; + } + std::string GetOuterJoin(bool cond_check, + const std::vector& left_shuffle_index_list, + const std::vector& right_shuffle_index_list) { + std::stringstream left_null_ss; + std::stringstream left_valid_ss; + std::stringstream right_valid_ss; + for (auto i : left_shuffle_index_list) { + left_valid_ss << "RETURN_NOT_OK(builder_0_" << i << "_->Append(cached_0_" << i + << "_[tmp.array_id]->GetView(tmp." + "id)));" + << std::endl; + left_null_ss << "RETURN_NOT_OK(builder_0_" << i << "_->AppendNull());" << std::endl; + } + for (auto i : right_shuffle_index_list) { + right_valid_ss << "RETURN_NOT_OK(builder_1_" << i << "_->Append(cached_1_" << i + << "_->GetView(i)));" << std::endl; + } + std::string shuffle_str; + if (cond_check) { + shuffle_str = R"( + if (ConditionCheck(tmp, i)) { + )" + left_valid_ss.str() + + right_valid_ss.str() + R"( + out_length += 1; + } + )"; + } else { + shuffle_str = R"( + )" + left_valid_ss.str() + + right_valid_ss.str() + R"( + out_length += 1; + )"; + } + return R"( + int32_t index; + if (!typed_array->IsNull(i)) { + index = hash_table_->Get(typed_array->GetView(i)); + } else { + index = hash_table_->GetNull(); + } + if (index == -1) { + )" + + left_null_ss.str() + right_valid_ss.str() + R"( + out_length += 1; + } else { + for (auto tmp : (*memo_index_to_arrayid_)[index]) { + )" + + shuffle_str + R"( + } + } + )"; + } + std::string GetAntiJoin(bool cond_check, + const std::vector& left_shuffle_index_list, + const std::vector& right_shuffle_index_list) { + std::stringstream left_null_ss; + std::stringstream right_valid_ss; + for (auto i : left_shuffle_index_list) { + left_null_ss << "RETURN_NOT_OK(builder_0_" << i << "_->AppendNull());" << std::endl; + } + for (auto i : right_shuffle_index_list) { + right_valid_ss << "RETURN_NOT_OK(builder_1_" << i << "_->Append(cached_1_" << i + << "_->GetView(i)));" << std::endl; + } + std::string shuffle_str; + if (cond_check) { + shuffle_str = R"( + } else { + bool found = false; + for (auto tmp : (*memo_index_to_arrayid_)[index]) { + if (ConditionCheck(tmp, i)) { + found = true; + break; + } + } + if (!found) { + )" + left_null_ss.str() + + right_valid_ss.str() + R"( + out_length += 1; + } + )"; + } + return R"( + int32_t index; + if (!typed_array->IsNull(i)) { + index = hash_table_->Get(typed_array->GetView(i)); + } else { + index = hash_table_->GetNull(); + } + if (index == -1) { + )" + + left_null_ss.str() + right_valid_ss.str() + R"( + out_length += 1; + )" + + shuffle_str + R"( + } + )"; + } + std::string GetSemiJoin(bool cond_check, + const std::vector& left_shuffle_index_list, + const std::vector& right_shuffle_index_list) { + std::stringstream ss; + for (auto i : left_shuffle_index_list) { + ss << "RETURN_NOT_OK(builder_0_" << i << "_->AppendNull());" << std::endl; + } + for (auto i : right_shuffle_index_list) { + ss << "RETURN_NOT_OK(builder_1_" << i << "_->Append(cached_1_" << 
i + << "_->GetView(i)));" << std::endl; + } + std::string shuffle_str; + if (cond_check) { + shuffle_str = R"( + for (auto tmp : (*memo_index_to_arrayid_)[index]) { + if (ConditionCheck(tmp, i)) { + )" + ss.str() + + R"( + out_length += 1; + break; + } + } + )"; + } else { + shuffle_str = R"( + )" + ss.str() + + R"( + out_length += 1; + )"; + } + return R"( + if (!typed_array->IsNull(i)) { + auto index = hash_table_->Get(typed_array->GetView(i)); + if (index != -1) { + )" + + shuffle_str + R"( + } + } + )"; + } + std::string GetProcessProbe(int join_type, bool cond_check, + const std::vector& left_shuffle_index_list, + const std::vector& right_shuffle_index_list) { + switch (join_type) { + case 0: { /*Inner Join*/ + return GetInnerJoin(cond_check, left_shuffle_index_list, + right_shuffle_index_list); + } break; + case 1: { /*Outer Join*/ + return GetOuterJoin(cond_check, left_shuffle_index_list, + right_shuffle_index_list); + } break; + case 2: { /*Anti Join*/ + return GetAntiJoin(cond_check, left_shuffle_index_list, right_shuffle_index_list); + } break; + case 3: { /*Semi Join*/ + return GetSemiJoin(cond_check, left_shuffle_index_list, right_shuffle_index_list); + } break; + default: + std::cout << "ConditionedProbeArraysTypedImpl only support join type: InnerJoin, " + "RightJoin" + << std::endl; + throw; + } + return ""; + } + std::string GetConditionCheckFunc( + const std::shared_ptr& func_node, + const std::vector>& left_field_list, + const std::vector>& right_field_list, + std::vector* left_out_index_list, std::vector* right_out_index_list) { + std::shared_ptr func_node_visitor; + int func_count = 0; + std::stringstream codes_ss; + MakeCodeGenNodeVisitor(func_node, {left_field_list, right_field_list}, &func_count, + &codes_ss, left_out_index_list, right_out_index_list, + &func_node_visitor); + + return R"( + inline bool ConditionCheck(ArrayItemIndex x, int y) { + )" + codes_ss.str() + + R"( + return )" + + func_node_visitor->GetResult() + + R"(; + } + )"; + } + arrow::Status GetTypedProberCodeGen( + std::string prefix, bool left, const std::vector& index_list, + const std::vector>& field_list, + std::vector>* out_list) { + for (auto i : index_list) { + auto field = field_list[i]; + auto codegen = std::make_shared( + prefix + std::to_string(i), GetTypeString(field->type()), left); + (*out_list).push_back(codegen); + } + return arrow::Status::OK(); + } + std::vector MergeKeyIndexList(const std::vector& left_index_list, + const std::vector& right_index_list) { + std::vector ret = left_index_list; + for (auto i : right_index_list) { + if (std::find(left_index_list.begin(), left_index_list.end(), i) == + left_index_list.end()) { + ret.push_back(i); + } + } + std::sort(ret.begin(), ret.end()); + return ret; + } + std::string GetKeyCType(const std::vector& key_index_list, + const std::vector>& field_list) { + auto field = field_list[key_index_list[0]]; + return GetCTypeString(field->type()); + } + std::string GetTypedArray(bool multiple_cols, std::string index, int i, + std::string data_type, + std::string evaluate_encode_join_key_str) { + std::stringstream ss; + if (multiple_cols) { + ss << "auto concat_kernel_arr_list = {" << evaluate_encode_join_key_str << "};" + << std::endl; + ss << "std::shared_ptr hash_in;" << std::endl; + ss << "RETURN_NOT_OK(hash_kernel_->Evaluate(concat_kernel_arr_list, &hash_in));" + << std::endl; + ss << "auto typed_array = std::dynamic_pointer_cast(hash_in);" + << std::endl; + } else { + ss << "auto typed_array = std::dynamic_pointer_cast(in[" + << i << "]);" << 
std::endl; + } + return ss.str(); + } + std::string ProduceCodes( + const std::shared_ptr& func_node, int join_type, + const std::vector& left_key_index_list, + const std::vector& right_key_index_list, + const std::vector& left_shuffle_index_list, + const std::vector& right_shuffle_index_list, + const std::vector>& left_field_list, + const std::vector>& right_field_list, + const std::vector>& result_schema_index_list) { + std::vector left_cond_index_list; + std::vector right_cond_index_list; + bool cond_check = false; + bool multiple_cols = (left_key_index_list.size() > 1); + std::string key_ctype_str = + "int64_t"; // multiple col will use gandiva hash to get int64_t + if (!multiple_cols) { + key_ctype_str = GetKeyCType(left_key_index_list, left_field_list); + } + std::string condition_check_str; + if (func_node) { + condition_check_str = + GetConditionCheckFunc(func_node, left_field_list, right_field_list, + &left_cond_index_list, &right_cond_index_list); + cond_check = true; + } + auto process_probe_str = GetProcessProbe( + join_type, cond_check, left_shuffle_index_list, right_shuffle_index_list); + auto left_cache_index_list = + MergeKeyIndexList(left_cond_index_list, left_shuffle_index_list); + auto right_cache_index_list = + MergeKeyIndexList(right_cond_index_list, right_shuffle_index_list); + + std::vector> left_cache_codegen_list; + std::vector> left_shuffle_codegen_list; + std::vector> right_shuffle_codegen_list; + GetTypedProberCodeGen("0_", true, left_cache_index_list, left_field_list, + &left_cache_codegen_list); + GetTypedProberCodeGen("0_", true, left_shuffle_index_list, left_field_list, + &left_shuffle_codegen_list); + GetTypedProberCodeGen("1_", false, right_shuffle_index_list, right_field_list, + &right_shuffle_codegen_list); + auto join_key_type_list_define_str = + GetJoinKeyTypeListDefine(left_key_index_list, left_field_list); + auto evaluate_cache_insert_str = GetEvaluateCacheInsert(left_cache_index_list); + auto evaluate_encode_join_key_str = GetEncodeJoinKey(left_key_index_list); + auto finish_cached_parameter_str = GetFinishCachedParameter(left_cache_index_list); + auto impl_cached_define_str = GetImplCachedDefine(left_cache_codegen_list); + auto result_iter_params_str = GetResultIteratorParams(left_cache_index_list); + auto result_iter_set_str = GetResultIteratorSet(left_cache_index_list); + auto result_iter_prepare_str = + GetResultIteratorPrepare(left_shuffle_codegen_list, right_shuffle_codegen_list); + auto process_right_set_str = GetProcessRightSet(right_cache_index_list); + auto process_encode_join_key_str = GetEncodeJoinKey(right_key_index_list); + auto process_finish_str = + GetProcessFinish(left_shuffle_codegen_list, right_shuffle_codegen_list); + auto process_out_list_str = GetProcessOutList( + result_schema_index_list, left_shuffle_codegen_list, right_shuffle_codegen_list); + auto result_iter_cached_define_str = + GetResultIterCachedDefine(left_cache_codegen_list, right_shuffle_codegen_list); + auto evaluate_get_typed_array_str = GetTypedArray( + multiple_cols, "0_" + std::to_string(left_key_index_list[0]), + left_key_index_list[0], + GetTypeString(left_field_list[left_key_index_list[0]]->type(), "Array"), + evaluate_encode_join_key_str); + auto process_get_typed_array_str = GetTypedArray( + multiple_cols, "1_" + std::to_string(right_key_index_list[0]), + right_key_index_list[0], + GetTypeString(left_field_list[left_key_index_list[0]]->type(), "Array"), + process_encode_join_key_str); + return BaseCodes() + R"( +//#include +//using HashMap = 
arrow::internal::ScalarMemoTable<)" + + key_ctype_str + R"(>; +using HashMap = SparseHashMap<)" + + key_ctype_str + R"(>; +class TypedProberImpl : public CodeGenBase { + public: + TypedProberImpl(arrow::compute::FunctionContext *ctx) : ctx_(ctx) { + hash_table_ = std::make_shared( + ctx_->memory_pool()); + )" + + (multiple_cols ? R"( + // Create Hash Kernel + auto type_list = {)" + join_key_type_list_define_str + + R"(}; + HashAggrArrayKernel::Make(ctx_, type_list, &hash_kernel_);)" + : "") + + R"( + } + ~TypedProberImpl() {} + + arrow::Status Evaluate(const ArrayList& in) override { + )" + evaluate_cache_insert_str + + evaluate_get_typed_array_str + + R"( + + auto insert_on_found = [this](int32_t i) { + memo_index_to_arrayid_[i].emplace_back(cur_array_id_, cur_id_); + }; + auto insert_on_not_found = [this](int32_t i) { + memo_index_to_arrayid_.push_back( + {ArrayItemIndex(cur_array_id_, cur_id_)}); + }; + + cur_id_ = 0; + int memo_index = 0; + if (typed_array->null_count() == 0) { + for (; cur_id_ < typed_array->length(); cur_id_++) { + hash_table_->GetOrInsert(typed_array->GetView(cur_id_), insert_on_found, + insert_on_not_found, &memo_index); + } + } else { + for (; cur_id_ < typed_array->length(); cur_id_++) { + if (typed_array->IsNull(cur_id_)) { + hash_table_->GetOrInsertNull(insert_on_found, insert_on_not_found); + } else { + hash_table_->GetOrInsert(typed_array->GetView(cur_id_), + insert_on_found, insert_on_not_found, + &memo_index); + } + } + } + cur_array_id_++; + return arrow::Status::OK(); + } + + arrow::Status MakeResultIterator( + std::shared_ptr schema, + std::shared_ptr> *out) override { + *out = std::make_shared( + ctx_, schema, hash_kernel_, hash_table_, &memo_index_to_arrayid_)" + + finish_cached_parameter_str + R"( + ); + return arrow::Status::OK(); + } + +private: + uint64_t cur_array_id_ = 0; + uint64_t cur_id_ = 0; + arrow::compute::FunctionContext *ctx_; + std::shared_ptr hash_kernel_; + std::shared_ptr hash_table_; + std::vector> memo_index_to_arrayid_; + )" + impl_cached_define_str + + R"( + + class ProberResultIterator : public ResultIterator { + public: + ProberResultIterator( + arrow::compute::FunctionContext *ctx, + std::shared_ptr schema, + std::shared_ptr hash_kernel, + std::shared_ptr hash_table, + std::vector> *memo_index_to_arrayid)" + + result_iter_params_str + R"( + ) + : ctx_(ctx), result_schema_(schema), hash_kernel_(hash_kernel), + hash_table_(hash_table), + memo_index_to_arrayid_(memo_index_to_arrayid) { + )" + + result_iter_set_str + result_iter_prepare_str + R"( + } + + std::string ToString() override { return "ProberResultIterator"; } + + arrow::Status + Process(const ArrayList &in, std::shared_ptr *out, + const std::shared_ptr &selection) override { + auto length = in[0]->length(); + uint64_t out_length = 0; + )" + process_right_set_str + + process_get_typed_array_str + + R"( + + for (int i = 0; i < length; i++) {)" + + process_probe_str + R"( + } + )" + process_finish_str + + R"( + *out = arrow::RecordBatch::Make( + result_schema_, out_length, + {)" + + process_out_list_str + R"(}); + //arrow::PrettyPrint(*(*out).get(), 2, &std::cout); + return arrow::Status::OK(); + } + + private: + arrow::compute::FunctionContext *ctx_; + std::shared_ptr result_schema_; + std::shared_ptr hash_kernel_; + std::shared_ptr hash_table_; + std::vector> *memo_index_to_arrayid_; +)" + result_iter_cached_define_str + + R"( + )" + condition_check_str + + R"( + }; +}; + +extern "C" void MakeCodeGen(arrow::compute::FunctionContext *ctx, + std::shared_ptr *out) { + *out 
= std::make_shared(ctx); +} + )"; + } +}; + +arrow::Status ConditionedProbeArraysKernel::Make( + arrow::compute::FunctionContext* ctx, + const std::vector>& left_key_list, + const std::vector>& right_key_list, + const std::shared_ptr& func_node, int join_type, + const std::vector>& left_field_list, + const std::vector>& right_field_list, + const std::shared_ptr& result_schema, + std::shared_ptr* out) { + *out = std::make_shared( + ctx, left_key_list, right_key_list, func_node, join_type, left_field_list, + right_field_list, result_schema); + return arrow::Status::OK(); +} + +ConditionedProbeArraysKernel::ConditionedProbeArraysKernel( + arrow::compute::FunctionContext* ctx, + const std::vector>& left_key_list, + const std::vector>& right_key_list, + const std::shared_ptr& func_node, int join_type, + const std::vector>& left_field_list, + const std::vector>& right_field_list, + const std::shared_ptr& result_schema) { + impl_.reset(new Impl(ctx, left_key_list, right_key_list, func_node, join_type, + left_field_list, right_field_list, result_schema)); + kernel_name_ = "ConditionedProbeArraysKernel"; +} + +arrow::Status ConditionedProbeArraysKernel::Evaluate(const ArrayList& in) { + return impl_->Evaluate(in); +} + +arrow::Status ConditionedProbeArraysKernel::MakeResultIterator( + std::shared_ptr schema, + std::shared_ptr>* out) { + return impl_->MakeResultIterator(schema, out); +} +} // namespace extra +} // namespace arrowcompute +} // namespace codegen +} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/shuffle_v2_action.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/shuffle_v2_action.cc deleted file mode 100644 index dd788fbbb..000000000 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/shuffle_v2_action.cc +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "codegen/arrow_compute/ext/shuffle_v2_action.h" - -#include - -#include "codegen/arrow_compute/ext/array_item_index.h" - -namespace sparkcolumnarplugin { -namespace codegen { -namespace arrowcompute { -namespace extra { -using namespace arrow; - -template -class ShuffleV2ActionTypedImpl; - -class ShuffleV2Action::Impl { - public: -#define PROCESS_SUPPORTED_TYPES(PROCESS) \ - PROCESS(UInt8Type) \ - PROCESS(Int8Type) \ - PROCESS(UInt16Type) \ - PROCESS(Int16Type) \ - PROCESS(UInt32Type) \ - PROCESS(Int32Type) \ - PROCESS(UInt64Type) \ - PROCESS(Int64Type) \ - PROCESS(FloatType) \ - PROCESS(DoubleType) \ - PROCESS(Date32Type) - static arrow::Status MakeShuffleV2ActionImpl(arrow::compute::FunctionContext* ctx, - std::shared_ptr type, - bool is_arr_list, - std::shared_ptr* out) { - switch (type->id()) { -#define PROCESS(InType) \ - case InType::type_id: { \ - using CType = typename TypeTraits::CType; \ - auto res = \ - std::make_shared>(ctx, is_arr_list); \ - *out = std::dynamic_pointer_cast(res); \ - } break; - PROCESS_SUPPORTED_TYPES(PROCESS) -#undef PROCESS - case arrow::StringType::type_id: { - auto res = std::make_shared>( - ctx, is_arr_list); - *out = std::dynamic_pointer_cast(res); - } break; - default: { - std::cout << "Not Found " << type->ToString() << ", type id is " << type->id() - << std::endl; - } break; - } - return arrow::Status::OK(); - } -#undef PROCESS_SUPPORTED_TYPES - virtual arrow::Status Submit(std::shared_ptr in_arr, - std::shared_ptr selection, - std::function* func) = 0; - virtual arrow::Status Submit(ArrayList in_arr_list, - std::shared_ptr selection, - std::function* func) = 0; - virtual arrow::Status FinishAndReset(ArrayList* out) = 0; -}; - -template -class ShuffleV2ActionTypedImpl : public ShuffleV2Action::Impl { - public: - ShuffleV2ActionTypedImpl(arrow::compute::FunctionContext* ctx, bool is_arr_list) - : ctx_(ctx) { -#ifdef DEBUG - std::cout << "Construct ShuffleV2ActionTypedImpl" << std::endl; -#endif - std::unique_ptr builder; - arrow::MakeBuilder(ctx_->memory_pool(), arrow::TypeTraits::type_singleton(), - &builder); - builder_.reset(arrow::internal::checked_cast(builder.release())); - if (is_arr_list) { - exec_ = [this]() { - auto item = structed_selection_[row_id_]; - if (!typed_in_arr_list_[item.array_id]->IsNull(item.id)) { - RETURN_NOT_OK( - builder_->Append(typed_in_arr_list_[item.array_id]->GetView(item.id))); - } else { - RETURN_NOT_OK(builder_->AppendNull()); - } - row_id_++; - return arrow::Status::OK(); - }; - - nullable_exec_ = [this]() { - if (!selection_->IsNull(row_id_)) { - auto item = structed_selection_[row_id_]; - if (!typed_in_arr_list_[item.array_id]->IsNull(item.id)) { - RETURN_NOT_OK( - builder_->Append(typed_in_arr_list_[item.array_id]->GetView(item.id))); - } else { - RETURN_NOT_OK(builder_->AppendNull()); - } - } else { - RETURN_NOT_OK(builder_->AppendNull()); - } - row_id_++; - return arrow::Status::OK(); - }; - - } else { - exec_ = [this]() { - auto item = uint32_selection_[row_id_]; - if (!typed_in_arr_->IsNull(item)) { - RETURN_NOT_OK(builder_->Append(typed_in_arr_->GetView(item))); - } else { - RETURN_NOT_OK(builder_->AppendNull()); - } - row_id_++; - return arrow::Status::OK(); - }; - nullable_exec_ = [this]() { - if (!selection_->IsNull(row_id_)) { - auto item = uint32_selection_[row_id_]; - if (!typed_in_arr_->IsNull(item)) { - RETURN_NOT_OK(builder_->Append(typed_in_arr_->GetView(item))); - } else { - RETURN_NOT_OK(builder_->AppendNull()); - } - } else { - RETURN_NOT_OK(builder_->AppendNull()); - } - 
row_id_++; - return arrow::Status::OK(); - }; - } - } - ~ShuffleV2ActionTypedImpl() { -#ifdef DEBUG - std::cout << "Destruct ShuffleV2ActionTypedImpl" << std::endl; -#endif - } - - arrow::Status Submit(ArrayList in_arr_list, std::shared_ptr selection, - std::function* exec) { - row_id_ = 0; - if (typed_in_arr_list_.size() == 0) { - for (auto arr : in_arr_list) { - typed_in_arr_list_.push_back(std::dynamic_pointer_cast(arr)); - } - } - selection_ = selection; - structed_selection_ = - (ArrayItemIndex*)std::dynamic_pointer_cast(selection) - ->raw_values(); - if (selection->null_count() == 0) { - *exec = exec_; - } else { - *exec = nullable_exec_; - } - return arrow::Status::OK(); - } - - arrow::Status Submit(std::shared_ptr in_arr, - std::shared_ptr selection, - std::function* exec) { - row_id_ = 0; - typed_in_arr_ = std::dynamic_pointer_cast(in_arr); - selection_ = selection; - uint32_selection_ = - (uint32_t*)std::dynamic_pointer_cast(selection)->raw_values(); - if (selection->null_count() == 0) { - *exec = exec_; - } else { - *exec = nullable_exec_; - } - return arrow::Status::OK(); - } - - arrow::Status FinishAndReset(ArrayList* out) { - std::shared_ptr arr_out; - RETURN_NOT_OK(builder_->Finish(&arr_out)); - out->push_back(arr_out); - builder_->Reset(); - return arrow::Status::OK(); - } - - private: - using ArrayType = typename arrow::TypeTraits::ArrayType; - using BuilderType = typename arrow::TypeTraits::BuilderType; - // input - arrow::compute::FunctionContext* ctx_; - int arg_id_; - uint64_t row_id_ = 0; - std::vector> typed_in_arr_list_; - std::shared_ptr typed_in_arr_; - std::shared_ptr selection_; - ArrayItemIndex* structed_selection_; - uint32_t* uint32_selection_; - // result - std::function exec_; - std::function nullable_exec_; - std::shared_ptr builder_; -}; - -ShuffleV2Action::ShuffleV2Action(arrow::compute::FunctionContext* ctx, - std::shared_ptr type, - bool is_arr_list) { - auto status = Impl::MakeShuffleV2ActionImpl(ctx, type, is_arr_list, &impl_); -} - -ShuffleV2Action::~ShuffleV2Action() {} - -arrow::Status ShuffleV2Action::Submit(ArrayList in_arr_list, - std::shared_ptr selection, - std::function* func) { - RETURN_NOT_OK(impl_->Submit(in_arr_list, selection, func)); - return arrow::Status::OK(); -} - -arrow::Status ShuffleV2Action::Submit(std::shared_ptr in_arr, - std::shared_ptr selection, - std::function* func) { - RETURN_NOT_OK(impl_->Submit(in_arr, selection, func)); - return arrow::Status::OK(); -} - -arrow::Status ShuffleV2Action::FinishAndReset(ArrayList* out) { - RETURN_NOT_OK(impl_->FinishAndReset(out)); - return arrow::Status::OK(); -} -} // namespace extra -} // namespace arrowcompute -} // namespace codegen -} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/shuffle_v2_action.h b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/shuffle_v2_action.h deleted file mode 100644 index adc1965f6..000000000 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/shuffle_v2_action.h +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace sparkcolumnarplugin { -namespace codegen { -namespace arrowcompute { -namespace extra { -using ArrayList = std::vector>; -//////////////// ShuffleV2Action /////////////// -class ShuffleV2Action { - public: - ShuffleV2Action(arrow::compute::FunctionContext* ctx, - std::shared_ptr type, bool is_arr_list); - ~ShuffleV2Action(); - arrow::Status Submit(ArrayList in_arr_list, std::shared_ptr selection, - std::function* func); - arrow::Status Submit(std::shared_ptr in_arr, - std::shared_ptr selection, - std::function* func); - arrow::Status FinishAndReset(ArrayList* out); - class Impl; - - private: - std::shared_ptr impl_; -}; - -///////////////////// Public Functions ////////////////// -static arrow::Status MakeShuffleV2Action(arrow::compute::FunctionContext* ctx, - std::shared_ptr type, - bool is_arr_list, - std::shared_ptr* out) { - *out = std::make_shared(ctx, type, is_arr_list); - return arrow::Status::OK(); -} -} // namespace extra -} // namespace arrowcompute -} // namespace codegen -} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc index 31253208b..25077e533 100644 --- a/oap-native-sql/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc +++ b/oap-native-sql/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc @@ -15,15 +15,7 @@ * limitations under the License. */ -#include -#include -#include -#include #include -#include -#include -#include -#include #include #include #include @@ -38,7 +30,6 @@ #include "codegen/arrow_compute/ext/array_item_index.h" #include "codegen/arrow_compute/ext/code_generator_base.h" #include "codegen/arrow_compute/ext/codegen_common.h" -#include "codegen/arrow_compute/ext/item_iterator.h" #include "codegen/arrow_compute/ext/kernels_ext.h" namespace sparkcolumnarplugin { @@ -64,29 +55,59 @@ class SortArraysToIndicesKernel::Impl { } key_index_list_.push_back(indices[0]); } + } + virtual arrow::Status LoadJITFunction( + std::vector> key_field_list, + std::shared_ptr result_schema) { + // generate ddl signature + std::stringstream func_args_ss; + func_args_ss << "[Sorter]" << (nulls_first_ ? "nulls_first" : "nulls_last") << "|" + << (asc_ ? 
"asc" : "desc"); + int i = 0; + for (auto field : key_field_list) { + func_args_ss << "[sort_key_" << i << "]" << field->ToString(); + } + + func_args_ss << "[schema]" << result_schema->ToString(); + + //#ifdef DEBUG + std::cout << "func_args_ss is " << func_args_ss.str() << std::endl; + //#endif + + std::stringstream signature_ss; + signature_ss << std::hex << std::hash{}(func_args_ss.str()); + std::string signature = signature_ss.str(); - auto status = LoadJITFunction(key_field_list, result_schema, &sorter); + auto file_lock = FileSpinLock(); + auto status = LoadLibrary(signature, ctx_, &sorter); if (!status.ok()) { - std::cout << "LoadJITFunction failed, msg is " << status.message() << std::endl; - throw; + // process + auto codes = ProduceCodes(result_schema); + // compile codes + RETURN_NOT_OK(CompileCodes(codes, signature)); + RETURN_NOT_OK(LoadLibrary(signature, ctx_, &sorter)); } + FileSpinUnLock(file_lock); + return arrow::Status::OK(); } - arrow::Status Evaluate(const ArrayList& in) { + virtual arrow::Status Evaluate(const ArrayList& in) { RETURN_NOT_OK(sorter->Evaluate(in)); return arrow::Status::OK(); } - arrow::Status MakeResultIterator( + virtual arrow::Status MakeResultIterator( std::shared_ptr schema, std::shared_ptr>* out) { RETURN_NOT_OK(sorter->MakeResultIterator(schema, out)); return arrow::Status::OK(); } - arrow::Status Finish(std::shared_ptr* out) { return arrow::Status::OK(); } + virtual arrow::Status Finish(std::shared_ptr* out) { + return arrow::Status::OK(); + } - private: + protected: std::shared_ptr sorter; arrow::compute::FunctionContext* ctx_; bool nulls_first_; @@ -147,38 +168,7 @@ class SortArraysToIndicesKernel::Impl { std::string name_; }; - arrow::Status LoadJITFunction(std::vector> key_field_list, - std::shared_ptr result_schema, - std::shared_ptr* out) { - // generate ddl signature - std::stringstream func_args_ss; - func_args_ss << "[Sorter]" << (nulls_first_ ? "nulls_first" : "nulls_last") << "|" - << (asc_ ? 
"asc" : "desc"); - int i = 0; - for (auto field : key_field_list) { - func_args_ss << "[sort_key_" << i << "]" << field->ToString(); - } - - func_args_ss << "[schema]" << result_schema->ToString(); - - std::stringstream signature_ss; - signature_ss << std::hex << std::hash{}(func_args_ss.str()); - std::string signature = signature_ss.str(); - - auto file_lock = FileSpinLock("/tmp"); - auto status = LoadLibrary(signature, ctx_, out); - if (!status.ok()) { - // process - auto codes = ProduceCodes(result_schema); - // compile codes - RETURN_NOT_OK(CompileCodes(codes, signature)); - RETURN_NOT_OK(LoadLibrary(signature, ctx_, out)); - } - FileSpinUnLock(file_lock); - return arrow::Status::OK(); - } - - std::string ProduceCodes(std::shared_ptr result_schema) { + virtual std::string ProduceCodes(std::shared_ptr result_schema) { int indice = 0; std::vector> shuffle_typed_codegen_list; for (auto field : result_schema->fields()) { @@ -194,7 +184,7 @@ class SortArraysToIndicesKernel::Impl { std::string pre_sort_null_str = GetPreSortNull(); - std::string sort_func_str = GetSortFunction(); + std::string sort_func_str = GetSortFunction(key_index_list_); std::string make_result_iter_str = GetMakeResultIter(shuffle_typed_codegen_list.size()); @@ -218,6 +208,8 @@ class SortArraysToIndicesKernel::Impl { std::string typed_res_array_str = GetTypedResArray(shuffle_typed_codegen_list.size()); return BaseCodes() + R"( +#include "third_party/ska_sort.hpp" + class TypedSorterImpl : public CodeGenBase { public: TypedSorterImpl(arrow::compute::FunctionContext* ctx) : ctx_(ctx) {} @@ -315,7 +307,9 @@ class TypedSorterImpl : public CodeGenBase { } arrow::Status Next(std::shared_ptr* out) { - auto length = (total_length_ - offset_) > 4096 ? 4096 : (total_length_ - offset_); + auto length = (total_length_ - offset_) > )" + + std::to_string(GetBatchSize()) + R"( ? 
)" + std::to_string(GetBatchSize()) + + R"( : (total_length_ - offset_); uint64_t count = 0; while (count < length) { auto item = indices_begin_ + offset_ + count++; @@ -393,7 +387,6 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, return R"( (indices_begin + nulls_total_ + indices_i)->array_id = array_id; (indices_begin + nulls_total_ + indices_i)->id = i;)"; - } else { return R"( (indices_begin + indices_i)->array_id = array_id; @@ -411,14 +404,40 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, (indices_end - nulls_total_ + indices_null)->id = i;)"; } } - std::string GetSortFunction() { - if (nulls_first_) { - return "std::sort(indices_begin + nulls_total_, indices_begin + " - "items_total_, " - "comp);"; + std::string GetSortFunction(std::vector& key_index_list) { + if (asc_) { + if (key_index_list.size() == 1) { + if (nulls_first_) { + return "ska_sort(indices_begin + nulls_total_, indices_begin + " + "items_total_, " + "[this](auto& x) -> decltype(auto){ return cached_" + + std::to_string(key_index_list[0]) + "_[x.array_id]->GetView(x.id); });"; + } else { + return "ska_sort(indices_begin, indices_begin + items_total_ - " + "nulls_total_, " + "[this](auto& x) -> decltype(auto){ return cached_" + + std::to_string(key_index_list[0]) + "_[x.array_id]->GetView(x.id); });"; + } + + } else { + if (nulls_first_) { + return "std::sort(indices_begin + nulls_total_, indices_begin + " + "items_total_, " + "comp);"; + } else { + return "std::sort(indices_begin, indices_begin + items_total_ - " + "nulls_total_, comp);"; + } + } } else { - return "std::sort(indices_begin, indices_begin + items_total_ - " - "nulls_total_, comp);"; + if (nulls_first_) { + return "std::sort(indices_begin + nulls_total_, indices_begin + " + "items_total_, " + "comp);"; + } else { + return "std::sort(indices_begin, indices_begin + items_total_ - " + "nulls_total_, comp);"; + } } } std::string GetMakeResultIter(int shuffle_size) { @@ -514,6 +533,329 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, } }; +/////////////// SortArraysInPlace //////////////// +class SortInplaceKernel : public SortArraysToIndicesKernel::Impl { + public: + SortInplaceKernel(arrow::compute::FunctionContext* ctx, + std::vector> key_field_list, + std::shared_ptr result_schema, bool nulls_first, + bool asc) + : Impl(ctx, key_field_list, result_schema, nulls_first, asc) { + auto indices = result_schema->GetAllFieldIndices(key_field_list[0]->name()); + if (indices.size() != 1) { + std::cout << "[ERROR] SortArraysToIndicesKernel::Impl can't find key " + << key_field_list[0]->ToString() << " from " << result_schema->ToString() + << std::endl; + throw; + } + } + + arrow::Status Evaluate(const ArrayList& in) override { + RETURN_NOT_OK(sorter->Evaluate(in)); + return arrow::Status::OK(); + } + + arrow::Status MakeResultIterator( + std::shared_ptr schema, + std::shared_ptr>* out) override { + RETURN_NOT_OK(sorter->MakeResultIterator(schema, out)); + return arrow::Status::OK(); + } + + arrow::Status Finish(std::shared_ptr* out) override { + return arrow::Status::OK(); + } + + private: + class TypedSorterCodeGenImpl { + public: + TypedSorterCodeGenImpl(std::shared_ptr dataType, std::string indice, + std::string name) + : indice_(indice), + dataTypeName_(GetTypeString(dataType)), + name_(name), + dataType_(dataType) {} + std::string GetCTypeName() { return GetCTypeString(dataType_); } + std::string GetResultIterDefine() { + return "std::unique_ptr builder_" + indice_ + + ";\n" + 
"arrow::MakeBuilder(ctx_->memory_pool(), data_type_" + + indice_ + ", &builder_" + indice_ + + ");\n" + "builder_" + + indice_ + "_.reset(arrow::internal::checked_cast(builder_" + indice_ + ".release()));\n"; + } + std::string GetFieldDefine() { + return "arrow::field(\"" + name_ + "\", data_type_" + indice_ + ")"; + } + std::string GetResultIterVariables() { + return R"( + using DataType_)" + + indice_ + R"( = typename arrow::)" + dataTypeName_ + R"(; + using ArrayType_)" + + indice_ + R"( = typename arrow::TypeTraits::ArrayType; + using BuilderType_)" + + indice_ + R"( = typename arrow::TypeTraits::BuilderType; + std::shared_ptr data_type_)" + + indice_ + R"( = arrow::TypeTraits::type_singleton(); + std::vector> cached_)" + indice_ + R"(_; + std::shared_ptr builder_)" + indice_ + R"(_; + )"; + } + + private: + std::string indice_; + std::string dataTypeName_; + std::string name_; + std::shared_ptr dataType_; + }; + + std::string ProduceCodes(std::shared_ptr result_schema) override { + int indice = 0; + std::vector> typed_codegen_list; + for (auto field : result_schema->fields()) { + auto codegen = std::make_shared( + field->type(), std::to_string(indice), field->name()); + typed_codegen_list.push_back(codegen); + indice++; + } + std::string ctype_str = typed_codegen_list[0]->GetCTypeName(); + + std::string comp_func_str = GetCompFunction(ctype_str); + + std::string sort_func_str = GetSortFunction(); + + std::string partition_func_str = GetPartitionFunction(); + + std::string result_iter_define_str = GetResultIterDefine(typed_codegen_list); + + std::string result_variables_define_str = GetResultIterVariables(typed_codegen_list); + + std::string typed_res_array_build_str = + GetTypedResArrayBuild(typed_codegen_list.size()); + + std::string typed_res_array_str = GetTypedResArray(typed_codegen_list.size()); + + return BaseCodes() + R"( +#include + +#include "third_party/ska_sort.hpp" + +class TypedSorterImpl : public CodeGenBase { + public: + TypedSorterImpl(arrow::compute::FunctionContext* ctx) : ctx_(ctx) {} + + arrow::Status Evaluate(const ArrayList& in) override { + num_batches_++; + items_total_ += in[0]->length(); + nulls_total_ += in[0]->null_count(); + cached_0_.push_back(in[0]); + return arrow::Status::OK(); + } + + arrow::Status MakeResultIterator( + std::shared_ptr schema, + std::shared_ptr>* out) override { + )" + comp_func_str + + R"( + RETURN_NOT_OK(arrow::Concatenate(cached_0_, ctx_->memory_pool(), &concatenated_array_)); + )" + ctype_str + + " *indices_begin = concatenated_array_->data()->GetMutableValues<" + + ctype_str + ">(1);\n" + ctype_str + + R"(* indices_end = indices_begin + concatenated_array_->length(); + if (nulls_total_ > 0) { + )" + partition_func_str + + R"( + } + + )" + sort_func_str + + R"( + *out = std::make_shared(ctx_, indices_begin, nulls_total_, + items_total_); + return arrow::Status::OK(); + } + + private: + arrow::ArrayVector cached_0_; + std::shared_ptr concatenated_array_; + arrow::compute::FunctionContext* ctx_; + uint64_t num_batches_ = 0; + uint64_t items_total_ = 0; + uint64_t nulls_total_ = 0; + + class SorterResultIterator : public ResultIterator { + public: + SorterResultIterator(arrow::compute::FunctionContext* ctx, + )" + + ctype_str + R"(*indices_begin, uint64_t nulls_total, uint64_t length) + : ctx_(ctx), total_length_(length), nulls_total_(nulls_total), indices_begin_(indices_begin) { + )" + result_iter_define_str + + R"( + } + + std::string ToString() override { return "SortArraysToIndicesResultIterator"; } + + bool HasNext() override { 
+ if (offset_ >= total_length_) { + return false; + } + return true; + } + + arrow::Status Next(std::shared_ptr* out) { + auto length = (total_length_ - offset_) > )" + + std::to_string(GetBatchSize()) + R"( ? )" + std::to_string(GetBatchSize()) + + R"( : (total_length_ - offset_); + uint64_t count = 0; + if (offset_ >= nulls_total_) { + while (count < length){ + RETURN_NOT_OK(builder_0_->Append(indices_begin_[offset_ + count++])); + } + } else { + while (count < length) { + if ((offset_ + count) < nulls_total_) { + RETURN_NOT_OK(builder_0_->AppendNull()); + } else { + RETURN_NOT_OK(builder_0_->Append(indices_begin_[offset_ + count++])); + } + } + } + offset_ += length; + )" + typed_res_array_build_str + + R"( + *out = arrow::RecordBatch::Make(result_schema_, length, {)" + + typed_res_array_str + R"(}); + return arrow::Status::OK(); + } + + private: + )" + result_variables_define_str + + R"( + std::shared_ptr indices_in_cache_; + uint64_t offset_ = 0; + )" + ctype_str + + R"(* indices_begin_; + const uint64_t total_length_; + const uint64_t nulls_total_; + std::shared_ptr result_schema_; + arrow::compute::FunctionContext* ctx_; + }; +}; + +extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, + std::shared_ptr* out) { + *out = std::make_shared(ctx); +} + + )"; + } + std::string GetCompFunction(std::string data_type) { + std::stringstream ss; + if (asc_) { + ss << "auto comp = [this](" << data_type << "& x, " << data_type << "& y) {" + << "return x < y; };"; + + } else { + ss << "auto comp = [this](" << data_type << "& x, " << data_type << "& y) {" + << "return x > y; };"; + } + return ss.str(); + } + std::string GetPartitionFunction() { + if (nulls_first_) { + return "std::stable_partition(indices_begin, indices_end, [this](auto ind) {return " + "concatenated_array_->IsNull(ind);});"; + } else { + return "std::stable_partition(indices_begin, indices_end, [this](auto ind) {return " + "!concatenated_array_->IsNull(ind);});"; + } + } + std::string GetSortFunction() { + if (asc_) { + if (nulls_first_) { + return "ska_sort(indices_begin + nulls_total_, indices_begin + " + "items_total_);"; + } else { + return "ska_sort(indices_begin, indices_begin + items_total_ - " + "nulls_total_);"; + } + } else { + if (nulls_first_) { + return "std::sort(indices_begin + nulls_total_, indices_begin + " + "items_total_, comp);"; + } else { + return "std::sort(indices_begin, indices_begin + items_total_ - " + "nulls_total_, comp);"; + } + } + } + std::string GetMakeResultIter(int shuffle_size) { + std::stringstream ss; + std::stringstream params_ss; + for (int i = 0; i < shuffle_size; i++) { + if (i + 1 < shuffle_size) { + params_ss << "cached_" << i << "_,"; + } else { + params_ss << "cached_" << i << "_"; + } + } + auto params = params_ss.str(); + ss << "*out = std::make_shared(ctx_, indices_out, " << params + << ");"; + return ss.str(); + } + std::string GetResultIterDefine( + std::vector> shuffle_typed_codegen_list) { + std::stringstream ss; + std::stringstream field_define_ss; + for (auto codegen : shuffle_typed_codegen_list) { + ss << codegen->GetResultIterDefine() << std::endl; + if (codegen != *(shuffle_typed_codegen_list.end() - 1)) { + field_define_ss << codegen->GetFieldDefine() << ","; + } else { + field_define_ss << codegen->GetFieldDefine(); + } + } + ss << "result_schema_ = arrow::schema({" << field_define_ss.str() << "});\n" + << std::endl; + return ss.str(); + } + std::string GetTypedResArrayBuild(int shuffle_size) { + std::stringstream ss; + for (int i = 0; i < shuffle_size; i++) 
{ + ss << "std::shared_ptr out_" << i << ";\n" + << "RETURN_NOT_OK(builder_" << i << "_->Finish(&out_" << i << "));\n" + << "builder_" << i << "_->Reset();" << std::endl; + } + return ss.str(); + } + std::string GetTypedResArray(int shuffle_size) { + std::stringstream ss; + for (int i = 0; i < shuffle_size; i++) { + if (i + 1 < shuffle_size) { + ss << "out_" << i << ", "; + } else { + ss << "out_" << i; + } + } + return ss.str(); + } + std::string GetResultIterVariables( + std::vector> shuffle_typed_codegen_list) { + std::stringstream ss; + for (auto codegen : shuffle_typed_codegen_list) { + ss << codegen->GetResultIterVariables() << std::endl; + } + return ss.str(); + } +}; + arrow::Status SortArraysToIndicesKernel::Make( arrow::compute::FunctionContext* ctx, std::vector> key_field_list, @@ -528,7 +870,18 @@ SortArraysToIndicesKernel::SortArraysToIndicesKernel( arrow::compute::FunctionContext* ctx, std::vector> key_field_list, std::shared_ptr result_schema, bool nulls_first, bool asc) { - impl_.reset(new Impl(ctx, key_field_list, result_schema, nulls_first, asc)); + if (key_field_list.size() == 1 && result_schema->num_fields() == 1) { + std::cout << "UseSortInplace" << std::endl; + impl_.reset( + new SortInplaceKernel(ctx, key_field_list, result_schema, nulls_first, asc)); + } else { + impl_.reset(new Impl(ctx, key_field_list, result_schema, nulls_first, asc)); + } + auto status = impl_->LoadJITFunction(key_field_list, result_schema); + if (!status.ok()) { + std::cout << "LoadJITFunction failed, msg is " << status.message() << std::endl; + throw; + } kernel_name_ = "SortArraysToIndicesKernelKernel"; } #undef PROCESS_SUPPORTED_TYPES diff --git a/oap-native-sql/cpp/src/codegen/common/result_iterator.h b/oap-native-sql/cpp/src/codegen/common/result_iterator.h index a179e02fc..1625ffe0c 100644 --- a/oap-native-sql/cpp/src/codegen/common/result_iterator.h +++ b/oap-native-sql/cpp/src/codegen/common/result_iterator.h @@ -29,12 +29,12 @@ class ResultIterator { return arrow::Status::NotImplemented("ResultIterator abstract Next()"); } virtual arrow::Status Process( - std::vector> in, std::shared_ptr* out, + const std::vector>& in, std::shared_ptr* out, const std::shared_ptr& selection = nullptr) { return arrow::Status::NotImplemented("ResultIterator abstract Process()"); } virtual arrow::Status ProcessAndCacheOne( - std::vector> in, + const std::vector>& in, const std::shared_ptr& selection = nullptr) { return arrow::Status::NotImplemented("ResultIterator abstract ProcessAndCacheOne()"); } diff --git a/oap-native-sql/cpp/src/jni/jni_common.h b/oap-native-sql/cpp/src/jni/jni_common.h index 5a56dd588..fff791cad 100644 --- a/oap-native-sql/cpp/src/jni/jni_common.h +++ b/oap-native-sql/cpp/src/jni/jni_common.h @@ -207,3 +207,45 @@ jbyteArray ToSchemaByteArray(JNIEnv* env, std::shared_ptr schema) env->SetByteArrayRegion(out, 0, buffer->size(), src); return out; } + +arrow::Result> DecompressBuffer( + jlong in_addr, jlong in_size, arrow::Compression::type compression_codec) { + if (compression_codec == arrow::Compression::UNCOMPRESSED) { + return std::shared_ptr( + new arrow::Buffer(reinterpret_cast(in_addr), in_size)); + } + + if (in_size == 0) { + return nullptr; + } + + if (in_size < 8) { + return Status::Invalid( + "Likely corrupted message, compressed buffers " + "are larger than 8 bytes by construction"); + } + + auto data = reinterpret_cast(in_addr); + + std::unique_ptr codec; + ARROW_ASSIGN_OR_RAISE(codec, arrow::util::Codec::Create(compression_codec)); + + int64_t compressed_size = in_size - 
+  int64_t compressed_size = in_size - sizeof(int64_t);
+  int64_t uncompressed_size = arrow::util::SafeLoadAs<int64_t>(data);
+
+  std::shared_ptr<arrow::Buffer> uncompressed;
+  ARROW_ASSIGN_OR_RAISE(uncompressed, arrow::AllocateBuffer(uncompressed_size));
+
+  int64_t actual_decompressed;
+  ARROW_ASSIGN_OR_RAISE(
+      actual_decompressed,
+      codec->Decompress(compressed_size, data + sizeof(int64_t), uncompressed_size,
+                        uncompressed->mutable_data()));
+
+  if (actual_decompressed != uncompressed_size) {
+    return arrow::Status::Invalid("Failed to fully decompress buffer, expected ",
+                                  uncompressed_size, " bytes but decompressed ",
+                                  actual_decompressed);
+  }
+  return uncompressed;
+}
diff --git a/oap-native-sql/cpp/src/jni/jni_wrapper.cc b/oap-native-sql/cpp/src/jni/jni_wrapper.cc
index d2a3bf573..015904e69 100644
--- a/oap-native-sql/cpp/src/jni/jni_wrapper.cc
+++ b/oap-native-sql/cpp/src/jni/jni_wrapper.cc
@@ -23,6 +23,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -33,6 +34,7 @@
 #include "codegen/common/result_iterator.h"
 #include "jni/concurrent_map.h"
 #include "jni/jni_common.h"
+#include "shuffle/splitter.h"

 namespace types {
 class ExpressionList;
@@ -47,6 +49,9 @@ static jmethodID arrow_field_node_builder_constructor;
 static jclass arrowbuf_builder_class;
 static jmethodID arrowbuf_builder_constructor;

+static jclass partition_file_info_class;
+static jmethodID partition_file_info_constructor;
+
 using arrow::jni::ConcurrentMap;
 static ConcurrentMap<std::shared_ptr<arrow::Buffer>> buffer_holder_;
@@ -57,6 +62,11 @@ static arrow::jni::ConcurrentMap<std::shared_ptr<CodeGenerator>> handler_holder_
 static arrow::jni::ConcurrentMap<std::shared_ptr<ResultIterator<arrow::RecordBatch>>>
     batch_iterator_holder_;

+using sparkcolumnarplugin::shuffle::Splitter;
+static arrow::jni::ConcurrentMap<std::shared_ptr<Splitter>> shuffle_splitter_holder_;
+static arrow::jni::ConcurrentMap<std::shared_ptr<arrow::Schema>>
+    decompression_schema_holder_;
+
 std::shared_ptr<CodeGenerator> GetCodeGenerator(JNIEnv* env, jlong id) {
   auto handler = handler_holder_.Lookup(id);
   if (!handler) {
@@ -76,6 +86,16 @@ std::shared_ptr<ResultIterator<arrow::RecordBatch>> GetBatchIterator(JNIEnv* env
   return handler;
 }

+std::shared_ptr<Splitter> GetShuffleSplitter(JNIEnv* env, jlong id) {
+  auto splitter = shuffle_splitter_holder_.Lookup(id);
+  if (!splitter) {
+    std::string error_message = "invalid splitter id " + std::to_string(id);
+    env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
+  }
+
+  return splitter;
+}
+
 jobject MakeRecordBatchBuilder(JNIEnv* env, std::shared_ptr<arrow::Schema> schema,
                                std::shared_ptr<arrow::RecordBatch> record_batch) {
   jobjectArray field_array =
@@ -179,6 +199,11 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
   arrowbuf_builder_constructor =
       GetMethodID(env, arrowbuf_builder_class, "<init>", "(JJIJ)V");

+  partition_file_info_class = CreateGlobalClassReference(
+      env, "Lcom/intel/sparkColumnarPlugin/vectorized/PartitionFileInfo;");
+  partition_file_info_constructor =
+      GetMethodID(env, partition_file_info_class, "<init>", "(ILjava/lang/String;)V");
+
   return JNI_VERSION;
 }

@@ -194,10 +219,28 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
   env->DeleteGlobalRef(arrow_field_node_builder_class);
   env->DeleteGlobalRef(arrowbuf_builder_class);
   env->DeleteGlobalRef(arrow_record_batch_builder_class);
+  env->DeleteGlobalRef(partition_file_info_class);

   buffer_holder_.Clear();
   handler_holder_.Clear();
   batch_iterator_holder_.Clear();
+  shuffle_splitter_holder_.Clear();
+  decompression_schema_holder_.Clear();
+}
+
+JNIEXPORT void JNICALL
+Java_com_intel_sparkColumnarPlugin_vectorized_ExpressionEvaluatorJniWrapper_nativeSetJavaTmpDir(
+    JNIEnv* env, jobject obj, jstring pathObj) {
+  jboolean ifCopy;
+  auto path = env->GetStringUTFChars(pathObj, &ifCopy);
+  setenv("NATIVESQL_TMP_DIR", path, 1);
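+  // Configuration crosses the JNI boundary as process environment variables:
+  // NATIVESQL_TMP_DIR here, NATIVESQL_BATCH_SIZE just below, and
+  // NATIVESQL_SPARK_LOCAL_DIRS in ShuffleSplitterJniWrapper_make. A native
+  // consumer then only needs std::getenv plus a fallback, e.g. (illustrative
+  // reader; 4096 is an assumed default, not taken from this patch):
+  //
+  //   int64_t GetConfiguredBatchSize() {
+  //     const char* env = std::getenv("NATIVESQL_BATCH_SIZE");
+  //     return env == nullptr ? 4096 : std::stoll(env);
+  //   }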
+  env->ReleaseStringUTFChars(pathObj, path);
+}
+
+JNIEXPORT void JNICALL
+Java_com_intel_sparkColumnarPlugin_vectorized_ExpressionEvaluatorJniWrapper_nativeSetBatchSize(
+    JNIEnv* env, jobject obj, jint batch_size) {
+  setenv("NATIVESQL_BATCH_SIZE", std::to_string(batch_size).c_str(), 1);
 }

 JNIEXPORT jlong JNICALL
@@ -1041,6 +1084,289 @@ Java_com_intel_sparkColumnarPlugin_datasource_parquet_ParquetWriterJniWrapper_na
   env->ReleaseLongArrayElements(bufAddrs, in_buf_addrs, JNI_ABORT);
   env->ReleaseLongArrayElements(bufSizes, in_buf_sizes, JNI_ABORT);
 }
+
+JNIEXPORT jlong JNICALL
+Java_com_intel_sparkColumnarPlugin_vectorized_ShuffleSplitterJniWrapper_make(
+    JNIEnv* env, jobject, jbyteArray schema_arr, jlong buffer_size, jstring pathObj) {
+  std::shared_ptr<arrow::Schema> schema;
+  arrow::Status status;
+
+  auto joined_path = env->GetStringUTFChars(pathObj, JNI_FALSE);
+  setenv("NATIVESQL_SPARK_LOCAL_DIRS", joined_path, 1);
+  env->ReleaseStringUTFChars(pathObj, joined_path);
+
+  status = MakeSchema(env, schema_arr, &schema);
+  if (!status.ok()) {
+    env->ThrowNew(
+        io_exception_class,
+        std::string("failed to readSchema, err msg is " + status.message()).c_str());
+    return -1;
+  }
+
+  auto result = Splitter::Make(schema);
+  if (!result.ok()) {
+    env->ThrowNew(io_exception_class,
+                  std::string("Failed to create native shuffle splitter").c_str());
+    return -1;
+  }
+
+  (*result)->set_buffer_size(buffer_size);
+
+  return shuffle_splitter_holder_.Insert(std::shared_ptr<Splitter>(*result));
+}
+
+JNIEXPORT void JNICALL
+Java_com_intel_sparkColumnarPlugin_vectorized_ShuffleSplitterJniWrapper_split(
+    JNIEnv* env, jobject, jlong splitter_id, jint num_rows, jlongArray buf_addrs,
+    jlongArray buf_sizes) {
+  auto splitter = GetShuffleSplitter(env, splitter_id);
+
+  int in_bufs_len = env->GetArrayLength(buf_addrs);
+  if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
+    std::string error_message =
+        "native split: mismatch in arraylen of buf_addrs and buf_sizes";
+    env->ThrowNew(io_exception_class, error_message.c_str());
+    return;
+  }
+  jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, JNI_FALSE);
+  jlong* in_buf_sizes = env->GetLongArrayElements(buf_sizes, JNI_FALSE);
+
+  std::shared_ptr<arrow::RecordBatch> in;
+  auto status = MakeRecordBatch(splitter->schema(), num_rows, (int64_t*)in_buf_addrs,
+                                (int64_t*)in_buf_sizes, in_bufs_len, &in);
+
+  env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
+  env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
+
+  if (!status.ok()) {
+    env->ThrowNew(io_exception_class,
+                  std::string("native split: make record batch failed").c_str());
+    return;
+  }
+
+  status = splitter->Split(*in);
+
+  if (!status.ok()) {
+    env->ThrowNew(io_exception_class,
+                  std::string("native split: splitter split failed").c_str());
+  }
+}
+
+JNIEXPORT void JNICALL
+Java_com_intel_sparkColumnarPlugin_vectorized_ShuffleSplitterJniWrapper_stop(
+    JNIEnv* env, jobject, jlong splitter_id) {
+  auto splitter = GetShuffleSplitter(env, splitter_id);
+  auto status = splitter->Stop();
+
+  if (!status.ok()) {
+    env->ThrowNew(io_exception_class,
+                  std::string("native split: splitter stop failed, error message is " +
+                              status.message())
+                      .c_str());
+  }
+}
+
+JNIEXPORT void JNICALL
+Java_com_intel_sparkColumnarPlugin_vectorized_ShuffleSplitterJniWrapper_setPartitionBufferSize(
+    JNIEnv* env, jobject, jlong splitter_id, jlong buffer_size) {
+  auto splitter = GetShuffleSplitter(env, splitter_id);
+
+  splitter->set_buffer_size((int64_t)buffer_size);
+}
+
+JNIEXPORT void JNICALL
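+// setCompressionCodec below normalizes the codec name handed down from
+// Spark: uppercase it, resolve it through Arrow, and widen bare LZ4 to
+// LZ4_FRAME, since the shuffle path works with framed LZ4. Condensed into
+// one helper (a sketch using only the calls visible in this file; error
+// handling elided, where the JNI code throws instead):
+//
+//   arrow::Compression::type ResolveCodec(std::string name) {
+//     std::transform(name.begin(), name.end(), name.begin(), ::toupper);
+//     auto maybe = arrow::util::Codec::GetCompressionType(name);
+//     auto codec = maybe.ok() ? *maybe : arrow::Compression::UNCOMPRESSED;
+//     return codec == arrow::Compression::LZ4 ? arrow::Compression::LZ4_FRAME
+//                                             : codec;
+//   }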
+Java_com_intel_sparkColumnarPlugin_vectorized_ShuffleSplitterJniWrapper_setCompressionCodec( + JNIEnv* env, jobject, jlong splitter_id, jstring codec_jstr) { + auto splitter = GetShuffleSplitter(env, splitter_id); + + auto compression_codec = arrow::Compression::UNCOMPRESSED; + auto codec_l = env->GetStringUTFChars(codec_jstr, JNI_FALSE); + if (codec_l != nullptr) { + std::string codec_u; + std::transform(codec_l, codec_l + std::strlen(codec_l), std::back_inserter(codec_u), + ::toupper); + auto result = arrow::util::Codec::GetCompressionType(codec_u); + if (result.ok()) { + compression_codec = *result; + } else { + env->ThrowNew(io_exception_class, + std::string("failed to get compression codec, error message is " + + result.status().message()) + .c_str()); + } + if (compression_codec == arrow::Compression::LZ4) { + compression_codec = arrow::Compression::LZ4_FRAME; + } + } + env->ReleaseStringUTFChars(codec_jstr, codec_l); + + splitter->set_compression_codec(compression_codec); +} + +JNIEXPORT jobjectArray JNICALL +Java_com_intel_sparkColumnarPlugin_vectorized_ShuffleSplitterJniWrapper_getPartitionFileInfo( + JNIEnv* env, jobject, jlong splitter_id) { + auto splitter = GetShuffleSplitter(env, splitter_id); + + const auto& partition_file_info = splitter->GetPartitionFileInfo(); + auto num_partitions = partition_file_info.size(); + + jobjectArray partition_file_info_array = + env->NewObjectArray(num_partitions, partition_file_info_class, nullptr); + + for (auto i = 0; i < num_partitions; ++i) { + jobject file_info_obj = + env->NewObject(partition_file_info_class, partition_file_info_constructor, + partition_file_info[i].first, + env->NewStringUTF(partition_file_info[i].second.c_str())); + env->SetObjectArrayElement(partition_file_info_array, i, file_info_obj); + } + return partition_file_info_array; +} + +JNIEXPORT jlong JNICALL +Java_com_intel_sparkColumnarPlugin_vectorized_ShuffleSplitterJniWrapper_getTotalBytesWritten( + JNIEnv* env, jobject, jlong splitter_id) { + auto splitter = GetShuffleSplitter(env, splitter_id); + auto result = splitter->TotalBytesWritten(); + + if (!result.ok()) { + env->ThrowNew(io_exception_class, + std::string("native split: get total bytes written failed").c_str()); + } + + return (jlong)*result; +} + +JNIEXPORT void JNICALL +Java_com_intel_sparkColumnarPlugin_vectorized_ShuffleSplitterJniWrapper_close( + JNIEnv* env, jobject, jlong splitter_id) { + shuffle_splitter_holder_.Erase(splitter_id); +} + +JNIEXPORT jlong JNICALL +Java_com_intel_sparkColumnarPlugin_vectorized_ShuffleDecompressionJniWrapper_make( + JNIEnv* env, jobject, jbyteArray schema_arr) { + std::shared_ptr schema; + arrow::Status status; + + status = MakeSchema(env, schema_arr, &schema); + if (!status.ok()) { + env->ThrowNew( + io_exception_class, + std::string("failed to readSchema, err msg is " + status.message()).c_str()); + } + + return decompression_schema_holder_.Insert(schema); +} + +JNIEXPORT jobject JNICALL +Java_com_intel_sparkColumnarPlugin_vectorized_ShuffleDecompressionJniWrapper_decompress( + JNIEnv* env, jobject obj, jlong schema_holder_id, jstring codec_jstr, jint num_rows, + jlongArray buf_addrs, jlongArray buf_sizes, jlongArray buf_mask) { + auto schema = decompression_schema_holder_.Lookup(schema_holder_id); + + int in_bufs_len = env->GetArrayLength(buf_addrs); + if (in_bufs_len != env->GetArrayLength(buf_sizes)) { + std::string error_message = + "native decompress: mismatch in arraylen of buf_addrs and buf_sizes"; + env->ThrowNew(io_exception_class, error_message.c_str()); + 
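+    // buf_mask packs one bit per buffer: a set bit tells the reader to wrap
+    // the validity buffer as-is (see the ternary below), a cleared bit means
+    // it went through the codec like the value buffers. Bit i of the
+    // jlong-packed mask is read exactly as arrow::BitUtil::GetBit does:
+    //
+    //   bool BufferIsUncompressed(const jlong* mask, int i) {
+    //     return (reinterpret_cast<const uint8_t*>(mask)[i / 8] >> (i % 8)) & 1;
+    //   }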
} + + auto in_buf_addrs = env->GetLongArrayElements(buf_addrs, JNI_FALSE); + auto in_buf_sizes = env->GetLongArrayElements(buf_sizes, JNI_FALSE); + auto in_buf_mask = env->GetLongArrayElements(buf_mask, JNI_FALSE); + int buf_idx = 0; + int field_idx = 0; + + // get decompression compression_codec + auto compression_codec = arrow::Compression::UNCOMPRESSED; + auto codec_l = env->GetStringUTFChars(codec_jstr, JNI_FALSE); + if (codec_l != nullptr) { + std::string codec_u; + std::transform(codec_l, codec_l + std::strlen(codec_l), std::back_inserter(codec_u), + ::toupper); + auto result = arrow::util::Codec::GetCompressionType(codec_u); + if (result.ok()) { + compression_codec = *result; + } else { + env->ThrowNew(io_exception_class, + std::string("failed to get compression codec, error message is " + + result.status().message()) + .c_str()); + } + if (compression_codec == arrow::Compression::LZ4) { + compression_codec = arrow::Compression::LZ4_FRAME; + } + } + env->ReleaseStringUTFChars(codec_jstr, codec_l); + + std::vector> arrays; + while (field_idx < schema->num_fields()) { + auto field = schema->field(field_idx); + std::vector> buffers; + + // decompress validity buffer + auto result = arrow::BitUtil::GetBit((uint8_t*)in_buf_mask, buf_idx) + ? DecompressBuffer(in_buf_addrs[buf_idx], in_buf_sizes[buf_idx], + arrow::Compression::UNCOMPRESSED) + : DecompressBuffer(in_buf_addrs[buf_idx], in_buf_sizes[buf_idx], + compression_codec); + if (result.ok()) { + buffers.push_back(std::move(result).ValueOrDie()); + } else { + env->ThrowNew( + io_exception_class, + std::string("failed to decompress validity buffer, error message is " + + result.status().message()) + .c_str()); + } + + // decompress value buffer + result = DecompressBuffer(in_buf_addrs[buf_idx + 1], in_buf_sizes[buf_idx + 1], + compression_codec); + if (result.ok()) { + buffers.push_back(std::move(result).ValueOrDie()); + } else { + env->ThrowNew(io_exception_class, + std::string("failed to decompress value buffer, error message is " + + result.status().message()) + .c_str()); + } + + if (arrow::is_binary_like(field->type()->id())) { + // decompress offset buffer + result = DecompressBuffer(in_buf_addrs[buf_idx + 2], in_buf_sizes[buf_idx + 2], + compression_codec); + if (result.ok()) { + buffers.push_back(std::move(result).ValueOrDie()); + } else { + env->ThrowNew( + io_exception_class, + std::string("failed to decompress offset buffer, error message is " + + result.status().message()) + .c_str()); + } + buf_idx += 3; + } else { + buf_idx += 2; + } + arrays.push_back(arrow::ArrayData::Make(field->type(), num_rows, std::move(buffers))); + + ++field_idx; + } + + env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT); + env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT); + env->ReleaseLongArrayElements(buf_mask, in_buf_mask, JNI_ABORT); + + return MakeRecordBatchBuilder( + env, schema, arrow::RecordBatch::Make(schema, num_rows, std::move(arrays))); +} + +JNIEXPORT void JNICALL +Java_com_intel_sparkColumnarPlugin_vectorized_ShuffleDecompressionJniWrapper_close( + JNIEnv* env, jobject, jlong schema_holder_id) { + decompression_schema_holder_.Erase(schema_holder_id); +} + #ifdef __cplusplus } #endif diff --git a/oap-native-sql/cpp/src/shuffle/partition_writer.cc b/oap-native-sql/cpp/src/shuffle/partition_writer.cc new file mode 100644 index 000000000..bf582cf16 --- /dev/null +++ b/oap-native-sql/cpp/src/shuffle/partition_writer.cc @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or 
more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "shuffle/partition_writer.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace sparkcolumnarplugin {
+namespace shuffle {
+
+arrow::Result<std::shared_ptr<PartitionWriter>> PartitionWriter::Create(
+    int32_t pid, int64_t capacity, Type::typeId last_type,
+    const std::vector<Type::typeId>& column_type_id,
+    const std::shared_ptr<arrow::Schema>& schema, const std::string& temp_file_path,
+    arrow::Compression::type compression_codec) {
+  auto buffers = TypeBufferMessages(Type::NUM_TYPES);
+  auto binary_builders = BinaryBuilders();
+  auto large_binary_builders = LargeBinaryBuilders();
+
+  for (auto type_id : column_type_id) {
+    switch (type_id) {
+      case Type::SHUFFLE_BINARY: {
+        std::unique_ptr<arrow::BinaryBuilder> builder;
+        builder.reset(new arrow::BinaryBuilder(arrow::default_memory_pool()));
+        binary_builders.push_back(std::move(builder));
+      } break;
+      case Type::SHUFFLE_LARGE_BINARY: {
+        std::unique_ptr<arrow::LargeBinaryBuilder> builder;
+        builder.reset(new arrow::LargeBinaryBuilder(arrow::default_memory_pool()));
+        large_binary_builders.push_back(std::move(builder));
+      } break;
+      case Type::SHUFFLE_NULL: {
+        buffers[type_id].push_back(std::unique_ptr<BufferMessage>(
+            new BufferMessage{.validity_buffer = nullptr, .value_buffer = nullptr}));
+      } break;
+      default: {
+        std::shared_ptr<arrow::Buffer> validity_buffer;
+        std::shared_ptr<arrow::Buffer> value_buffer;
+        uint8_t* validity_addr;
+        uint8_t* value_addr;
+
+        ARROW_ASSIGN_OR_RAISE(validity_buffer, arrow::AllocateEmptyBitmap(capacity));
+        if (type_id == Type::SHUFFLE_BIT) {
+          ARROW_ASSIGN_OR_RAISE(value_buffer, arrow::AllocateEmptyBitmap(capacity));
+        } else {
+          // Fixed-width value buffers: the enum slot doubles as the log2 of
+          // the element width, so capacity * (1 << type_id) bytes holds
+          // exactly `capacity` elements of this type.
+          ARROW_ASSIGN_OR_RAISE(value_buffer,
+                                arrow::AllocateBuffer(capacity * (1 << type_id)));
+        }
+        validity_addr = validity_buffer->mutable_data();
+        value_addr = value_buffer->mutable_data();
+        buffers[type_id].push_back(std::unique_ptr<BufferMessage>(
+            new BufferMessage{.validity_buffer = std::move(validity_buffer),
+                              .value_buffer = std::move(value_buffer),
+                              .validity_addr = validity_addr,
+                              .value_addr = value_addr}));
+      } break;
+    }
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto file,
+                        arrow::io::FileOutputStream::Open(temp_file_path, true));
+
+  return std::make_shared<PartitionWriter>(
+      pid, capacity, last_type, column_type_id, schema, temp_file_path, std::move(file),
+      std::move(buffers), std::move(binary_builders), std::move(large_binary_builders),
+      compression_codec);
+}
+
+arrow::Status PartitionWriter::Stop() {
+  if (write_offset_[last_type_] != 0) {
+    RETURN_NOT_OK(WriteArrowRecordBatch());
+    std::fill(std::begin(write_offset_), std::end(write_offset_), 0);
+  }
+  if (file_writer_opened_) {
+    RETURN_NOT_OK(file_writer_->Close());
+    file_writer_opened_ = false;
+  }
+  if (!file_->closed()) {
+    ARROW_ASSIGN_OR_RAISE(file_footer_, file_->Tell());
+    return file_->Close();
+  }
+  return arrow::Status::OK();
+}
+
+arrow::Status PartitionWriter::WriteArrowRecordBatch() {
+  std::vector<std::shared_ptr<arrow::Array>>
arrays(schema_->num_fields()); + for (int i = 0; i < schema_->num_fields(); ++i) { + auto type_id = column_type_id_[i]; + if (type_id == Type::SHUFFLE_BINARY) { + auto builder = std::move(binary_builders_.front()); + binary_builders_.pop_front(); + RETURN_NOT_OK(builder->Finish(&arrays[i])); + binary_builders_.push_back(std::move(builder)); + } else if (type_id == Type::SHUFFLE_LARGE_BINARY) { + auto builder = std::move(large_binary_builders_.front()); + large_binary_builders_.pop_front(); + RETURN_NOT_OK(builder->Finish(&arrays[i])); + large_binary_builders_.push_back(std::move(builder)); + } else { + auto buf_msg_ptr = std::move(buffers_[type_id].front()); + buffers_[type_id].pop_front(); + auto arr = arrow::ArrayData::Make( + schema_->field(i)->type(), write_offset_[last_type_], + std::vector>{buf_msg_ptr->validity_buffer, + buf_msg_ptr->value_buffer}); + arrays[i] = arrow::MakeArray(arr); + buffers_[type_id].push_back(std::move(buf_msg_ptr)); + } + } + auto record_batch = + arrow::RecordBatch::Make(schema_, write_offset_[last_type_], std::move(arrays)); + + if (!file_writer_opened_) { + auto options = arrow::ipc::IpcWriteOptions::Defaults(); + options.allow_64bit = true; + options.compression = compression_codec_; + options.use_threads = false; + + auto res = arrow::ipc::NewStreamWriter(file_.get(), schema_, options); + RETURN_NOT_OK(res.status()); + file_writer_ = *res; + file_writer_opened_ = true; + } + RETURN_NOT_OK(file_writer_->WriteRecordBatch(*record_batch)); + + return arrow::Status::OK(); +} + +} // namespace shuffle +} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/shuffle/partition_writer.h b/oap-native-sql/cpp/src/shuffle/partition_writer.h new file mode 100644 index 000000000..4711be882 --- /dev/null +++ b/oap-native-sql/cpp/src/shuffle/partition_writer.h @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "shuffle/type.h" + +namespace sparkcolumnarplugin { +namespace shuffle { + +namespace detail { + +template +arrow::Status inline Write(const SrcBuffers& src, int64_t src_offset, + const BufferMessages& dst, int64_t dst_offset) { + for (size_t i = 0; i < src.size(); ++i) { + dst[i]->validity_addr[dst_offset / 8] |= + (((src[i].validity_addr)[src_offset / 8] >> (src_offset % 8)) & 1) + << (dst_offset % 8); + reinterpret_cast(dst[i]->value_addr)[dst_offset] = + reinterpret_cast(src[i].value_addr)[src_offset]; + } + return arrow::Status::OK(); +} + +template <> +arrow::Status inline Write(const SrcBuffers& src, int64_t src_offset, + const BufferMessages& dst, int64_t dst_offset) { + for (size_t i = 0; i < src.size(); ++i) { + dst[i]->validity_addr[dst_offset / 8] |= + (((src[i].validity_addr)[src_offset / 8] >> (src_offset % 8)) & 1) + << (dst_offset % 8); + dst[i]->value_addr[dst_offset / 8] |= + (((src[i].value_addr)[src_offset / 8] >> (src_offset % 8)) & 1) + << (dst_offset % 8); + } + return arrow::Status::OK(); +} + +template ::ArrayType, + typename BuilderType = typename arrow::TypeTraits::BuilderType> +arrow::enable_if_binary_like inline WriteBinary( + const std::vector>& src, int64_t offset, + const std::deque>& builders) { + using offset_type = typename T::offset_type; + + for (size_t i = 0; i < src.size(); ++i) { + offset_type length; + auto value = src[i]->GetValue(offset, &length); + RETURN_NOT_OK(builders[i]->Append(value, length)); + } + return arrow::Status::OK(); +} + +template ::ArrayType, + typename BuilderType = typename arrow::TypeTraits::BuilderType> +arrow::enable_if_binary_like inline WriteNullableBinary( + const std::vector>& src, int64_t offset, + const std::deque>& builders) { + using offset_type = typename T::offset_type; + + for (size_t i = 0; i < src.size(); ++i) { + // check not null + if (src[i]->IsValid(offset)) { + offset_type length; + auto value = src[i]->GetValue(offset, &length); + RETURN_NOT_OK(builders[i]->Append(value, length)); + } else { + RETURN_NOT_OK(builders[i]->AppendNull()); + } + } + return arrow::Status::OK(); +} + +} // namespace detail +class PartitionWriter { + public: + explicit PartitionWriter(int32_t pid, int64_t capacity, Type::typeId last_type, + const std::vector& column_type_id, + const std::shared_ptr& schema, + std::string file_path, + std::shared_ptr file, + TypeBufferMessages buffers, BinaryBuilders binary_builders, + LargeBinaryBuilders large_binary_builders, + arrow::Compression::type compression_codec) + : pid_(pid), + capacity_(capacity), + last_type_(last_type), + column_type_id_(column_type_id), + schema_(schema), + file_path_(std::move(file_path)), + file_(std::move(file)), + buffers_(std::move(buffers)), + binary_builders_(std::move(binary_builders)), + large_binary_builders_(std::move(large_binary_builders)), + compression_codec_(compression_codec), + write_offset_(Type::typeId::NUM_TYPES), + file_footer_(0), + file_writer_opened_(false), + file_writer_(nullptr) {} + + static arrow::Result> Create( + int32_t pid, int64_t capacity, Type::typeId last_type, + const std::vector& column_type_id, + const std::shared_ptr& schema, const std::string& temp_file_path, + arrow::Compression::type compression_codec); + + arrow::Status Stop(); + + int32_t pid() { return pid_; } + + int64_t capacity() { return capacity_; } + + int64_t write_offset() { return write_offset_[last_type_]; } + + Type::typeId last_type() { return last_type_; 
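+  // detail::Write above moves one element per source column into this
+  // writer's buffers; its validity half is a single-bit copy between
+  // arbitrary bit offsets. Spelled out on its own (equivalent to the
+  // expression in the template, which only ever sets bits because the
+  // destination bitmap starts out all-zero):
+  //
+  //   void CopyBit(const uint8_t* src, int64_t src_off,
+  //                uint8_t* dst, int64_t dst_off) {
+  //     uint8_t bit = (src[src_off / 8] >> (src_off % 8)) & 1;
+  //     dst[dst_off / 8] |= bit << (dst_off % 8);
+  //   }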
} + + const std::string& file_path() const { return file_path_; } + + int64_t file_footer() const { return file_footer_; } + + arrow::Status WriteArrowRecordBatch(); + + arrow::Result BytesWritten() { + if (!file_->closed()) { + ARROW_ASSIGN_OR_RAISE(file_footer_, file_->Tell()); + } + return file_footer_; + } + + arrow::Result inline CheckTypeWriteEnds(const Type::typeId& type_id) { + if (write_offset_[type_id] == capacity_) { + if (type_id == last_type_) { + RETURN_NOT_OK(WriteArrowRecordBatch()); + std::fill(std::begin(write_offset_), std::end(write_offset_), 0); + } + return true; + } + return false; + } + + /// Do memory copy, return true if mem-copy performed + /// if writer's memory buffer is full, then no mem-copy will be performed, will spill to + /// disk and return false + /// \tparam T arrow::DataType + /// \param type_id shuffle type id mapped from T + /// \param src source buffers + /// \param offset index of the element in source buffers + /// \return true if write performed, else false + template + arrow::Result inline Write(Type::typeId type_id, const SrcBuffers& src, + int64_t offset) { + // for the type_id, check if write ends. For the last type reset write_offset and + // spill + ARROW_ASSIGN_OR_RAISE(auto write_ends, CheckTypeWriteEnds(type_id)) + if (write_ends) { + return false; + } + + RETURN_NOT_OK( + detail::Write(src, offset, buffers_[type_id], write_offset_[type_id])); + + ++write_offset_[type_id]; + return true; + } + + /// Do memory copy for binary type + /// \param src source binary array + /// \param offset index of the element in source binary array + /// \return true if write performed, else false + arrow::Result inline WriteBinary( + const std::vector>& src, int64_t offset) { + ARROW_ASSIGN_OR_RAISE(auto write_ends, CheckTypeWriteEnds(Type::SHUFFLE_BINARY)) + if (write_ends) { + return false; + } + + RETURN_NOT_OK(detail::WriteBinary(src, offset, binary_builders_)); + + ++write_offset_[Type::SHUFFLE_BINARY]; + return true; + } + + /// Do memory copy for large binary type + /// \param src source binary array + /// \param offset index of the element in source binary array + /// \return + arrow::Result inline WriteLargeBinary( + const std::vector>& src, int64_t offset) { + ARROW_ASSIGN_OR_RAISE(auto write_ends, CheckTypeWriteEnds(Type::SHUFFLE_LARGE_BINARY)) + if (write_ends) { + return false; + } + + RETURN_NOT_OK( + detail::WriteBinary(src, offset, large_binary_builders_)); + + ++write_offset_[Type::SHUFFLE_LARGE_BINARY]; + return true; + } + /// Do memory copy for binary type + /// \param src source binary array + /// \param offset index of the element in source binary array + /// \return + arrow::Result inline WriteNullableBinary( + const std::vector>& src, int64_t offset) { + ARROW_ASSIGN_OR_RAISE(auto write_ends, CheckTypeWriteEnds(Type::SHUFFLE_BINARY)) + if (write_ends) { + return false; + } + + RETURN_NOT_OK( + detail::WriteNullableBinary(src, offset, binary_builders_)); + + ++write_offset_[Type::SHUFFLE_BINARY]; + return true; + } + + /// Do memory copy for large binary type + /// \param src source binary array + /// \param offset index of the element in source binary array + /// \return + arrow::Result inline WriteNullableLargeBinary( + const std::vector>& src, int64_t offset) { + ARROW_ASSIGN_OR_RAISE(auto write_ends, CheckTypeWriteEnds(Type::SHUFFLE_LARGE_BINARY)) + if (write_ends) { + return false; + } + + RETURN_NOT_OK(detail::WriteNullableBinary( + src, offset, large_binary_builders_)); + + ++write_offset_[Type::SHUFFLE_LARGE_BINARY]; + 
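+    // Every Write* method shares one protocol: arrow::Result<bool>, where
+    // true means the element was copied and false means this type's buffer
+    // hit capacity_ (a flush to disk happens once the last_type_ column
+    // fills, see CheckTypeWriteEnds). Callers retry the same row after a
+    // false, as splitter.cc's WRITE_FIXEDWIDTH macro does; reduced to one
+    // fixed-width column whose type is assumed to be the writer's last_type_:
+    //
+    //   int64_t row = 0;
+    //   while (row < num_rows) {
+    //     ARROW_ASSIGN_OR_RAISE(
+    //         auto copied,
+    //         writer->Write<uint32_t>(Type::SHUFFLE_4BYTE, src, row));
+    //     if (copied) ++row;  // on false: buffer was flushed, retry the row
+    //   }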
return true; + } + + private: + const int32_t pid_; + const int64_t capacity_; + const Type::typeId last_type_; + const std::vector& column_type_id_; + const std::shared_ptr& schema_; + const std::string file_path_; + + std::shared_ptr file_; + TypeBufferMessages buffers_; + BinaryBuilders binary_builders_; + LargeBinaryBuilders large_binary_builders_; + arrow::Compression::type compression_codec_; + + std::vector write_offset_; + int64_t file_footer_; + bool file_writer_opened_; + std::shared_ptr file_writer_; +}; + +} // namespace shuffle +} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/shuffle/splitter.cc b/oap-native-sql/cpp/src/shuffle/splitter.cc new file mode 100644 index 000000000..1a602ec2d --- /dev/null +++ b/oap-native-sql/cpp/src/shuffle/splitter.cc @@ -0,0 +1,390 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "shuffle/splitter.h" +#include +#include +#include +#include +#include +#include +#include "shuffle/partition_writer.h" + +#include +#include +#include +#include + +namespace sparkcolumnarplugin { +namespace shuffle { + +std::string GenerateUUID() { + boost::uuids::random_generator generator; + return boost::uuids::to_string(generator()); +} + +class Splitter::Impl { + public: + explicit Impl(const std::shared_ptr& schema) : schema_(schema) {} + + arrow::Status Init() { + // remove partition id field since we don't need it while splitting + ARROW_ASSIGN_OR_RAISE(writer_schema_, schema_->RemoveField(0)) + + const auto& fields = writer_schema_->fields(); + std::vector result; + result.reserve(fields.size()); + + std::transform(std::cbegin(fields), std::cend(fields), std::back_inserter(result), + [](const std::shared_ptr& field) -> Type::typeId { + auto arrow_type_id = field->type()->id(); + switch (arrow_type_id) { + case arrow::BooleanType::type_id: + return Type::SHUFFLE_BIT; + case arrow::Int8Type::type_id: + case arrow::UInt8Type::type_id: + return Type::SHUFFLE_1BYTE; + case arrow::Int16Type::type_id: + case arrow::UInt16Type::type_id: + case arrow::HalfFloatType::type_id: + return Type::SHUFFLE_2BYTE; + case arrow::Int32Type::type_id: + case arrow::UInt32Type::type_id: + case arrow::FloatType::type_id: + case arrow::Date32Type::type_id: + case arrow::Time32Type::type_id: + return Type::SHUFFLE_4BYTE; + case arrow::Int64Type::type_id: + case arrow::UInt64Type::type_id: + case arrow::DoubleType::type_id: + case arrow::Date64Type::type_id: + case arrow::Time64Type::type_id: + case arrow::TimestampType::type_id: + return Type::SHUFFLE_8BYTE; + case arrow::BinaryType::type_id: + case arrow::StringType::type_id: + return Type::SHUFFLE_BINARY; + case arrow::LargeBinaryType::type_id: + case arrow::LargeStringType::type_id: + return Type::SHUFFLE_LARGE_BINARY; + case 
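+              // This switch collapses Arrow types into width buckets whose
+              // enum value is the log2 of the element width (SHUFFLE_1BYTE
+              // = 0 ... SHUFFLE_8BYTE = 3); partition_writer.cc relies on
+              // that when sizing value buffers as capacity * (1 << type_id).
+              // A guard making the assumption explicit (illustrative only):
+              //
+              //   static_assert(Type::SHUFFLE_1BYTE == 0 &&
+              //                     Type::SHUFFLE_2BYTE == 1 &&
+              //                     Type::SHUFFLE_4BYTE == 2 &&
+              //                     Type::SHUFFLE_8BYTE == 3,
+              //                 "enum slot must encode element byte width");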
arrow::NullType::type_id: + return Type::SHUFFLE_NULL; + default: + std::cout << field->ToString() << " field type id " + << arrow_type_id << std::endl; + return Type::SHUFFLE_NOT_IMPLEMENTED; + } + }); + + auto it = + std::find(std::begin(result), std::end(result), Type::SHUFFLE_NOT_IMPLEMENTED); + if (it != std::end(result)) { + RETURN_NOT_OK(arrow::Status::NotImplemented("field contains not implemented type")); + } + column_type_id_ = std::move(result); + + decltype(column_type_id_) remove_null_id(column_type_id_.size()); + std::copy_if(std::cbegin(column_type_id_), std::cend(column_type_id_), + std::begin(remove_null_id), + [](Type::typeId id) { return id != Type::typeId::SHUFFLE_NULL; }); + last_type_ = + *std::max_element(std::cbegin(remove_null_id), std::cend(remove_null_id)); + + auto local_fs = std::make_shared(); + ARROW_ASSIGN_OR_RAISE(auto local_dirs, GetConfiguredLocalDirs()); + std::transform(local_dirs.cbegin(), local_dirs.cend(), std::back_inserter(local_dirs_fs_), + [local_fs](const auto& base_dir) { + return std::make_unique(base_dir, + local_fs); + }); + return arrow::Status::OK(); + } + + arrow::Status Split(const arrow::RecordBatch& record_batch) { + const auto& pid_arr = record_batch.column_data(0); + if (pid_arr->GetNullCount() != 0) { + return arrow::Status::Invalid("Column partition id should not contain NULL value"); + } + if (pid_arr->type->id() != arrow::Int32Type::type_id) { + return arrow::Status::Invalid("Partition id data type mismatch, expected ", + arrow::Int32Type::type_name(), ", but got ", + record_batch.column(0)->type()->name()); + } + + auto num_rows = record_batch.num_rows(); + auto num_cols = record_batch.num_columns(); + auto src_addr = std::vector(Type::NUM_TYPES); + + auto src_binary_arr = SrcBinaryArrays(); + auto src_nullable_binary_arr = SrcBinaryArrays(); + + auto src_large_binary_arr = SrcLargeBinaryArrays(); + auto src_nullable_large_binary_arr = SrcLargeBinaryArrays(); + + // TODO: make dummy_buf private static if possible + arrow::TypedBufferBuilder null_bitmap_builder_; + RETURN_NOT_OK(null_bitmap_builder_.Append(num_rows, true)); + + std::shared_ptr dummy_buf; + RETURN_NOT_OK(null_bitmap_builder_.Finish(&dummy_buf)); + auto dummy_buf_p = const_cast(dummy_buf->data()); + + // Get the pointer of each buffer, Ignore column_data(0) which indicates the partition + // id + for (auto i = 0; i < num_cols - 1; ++i) { + const auto& buffers = record_batch.column_data(i + 1)->buffers; + if (record_batch.column_data(i + 1)->GetNullCount() == 0) { + if (column_type_id_[i] == Type::SHUFFLE_BINARY) { + src_binary_arr.push_back( + std::static_pointer_cast(record_batch.column(i + 1))); + } else if (column_type_id_[i] == Type::SHUFFLE_LARGE_BINARY) { + src_large_binary_arr.push_back( + std::static_pointer_cast( + record_batch.column(i + 1))); + } else if (column_type_id_[i] != Type::SHUFFLE_NULL) { + // null bitmap may be nullptr + src_addr[column_type_id_[i]].push_back( + {.validity_addr = dummy_buf_p, + .value_addr = const_cast(buffers[1]->data())}); + } + } else { + if (column_type_id_[i] == Type::SHUFFLE_BINARY) { + src_nullable_binary_arr.push_back( + std::static_pointer_cast(record_batch.column(i + 1))); + } else if (column_type_id_[i] == Type::SHUFFLE_LARGE_BINARY) { + src_nullable_large_binary_arr.push_back( + std::static_pointer_cast( + record_batch.column(i + 1))); + } else if (column_type_id_[i] != Type::SHUFFLE_NULL) { + src_addr[column_type_id_[i]].push_back( + {.validity_addr = const_cast(buffers[0]->data()), + .value_addr = 
const_cast(buffers[1]->data())}); + } + } + } + + // map discrete partition id (pid) to continuous integer (new_id) + // create a new writer every time a new_id occurs + std::vector new_id; + new_id.reserve(num_rows); + auto pid_cast_p = reinterpret_cast(pid_arr->buffers[1]->data()); + for (int64_t i = 0; i < num_rows; ++i) { + auto pid = pid_cast_p[i]; + if (pid_to_new_id_.find(pid) == pid_to_new_id_.end()) { + auto temp_dir = GenerateUUID(); + const auto& fs = local_dirs_fs_[num_partitiions_ % local_dirs_fs_.size()]; + while ((*fs->GetFileInfo(temp_dir)).type() != arrow::fs::FileType::NotFound) { + temp_dir = GenerateUUID(); + } + RETURN_NOT_OK(fs->CreateDir(temp_dir)); + auto temp_file_path = arrow::fs::internal::ConcatAbstractPath( + fs->base_path(), (*fs->GetFileInfo(temp_dir)).path()) + + "/data"; + temp_files.push_back({pid, temp_file_path}); + + ARROW_ASSIGN_OR_RAISE( + auto writer, + PartitionWriter::Create(pid, buffer_size_, last_type_, column_type_id_, + writer_schema_, temp_file_path, compression_codec_)); + pid_writer_.push_back(std::move(writer)); + new_id.push_back(num_partitiions_); + pid_to_new_id_[pid] = num_partitiions_++; + } else { + new_id.push_back(pid_to_new_id_[pid]); + } + } + + auto read_offset = 0; + +#define WRITE_FIXEDWIDTH(TYPE_ID, T) \ + if (!src_addr[TYPE_ID].empty()) { \ + for (i = read_offset; i < num_rows; ++i) { \ + ARROW_ASSIGN_OR_RAISE( \ + auto result, pid_writer_[new_id[i]]->Write(TYPE_ID, src_addr[TYPE_ID], i)) \ + if (!result) { \ + break; \ + } \ + } \ + } + +#define WRITE_BINARY(func, T, src_arr) \ + if (!src_arr.empty()) { \ + for (i = read_offset; i < num_rows; ++i) { \ + ARROW_ASSIGN_OR_RAISE(auto result, pid_writer_[new_id[i]]->func(src_arr, i)) \ + if (!result) { \ + break; \ + } \ + } \ + } + + while (read_offset < num_rows) { + auto i = read_offset; + WRITE_FIXEDWIDTH(Type::SHUFFLE_1BYTE, uint8_t); + WRITE_FIXEDWIDTH(Type::SHUFFLE_2BYTE, uint16_t); + WRITE_FIXEDWIDTH(Type::SHUFFLE_4BYTE, uint32_t); + WRITE_FIXEDWIDTH(Type::SHUFFLE_8BYTE, uint64_t); + WRITE_FIXEDWIDTH(Type::SHUFFLE_BIT, bool); + WRITE_BINARY(WriteBinary, arrow::BinaryType, src_binary_arr); + WRITE_BINARY(WriteLargeBinary, arrow::LargeBinaryType, src_large_binary_arr); + WRITE_BINARY(WriteNullableBinary, arrow::BinaryType, src_nullable_binary_arr); + WRITE_BINARY(WriteNullableLargeBinary, arrow::LargeBinaryType, + src_nullable_large_binary_arr); + read_offset = i; + } +#undef WRITE_FIXEDWIDTH + + return arrow::Status::OK(); + } + + arrow::Status Stop() { + // write final record batch + for (const auto& writer : pid_writer_) { + RETURN_NOT_OK(writer->Stop()); + } + std::sort(std::begin(temp_files), std::end(temp_files)); + return arrow::Status::OK(); + } + + arrow::Result TotalBytesWritten() { + int64_t res = 0; + for (const auto& writer : pid_writer_) { + ARROW_ASSIGN_OR_RAISE(auto bytes, writer->BytesWritten()); + res += bytes; + } + return res; + } + + static arrow::Result CreateAttemptSubDir(const std::string& root_dir) { + auto attempt_sub_dir = arrow::fs::internal::ConcatAbstractPath(root_dir, "columnar-shuffle-" + GenerateUUID()); + ARROW_ASSIGN_OR_RAISE(auto created, arrow::internal::CreateDirTree( + *arrow::internal::PlatformFilename::FromString(attempt_sub_dir))); + // if create succeed, use created subdir, else use root dir + if (created) { + return attempt_sub_dir; + } else { + return root_dir; + } + } + + static arrow::Result> GetConfiguredLocalDirs() { + auto joined_dirs_c = std::getenv("NATIVESQL_SPARK_LOCAL_DIRS"); + if (joined_dirs_c != nullptr && 
strcmp(joined_dirs_c, "") > 0) { + auto joined_dirs = std::string(joined_dirs_c); + std::string delimiter = ","; + std::vector dirs; + + size_t pos = 0; + std::string root_dir; + while ((pos = joined_dirs.find(delimiter)) != std::string::npos) { + root_dir = joined_dirs.substr(0, pos); + if (root_dir.length() > 0) { + dirs.push_back(*CreateAttemptSubDir(root_dir)); + } + joined_dirs.erase(0, pos + delimiter.length()); + } + if (joined_dirs.length() > 0) { + dirs.push_back(*CreateAttemptSubDir(joined_dirs)); + } + return dirs; + } else { + ARROW_ASSIGN_OR_RAISE(auto arrow_tmp_dir, + arrow::internal::TemporaryDir::Make("columnar-shuffle-")); + return std::vector{arrow_tmp_dir->path().ToString()}; + } + } + + Type::typeId column_type_id(int i) const { return column_type_id_[i]; } + + void set_buffer_size(int64_t buffer_size) { buffer_size_ = buffer_size; } + + void set_compression_codec(arrow::Compression::type compression_codec) { + compression_codec_ = compression_codec; + } + + std::shared_ptr schema() const { return schema_; } + + std::shared_ptr writer(int32_t pid) { + if (pid_to_new_id_.find(pid) == pid_to_new_id_.end()) { + return nullptr; + } + return pid_writer_[pid_to_new_id_[pid]]; + } + + std::vector> temp_files; + + private: + std::shared_ptr schema_; + + // writer_schema_ removes the first field of schema_ which indicates the partition id + std::shared_ptr writer_schema_; + + int32_t num_partitiions_ = 0; + Type::typeId last_type_; + std::vector column_type_id_; + std::unordered_map pid_to_new_id_; + std::vector> pid_writer_; + + int64_t buffer_size_ = kDefaultSplitterBufferSize; + arrow::Compression::type compression_codec_ = arrow::Compression::UNCOMPRESSED; + + std::vector> local_dirs_fs_; +}; + +arrow::Result> Splitter::Make( + const std::shared_ptr& schema) { + std::shared_ptr ptr(new Splitter(schema)); + RETURN_NOT_OK(ptr->impl_->Init()); + return ptr; +} + +Splitter::Splitter(const std::shared_ptr& schema) { + impl_.reset(new Impl(schema)); +} + +std::shared_ptr Splitter::schema() const { return impl_->schema(); } + +Type::typeId Splitter::column_type_id(int i) const { return impl_->column_type_id(i); } + +arrow::Status Splitter::Split(const arrow::RecordBatch& rb) { return impl_->Split(rb); } + +std::shared_ptr Splitter::writer(int32_t pid) { + return impl_->writer(pid); +} + +arrow::Status Splitter::Stop() { return impl_->Stop(); } + +const std::vector>& Splitter::GetPartitionFileInfo() + const { + return impl_->temp_files; +} + +void Splitter::set_buffer_size(int64_t buffer_size) { + impl_->set_buffer_size(buffer_size); +} + +void Splitter::set_compression_codec(arrow::Compression::type compression_codec) { + impl_->set_compression_codec(compression_codec); +} + +arrow::Result Splitter::TotalBytesWritten() { + return impl_->TotalBytesWritten(); +} + +Splitter::~Splitter() = default; + +} // namespace shuffle +} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/shuffle/splitter.h b/oap-native-sql/cpp/src/shuffle/splitter.h new file mode 100644 index 000000000..b71f3ea3d --- /dev/null +++ b/oap-native-sql/cpp/src/shuffle/splitter.h @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "shuffle/partition_writer.h" +#include "shuffle/type.h" + +namespace sparkcolumnarplugin { +namespace shuffle { + +class Splitter { + public: + ~Splitter(); + + static arrow::Result> Make( + const std::shared_ptr& schema); + + std::shared_ptr schema() const; + + Type::typeId column_type_id(int i) const; + + void set_buffer_size(int64_t buffer_size); + + void set_compression_codec(arrow::Compression::type compression_codec); + + arrow::Status Split(const arrow::RecordBatch&); + + /*** + * Stop all writers created by this splitter. If the data buffer managed by the writer + * is not empty, write to output stream as RecordBatch. Then sort the temporary files by + * partition id. + * @return + */ + arrow::Status Stop(); + + const std::vector>& GetPartitionFileInfo() const; + + arrow::Result TotalBytesWritten(); + + // writer must be called after Split. + std::shared_ptr writer(int32_t pid); + + private: + explicit Splitter(const std::shared_ptr& schema); + class Impl; + std::unique_ptr impl_; +}; + +} // namespace shuffle +} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/shuffle/type.h b/oap-native-sql/cpp/src/shuffle/type.h new file mode 100644 index 000000000..c03cf8758 --- /dev/null +++ b/oap-native-sql/cpp/src/shuffle/type.h @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace sparkcolumnarplugin { +namespace shuffle { + +static constexpr int64_t kDefaultSplitterBufferSize = 4096; + +struct BufferMessage { + std::shared_ptr validity_buffer; + std::shared_ptr value_buffer; + uint8_t* validity_addr; + uint8_t* value_addr; +}; + +struct BufferAddr { + uint8_t* validity_addr; + uint8_t* value_addr; +}; + +namespace Type { +/// \brief Data type enumeration for shuffle splitter +/// +/// This enumeration maps the types of arrow::Type::type with same length +/// to identical type + +enum typeId : int { + SHUFFLE_1BYTE, + SHUFFLE_2BYTE, + SHUFFLE_4BYTE, + SHUFFLE_8BYTE, + SHUFFLE_BIT, + SHUFFLE_BINARY, + SHUFFLE_LARGE_BINARY, + SHUFFLE_NULL, + NUM_TYPES, + SHUFFLE_NOT_IMPLEMENTED +}; + +static const typeId all[] = { + SHUFFLE_1BYTE, SHUFFLE_2BYTE, SHUFFLE_4BYTE, SHUFFLE_8BYTE, + SHUFFLE_BIT, SHUFFLE_BINARY, SHUFFLE_LARGE_BINARY, SHUFFLE_NULL, +}; + +// std::shared_ptr fixed_size_binary(int32_t byte_width) { +// return std::make_shared(byte_width); +//} +// +// class Fixed1ByteType : public arrow::ExtensionType { +// public: +// static constexpr Type::typeId shuffle_type_id = Type::SHUFFLE_1BYTE; +// +// Fixed1ByteType() : arrow::ExtensionType(fixed_size_binary(1)) {} +// +// std::string extension_name() const override { return "fixed_1_byte"; } +// +// bool ExtensionEquals(const ExtensionType& other) const override { +// return other.extension_name() == this->extension_name(); +// } +// +// std::shared_ptr MakeArray( +// std::shared_ptr data) const override { +// DCHECK_EQ(data->type->id(), arrow::Type::EXTENSION); +// DCHECK_EQ("fixed_1_byte", +// static_cast(*data->type).extension_name()); +// return std::make_shared(data); +// } +// +// arrow::Status Deserialize(std::shared_ptr storage_type, +// const std::string& serialized, +// std::shared_ptr* out) const override { +// if (serialized != "fixed-1-byte-type-unique-code") { +// return arrow::Status::Invalid("Type identifier did not match"); +// } +// DCHECK(storage_type->Equals(*fixed_size_binary(1))); +// *out = std::make_shared(); +// return arrow::Status::OK(); +// } +// +// std::string Serialize() const override { return "fixed-1-byte-type-unique-code"; } +//}; + +} // namespace Type + +using BufferMessages = std::deque>; +using TypeBufferMessages = std::vector; +using BinaryBuilders = std::deque>; +using LargeBinaryBuilders = std::deque>; +using BufferPtr = std::shared_ptr; +using SrcBuffers = std::vector; +using SrcArrays = std::vector>; +using SrcBinaryArrays = std::vector>; +using SrcLargeBinaryArrays = std::vector>; + +} // namespace shuffle +} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/tests/CMakeLists.txt b/oap-native-sql/cpp/src/tests/CMakeLists.txt index 374b461e2..0f7df7f86 100644 --- a/oap-native-sql/cpp/src/tests/CMakeLists.txt +++ b/oap-native-sql/cpp/src/tests/CMakeLists.txt @@ -1,3 +1,4 @@ package_add_test(TestArrowComputeAggregate arrow_compute_test_aggregate.cc) package_add_test(TestArrowComputeJoin arrow_compute_test_join.cc) package_add_test(TestArrowComputeSort arrow_compute_test_sort.cc) +package_add_test(TestShuffleSplit shuffle_split_test.cc) diff --git a/oap-native-sql/cpp/src/tests/arrow_compute_test_join.cc b/oap-native-sql/cpp/src/tests/arrow_compute_test_join.cc index 850c6ebb3..de0dc596d 100644 --- a/oap-native-sql/cpp/src/tests/arrow_compute_test_join.cc +++ b/oap-native-sql/cpp/src/tests/arrow_compute_test_join.cc @@ -27,7 +27,7 @@ namespace sparkcolumnarplugin { namespace codegen { 
-TEST(TestArrowCompute, JoinTestUsingInnerJoin) { +TEST(TestArrowComputeJoin, JoinTestUsingInnerJoin) { ////////////////////// prepare expr_vector /////////////////////// auto table0_f0 = field("table0_f0", uint32()); auto table0_f1 = field("table0_f1", uint32()); @@ -35,9 +35,6 @@ TEST(TestArrowCompute, JoinTestUsingInnerJoin) { auto table1_f0 = field("table1_f0", uint32()); auto table1_f1 = field("table1_f1", uint32()); - auto indices_type = std::make_shared(4); - auto f_indices = field("indices", indices_type); - auto n_left = TreeExprBuilder::MakeFunction( "codegen_left_schema", {TreeExprBuilder::MakeField(table0_f0), TreeExprBuilder::MakeField(table0_f1), @@ -53,38 +50,25 @@ TEST(TestArrowCompute, JoinTestUsingInnerJoin) { "codegen_left_key_schema", {TreeExprBuilder::MakeField(table0_f0)}, uint32()); auto n_right_key = TreeExprBuilder::MakeFunction( "codegen_right_key_schema", {TreeExprBuilder::MakeField(table1_f0)}, uint32()); - auto n_probeArrays = TreeExprBuilder::MakeFunction( - "conditionedProbeArraysInner", {n_left_key, n_right_key}, indices_type); + auto n_probeArrays = TreeExprBuilder::MakeFunction("conditionedProbeArraysInner", + {n_left_key, n_right_key}, uint32()); auto n_codegen_probe = TreeExprBuilder::MakeFunction( "codegen_withTwoInputs", {n_probeArrays, n_left, n_right}, uint32()); auto probeArrays_expr = TreeExprBuilder::MakeExpression(n_codegen_probe, f_res); - auto n_conditionedShuffleArrayList = - TreeExprBuilder::MakeFunction("conditionedShuffleArrayList", {}, uint32()); - auto n_codegen_shuffle = TreeExprBuilder::MakeFunction( - "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right}, - uint32()); - - auto conditionShuffleExpr = TreeExprBuilder::MakeExpression(n_codegen_shuffle, f_res); - auto schema_table_0 = arrow::schema({table0_f0, table0_f1, table0_f2}); auto schema_table_1 = arrow::schema({table1_f0, table1_f1}); auto schema_table = arrow::schema({table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}); ///////////////////// Calculation ////////////////// std::shared_ptr expr_probe; - ASSERT_NOT_OK(CreateCodeGenerator(schema_table_0, {probeArrays_expr}, {f_indices}, - &expr_probe, true)); - std::shared_ptr expr_conditioned_shuffle; - ASSERT_NOT_OK( - CreateCodeGenerator(schema_table, {conditionShuffleExpr}, - {table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}, - &expr_conditioned_shuffle, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + schema_table_0, {probeArrays_expr}, + {table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}, &expr_probe, true)); std::shared_ptr input_batch; std::vector> dummy_result_batches; std::shared_ptr> probe_result_iterator; - std::shared_ptr> shuffle_result_iterator; std::vector> table_0; std::vector> table_1; @@ -128,11 +112,8 @@ TEST(TestArrowCompute, JoinTestUsingInnerJoin) { ////////////////////// evaluate ////////////////////// for (auto batch : table_0) { ASSERT_NOT_OK(expr_probe->evaluate(batch, &dummy_result_batches)); - ASSERT_NOT_OK(expr_conditioned_shuffle->evaluate(batch, &dummy_result_batches)); } ASSERT_NOT_OK(expr_probe->finish(&probe_result_iterator)); - ASSERT_NOT_OK(expr_conditioned_shuffle->SetDependency(probe_result_iterator)); - ASSERT_NOT_OK(expr_conditioned_shuffle->finish(&shuffle_result_iterator)); for (int i = 0; i < 2; i++) { auto left_batch = table_0[i]; @@ -144,13 +125,12 @@ TEST(TestArrowCompute, JoinTestUsingInnerJoin) { input.push_back(right_batch->column(i)); } - ASSERT_NOT_OK(probe_result_iterator->ProcessAndCacheOne(input)); - 
ASSERT_NOT_OK(shuffle_result_iterator->Process(input, &result_batch)); + ASSERT_NOT_OK(probe_result_iterator->Process(input, &result_batch)); ASSERT_NOT_OK(Equals(*(expected_table[i]).get(), *result_batch.get())); } } -TEST(TestArrowCompute, JoinTestWithTwoKeysUsingInnerJoin) { +TEST(TestArrowComputeJoin, JoinTestWithTwoKeysUsingInnerJoin) { ////////////////////// prepare expr_vector /////////////////////// auto table0_f0 = field("table0_f0", utf8()); auto table0_f1 = field("table0_f1", utf8()); @@ -158,9 +138,6 @@ TEST(TestArrowCompute, JoinTestWithTwoKeysUsingInnerJoin) { auto table1_f0 = field("table1_f0", utf8()); auto table1_f1 = field("table1_f1", utf8()); - auto indices_type = std::make_shared(4); - auto f_indices = field("indices", indices_type); - auto n_left = TreeExprBuilder::MakeFunction( "codegen_left_schema", {TreeExprBuilder::MakeField(table0_f0), TreeExprBuilder::MakeField(table0_f1), @@ -173,23 +150,19 @@ TEST(TestArrowCompute, JoinTestWithTwoKeysUsingInnerJoin) { auto f_res = field("res", uint32()); auto n_left_key = TreeExprBuilder::MakeFunction( - "codegen_left_key_schema", {TreeExprBuilder::MakeField(table0_f0), TreeExprBuilder::MakeField(table0_f1)}, uint32()); + "codegen_left_key_schema", + {TreeExprBuilder::MakeField(table0_f0), TreeExprBuilder::MakeField(table0_f1)}, + uint32()); auto n_right_key = TreeExprBuilder::MakeFunction( - "codegen_right_key_schema", {TreeExprBuilder::MakeField(table1_f0), TreeExprBuilder::MakeField(table1_f1)}, uint32()); - auto n_probeArrays = TreeExprBuilder::MakeFunction( - "conditionedProbeArraysInner", {n_left_key, n_right_key}, indices_type); + "codegen_right_key_schema", + {TreeExprBuilder::MakeField(table1_f0), TreeExprBuilder::MakeField(table1_f1)}, + uint32()); + auto n_probeArrays = TreeExprBuilder::MakeFunction("conditionedProbeArraysInner", + {n_left_key, n_right_key}, uint32()); auto n_codegen_probe = TreeExprBuilder::MakeFunction( "codegen_withTwoInputs", {n_probeArrays, n_left, n_right}, uint32()); auto probeArrays_expr = TreeExprBuilder::MakeExpression(n_codegen_probe, f_res); - auto n_conditionedShuffleArrayList = - TreeExprBuilder::MakeFunction("conditionedShuffleArrayList", {}, uint32()); - auto n_codegen_shuffle = TreeExprBuilder::MakeFunction( - "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right}, - uint32()); - - auto conditionShuffleExpr = TreeExprBuilder::MakeExpression(n_codegen_shuffle, f_res); - auto schema_table_0 = arrow::schema({table0_f0, table0_f1, table0_f2}); auto schema_table_1 = arrow::schema({table1_f0, table1_f1}); auto schema_table = @@ -199,18 +172,14 @@ TEST(TestArrowCompute, JoinTestWithTwoKeysUsingInnerJoin) { ///////////////////// Calculation ////////////////// std::shared_ptr expr_probe; - ASSERT_NOT_OK(CreateCodeGenerator(schema_table_0, {probeArrays_expr}, {f_indices}, - &expr_probe, true)); - std::shared_ptr expr_conditioned_shuffle; - ASSERT_NOT_OK( - CreateCodeGenerator(schema_table, {conditionShuffleExpr}, - {table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}, - &expr_conditioned_shuffle, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + schema_table_0, {probeArrays_expr}, + {table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}, &expr_probe, true)); + std::shared_ptr input_batch; std::vector> dummy_result_batches; std::shared_ptr> probe_result_iterator; - std::shared_ptr> shuffle_result_iterator; std::vector> table_0; std::vector> table_1; @@ -253,11 +222,8 @@ TEST(TestArrowCompute, JoinTestWithTwoKeysUsingInnerJoin) { ////////////////////// evaluate 
////////////////////// for (auto batch : table_0) { ASSERT_NOT_OK(expr_probe->evaluate(batch, &dummy_result_batches)); - ASSERT_NOT_OK(expr_conditioned_shuffle->evaluate(batch, &dummy_result_batches)); } ASSERT_NOT_OK(expr_probe->finish(&probe_result_iterator)); - ASSERT_NOT_OK(expr_conditioned_shuffle->SetDependency(probe_result_iterator)); - ASSERT_NOT_OK(expr_conditioned_shuffle->finish(&shuffle_result_iterator)); for (int i = 0; i < 2; i++) { auto left_batch = table_0[i]; @@ -269,13 +235,12 @@ TEST(TestArrowCompute, JoinTestWithTwoKeysUsingInnerJoin) { input.push_back(right_batch->column(i)); } - ASSERT_NOT_OK(probe_result_iterator->ProcessAndCacheOne(input)); - ASSERT_NOT_OK(shuffle_result_iterator->Process(input, &result_batch)); + ASSERT_NOT_OK(probe_result_iterator->Process(input, &result_batch)); ASSERT_NOT_OK(Equals(*(expected_table[i]).get(), *result_batch.get())); } } -TEST(TestArrowCompute, JoinTestUsingOuterJoin) { +TEST(TestArrowComputeJoin, JoinTestUsingOuterJoin) { ////////////////////// prepare expr_vector /////////////////////// auto table0_f0 = field("table0_f0", uint32()); auto table0_f1 = field("table0_f1", uint32()); @@ -283,9 +248,6 @@ TEST(TestArrowCompute, JoinTestUsingOuterJoin) { auto table1_f0 = field("table1_f0", uint32()); auto table1_f1 = field("table1_f1", uint32()); - auto indices_type = std::make_shared(4); - auto f_indices = field("indices", indices_type); - auto n_left = TreeExprBuilder::MakeFunction( "codegen_left_schema", {TreeExprBuilder::MakeField(table0_f0), TreeExprBuilder::MakeField(table0_f1), @@ -301,38 +263,26 @@ TEST(TestArrowCompute, JoinTestUsingOuterJoin) { "codegen_left_key_schema", {TreeExprBuilder::MakeField(table0_f0)}, uint32()); auto n_right_key = TreeExprBuilder::MakeFunction( "codegen_right_key_schema", {TreeExprBuilder::MakeField(table1_f0)}, uint32()); - auto n_probeArrays = TreeExprBuilder::MakeFunction( - "conditionedProbeArraysOuter", {n_left_key, n_right_key}, indices_type); + auto n_probeArrays = TreeExprBuilder::MakeFunction("conditionedProbeArraysOuter", + {n_left_key, n_right_key}, uint32()); auto n_codegen_probe = TreeExprBuilder::MakeFunction( "codegen_withTwoInputs", {n_probeArrays, n_left, n_right}, uint32()); auto probeArrays_expr = TreeExprBuilder::MakeExpression(n_codegen_probe, f_res); - auto n_conditionedShuffleArrayList = - TreeExprBuilder::MakeFunction("conditionedShuffleArrayList", {}, uint32()); - auto n_codegen_shuffle = TreeExprBuilder::MakeFunction( - "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right}, - uint32()); - - auto conditionShuffleExpr = TreeExprBuilder::MakeExpression(n_codegen_shuffle, f_res); - auto schema_table_0 = arrow::schema({table0_f0, table0_f1, table0_f2}); auto schema_table_1 = arrow::schema({table1_f0, table1_f1}); auto schema_table = arrow::schema({table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}); ///////////////////// Calculation ////////////////// std::shared_ptr expr_probe; - ASSERT_NOT_OK(CreateCodeGenerator(schema_table_0, {probeArrays_expr}, {f_indices}, - &expr_probe, true)); - std::shared_ptr expr_conditioned_shuffle; - ASSERT_NOT_OK( - CreateCodeGenerator(schema_table, {conditionShuffleExpr}, - {table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}, - &expr_conditioned_shuffle, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + schema_table_0, {probeArrays_expr}, + {table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}, &expr_probe, true)); + std::shared_ptr input_batch; std::vector> dummy_result_batches; std::shared_ptr> probe_result_iterator; - 
std::shared_ptr> shuffle_result_iterator; std::vector> table_0; std::vector> table_1; @@ -378,11 +328,8 @@ TEST(TestArrowCompute, JoinTestUsingOuterJoin) { ////////////////////// evaluate ////////////////////// for (auto batch : table_0) { ASSERT_NOT_OK(expr_probe->evaluate(batch, &dummy_result_batches)); - ASSERT_NOT_OK(expr_conditioned_shuffle->evaluate(batch, &dummy_result_batches)); } ASSERT_NOT_OK(expr_probe->finish(&probe_result_iterator)); - ASSERT_NOT_OK(expr_conditioned_shuffle->SetDependency(probe_result_iterator)); - ASSERT_NOT_OK(expr_conditioned_shuffle->finish(&shuffle_result_iterator)); for (int i = 0; i < 2; i++) { auto left_batch = table_0[i]; @@ -394,13 +341,12 @@ TEST(TestArrowCompute, JoinTestUsingOuterJoin) { input.push_back(right_batch->column(i)); } - ASSERT_NOT_OK(probe_result_iterator->ProcessAndCacheOne(input)); - ASSERT_NOT_OK(shuffle_result_iterator->Process(input, &result_batch)); + ASSERT_NOT_OK(probe_result_iterator->Process(input, &result_batch)); ASSERT_NOT_OK(Equals(*(expected_table[i]).get(), *result_batch.get())); } } -TEST(TestArrowCompute, JoinTestUsingAntiJoin) { +TEST(TestArrowComputeJoin, JoinTestUsingAntiJoin) { ////////////////////// prepare expr_vector /////////////////////// auto table0_f0 = field("table0_f0", uint32()); auto table0_f1 = field("table0_f1", uint32()); @@ -408,9 +354,6 @@ TEST(TestArrowCompute, JoinTestUsingAntiJoin) { auto table1_f0 = field("table1_f0", uint32()); auto table1_f1 = field("table1_f1", uint32()); - auto indices_type = std::make_shared(4); - auto f_indices = field("indices", indices_type); - auto n_left = TreeExprBuilder::MakeFunction( "codegen_left_schema", {TreeExprBuilder::MakeField(table0_f0), TreeExprBuilder::MakeField(table0_f1), @@ -426,37 +369,24 @@ TEST(TestArrowCompute, JoinTestUsingAntiJoin) { "codegen_left_key_schema", {TreeExprBuilder::MakeField(table0_f0)}, uint32()); auto n_right_key = TreeExprBuilder::MakeFunction( "codegen_right_key_schema", {TreeExprBuilder::MakeField(table1_f0)}, uint32()); - auto n_probeArrays = TreeExprBuilder::MakeFunction( - "conditionedProbeArraysAnti", {n_left_key, n_right_key}, indices_type); + auto n_probeArrays = TreeExprBuilder::MakeFunction("conditionedProbeArraysAnti", + {n_left_key, n_right_key}, uint32()); auto n_codegen_probe = TreeExprBuilder::MakeFunction( "codegen_withTwoInputs", {n_probeArrays, n_left, n_right}, uint32()); auto probeArrays_expr = TreeExprBuilder::MakeExpression(n_codegen_probe, f_res); - auto n_conditionedShuffleArrayList = - TreeExprBuilder::MakeFunction("conditionedShuffleArrayList", {}, uint32()); - auto n_codegen_shuffle = TreeExprBuilder::MakeFunction( - "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right}, - uint32()); - - auto conditionShuffleExpr = TreeExprBuilder::MakeExpression(n_codegen_shuffle, f_res); - auto schema_table_0 = arrow::schema({table0_f0, table0_f1, table0_f2}); auto schema_table_1 = arrow::schema({table1_f0, table1_f1}); auto schema_table = arrow::schema({table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}); ///////////////////// Calculation ////////////////// std::shared_ptr expr_probe; - ASSERT_NOT_OK(CreateCodeGenerator(schema_table_0, {probeArrays_expr}, {f_indices}, - &expr_probe, true)); - std::shared_ptr expr_conditioned_shuffle; - ASSERT_NOT_OK(CreateCodeGenerator(schema_table, {conditionShuffleExpr}, - {table1_f0, table1_f1}, &expr_conditioned_shuffle, - true)); + ASSERT_NOT_OK(CreateCodeGenerator(schema_table_0, {probeArrays_expr}, + {table1_f0, table1_f1}, &expr_probe, true)); 
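// Sketch of how the Outer and Anti probe variants exercised above differ from
// the Inner one -- again a plain-STL stand-in with hypothetical names; the
// exact null semantics follow the generated kernel, not this sketch:
#include <cstdint>
#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>

// Outer: a probe-side row with no build-side match is still emitted, with the
// build side null-padded (std::nullopt here, validity-bitmap nulls in Arrow).
inline void OuterProbe(
    const std::unordered_multimap<uint32_t, size_t>& idx,
    const std::vector<uint32_t>& keys,
    std::vector<std::pair<std::optional<size_t>, size_t>>* out) {
  for (size_t r = 0; r < keys.size(); ++r) {
    auto range = idx.equal_range(keys[r]);
    if (range.first == range.second) out->emplace_back(std::nullopt, r);
    for (auto it = range.first; it != range.second; ++it)
      out->emplace_back(it->second, r);
  }
}

// Anti: only probe-side rows with *no* match survive.
inline void AntiProbe(const std::unordered_multimap<uint32_t, size_t>& idx,
                      const std::vector<uint32_t>& keys,
                      std::vector<size_t>* out) {
  for (size_t r = 0; r < keys.size(); ++r) {
    if (idx.count(keys[r]) == 0) out->push_back(r);
  }
}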
std::shared_ptr input_batch; std::vector> dummy_result_batches; std::shared_ptr> probe_result_iterator; - std::shared_ptr> shuffle_result_iterator; std::vector> table_0; std::vector> table_1; @@ -496,11 +426,8 @@ TEST(TestArrowCompute, JoinTestUsingAntiJoin) { ////////////////////// evaluate ////////////////////// for (auto batch : table_0) { ASSERT_NOT_OK(expr_probe->evaluate(batch, &dummy_result_batches)); - ASSERT_NOT_OK(expr_conditioned_shuffle->evaluate(batch, &dummy_result_batches)); } ASSERT_NOT_OK(expr_probe->finish(&probe_result_iterator)); - ASSERT_NOT_OK(expr_conditioned_shuffle->SetDependency(probe_result_iterator)); - ASSERT_NOT_OK(expr_conditioned_shuffle->finish(&shuffle_result_iterator)); for (int i = 0; i < 2; i++) { auto left_batch = table_0[i]; @@ -512,13 +439,12 @@ TEST(TestArrowCompute, JoinTestUsingAntiJoin) { input.push_back(right_batch->column(i)); } - ASSERT_NOT_OK(probe_result_iterator->ProcessAndCacheOne(input)); - ASSERT_NOT_OK(shuffle_result_iterator->Process(input, &result_batch)); + ASSERT_NOT_OK(probe_result_iterator->Process(input, &result_batch)); ASSERT_NOT_OK(Equals(*(expected_table[i]).get(), *result_batch.get())); } } -TEST(TestArrowCompute, JoinTestUsingInnerJoinWithCondition) { +TEST(TestArrowComputeJoin, JoinTestUsingInnerJoinWithCondition) { ////////////////////// prepare expr_vector /////////////////////// auto table0_f0 = field("table0_f0", uint32()); auto table0_f1 = field("table0_f1", uint32()); @@ -526,8 +452,6 @@ TEST(TestArrowCompute, JoinTestUsingInnerJoinWithCondition) { auto table1_f0 = field("table1_f0", uint32()); auto table1_f1 = field("table1_f1", uint32()); - auto indices_type = std::make_shared(4); - auto f_indices = field("indices", indices_type); auto greater_than_function = TreeExprBuilder::MakeFunction( "greater_than", {TreeExprBuilder::MakeField(table0_f1), TreeExprBuilder::MakeField(table1_f1)}, @@ -549,37 +473,25 @@ TEST(TestArrowCompute, JoinTestUsingInnerJoinWithCondition) { "codegen_right_schema", {TreeExprBuilder::MakeField(table1_f0)}, uint32()); auto n_probeArrays = TreeExprBuilder::MakeFunction( "conditionedProbeArraysInner", {n_left_key, n_right_key, greater_than_function}, - indices_type); + uint32()); auto n_codegen_probe = TreeExprBuilder::MakeFunction( "codegen_withTwoInputs", {n_probeArrays, n_left, n_right}, uint32()); auto probeArrays_expr = TreeExprBuilder::MakeExpression(n_codegen_probe, f_res); - auto n_conditionedShuffleArrayList = - TreeExprBuilder::MakeFunction("conditionedShuffleArrayList", {}, uint32()); - auto n_codegen_shuffle = TreeExprBuilder::MakeFunction( - "codegen_withTwoInputs", {n_conditionedShuffleArrayList, n_left, n_right}, - uint32()); - - auto conditionShuffleExpr = TreeExprBuilder::MakeExpression(n_codegen_shuffle, f_res); - auto schema_table_0 = arrow::schema({table0_f0, table0_f1, table0_f2}); auto schema_table_1 = arrow::schema({table1_f0, table1_f1}); auto schema_table = arrow::schema({table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}); ///////////////////// Calculation ////////////////// std::shared_ptr expr_probe; - ASSERT_NOT_OK(CreateCodeGenerator(schema_table_0, {probeArrays_expr}, {f_indices}, - &expr_probe, true)); - std::shared_ptr expr_conditioned_shuffle; - ASSERT_NOT_OK( - CreateCodeGenerator(schema_table, {conditionShuffleExpr}, - {table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}, - &expr_conditioned_shuffle, true)); + ASSERT_NOT_OK(CreateCodeGenerator( + schema_table_0, {probeArrays_expr}, + {table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}, &expr_probe, 
true)); + std::shared_ptr input_batch; std::vector> dummy_result_batches; std::shared_ptr> probe_result_iterator; - std::shared_ptr> shuffle_result_iterator; std::vector> table_0; std::vector> table_1; @@ -620,11 +532,8 @@ TEST(TestArrowCompute, JoinTestUsingInnerJoinWithCondition) { ////////////////////// evaluate ////////////////////// for (auto batch : table_0) { ASSERT_NOT_OK(expr_probe->evaluate(batch, &dummy_result_batches)); - ASSERT_NOT_OK(expr_conditioned_shuffle->evaluate(batch, &dummy_result_batches)); } ASSERT_NOT_OK(expr_probe->finish(&probe_result_iterator)); - ASSERT_NOT_OK(expr_conditioned_shuffle->SetDependency(probe_result_iterator)); - ASSERT_NOT_OK(expr_conditioned_shuffle->finish(&shuffle_result_iterator)); for (int i = 0; i < 2; i++) { auto left_batch = table_0[i]; @@ -636,8 +545,110 @@ TEST(TestArrowCompute, JoinTestUsingInnerJoinWithCondition) { input.push_back(right_batch->column(i)); } - ASSERT_NOT_OK(probe_result_iterator->ProcessAndCacheOne(input)); - ASSERT_NOT_OK(shuffle_result_iterator->Process(input, &result_batch)); + ASSERT_NOT_OK(probe_result_iterator->Process(input, &result_batch)); + ASSERT_NOT_OK(Equals(*(expected_table[i]).get(), *result_batch.get())); + } +} + +TEST(TestArrowComputeJoin, JoinTestUsingAntiJoinWithCondition) { + ////////////////////// prepare expr_vector /////////////////////// + auto table0_f0 = field("table0_f0", uint32()); + auto table0_f1 = field("table0_f1", uint32()); + auto table0_f2 = field("table0_f2", uint32()); + auto table1_f0 = field("table1_f0", uint32()); + auto table1_f1 = field("table1_f1", uint32()); + + auto greater_than_function = TreeExprBuilder::MakeFunction( + "greater_than", + {TreeExprBuilder::MakeField(table0_f1), TreeExprBuilder::MakeField(table1_f1)}, + arrow::boolean()); + auto n_left = TreeExprBuilder::MakeFunction( + "codegen_left_schema", + {TreeExprBuilder::MakeField(table0_f0), TreeExprBuilder::MakeField(table0_f1), + TreeExprBuilder::MakeField(table0_f2)}, + uint32()); + auto n_right = TreeExprBuilder::MakeFunction( + "codegen_right_schema", + {TreeExprBuilder::MakeField(table1_f0), TreeExprBuilder::MakeField(table1_f1)}, + uint32()); + auto f_res = field("res", uint32()); + + auto n_left_key = TreeExprBuilder::MakeFunction( + "codegen_left_key_schema", {TreeExprBuilder::MakeField(table0_f0)}, uint32()); + auto n_right_key = TreeExprBuilder::MakeFunction( + "codegen_right_key_schema", {TreeExprBuilder::MakeField(table1_f0)}, uint32()); + auto n_probeArrays = TreeExprBuilder::MakeFunction( + "conditionedProbeArraysAnti", {n_left_key, n_right_key, greater_than_function}, + uint32()); + auto n_codegen_probe = TreeExprBuilder::MakeFunction( + "codegen_withTwoInputs", {n_probeArrays, n_left, n_right}, uint32()); + auto probeArrays_expr = TreeExprBuilder::MakeExpression(n_codegen_probe, f_res); + + auto schema_table_0 = arrow::schema({table0_f0, table0_f1, table0_f2}); + auto schema_table_1 = arrow::schema({table1_f0, table1_f1}); + auto schema_table = + arrow::schema({table0_f0, table0_f1, table0_f2, table1_f0, table1_f1}); + ///////////////////// Calculation ////////////////// + std::shared_ptr expr_probe; + ASSERT_NOT_OK(CreateCodeGenerator(schema_table_0, {probeArrays_expr}, + {table1_f0, table1_f1}, &expr_probe, true)); + std::shared_ptr input_batch; + + std::vector> dummy_result_batches; + std::shared_ptr> probe_result_iterator; + + std::vector> table_0; + std::vector> table_1; + + std::vector input_data_string = { + "[10, 3, 1, 2, 3, 1]", "[10, 3, 1, 2, 13, 11]", "[10, 3, 1, 2, 13, 11]"}; + 
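// Why "[2, 4, 5]" and "[7, 8, 9, 11, 12]" are the expected batches below: both
// left batches of this test (the second appears just after this point) are
// evaluated into the hash table before probing, giving build pairs
// (key -> f1 values) of 1->{1,11}, 2->{2}, 3->{3,13}, 5->{5}, 6->{6,16},
// 8->{8}, 10->{10,110}, 12->{12}. A right row is dropped only if some matching
// left row satisfies greater_than(table0_f1, table1_f1); e.g. right key 1
// (f1 = 1) matches left f1 = 11 > 1 and is dropped, while right key 2 (f1 = 2)
// only matches left f1 = 2, and 2 > 2 is false, so it is kept. A stand-in for
// the conditioned anti probe (hypothetical names, illustration only):
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

inline std::vector<size_t> ConditionedAntiProbe(
    // key -> (left row index, left f1), mirroring the build side above
    const std::unordered_multimap<uint32_t, std::pair<size_t, uint32_t>>& idx,
    const std::vector<uint32_t>& right_keys,
    const std::vector<uint32_t>& right_f1) {
  std::vector<size_t> kept;
  for (size_t r = 0; r < right_keys.size(); ++r) {
    bool matched = false;
    auto range = idx.equal_range(right_keys[r]);
    for (auto it = range.first; !matched && it != range.second; ++it)
      matched = it->second.second > right_f1[r];  // greater_than(left_f1, right_f1)
    if (!matched) kept.push_back(r);
  }
  return kept;
}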
MakeInputBatch(input_data_string, schema_table_0, &input_batch); + table_0.push_back(input_batch); + + input_data_string = {"[6, 12, 5, 8, 6, 10]", "[6, 12, 5, 8, 16, 110]", + "[6, 12, 5, 8, 16, 110]"}; + MakeInputBatch(input_data_string, schema_table_0, &input_batch); + table_0.push_back(input_batch); + + std::vector input_data_2_string = {"[1, 2, 3, 4, 5, 6]", + "[1, 2, 3, 4, 5, 6]"}; + MakeInputBatch(input_data_2_string, schema_table_1, &input_batch); + table_1.push_back(input_batch); + + input_data_2_string = {"[7, 8, 9, 10, 11, 12]", "[7, 8, 9, 10, 11, 12]"}; + MakeInputBatch(input_data_2_string, schema_table_1, &input_batch); + table_1.push_back(input_batch); + + //////////////////////// data prepared ///////////////////////// + + std::vector> expected_table; + std::shared_ptr expected_result; + auto res_sch = arrow::schema({f_res, f_res}); + std::vector expected_result_string = {"[2, 4, 5]", "[2, 4, 5]"}; + MakeInputBatch(expected_result_string, res_sch, &expected_result); + expected_table.push_back(expected_result); + + expected_result_string = {"[7, 8, 9, 11, 12]", "[7, 8, 9, 11, 12]"}; + MakeInputBatch(expected_result_string, res_sch, &expected_result); + expected_table.push_back(expected_result); + + ////////////////////// evaluate ////////////////////// + for (auto batch : table_0) { + ASSERT_NOT_OK(expr_probe->evaluate(batch, &dummy_result_batches)); + } + ASSERT_NOT_OK(expr_probe->finish(&probe_result_iterator)); + + for (int i = 0; i < 2; i++) { + auto left_batch = table_0[i]; + auto right_batch = table_1[i]; + + std::shared_ptr result_batch; + std::vector> input; + for (int i = 0; i < right_batch->num_columns(); i++) { + input.push_back(right_batch->column(i)); + } + + ASSERT_NOT_OK(probe_result_iterator->Process(input, &result_batch)); ASSERT_NOT_OK(Equals(*(expected_table[i]).get(), *result_batch.get())); } } diff --git a/oap-native-sql/cpp/src/tests/shuffle_split_test.cc b/oap-native-sql/cpp/src/tests/shuffle_split_test.cc new file mode 100644 index 000000000..51e4c6ca3 --- /dev/null +++ b/oap-native-sql/cpp/src/tests/shuffle_split_test.cc @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "shuffle/splitter.h" +#include "shuffle/type.h" +#include "tests/test_utils.h" + +namespace sparkcolumnarplugin { +namespace shuffle { + +class ShuffleTest : public ::testing::Test { + protected: + void SetUp() { + + auto f_pid = field("f_pid", arrow::int32()); + auto f_na = field("f_na", arrow::null()); + auto f_int8 = field("f_int8", arrow::int8()); + auto f_int16 = field("f_int16", arrow::int16()); + auto f_uint64 = field("f_uint64", arrow::uint64()); + auto f_bool = field("f_bool", arrow::boolean()); + auto f_string = field("f_string", arrow::utf8()); + + std::shared_ptr tmp_dir1; + std::shared_ptr tmp_dir2; + ARROW_ASSIGN_OR_THROW(tmp_dir1, std::move(arrow::internal::TemporaryDir::Make(tmp_dir_prefix))) + ARROW_ASSIGN_OR_THROW(tmp_dir2, std::move(arrow::internal::TemporaryDir::Make(tmp_dir_prefix))) + auto config_dirs = tmp_dir1->path().ToString() + "," + tmp_dir2->path().ToString(); + + setenv("NATIVESQL_SPARK_LOCAL_DIRS", config_dirs.c_str(), 1); + + schema_ = arrow::schema({f_pid, f_na, f_int8, f_int16, f_uint64, f_bool, f_string}); + ARROW_ASSIGN_OR_THROW(writer_schema_, schema_->RemoveField(0)) + + ARROW_ASSIGN_OR_THROW(splitter_, Splitter::Make(schema_)); + } + + void TearDown() { ASSERT_NOT_OK(splitter_->Stop()); } + + std::string tmp_dir_prefix = "columnar-shuffle-test"; + + std::string c_pid_ = "[1, 2, 1, 10]"; + std::vector input_data_ = {c_pid_, + "[null, null, null, null]", + "[1, 2, 3, null]", + "[1, -1, null, null]", + "[null, null, null, null]", + "[null, 1, 0, null]", + R"(["alice", "bob", null, null])"}; + + std::shared_ptr schema_; + std::shared_ptr writer_schema_; + std::shared_ptr splitter_; +}; + +TEST_F(ShuffleTest, TestSplitterSchema) { ASSERT_EQ(*schema_, *splitter_->schema()); } + +TEST_F(ShuffleTest, TestSplitterTypeId) { + ASSERT_EQ(splitter_->column_type_id(0), Type::SHUFFLE_NULL); + ASSERT_EQ(splitter_->column_type_id(1), Type::SHUFFLE_1BYTE); + ASSERT_EQ(splitter_->column_type_id(2), Type::SHUFFLE_2BYTE); + ASSERT_EQ(splitter_->column_type_id(3), Type::SHUFFLE_8BYTE); + ASSERT_EQ(splitter_->column_type_id(4), Type::SHUFFLE_BIT); +} + +TEST_F(ShuffleTest, TestWriterAfterSplit) { + std::shared_ptr input_batch; + MakeInputBatch(input_data_, schema_, &input_batch); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + + ASSERT_NE(splitter_->writer(1), nullptr); + ASSERT_NE(splitter_->writer(2), nullptr); + ASSERT_NE(splitter_->writer(10), nullptr); + ASSERT_EQ(splitter_->writer(100), nullptr); + + ASSERT_EQ(splitter_->writer(1)->pid(), 1); + ASSERT_EQ(splitter_->writer(2)->pid(), 2); + ASSERT_EQ(splitter_->writer(10)->pid(), 10); + + ASSERT_EQ(splitter_->writer(1)->capacity(), kDefaultSplitterBufferSize); + + ASSERT_EQ(splitter_->writer(1)->write_offset(), 2); + ASSERT_EQ(splitter_->writer(2)->write_offset(), 1); + ASSERT_EQ(splitter_->writer(10)->write_offset(), 1); +} + +TEST_F(ShuffleTest, TestLastType) { + std::shared_ptr input_batch; + MakeInputBatch(input_data_, schema_, &input_batch); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + ASSERT_EQ(splitter_->writer(1)->last_type(), Type::SHUFFLE_BINARY); + ASSERT_EQ(splitter_->writer(2)->last_type(), Type::SHUFFLE_BINARY); + ASSERT_EQ(splitter_->writer(10)->last_type(), Type::SHUFFLE_BINARY); +} + +TEST_F(ShuffleTest, TestMultipleInput) { + std::shared_ptr input_batch; + MakeInputBatch(input_data_, schema_, &input_batch); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + 
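// The splitter behaviour asserted above, in miniature: rows are routed by the
// leading f_pid column and per-partition writers are created lazily, so after
// splitting c_pid_ = "[1, 2, 1, 10]" the writers for pids 1, 2 and 10 exist
// (write offsets 2, 1 and 1) while writer(100) is still nullptr. A stand-in
// sketch with hypothetical types; the per-row payload copy is elided:
#include <cstdint>
#include <map>
#include <memory>
#include <vector>

struct MiniWriter {
  int32_t pid;
  size_t write_offset = 0;  // rows buffered for this partition so far
};

inline void MiniSplit(const std::vector<int32_t>& pid_column,
                      std::map<int32_t, std::unique_ptr<MiniWriter>>* writers) {
  for (int32_t pid : pid_column) {
    auto& w = (*writers)[pid];  // lazily create a writer on first row seen
    if (!w) w = std::make_unique<MiniWriter>(MiniWriter{pid});
    ++w->write_offset;          // copying the row's column values elided here
  }
}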
ASSERT_EQ(splitter_->writer(1)->write_offset(), 2); + ASSERT_EQ(splitter_->writer(2)->write_offset(), 1); + ASSERT_EQ(splitter_->writer(10)->write_offset(), 1); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + + ASSERT_EQ(splitter_->writer(1)->write_offset(), 6); + ASSERT_EQ(splitter_->writer(2)->write_offset(), 3); + ASSERT_EQ(splitter_->writer(10)->write_offset(), 3); +} + +TEST_F(ShuffleTest, TestCustomBufferSize) { + int64_t buffer_size = 2; + splitter_->set_buffer_size(buffer_size); + + std::shared_ptr input_batch; + MakeInputBatch(input_data_, schema_, &input_batch); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + ASSERT_EQ(splitter_->writer(1)->write_offset(), 2); + ASSERT_EQ(splitter_->writer(2)->write_offset(), 1); + ASSERT_EQ(splitter_->writer(10)->write_offset(), 1); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + ASSERT_EQ(splitter_->writer(1)->write_offset(), 2); + ASSERT_EQ(splitter_->writer(2)->write_offset(), 2); + ASSERT_EQ(splitter_->writer(10)->write_offset(), 2); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + ASSERT_EQ(splitter_->writer(1)->write_offset(), 2); + ASSERT_EQ(splitter_->writer(2)->write_offset(), 1); + ASSERT_EQ(splitter_->writer(10)->write_offset(), 1); +} + +TEST_F(ShuffleTest, TestCreateTempFile) { + std::shared_ptr input_batch; + MakeInputBatch(input_data_, schema_, &input_batch); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + ASSERT_EQ(splitter_->GetPartitionFileInfo().size(), 3); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + ASSERT_EQ(splitter_->GetPartitionFileInfo().size(), 3); + + MakeInputBatch({"[100]", "[null]", "[null]", "[null]", "[null]", "[null]", "[null]"}, + schema_, &input_batch); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + ASSERT_EQ(splitter_->GetPartitionFileInfo().size(), 4); + + auto pfn0 = splitter_->GetPartitionFileInfo()[0].second; + auto pfn1 = splitter_->GetPartitionFileInfo()[1].second; + auto pfn2 = splitter_->GetPartitionFileInfo()[2].second; + auto pfn3 = splitter_->GetPartitionFileInfo()[3].second; + ASSERT_EQ(*arrow::internal::FileExists(*arrow::internal::PlatformFilename::FromString(pfn0)), true); + ASSERT_EQ(*arrow::internal::FileExists(*arrow::internal::PlatformFilename::FromString(pfn1)), true); + ASSERT_EQ(*arrow::internal::FileExists(*arrow::internal::PlatformFilename::FromString(pfn2)), true); + ASSERT_EQ(*arrow::internal::FileExists(*arrow::internal::PlatformFilename::FromString(pfn3)), true); + + ASSERT_NE(pfn0.find(tmp_dir_prefix), std::string::npos); + ASSERT_NE(pfn1.find(tmp_dir_prefix), std::string::npos); + ASSERT_NE(pfn2.find(tmp_dir_prefix), std::string::npos); + ASSERT_NE(pfn3.find(tmp_dir_prefix), std::string::npos); +} + +TEST_F(ShuffleTest, TestWriterMakeArrowRecordBatch) { + int64_t buffer_size = 2; + splitter_->set_buffer_size(buffer_size); + + std::vector output_data = {"[null, null]", "[1, 3]", + "[1, null]", "[null, null]", + "[null, 0]", R"(["alice", null])"}; + + std::shared_ptr input_batch; + std::shared_ptr output_batch; + MakeInputBatch(input_data_, schema_, &input_batch); + MakeInputBatch(output_data, writer_schema_, &output_batch); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + ASSERT_NOT_OK(splitter_->Split(*input_batch)); + + ASSERT_NOT_OK(splitter_->Stop()); + + std::shared_ptr file_in; + std::shared_ptr file_reader; + ARROW_ASSIGN_OR_THROW(file_in, + arrow::io::ReadableFile::Open(splitter_->writer(1)->file_path())) + + 
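// Macro-free equivalent of the read-back this test performs, using Arrow's
// Result-based IPC stream API (a sketch against the 0.17-era API, matching the
// ARROW_ASSIGN_OR_THROW calls used in the test itself):
#include <memory>
#include <string>
#include <arrow/io/file.h>
#include <arrow/ipc/reader.h>
#include <arrow/record_batch.h>
#include <arrow/result.h>
#include <arrow/status.h>

inline arrow::Status ReadSpillFile(const std::string& path) {
  ARROW_ASSIGN_OR_RAISE(auto file, arrow::io::ReadableFile::Open(path));
  ARROW_ASSIGN_OR_RAISE(auto reader,
                        arrow::ipc::RecordBatchStreamReader::Open(file));
  std::shared_ptr<arrow::RecordBatch> batch;
  ARROW_RETURN_NOT_OK(reader->ReadNext(&batch));
  while (batch != nullptr) {  // ReadNext yields nullptr at end of stream
    // ... inspect batch->num_rows(), compare against the expected batch ...
    ARROW_RETURN_NOT_OK(reader->ReadNext(&batch));
  }
  return file->Close();
}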
ARROW_ASSIGN_OR_THROW(file_reader, arrow::ipc::RecordBatchStreamReader::Open(file_in)) + ASSERT_EQ(*file_reader->schema(), *writer_schema_); + + int num_rb = 3; + for (int i = 0; i < num_rb; ++i) { + std::shared_ptr rb; + ASSERT_NOT_OK(file_reader->ReadNext(&rb)); + ASSERT_NOT_OK(Equals(*output_batch, *rb)); + } + ASSERT_NOT_OK(file_in->Close()) +} + +TEST_F(ShuffleTest, TestCustomCompressionCodec) { + auto compression_codec = arrow::Compression::LZ4_FRAME; + splitter_->set_compression_codec(compression_codec); + + std::vector output_data = {"[null, null]", "[1, 3]", + "[1, null]", "[null, null]", + "[null, 0]", R"(["alice", null])"}; + + std::shared_ptr input_batch; + std::shared_ptr output_batch; + MakeInputBatch(input_data_, schema_, &input_batch); + MakeInputBatch(output_data, writer_schema_, &output_batch); + + ASSERT_NOT_OK(splitter_->Split(*input_batch)) + ASSERT_NOT_OK(splitter_->Stop()) + + std::shared_ptr file_in; + std::shared_ptr file_reader; + ARROW_ASSIGN_OR_THROW(file_in, + arrow::io::ReadableFile::Open(splitter_->writer(1)->file_path())) + + ARROW_ASSIGN_OR_THROW(file_reader, arrow::ipc::RecordBatchStreamReader::Open(file_in)) + ASSERT_EQ(*file_reader->schema(), *writer_schema_); + + std::shared_ptr rb; + ASSERT_NOT_OK(file_reader->ReadNext(&rb)); + ASSERT_NOT_OK(Equals(*rb, *output_batch)); + + ASSERT_NOT_OK(file_in->Close()) +} + +} // namespace shuffle +} // namespace sparkcolumnarplugin diff --git a/oap-native-sql/cpp/src/third_party/ska_sort.hpp b/oap-native-sql/cpp/src/third_party/ska_sort.hpp new file mode 100644 index 000000000..81a9ef2b7 --- /dev/null +++ b/oap-native-sql/cpp/src/third_party/ska_sort.hpp @@ -0,0 +1,1445 @@ +// Copyright Malte Skarupke 2016. +// Distributed under the Boost Software License, Version 1.0. +// (See http://www.boost.org/LICENSE_1_0.txt) + +#pragma once + +#include +#include +#include +#include +#include + +namespace detail +{ +template +void counting_sort_impl(It begin, It end, OutIt out_begin, ExtractKey && extract_key) +{ + count_type counts[256] = {}; + for (It it = begin; it != end; ++it) + { + ++counts[extract_key(*it)]; + } + count_type total = 0; + for (count_type & count : counts) + { + count_type old_count = count; + count = total; + total += old_count; + } + for (; begin != end; ++begin) + { + std::uint8_t key = extract_key(*begin); + out_begin[counts[key]++] = std::move(*begin); + } +} +template +void counting_sort_impl(It begin, It end, OutIt out_begin, ExtractKey && extract_key) +{ + counting_sort_impl(begin, end, out_begin, extract_key); +} +inline bool to_unsigned_or_bool(bool b) +{ + return b; +} +inline unsigned char to_unsigned_or_bool(unsigned char c) +{ + return c; +} +inline unsigned char to_unsigned_or_bool(signed char c) +{ + return static_cast(c) + 128; +} +inline unsigned char to_unsigned_or_bool(char c) +{ + return static_cast(c); +} +inline std::uint16_t to_unsigned_or_bool(char16_t c) +{ + return static_cast(c); +} +inline std::uint32_t to_unsigned_or_bool(char32_t c) +{ + return static_cast(c); +} +inline std::uint32_t to_unsigned_or_bool(wchar_t c) +{ + return static_cast(c); +} +inline unsigned short to_unsigned_or_bool(short i) +{ + return static_cast(i) + static_cast(1 << (sizeof(short) * 8 - 1)); +} +inline unsigned short to_unsigned_or_bool(unsigned short i) +{ + return i; +} +inline unsigned int to_unsigned_or_bool(int i) +{ + return static_cast(i) + static_cast(1 << (sizeof(int) * 8 - 1)); +} +inline unsigned int to_unsigned_or_bool(unsigned int i) +{ + return i; +} +inline unsigned long 
to_unsigned_or_bool(long l) +{ + return static_cast(l) + static_cast(1l << (sizeof(long) * 8 - 1)); +} +inline unsigned long to_unsigned_or_bool(unsigned long l) +{ + return l; +} +inline unsigned long long to_unsigned_or_bool(long long l) +{ + return static_cast(l) + static_cast(1ll << (sizeof(long long) * 8 - 1)); +} +inline unsigned long long to_unsigned_or_bool(unsigned long long l) +{ + return l; +} +inline std::uint32_t to_unsigned_or_bool(float f) +{ + union + { + float f; + std::uint32_t u; + } as_union = { f }; + std::uint32_t sign_bit = -std::int32_t(as_union.u >> 31); + return as_union.u ^ (sign_bit | 0x80000000); +} +inline std::uint64_t to_unsigned_or_bool(double f) +{ + union + { + double d; + std::uint64_t u; + } as_union = { f }; + std::uint64_t sign_bit = -std::int64_t(as_union.u >> 63); + return as_union.u ^ (sign_bit | 0x8000000000000000); +} +template +inline size_t to_unsigned_or_bool(T * ptr) +{ + return reinterpret_cast(ptr); +} + +template +struct SizedRadixSorter; + +template<> +struct SizedRadixSorter<1> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + counting_sort_impl(begin, end, buffer_begin, [&](auto && o) + { + return to_unsigned_or_bool(extract_key(o)); + }); + return true; + } + + static constexpr size_t pass_count = 2; +}; +template<> +struct SizedRadixSorter<2> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + std::ptrdiff_t num_elements = end - begin; + if (num_elements <= (1ll << 32)) + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + else + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + } + + template + static bool sort_inline(It begin, It end, OutIt out_begin, OutIt out_end, ExtractKey && extract_key) + { + count_type counts0[256] = {}; + count_type counts1[256] = {}; + + for (It it = begin; it != end; ++it) + { + uint16_t key = to_unsigned_or_bool(extract_key(*it)); + ++counts0[key & 0xff]; + ++counts1[(key >> 8) & 0xff]; + } + count_type total0 = 0; + count_type total1 = 0; + for (int i = 0; i < 256; ++i) + { + count_type old_count0 = counts0[i]; + count_type old_count1 = counts1[i]; + counts0[i] = total0; + counts1[i] = total1; + total0 += old_count0; + total1 += old_count1; + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)); + out_begin[counts0[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 8; + begin[counts1[key]++] = std::move(*it); + } + return false; + } + + static constexpr size_t pass_count = 3; +}; +template<> +struct SizedRadixSorter<4> +{ + + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + std::ptrdiff_t num_elements = end - begin; + if (num_elements <= (1ll << 32)) + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + else + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + } + template + static bool sort_inline(It begin, It end, OutIt out_begin, OutIt out_end, ExtractKey && extract_key) + { + count_type counts0[256] = {}; + count_type counts1[256] = {}; + count_type counts2[256] = {}; + count_type counts3[256] = {}; + + for (It it = begin; it != end; ++it) + { + uint32_t key = to_unsigned_or_bool(extract_key(*it)); + ++counts0[key & 0xff]; + ++counts1[(key >> 8) & 
0xff]; + ++counts2[(key >> 16) & 0xff]; + ++counts3[(key >> 24) & 0xff]; + } + count_type total0 = 0; + count_type total1 = 0; + count_type total2 = 0; + count_type total3 = 0; + for (int i = 0; i < 256; ++i) + { + count_type old_count0 = counts0[i]; + count_type old_count1 = counts1[i]; + count_type old_count2 = counts2[i]; + count_type old_count3 = counts3[i]; + counts0[i] = total0; + counts1[i] = total1; + counts2[i] = total2; + counts3[i] = total3; + total0 += old_count0; + total1 += old_count1; + total2 += old_count2; + total3 += old_count3; + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)); + out_begin[counts0[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 8; + begin[counts1[key]++] = std::move(*it); + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 16; + out_begin[counts2[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 24; + begin[counts3[key]++] = std::move(*it); + } + return false; + } + + static constexpr size_t pass_count = 5; +}; +template<> +struct SizedRadixSorter<8> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + std::ptrdiff_t num_elements = end - begin; + if (num_elements <= (1ll << 32)) + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + else + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + } + template + static bool sort_inline(It begin, It end, OutIt out_begin, OutIt out_end, ExtractKey && extract_key) + { + count_type counts0[256] = {}; + count_type counts1[256] = {}; + count_type counts2[256] = {}; + count_type counts3[256] = {}; + count_type counts4[256] = {}; + count_type counts5[256] = {}; + count_type counts6[256] = {}; + count_type counts7[256] = {}; + + for (It it = begin; it != end; ++it) + { + uint64_t key = to_unsigned_or_bool(extract_key(*it)); + ++counts0[key & 0xff]; + ++counts1[(key >> 8) & 0xff]; + ++counts2[(key >> 16) & 0xff]; + ++counts3[(key >> 24) & 0xff]; + ++counts4[(key >> 32) & 0xff]; + ++counts5[(key >> 40) & 0xff]; + ++counts6[(key >> 48) & 0xff]; + ++counts7[(key >> 56) & 0xff]; + } + count_type total0 = 0; + count_type total1 = 0; + count_type total2 = 0; + count_type total3 = 0; + count_type total4 = 0; + count_type total5 = 0; + count_type total6 = 0; + count_type total7 = 0; + for (int i = 0; i < 256; ++i) + { + count_type old_count0 = counts0[i]; + count_type old_count1 = counts1[i]; + count_type old_count2 = counts2[i]; + count_type old_count3 = counts3[i]; + count_type old_count4 = counts4[i]; + count_type old_count5 = counts5[i]; + count_type old_count6 = counts6[i]; + count_type old_count7 = counts7[i]; + counts0[i] = total0; + counts1[i] = total1; + counts2[i] = total2; + counts3[i] = total3; + counts4[i] = total4; + counts5[i] = total5; + counts6[i] = total6; + counts7[i] = total7; + total0 += old_count0; + total1 += old_count1; + total2 += old_count2; + total3 += old_count3; + total4 += old_count4; + total5 += old_count5; + total6 += old_count6; + total7 += old_count7; + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)); + out_begin[counts0[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + 
std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 8; + begin[counts1[key]++] = std::move(*it); + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 16; + out_begin[counts2[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 24; + begin[counts3[key]++] = std::move(*it); + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 32; + out_begin[counts4[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 40; + begin[counts5[key]++] = std::move(*it); + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 48; + out_begin[counts6[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 56; + begin[counts7[key]++] = std::move(*it); + } + return false; + } + + static constexpr size_t pass_count = 9; +}; + +template +struct RadixSorter; +template<> +struct RadixSorter +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + size_t false_count = 0; + for (It it = begin; it != end; ++it) + { + if (!extract_key(*it)) + ++false_count; + } + size_t true_position = false_count; + false_count = 0; + for (; begin != end; ++begin) + { + if (extract_key(*begin)) + buffer_begin[true_position++] = std::move(*begin); + else + buffer_begin[false_count++] = std::move(*begin); + } + return true; + } + + static constexpr size_t pass_count = 2; +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template +struct RadixSorter> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + bool first_result = RadixSorter::sort(begin, end, buffer_begin, [&](auto && o) + { + return extract_key(o).second; + }); + auto extract_first = [&](auto && o) + { + return extract_key(o).first; + }; + + if (first_result) + { + return !RadixSorter::sort(buffer_begin, buffer_begin + (end - begin), begin, extract_first); + } + else + { + return RadixSorter::sort(begin, end, buffer_begin, extract_first); + } + } + + static constexpr size_t pass_count = RadixSorter::pass_count + RadixSorter::pass_count; +}; +template +struct RadixSorter &> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + bool first_result = RadixSorter::sort(begin, end, buffer_begin, [&](auto && o) -> const V & + { + return 
extract_key(o).second; + }); + auto extract_first = [&](auto && o) -> const K & + { + return extract_key(o).first; + }; + + if (first_result) + { + return !RadixSorter::sort(buffer_begin, buffer_begin + (end - begin), begin, extract_first); + } + else + { + return RadixSorter::sort(begin, end, buffer_begin, extract_first); + } + } + + static constexpr size_t pass_count = RadixSorter::pass_count + RadixSorter::pass_count; +}; +template +struct TupleRadixSorter +{ + using NextSorter = TupleRadixSorter; + using ThisSorter = RadixSorter::type>; + + template + static bool sort(It begin, It end, OutIt out_begin, OutIt out_end, ExtractKey && extract_key) + { + bool which = NextSorter::sort(begin, end, out_begin, out_end, extract_key); + auto extract_i = [&](auto && o) + { + return std::get(extract_key(o)); + }; + if (which) + return !ThisSorter::sort(out_begin, out_end, begin, extract_i); + else + return ThisSorter::sort(begin, end, out_begin, extract_i); + } + + static constexpr size_t pass_count = ThisSorter::pass_count + NextSorter::pass_count; +}; +template +struct TupleRadixSorter +{ + using NextSorter = TupleRadixSorter; + using ThisSorter = RadixSorter::type>; + + template + static bool sort(It begin, It end, OutIt out_begin, OutIt out_end, ExtractKey && extract_key) + { + bool which = NextSorter::sort(begin, end, out_begin, out_end, extract_key); + auto extract_i = [&](auto && o) -> decltype(auto) + { + return std::get(extract_key(o)); + }; + if (which) + return !ThisSorter::sort(out_begin, out_end, begin, extract_i); + else + return ThisSorter::sort(begin, end, out_begin, extract_i); + } + + static constexpr size_t pass_count = ThisSorter::pass_count + NextSorter::pass_count; +}; +template +struct TupleRadixSorter +{ + template + static bool sort(It, It, OutIt, OutIt, ExtractKey &&) + { + return false; + } + + static constexpr size_t pass_count = 0; +}; +template +struct TupleRadixSorter +{ + template + static bool sort(It, It, OutIt, OutIt, ExtractKey &&) + { + return false; + } + + static constexpr size_t pass_count = 0; +}; + +template +struct RadixSorter> +{ + using SorterImpl = TupleRadixSorter<0, sizeof...(Args), std::tuple>; + + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + return SorterImpl::sort(begin, end, buffer_begin, buffer_begin + (end - begin), extract_key); + } + + static constexpr size_t pass_count = SorterImpl::pass_count; +}; + +template +struct RadixSorter &> +{ + using SorterImpl = TupleRadixSorter<0, sizeof...(Args), const std::tuple &>; + + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + return SorterImpl::sort(begin, end, buffer_begin, buffer_begin + (end - begin), extract_key); + } + + static constexpr size_t pass_count = SorterImpl::pass_count; +}; + +template +struct RadixSorter> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + auto buffer_end = buffer_begin + (end - begin); + bool which = false; + for (size_t i = S; i > 0; --i) + { + auto extract_i = [&, i = i - 1](auto && o) + { + return extract_key(o)[i]; + }; + if (which) + which = !RadixSorter::sort(buffer_begin, buffer_end, begin, extract_i); + else + which = RadixSorter::sort(begin, end, buffer_begin, extract_i); + } + return which; + } + + static constexpr size_t pass_count = RadixSorter::pass_count * S; +}; + +template +struct RadixSorter : RadixSorter +{ +}; +template +struct RadixSorter : RadixSorter +{ +}; +template +struct RadixSorter : 
RadixSorter +{ +}; +template +struct RadixSorter : RadixSorter +{ +}; +template +struct RadixSorter : RadixSorter +{ +}; +// these structs serve two purposes +// 1. they serve as illustration for how to implement the to_radix_sort_key function +// 2. they help produce better error messages. with these overloads you get the +// error message "no matching function for call to to_radix_sort(your_type)" +// without these examples, you'd get the error message "to_radix_sort_key was +// not declared in this scope" which is a much less useful error message +struct ExampleStructA { int i; }; +struct ExampleStructB { float f; }; +inline int to_radix_sort_key(ExampleStructA a) { return a.i; } +inline float to_radix_sort_key(ExampleStructB b) { return b.f; } +template +struct FallbackRadixSorter : RadixSorter()))> +{ + using base = RadixSorter()))>; + + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + return base::sort(begin, end, buffer_begin, [&](auto && a) -> decltype(auto) + { + return to_radix_sort_key(extract_key(a)); + }); + } +}; + +template +struct nested_void +{ + using type = void; +}; + +template +using void_t = typename nested_void::type; + +template +struct has_subscript_operator_impl +{ + template()[0])> + static std::true_type test(int); + template + static std::false_type test(...); + + using type = decltype(test(0)); +}; + +template +using has_subscript_operator = typename has_subscript_operator_impl::type; + + +template +struct FallbackRadixSorter()))>> + : RadixSorter()))> +{ +}; + +template +struct RadixSorter : FallbackRadixSorter +{ +}; + +template +size_t radix_sort_pass_count = RadixSorter::pass_count; + +template +inline void unroll_loop_four_times(It begin, size_t iteration_count, Func && to_call) +{ + size_t loop_count = iteration_count / 4; + size_t remainder_count = iteration_count - loop_count * 4; + for (; loop_count > 0; --loop_count) + { + to_call(begin); + ++begin; + to_call(begin); + ++begin; + to_call(begin); + ++begin; + to_call(begin); + ++begin; + } + switch(remainder_count) + { + case 3: + to_call(begin); + ++begin; + case 2: + to_call(begin); + ++begin; + case 1: + to_call(begin); + } +} + +template +inline It custom_std_partition(It begin, It end, F && func) +{ + for (;; ++begin) + { + if (begin == end) + return end; + if (!func(*begin)) + break; + } + It it = begin; + for(++it; it != end; ++it) + { + if (!func(*it)) + continue; + + std::iter_swap(begin, it); + ++begin; + } + return begin; +} + +struct PartitionInfo +{ + PartitionInfo() + : count(0) + { + } + + union + { + size_t count; + size_t offset; + }; + size_t next_offset; +}; + +template +struct UnsignedForSize; +template<> +struct UnsignedForSize<1> +{ + typedef uint8_t type; +}; +template<> +struct UnsignedForSize<2> +{ + typedef uint16_t type; +}; +template<> +struct UnsignedForSize<4> +{ + typedef uint32_t type; +}; +template<> +struct UnsignedForSize<8> +{ + typedef uint64_t type; +}; +template +struct SubKey; +template +struct SizedSubKey +{ + template + static auto sub_key(T && value, void *) + { + return to_unsigned_or_bool(value); + } + + typedef SubKey next; + + using sub_key_type = typename UnsignedForSize::type; +}; +template +struct SubKey : SubKey +{ +}; +template +struct SubKey : SubKey +{ +}; +template +struct SubKey : SubKey +{ +}; +template +struct SubKey : SubKey +{ +}; +template +struct SubKey : SubKey +{ +}; +template +struct FallbackSubKey + : SubKey()))> +{ + using base = SubKey()))>; + + template + static decltype(auto) 
sub_key(U && value, void * data) + { + return base::sub_key(to_radix_sort_key(value), data); + } +}; +template +struct FallbackSubKey()))>> + : SubKey()))> +{ +}; +template +struct SubKey : FallbackSubKey +{ +}; +template<> +struct SubKey +{ + template + static bool sub_key(T && value, void *) + { + return value; + } + + typedef SubKey next; + + using sub_key_type = bool; +}; +template<> +struct SubKey; +template<> +struct SubKey : SizedSubKey +{ +}; +template<> +struct SubKey : SizedSubKey +{ +}; +template<> +struct SubKey : SizedSubKey +{ +}; +template<> +struct SubKey : SizedSubKey +{ +}; +template<> +struct SubKey : SizedSubKey +{ +}; +template +struct SubKey : SizedSubKey +{ +}; +template +struct PairSecondSubKey : Current +{ + static decltype(auto) sub_key(const std::pair & value, void * sort_data) + { + return Current::sub_key(value.second, sort_data); + } + + using next = typename std::conditional, typename Current::next>::value, SubKey, PairSecondSubKey>::type; +}; +template +struct PairFirstSubKey : Current +{ + static decltype(auto) sub_key(const std::pair & value, void * sort_data) + { + return Current::sub_key(value.first, sort_data); + } + + using next = typename std::conditional, typename Current::next>::value, PairSecondSubKey>, PairFirstSubKey>::type; +}; +template +struct SubKey> : PairFirstSubKey> +{ +}; +template +struct TypeAt : TypeAt +{ +}; +template +struct TypeAt<0, First, More...> +{ + typedef First type; +}; + +template +struct TupleSubKey; + +template +struct NextTupleSubKey +{ + using type = TupleSubKey; +}; +template +struct NextTupleSubKey, First, Second, More...> +{ + using type = TupleSubKey, Second, More...>; +}; +template +struct NextTupleSubKey, First> +{ + using type = SubKey; +}; + +template +struct TupleSubKey : Current +{ + template + static decltype(auto) sub_key(const Tuple & value, void * sort_data) + { + return Current::sub_key(std::get(value), sort_data); + } + + using next = typename NextTupleSubKey::type; +}; +template +struct TupleSubKey : Current +{ + template + static decltype(auto) sub_key(const Tuple & value, void * sort_data) + { + return Current::sub_key(std::get(value), sort_data); + } + + using next = typename NextTupleSubKey::type; +}; +template +struct SubKey> : TupleSubKey<0, SubKey, First, More...> +{ +}; + +struct BaseListSortData +{ + size_t current_index; + size_t recursion_limit; + void * next_sort_data; +}; +template +struct ListSortData : BaseListSortData +{ + void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *); +}; + +template +struct ListElementSubKey : SubKey()[0])>::type> +{ + using base = SubKey()[0])>::type>; + + using next = ListElementSubKey; + + template + static decltype(auto) sub_key(U && value, void * sort_data) + { + BaseListSortData * list_sort_data = static_cast(sort_data); + const T & list = CurrentSubKey::sub_key(value, list_sort_data->next_sort_data); + return base::sub_key(list[list_sort_data->current_index], list_sort_data->next_sort_data); + } +}; + +template +struct ListSubKey +{ + using next = SubKey; + + using sub_key_type = T; + + static const T & sub_key(const T & value, void *) + { + return value; + } +}; + +template +struct FallbackSubKey::value>::type> : ListSubKey +{ +}; + +template +inline void StdSortFallback(It begin, It end, ExtractKey & extract_key) +{ + std::sort(begin, end, [&](auto && l, auto && r){ return extract_key(l) < extract_key(r); }); +} + +template +inline bool StdSortIfLessThanThreshold(It begin, It end, std::ptrdiff_t num_elements, ExtractKey & extract_key) +{ + if 
(num_elements <= 1) + return true; + if (num_elements >= StdSortThreshold) + return false; + StdSortFallback(begin, end, extract_key); + return true; +} + +template +struct InplaceSorter; + +template +struct UnsignedInplaceSorter +{ + static constexpr size_t ShiftAmount = (((NumBytes - 1) - Offset) * 8); + template + inline static uint8_t current_byte(T && elem, void * sort_data) + { + return CurrentSubKey::sub_key(elem, sort_data) >> ShiftAmount; + } + template + static void sort(It begin, It end, std::ptrdiff_t num_elements, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * sort_data) + { + if (num_elements < AmericanFlagSortThreshold) + american_flag_sort(begin, end, extract_key, next_sort, sort_data); + else + ska_byte_sort(begin, end, extract_key, next_sort, sort_data); + } + + template + static void american_flag_sort(It begin, It end, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * sort_data) + { + PartitionInfo partitions[256]; + for (It it = begin; it != end; ++it) + { + ++partitions[current_byte(extract_key(*it), sort_data)].count; + } + size_t total = 0; + uint8_t remaining_partitions[256]; + int num_partitions = 0; + for (int i = 0; i < 256; ++i) + { + size_t count = partitions[i].count; + if (!count) + continue; + partitions[i].offset = total; + total += count; + partitions[i].next_offset = total; + remaining_partitions[num_partitions] = i; + ++num_partitions; + } + if (num_partitions > 1) + { + uint8_t * current_block_ptr = remaining_partitions; + PartitionInfo * current_block = partitions + *current_block_ptr; + uint8_t * last_block = remaining_partitions + num_partitions - 1; + It it = begin; + It block_end = begin + current_block->next_offset; + It last_element = end - 1; + for (;;) + { + PartitionInfo * block = partitions + current_byte(extract_key(*it), sort_data); + if (block == current_block) + { + ++it; + if (it == last_element) + break; + else if (it == block_end) + { + for (;;) + { + ++current_block_ptr; + if (current_block_ptr == last_block) + goto recurse; + current_block = partitions + *current_block_ptr; + if (current_block->offset != current_block->next_offset) + break; + } + + it = begin + current_block->offset; + block_end = begin + current_block->next_offset; + } + } + else + { + size_t offset = block->offset++; + std::iter_swap(it, begin + offset); + } + } + } + recurse: + if (Offset + 1 != NumBytes || next_sort) + { + size_t start_offset = 0; + It partition_begin = begin; + for (uint8_t * it = remaining_partitions, * end = remaining_partitions + num_partitions; it != end; ++it) + { + size_t end_offset = partitions[*it].next_offset; + It partition_end = begin + end_offset; + std::ptrdiff_t num_elements = end_offset - start_offset; + if (!StdSortIfLessThanThreshold(partition_begin, partition_end, num_elements, extract_key)) + { + UnsignedInplaceSorter::sort(partition_begin, partition_end, num_elements, extract_key, next_sort, sort_data); + } + start_offset = end_offset; + partition_begin = partition_end; + } + } + } + + template + static void ska_byte_sort(It begin, It end, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * sort_data) + { + PartitionInfo partitions[256]; + for (It it = begin; it != end; ++it) + { + ++partitions[current_byte(extract_key(*it), sort_data)].count; + } + uint8_t remaining_partitions[256]; + size_t total = 0; + int num_partitions = 0; + for (int i = 0; i < 256; ++i) + { + size_t count = 
partitions[i].count; + if (count) + { + partitions[i].offset = total; + total += count; + remaining_partitions[num_partitions] = i; + ++num_partitions; + } + partitions[i].next_offset = total; + } + for (uint8_t * last_remaining = remaining_partitions + num_partitions, * end_partition = remaining_partitions + 1; last_remaining > end_partition;) + { + last_remaining = custom_std_partition(remaining_partitions, last_remaining, [&](uint8_t partition) + { + size_t & begin_offset = partitions[partition].offset; + size_t & end_offset = partitions[partition].next_offset; + if (begin_offset == end_offset) + return false; + + unroll_loop_four_times(begin + begin_offset, end_offset - begin_offset, [partitions = partitions, begin, &extract_key, sort_data](It it) + { + uint8_t this_partition = current_byte(extract_key(*it), sort_data); + size_t offset = partitions[this_partition].offset++; + std::iter_swap(it, begin + offset); + }); + return begin_offset != end_offset; + }); + } + if (Offset + 1 != NumBytes || next_sort) + { + for (uint8_t * it = remaining_partitions + num_partitions; it != remaining_partitions; --it) + { + uint8_t partition = it[-1]; + size_t start_offset = (partition == 0 ? 0 : partitions[partition - 1].next_offset); + size_t end_offset = partitions[partition].next_offset; + It partition_begin = begin + start_offset; + It partition_end = begin + end_offset; + std::ptrdiff_t num_elements = end_offset - start_offset; + if (!StdSortIfLessThanThreshold(partition_begin, partition_end, num_elements, extract_key)) + { + UnsignedInplaceSorter::sort(partition_begin, partition_end, num_elements, extract_key, next_sort, sort_data); + } + } + } + } +}; + +template +struct UnsignedInplaceSorter +{ + template + inline static void sort(It begin, It end, std::ptrdiff_t num_elements, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * next_sort_data) + { + next_sort(begin, end, num_elements, extract_key, next_sort_data); + } +}; + +template +size_t CommonPrefix(It begin, It end, size_t start_index, ExtractKey && extract_key, ElementKey && element_key) +{ + const auto & largest_match_list = extract_key(*begin); + size_t largest_match = largest_match_list.size(); + if (largest_match == start_index) + return start_index; + for (++begin; begin != end; ++begin) + { + const auto & current_list = extract_key(*begin); + size_t current_size = current_list.size(); + if (current_size < largest_match) + { + largest_match = current_size; + if (largest_match == start_index) + return start_index; + } + if (element_key(largest_match_list[start_index]) != element_key(current_list[start_index])) + return start_index; + for (size_t i = start_index + 1; i < largest_match; ++i) + { + if (element_key(largest_match_list[i]) != element_key(current_list[i])) + { + largest_match = i; + break; + } + } + } + return largest_match; +} + +template +struct ListInplaceSorter +{ + using ElementSubKey = ListElementSubKey; + template + static void sort(It begin, It end, ExtractKey & extract_key, ListSortData * sort_data) + { + size_t current_index = sort_data->current_index; + void * next_sort_data = sort_data->next_sort_data; + auto current_key = [&](auto && elem) -> decltype(auto) + { + return CurrentSubKey::sub_key(extract_key(elem), next_sort_data); + }; + auto element_key = [&](auto && elem) -> decltype(auto) + { + return ElementSubKey::base::sub_key(elem, sort_data); + }; + sort_data->current_index = current_index = CommonPrefix(begin, end, current_index, current_key, element_key); + 
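// What the CommonPrefix call above buys: every byte position shared by all
// keys in [begin, end) is skipped in one linear scan instead of being
// rediscovered by a full 256-bucket pass per byte. E.g. for {"queue", "quell",
// "query"} all keys share "que", so byte-sorting resumes at index 3. The
// std::partition below then peels off keys that end at or before the common
// prefix: they have no byte at current_index and must sort first.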
It end_of_shorter_ones = std::partition(begin, end, [&](auto && elem) + { + return current_key(elem).size() <= current_index; + }); + std::ptrdiff_t num_shorter_ones = end_of_shorter_ones - begin; + if (sort_data->next_sort && !StdSortIfLessThanThreshold(begin, end_of_shorter_ones, num_shorter_ones, extract_key)) + { + sort_data->next_sort(begin, end_of_shorter_ones, num_shorter_ones, extract_key, next_sort_data); + } + std::ptrdiff_t num_elements = end - end_of_shorter_ones; + if (!StdSortIfLessThanThreshold(end_of_shorter_ones, end, num_elements, extract_key)) + { + void (*sort_next_element)(It, It, std::ptrdiff_t, ExtractKey &, void *) = static_cast(&sort_from_recursion); + InplaceSorter::sort(end_of_shorter_ones, end, num_elements, extract_key, sort_next_element, sort_data); + } + } + + template + static void sort_from_recursion(It begin, It end, std::ptrdiff_t, ExtractKey & extract_key, void * next_sort_data) + { + ListSortData offset = *static_cast *>(next_sort_data); + ++offset.current_index; + --offset.recursion_limit; + if (offset.recursion_limit == 0) + { + StdSortFallback(begin, end, extract_key); + } + else + { + sort(begin, end, extract_key, &offset); + } + } + + + template + static void sort(It begin, It end, std::ptrdiff_t, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * next_sort_data) + { + ListSortData offset; + offset.current_index = 0; + offset.recursion_limit = 16; + offset.next_sort = next_sort; + offset.next_sort_data = next_sort_data; + sort(begin, end, extract_key, &offset); + } +}; + +template +struct InplaceSorter +{ + template + static void sort(It begin, It end, std::ptrdiff_t, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * sort_data) + { + It middle = std::partition(begin, end, [&](auto && a){ return !CurrentSubKey::sub_key(extract_key(a), sort_data); }); + if (next_sort) + { + next_sort(begin, middle, middle - begin, extract_key, sort_data); + next_sort(middle, end, end - middle, extract_key, sort_data); + } + } +}; + +template +struct InplaceSorter : UnsignedInplaceSorter +{ +}; +template +struct InplaceSorter : UnsignedInplaceSorter +{ +}; +template +struct InplaceSorter : UnsignedInplaceSorter +{ +}; +template +struct InplaceSorter : UnsignedInplaceSorter +{ +}; +template +struct FallbackInplaceSorter; + +template +struct InplaceSorter : FallbackInplaceSorter +{ +}; + +template +struct FallbackInplaceSorter::value>::type> + : ListInplaceSorter +{ +}; + +template +struct SortStarter; +template +struct SortStarter> +{ + template + static void sort(It, It, std::ptrdiff_t, ExtractKey &, void *) + { + } +}; + +template +struct SortStarter +{ + template + static void sort(It begin, It end, std::ptrdiff_t num_elements, ExtractKey & extract_key, void * next_sort_data = nullptr) + { + if (StdSortIfLessThanThreshold(begin, end, num_elements, extract_key)) + return; + + void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *) = static_cast(&SortStarter::sort); + if (next_sort == static_cast(&SortStarter>::sort)) + next_sort = nullptr; + InplaceSorter::sort(begin, end, num_elements, extract_key, next_sort, next_sort_data); + } +}; + +template +void inplace_radix_sort(It begin, It end, ExtractKey & extract_key) +{ + using SubKey = SubKey; + SortStarter::sort(begin, end, end - begin, extract_key); +} + +struct IdentityFunctor +{ + template + decltype(auto) operator()(T && i) const + { + return std::forward(i); + } +}; +} + +template +static void ska_sort(It begin, 
It end, ExtractKey && extract_key) +{ + detail::inplace_radix_sort<128, 1024>(begin, end, extract_key); +} + +template +static void ska_sort(It begin, It end) +{ + ska_sort(begin, end, detail::IdentityFunctor()); +} + +template +bool ska_sort_copy(It begin, It end, OutIt buffer_begin, ExtractKey && key) +{ + std::ptrdiff_t num_elements = end - begin; + if (num_elements < 128 || detail::radix_sort_pass_count::type> >= 8) + { + ska_sort(begin, end, key); + return false; + } + else + return detail::RadixSorter::type>::sort(begin, end, buffer_begin, key); +} +template +bool ska_sort_copy(It begin, It end, OutIt buffer_begin) +{ + return ska_sort_copy(begin, end, buffer_begin, detail::IdentityFunctor()); +} diff --git a/oap-native-sql/cpp/src/third_party/sparsehash/dense_hash_map b/oap-native-sql/cpp/src/third_party/sparsehash/dense_hash_map new file mode 100644 index 000000000..1802ca1e2 --- /dev/null +++ b/oap-native-sql/cpp/src/third_party/sparsehash/dense_hash_map @@ -0,0 +1,420 @@ +// Copyright (c) 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// ---- +// +// This is just a very thin wrapper over densehashtable.h, just +// like sgi stl's stl_hash_map is a very thin wrapper over +// stl_hashtable. The major thing we define is operator[], because +// we have a concept of a data_type which stl_hashtable doesn't +// (it only has a key and a value). +// +// NOTE: this is exactly like sparse_hash_map.h, with the word +// "sparse" replaced by "dense", except for the addition of +// set_empty_key(). +// +// YOU MUST CALL SET_EMPTY_KEY() IMMEDIATELY AFTER CONSTRUCTION. +// +// Otherwise your program will die in mysterious ways. (Note if you +// use the constructor that takes an InputIterator range, you pass in +// the empty key in the constructor, rather than after. As a result, +// this constructor differs from the standard STL version.) +// +// In other respects, we adhere mostly to the STL semantics for +// hash-map. 
One important exception is that insert() may invalidate +// iterators entirely -- STL semantics are that insert() may reorder +// iterators, but they all still refer to something valid in the +// hashtable. Not so for us. Likewise, insert() may invalidate +// pointers into the hashtable. (Whether insert invalidates iterators +// and pointers depends on whether it results in a hashtable resize). +// On the plus side, delete() doesn't invalidate iterators or pointers +// at all, or even change the ordering of elements. +// +// Here are a few "power user" tips: +// +// 1) set_deleted_key(): +// If you want to use erase() you *must* call set_deleted_key(), +// in addition to set_empty_key(), after construction. +// The deleted and empty keys must differ. +// +// 2) resize(0): +// When an item is deleted, its memory isn't freed right +// away. This allows you to iterate over a hashtable, +// and call erase(), without invalidating the iterator. +// To force the memory to be freed, call resize(0). +// For tr1 compatibility, this can also be called as rehash(0). +// +// 3) min_load_factor(0.0) +// Setting the minimum load factor to 0.0 guarantees that +// the hash table will never shrink. +// +// Roughly speaking: +// (1) dense_hash_map: fastest, uses the most memory unless entries are small +// (2) sparse_hash_map: slowest, uses the least memory +// (3) hash_map / unordered_map (STL): in the middle +// +// Typically I use sparse_hash_map when I care about space and/or when +// I need to save the hashtable on disk. I use hash_map otherwise. I +// don't personally use dense_hash_set ever; some people use it for +// small sets with lots of lookups. +// +// - dense_hash_map has, typically, about 78% memory overhead (if your +// data takes up X bytes, the hash_map uses .78X more bytes in overhead). +// - sparse_hash_map has about 4 bits overhead per entry. +// - sparse_hash_map can be 3-7 times slower than the others for lookup and, +// especially, inserts. See time_hash_map.cc for details. +// +// See /usr/(local/)?doc/sparsehash-*/dense_hash_map.html +// for information about how to use this class. + +#pragma once + +#include // needed by stl_alloc +#include // for equal_to<>, select1st<>, etc +#include // for initializer_list +#include // for alloc +#include // for pair<> +#include // forward_as_tuple +#include // for enable_if, is_constructible, etc +#include // IWYU pragma: export +#include + +namespace google { + +template , + class EqualKey = std::equal_to, + class Alloc = libc_allocator_with_realloc>> +class dense_hash_map { + private: + // Apparently select1st is not stl-standard, so we define our own + struct SelectKey { + typedef const Key& result_type; + template + const Key& operator()(Pair&& p) const { + return p.first; + } + }; + struct SetKey { + void operator()(std::pair* value, const Key& new_key) const { + using NCKey = typename std::remove_cv::type; + *const_cast(&value->first) = new_key; + + // It would be nice to clear the rest of value here as well, in + // case it's taking up a lot of memory. We do this by clearing + // the value. This assumes T has a zero-arg constructor! 
+ value->second = T(); + } + void operator()(std::pair* value, const Key& new_key, bool) const { + new(value) std::pair(std::piecewise_construct, std::forward_as_tuple(new_key), std::forward_as_tuple()); + } + }; + + // The actual data + typedef typename sparsehash_internal::key_equal_chosen::type EqualKeyChosen; + typedef dense_hashtable, Key, HashFcn, SelectKey, + SetKey, EqualKeyChosen, Alloc> ht; + ht rep; + + static_assert(!sparsehash_internal::has_transparent_key_equal::value + || std::is_same>::value + || std::is_same::value, + "Heterogeneous lookup requires key_equal to either be the default container value or the same as the type provided by hash"); + + public: + typedef typename ht::key_type key_type; + typedef T data_type; + typedef T mapped_type; + typedef typename ht::value_type value_type; + typedef typename ht::hasher hasher; + typedef typename ht::key_equal key_equal; + typedef Alloc allocator_type; + + typedef typename ht::size_type size_type; + typedef typename ht::difference_type difference_type; + typedef typename ht::pointer pointer; + typedef typename ht::const_pointer const_pointer; + typedef typename ht::reference reference; + typedef typename ht::const_reference const_reference; + + typedef typename ht::iterator iterator; + typedef typename ht::const_iterator const_iterator; + typedef typename ht::local_iterator local_iterator; + typedef typename ht::const_local_iterator const_local_iterator; + + // Iterator functions + iterator begin() { return rep.begin(); } + iterator end() { return rep.end(); } + const_iterator begin() const { return rep.begin(); } + const_iterator end() const { return rep.end(); } + const_iterator cbegin() const { return rep.begin(); } + const_iterator cend() const { return rep.end(); } + + // These come from tr1's unordered_map. For us, a bucket has 0 or 1 elements. 
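As a usage note for the wrapper defined here, a minimal sketch of the mandatory empty-key/deleted-key protocol described in the comments above (illustrative only; the sentinel keys are arbitrary "impossible" values):

#include <string>
// #include <sparsehash/dense_hash_map>  // hypothetical include path

void dense_hash_map_example() {
  google::dense_hash_map<int, std::string> m;
  m.set_empty_key(-1);    // MUST be called immediately after construction
  m.set_deleted_key(-2);  // required before erase(); must differ from the empty key

  m[7] = "seven";         // operator[] default-constructs a T only on a miss
  m.erase(7);             // memory is not reclaimed immediately...
  m.resize(0);            // ...so force it, per power-user tip 2 above
}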
+ local_iterator begin(size_type i) { return rep.begin(i); } + local_iterator end(size_type i) { return rep.end(i); } + const_local_iterator begin(size_type i) const { return rep.begin(i); } + const_local_iterator end(size_type i) const { return rep.end(i); } + const_local_iterator cbegin(size_type i) const { return rep.begin(i); } + const_local_iterator cend(size_type i) const { return rep.end(i); } + + // Accessor functions + allocator_type get_allocator() const { return rep.get_allocator(); } + hasher hash_funct() const { return rep.hash_funct(); } + hasher hash_function() const { return hash_funct(); } + key_equal key_eq() const { return rep.key_eq(); } + + // Constructors + explicit dense_hash_map(size_type expected_max_items_in_table = 0, + const hasher& hf = hasher(), + const key_equal& eql = key_equal(), + const allocator_type& alloc = allocator_type()) + : rep(expected_max_items_in_table, hf, eql, SelectKey(), SetKey(), + alloc) {} + + template + dense_hash_map(InputIterator f, InputIterator l, + const key_type& empty_key_val, + size_type expected_max_items_in_table = 0, + const hasher& hf = hasher(), + const key_equal& eql = key_equal(), + const allocator_type& alloc = allocator_type()) + : rep(expected_max_items_in_table, hf, eql, SelectKey(), SetKey(), + alloc) { + set_empty_key(empty_key_val); + rep.insert(f, l); + } + // We use the default copy constructor + // We use the default operator=() + // We use the default destructor + + void clear() { rep.clear(); } + // This clears the hash map without resizing it down to the minimum + // bucket count, but rather keeps the number of buckets constant + void clear_no_resize() { rep.clear_no_resize(); } + void swap(dense_hash_map& hs) { rep.swap(hs.rep); } + + // Functions concerning size + size_type size() const { return rep.size(); } + size_type max_size() const { return rep.max_size(); } + bool empty() const { return rep.empty(); } + size_type bucket_count() const { return rep.bucket_count(); } + size_type max_bucket_count() const { return rep.max_bucket_count(); } + + // These are tr1 methods. bucket() is the bucket the key is or would be in. + size_type bucket_size(size_type i) const { return rep.bucket_size(i); } + size_type bucket(const key_type& key) const { return rep.bucket(key); } + float load_factor() const { return size() * 1.0f / bucket_count(); } + float max_load_factor() const { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + return grow; + } + void max_load_factor(float new_grow) { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + rep.set_resizing_parameters(shrink, new_grow); + } + // These aren't tr1 methods but perhaps ought to be. + float min_load_factor() const { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + return shrink; + } + void min_load_factor(float new_shrink) { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + rep.set_resizing_parameters(new_shrink, grow); + } + // Deprecated; use min_load_factor() or max_load_factor() instead. 
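A short sketch of the load-factor knobs just defined, including the never-shrink setting from power-user tip 3 (illustrative only):

void load_factor_example(google::dense_hash_map<int, int>& m) {
  m.max_load_factor(0.8f);     // grow once the table is 80% full
  m.min_load_factor(0.0f);     // 0.0 guarantees the table never shrinks
  float lf = m.load_factor();  // computed as size() / bucket_count()
  (void)lf;
}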
+ void set_resizing_parameters(float shrink, float grow) { + rep.set_resizing_parameters(shrink, grow); + } + + void reserve(size_type size) { rehash(size); } // note: rehash internally treats hint/size as number of elements + void resize(size_type hint) { rep.resize(hint); } + void rehash(size_type hint) { resize(hint); } // the tr1 name + + // Lookup routines + iterator find(const key_type& key) { return rep.find(key); } + const_iterator find(const key_type& key) const { return rep.find(key); } + + template + typename std::enable_if::value, iterator>::type + find(const K& key) { return rep.find(key); } + template + typename std::enable_if::value, const_iterator>::type + find(const K& key) const { return rep.find(key); } + + data_type& operator[](const key_type& key) { // This is our value-add! + // If key is in the hashtable, returns find(key)->second, + // otherwise returns insert(value_type(key, T()).first->second. + // Note it does not create an empty T unless the find fails. + return rep.template find_or_insert(key).second; + } + + data_type& operator[](key_type&& key) { + return rep.template find_or_insert(std::move(key)).second; + } + + size_type count(const key_type& key) const { return rep.count(key); } + + template + typename std::enable_if::value, size_type>::type + count(const K& key) const { return rep.count(key); } + + std::pair equal_range(const key_type& key) { + return rep.equal_range(key); + } + std::pair equal_range( + const key_type& key) const { + return rep.equal_range(key); + } + + template + typename std::enable_if::value, std::pair>::type + equal_range(const K& key) { + return rep.equal_range(key); + } + template + typename std::enable_if::value, std::pair>::type + equal_range(const K& key) const { + return rep.equal_range(key); + } + + // Insertion routines + std::pair insert(const value_type& obj) { + return rep.insert(obj); + } + + template ::value>::type> + std::pair insert(Pair&& obj) { + return rep.insert(std::forward(obj)); + } + + // overload to allow {} syntax: .insert( { {key}, {args} } ) + std::pair insert(value_type&& obj) { + return rep.insert(std::move(obj)); + } + + template + std::pair emplace(Args&&... args) { + return rep.emplace(std::forward(args)...); + } + + template + std::pair emplace_hint(const_iterator hint, Args&&... args) { + return rep.emplace_hint(hint, std::forward(args)...); + } + + + template + void insert(InputIterator f, InputIterator l) { + rep.insert(f, l); + } + void insert(const_iterator f, const_iterator l) { rep.insert(f, l); } + void insert(std::initializer_list ilist) { rep.insert(ilist.begin(), ilist.end()); } + // Required for std::insert_iterator; the passed-in iterator is ignored. + iterator insert(const_iterator, const value_type& obj) { return insert(obj).first; } + iterator insert(const_iterator, value_type&& obj) { return insert(std::move(obj)).first; } + template ::value && + !std::is_same::value + >::type> + iterator insert(const_iterator, P&& obj) { return insert(std::forward
(obj)).first; } + + // Deletion and empty routines + // THESE ARE NON-STANDARD! I make you specify an "impossible" key + // value to identify deleted and empty buckets. You can change the + // deleted key as time goes on, or get rid of it entirely to be insert-only. + // YOU MUST CALL THIS! + void set_empty_key(const key_type& key) { rep.set_empty_key(key); } + key_type empty_key() const { return rep.empty_key(); } + + void set_deleted_key(const key_type& key) { rep.set_deleted_key(key); } + void clear_deleted_key() { rep.clear_deleted_key(); } + key_type deleted_key() const { return rep.deleted_key(); } + + // These are standard + size_type erase(const key_type& key) { return rep.erase(key); } + iterator erase(const_iterator it) { return rep.erase(it); } + iterator erase(const_iterator f, const_iterator l) { return rep.erase(f, l); } + + // Comparison + bool operator==(const dense_hash_map& hs) const { return rep == hs.rep; } + bool operator!=(const dense_hash_map& hs) const { return rep != hs.rep; } + + // I/O -- this is an add-on for writing hash map to disk + // + // For maximum flexibility, this does not assume a particular + // file type (though it will probably be a FILE *). We just pass + // the fp through to rep. + + // If your keys and values are simple enough, you can pass this + // serializer to serialize()/unserialize(). "Simple enough" means + // value_type is a POD type that contains no pointers. Note, + // however, we don't try to normalize endianness. + typedef typename ht::NopointerSerializer NopointerSerializer; + + // serializer: a class providing operator()(OUTPUT*, const value_type&) + // (writing value_type to OUTPUT). You can specify a + // NopointerSerializer object if appropriate (see above). + // fp: either a FILE*, OR an ostream*/subclass_of_ostream*, OR a + // pointer to a class providing size_t Write(const void*, size_t), + // which writes a buffer into a stream (which fp presumably + // owns) and returns the number of bytes successfully written. + // Note basic_ostream is not currently supported. + template + bool serialize(ValueSerializer serializer, OUTPUT* fp) { + return rep.serialize(serializer, fp); + } + + // serializer: a functor providing operator()(INPUT*, value_type*) + // (reading from INPUT and into value_type). You can specify a + // NopointerSerializer object if appropriate (see above). + // fp: either a FILE*, OR an istream*/subclass_of_istream*, OR a + // pointer to a class providing size_t Read(void*, size_t), + // which reads into a buffer from a stream (which fp presumably + // owns) and returns the number of bytes successfully read. + // Note basic_istream is not currently supported. + // NOTE: Since value_type is std::pair, ValueSerializer + // may need to do a const cast in order to fill in the key. + template + bool unserialize(ValueSerializer serializer, INPUT* fp) { + return rep.unserialize(serializer, fp); + } +}; + +// We need a global swap as well +template +inline void swap(dense_hash_map& hm1, + dense_hash_map& hm2) { + hm1.swap(hm2); +} + +} // namespace google diff --git a/oap-native-sql/cpp/src/third_party/sparsehash/dense_hash_set b/oap-native-sql/cpp/src/third_party/sparsehash/dense_hash_set new file mode 100644 index 000000000..fa91798d3 --- /dev/null +++ b/oap-native-sql/cpp/src/third_party/sparsehash/dense_hash_set @@ -0,0 +1,369 @@ +// Copyright (c) 2005, Google Inc. +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// +// This is just a very thin wrapper over densehashtable.h, just +// like sgi stl's stl_hash_set is a very thin wrapper over +// stl_hashtable. The major thing we define is operator[], because +// we have a concept of a data_type which stl_hashtable doesn't +// (it only has a key and a value). +// +// This is more different from dense_hash_map than you might think, +// because all iterators for sets are const (you obviously can't +// change the key, and for sets there is no value). +// +// NOTE: this is exactly like sparse_hash_set.h, with the word +// "sparse" replaced by "dense", except for the addition of +// set_empty_key(). +// +// YOU MUST CALL SET_EMPTY_KEY() IMMEDIATELY AFTER CONSTRUCTION. +// +// Otherwise your program will die in mysterious ways. (Note if you +// use the constructor that takes an InputIterator range, you pass in +// the empty key in the constructor, rather than after. As a result, +// this constructor differs from the standard STL version.) +// +// In other respects, we adhere mostly to the STL semantics for +// hash-map. One important exception is that insert() may invalidate +// iterators entirely -- STL semantics are that insert() may reorder +// iterators, but they all still refer to something valid in the +// hashtable. Not so for us. Likewise, insert() may invalidate +// pointers into the hashtable. (Whether insert invalidates iterators +// and pointers depends on whether it results in a hashtable resize). +// On the plus side, delete() doesn't invalidate iterators or pointers +// at all, or even change the ordering of elements. +// +// Here are a few "power user" tips: +// +// 1) set_deleted_key(): +// If you want to use erase() you must call set_deleted_key(), +// in addition to set_empty_key(), after construction. +// The deleted and empty keys must differ. +// +// 2) resize(0): +// When an item is deleted, its memory isn't freed right +// away. This allows you to iterate over a hashtable, +// and call erase(), without invalidating the iterator. 
+// To force the memory to be freed, call resize(0). +// For tr1 compatibility, this can also be called as rehash(0). +// +// 3) min_load_factor(0.0) +// Setting the minimum load factor to 0.0 guarantees that +// the hash table will never shrink. +// +// Roughly speaking: +// (1) dense_hash_set: fastest, uses the most memory unless entries are small +// (2) sparse_hash_set: slowest, uses the least memory +// (3) hash_set / unordered_set (STL): in the middle +// +// Typically I use sparse_hash_set when I care about space and/or when +// I need to save the hashtable on disk. I use hash_set otherwise. I +// don't personally use dense_hash_set ever; some people use it for +// small sets with lots of lookups. +// +// - dense_hash_set has, typically, about 78% memory overhead (if your +// data takes up X bytes, the hash_set uses .78X more bytes in overhead). +// - sparse_hash_set has about 4 bits overhead per entry. +// - sparse_hash_set can be 3-7 times slower than the others for lookup and, +// especially, inserts. See time_hash_map.cc for details. +// +// See /usr/(local/)?doc/sparsehash-*/dense_hash_set.html +// for information about how to use this class. + +#pragma once + +#include // needed by stl_alloc +#include // for equal_to<>, select1st<>, etc +#include // for initializer_list +#include // for alloc +#include // for pair<> +#include // IWYU pragma: export +#include + +namespace google { + +template , + class EqualKey = std::equal_to, + class Alloc = libc_allocator_with_realloc> +class dense_hash_set { + private: + // Apparently identity is not stl-standard, so we define our own + struct Identity { + typedef const Value& result_type; + template + const Value& operator()(V&& v) const { return v; } + }; + struct SetKey { + void operator()(Value* value, const Value& new_key) const { + *value = new_key; + } + void operator()(Value* value, const Value& new_key, bool) const { + new(value) Value(new_key); + } + }; + + // The actual data + typedef typename sparsehash_internal::key_equal_chosen::type EqualKeyChosen; + typedef dense_hashtable ht; + ht rep; + + static_assert(!sparsehash_internal::has_transparent_key_equal::value + || std::is_same>::value + || std::is_same::value, + "Heterogeneous lookup requires key_equal to either be the default container value or the same as the type provided by hash"); + + public: + typedef typename ht::key_type key_type; + typedef typename ht::value_type value_type; + typedef typename ht::hasher hasher; + typedef typename ht::key_equal key_equal; + typedef Alloc allocator_type; + + typedef typename ht::size_type size_type; + typedef typename ht::difference_type difference_type; + typedef typename ht::const_pointer pointer; + typedef typename ht::const_pointer const_pointer; + typedef typename ht::const_reference reference; + typedef typename ht::const_reference const_reference; + + typedef typename ht::const_iterator iterator; + typedef typename ht::const_iterator const_iterator; + typedef typename ht::const_local_iterator local_iterator; + typedef typename ht::const_local_iterator const_local_iterator; + + // Iterator functions -- recall all iterators are const + iterator begin() const { return rep.begin(); } + iterator end() const { return rep.end(); } + const_iterator cbegin() const { return rep.begin(); } + const_iterator cend() const { return rep.end(); } + + // These come from tr1's unordered_set. For us, a bucket has 0 or 1 elements. 
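Mirroring the map sketch earlier, a minimal dense_hash_set example following the same empty-key protocol; note that, as stated above, all set iterators are const (illustrative only):

#include <string>

void dense_hash_set_example() {
  google::dense_hash_set<std::string> s;
  s.set_empty_key("");        // mandatory, exactly as with the map
  s.set_deleted_key("\x7f");  // needed before erase(); distinct from the empty key

  s.insert("arrow");
  for (const std::string& v : s) {  // lookups only: elements cannot be mutated
    (void)v;
  }
}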
+ local_iterator begin(size_type i) const { return rep.begin(i); } + local_iterator end(size_type i) const { return rep.end(i); } + local_iterator cbegin(size_type i) const { return rep.begin(i); } + local_iterator cend(size_type i) const { return rep.end(i); } + + // Accessor functions + allocator_type get_allocator() const { return rep.get_allocator(); } + hasher hash_funct() const { return rep.hash_funct(); } + hasher hash_function() const { return hash_funct(); } // tr1 name + key_equal key_eq() const { return rep.key_eq(); } + + // Constructors + explicit dense_hash_set(size_type expected_max_items_in_table = 0, + const hasher& hf = hasher(), + const key_equal& eql = key_equal(), + const allocator_type& alloc = allocator_type()) + : rep(expected_max_items_in_table, hf, eql, Identity(), SetKey(), alloc) { + } + + template + dense_hash_set(InputIterator f, InputIterator l, + const key_type& empty_key_val, + size_type expected_max_items_in_table = 0, + const hasher& hf = hasher(), + const key_equal& eql = key_equal(), + const allocator_type& alloc = allocator_type()) + : rep(expected_max_items_in_table, hf, eql, Identity(), SetKey(), alloc) { + set_empty_key(empty_key_val); + rep.insert(f, l); + } + // We use the default copy constructor + // We use the default operator=() + // We use the default destructor + + void clear() { rep.clear(); } + // This clears the hash set without resizing it down to the minimum + // bucket count, but rather keeps the number of buckets constant + void clear_no_resize() { rep.clear_no_resize(); } + void swap(dense_hash_set& hs) { rep.swap(hs.rep); } + + // Functions concerning size + size_type size() const { return rep.size(); } + size_type max_size() const { return rep.max_size(); } + bool empty() const { return rep.empty(); } + size_type bucket_count() const { return rep.bucket_count(); } + size_type max_bucket_count() const { return rep.max_bucket_count(); } + + // These are tr1 methods. bucket() is the bucket the key is or would be in. + size_type bucket_size(size_type i) const { return rep.bucket_size(i); } + size_type bucket(const key_type& key) const { return rep.bucket(key); } + float load_factor() const { return size() * 1.0f / bucket_count(); } + float max_load_factor() const { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + return grow; + } + void max_load_factor(float new_grow) { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + rep.set_resizing_parameters(shrink, new_grow); + } + // These aren't tr1 methods but perhaps ought to be. + float min_load_factor() const { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + return shrink; + } + void min_load_factor(float new_shrink) { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + rep.set_resizing_parameters(new_shrink, grow); + } + // Deprecated; use min_load_factor() or max_load_factor() instead. 
+ void set_resizing_parameters(float shrink, float grow) { + rep.set_resizing_parameters(shrink, grow); + } + + void reserve(size_type size) { rehash(size); } // note: rehash internally treats hint/size as number of elements + void resize(size_type hint) { rep.resize(hint); } + void rehash(size_type hint) { resize(hint); } // the tr1 name + + // Lookup routines + iterator find(const key_type& key) const { return rep.find(key); } + + template + typename std::enable_if::value, iterator>::type + find(const K& key) const { return rep.find(key); } + + size_type count(const key_type& key) const { return rep.count(key); } + + template + typename std::enable_if::value, size_type>::type + count(const K& key) const { return rep.count(key); } + + std::pair equal_range(const key_type& key) const { + return rep.equal_range(key); + } + + template + typename std::enable_if::value, std::pair>::type + equal_range(const K& key) const { + return rep.equal_range(key); + } + + // Insertion routines + std::pair insert(const value_type& obj) { + std::pair p = rep.insert(obj); + return std::pair(p.first, p.second); // const to non-const + } + + std::pair insert(value_type&& obj) { + std::pair p = rep.insert(std::move(obj)); + return std::pair(p.first, p.second); // const to non-const + } + + template + std::pair emplace(Args&&... args) { + return rep.emplace(std::forward(args)...); + } + + template + std::pair emplace_hint(const_iterator hint, Args&&... args) { + return rep.emplace_hint(hint, std::forward(args)...); + } + + template + void insert(InputIterator f, InputIterator l) { + rep.insert(f, l); + } + void insert(const_iterator f, const_iterator l) { rep.insert(f, l); } + void insert(std::initializer_list ilist) { rep.insert(ilist.begin(), ilist.end()); } + // Required for std::insert_iterator; the passed-in iterator is ignored. + iterator insert(const_iterator, const value_type& obj) { return insert(obj).first; } + iterator insert(const_iterator, value_type&& obj) { return insert(std::move(obj)).first; } + + // Deletion and empty routines + // THESE ARE NON-STANDARD! I make you specify an "impossible" key + // value to identify deleted and empty buckets. You can change the + // deleted key as time goes on, or get rid of it entirely to be insert-only. + void set_empty_key(const key_type& key) { rep.set_empty_key(key); } + key_type empty_key() const { return rep.empty_key(); } + + void set_deleted_key(const key_type& key) { rep.set_deleted_key(key); } + void clear_deleted_key() { rep.clear_deleted_key(); } + key_type deleted_key() const { return rep.deleted_key(); } + + // These are standard + size_type erase(const key_type& key) { return rep.erase(key); } + iterator erase(const_iterator it) { return rep.erase(it); } + iterator erase(const_iterator f, const_iterator l) { return rep.erase(f, l); } + + // Comparison + bool operator==(const dense_hash_set& hs) const { return rep == hs.rep; } + bool operator!=(const dense_hash_set& hs) const { return rep != hs.rep; } + + // I/O -- this is an add-on for writing metainformation to disk + // + // For maximum flexibility, this does not assume a particular + // file type (though it will probably be a FILE *). We just pass + // the fp through to rep. + + // If your keys and values are simple enough, you can pass this + // serializer to serialize()/unserialize(). "Simple enough" means + // value_type is a POD type that contains no pointers. Note, + // however, we don't try to normalize endianness. 
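A hedged sketch of the serialize() API described just above, for pointer-free POD value types; the file path is a placeholder:

#include <cstdio>

void serialize_example(google::dense_hash_set<int>& s) {
  using Set = google::dense_hash_set<int>;
  if (std::FILE* fp = std::fopen("/tmp/ints.bin", "wb")) {  // placeholder path
    s.serialize(Set::NopointerSerializer(), fp);  // value_type has no pointers
    std::fclose(fp);
  }
}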
+ typedef typename ht::NopointerSerializer NopointerSerializer; + + // serializer: a class providing operator()(OUTPUT*, const value_type&) + // (writing value_type to OUTPUT). You can specify a + // NopointerSerializer object if appropriate (see above). + // fp: either a FILE*, OR an ostream*/subclass_of_ostream*, OR a + // pointer to a class providing size_t Write(const void*, size_t), + // which writes a buffer into a stream (which fp presumably + // owns) and returns the number of bytes successfully written. + // Note basic_ostream is not currently supported. + template + bool serialize(ValueSerializer serializer, OUTPUT* fp) { + return rep.serialize(serializer, fp); + } + + // serializer: a functor providing operator()(INPUT*, value_type*) + // (reading from INPUT and into value_type). You can specify a + // NopointerSerializer object if appropriate (see above). + // fp: either a FILE*, OR an istream*/subclass_of_istream*, OR a + // pointer to a class providing size_t Read(void*, size_t), + // which reads into a buffer from a stream (which fp presumably + // owns) and returns the number of bytes successfully read. + // Note basic_istream is not currently supported. + template + bool unserialize(ValueSerializer serializer, INPUT* fp) { + return rep.unserialize(serializer, fp); + } +}; + +template +inline void swap(dense_hash_set& hs1, + dense_hash_set& hs2) { + hs1.swap(hs2); +} + +} // namespace google diff --git a/oap-native-sql/cpp/src/third_party/sparsehash/internal/densehashtable.h b/oap-native-sql/cpp/src/third_party/sparsehash/internal/densehashtable.h new file mode 100644 index 000000000..e254126c2 --- /dev/null +++ b/oap-native-sql/cpp/src/third_party/sparsehash/internal/densehashtable.h @@ -0,0 +1,1380 @@ +// Copyright (c) 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// +// A dense hashtable is a particular implementation of +// a hashtable: one that is meant to minimize memory allocation. +// It does this by using an array to store all the data. 
We +// steal a value from the key space to indicate "empty" array +// elements (ie indices where no item lives) and another to indicate +// "deleted" elements. +// +// (Note it is possible to change the value of the delete key +// on the fly; you can even remove it, though after that point +// the hashtable is insert_only until you set it again. The empty +// value however can't be changed.) +// +// To minimize allocation and pointer overhead, we use internal +// probing, in which the hashtable is a single table, and collisions +// are resolved by trying to insert again in another bucket. The +// most cache-efficient internal probing schemes are linear probing +// (which suffers, alas, from clumping) and quadratic probing, which +// is what we implement by default. +// +// Type requirements: value_type is required to be Copy Constructible +// and Default Constructible. It is not required to be (and commonly +// isn't) Assignable. +// +// You probably shouldn't use this code directly. Use dense_hash_map<> +// or dense_hash_set<> instead. + +// You can change the following below: +// HT_OCCUPANCY_PCT -- how full before we double size +// HT_EMPTY_PCT -- how empty before we halve size +// HT_MIN_BUCKETS -- default smallest bucket size +// +// You can also change enlarge_factor (which defaults to +// HT_OCCUPANCY_PCT), and shrink_factor (which defaults to +// HT_EMPTY_PCT) with set_resizing_parameters(). +// +// How to decide what values to use? +// shrink_factor's default of .4 * OCCUPANCY_PCT, is probably good. +// HT_MIN_BUCKETS is probably unnecessary since you can specify +// (indirectly) the starting number of buckets at construct-time. +// For enlarge_factor, you can use this chart to try to trade-off +// expected lookup time to the space taken up. By default, this +// code uses quadratic probing, though you can change it to linear +// via JUMP_ below if you really want to. +// +// From +// http://www.augustana.ca/~mohrj/courses/1999.fall/csc210/lecture_notes/hashing.html +// NUMBER OF PROBES / LOOKUP Successful Unsuccessful +// Quadratic collision resolution 1 - ln(1-L) - L/2 1/(1-L) - L - ln(1-L) +// Linear collision resolution [1+1/(1-L)]/2 [1+1/(1-L)2]/2 +// +// -- enlarge_factor -- 0.10 0.50 0.60 0.75 0.80 0.90 0.99 +// QUADRATIC COLLISION RES. +// probes/successful lookup 1.05 1.44 1.62 2.01 2.21 2.85 5.11 +// probes/unsuccessful lookup 1.11 2.19 2.82 4.64 5.81 11.4 103.6 +// LINEAR COLLISION RES. +// probes/successful lookup 1.06 1.5 1.75 2.5 3.0 5.5 50.5 +// probes/unsuccessful lookup 1.12 2.5 3.6 8.5 13.0 50.0 5000.0 + +#pragma once + +#include +#include // for FILE, fwrite, fread +#include // For swap(), eg +#include // For iterator tags +#include // for numeric_limits +#include // For uninitialized_fill +#include // for pair +#include // For length_error +#include +#include +#include + +namespace google { + +// The probing method +// Linear probing +// #define JUMP_(key, num_probes) ( 1 ) +// Quadratic probing +#define JUMP_(key, num_probes) (num_probes) + +// Hashtable class, used to implement the hashed associative containers +// hash_set and hash_map. + +// Value: what is stored in the table (each bucket is a Value). +// Key: something in a 1-to-1 correspondence to a Value, that can be used +// to search for a Value in the table (find() takes a Key). +// HashFcn: Takes a Key and returns an integer, the more unique the better. +// ExtractKey: given a Value, returns the unique Key associated with it. 
+// Must inherit from unary_function, or at least have a +// result_type enum indicating the return type of operator(). +// SetKey: given a Value* and a Key, modifies the value such that +// ExtractKey(value) == key. We guarantee this is only called +// with key == deleted_key or key == empty_key. +// EqualKey: Given two Keys, says whether they are the same (that is, +// if they are both associated with the same Value). +// Alloc: STL allocator to use to allocate memory. + +template +class dense_hashtable; + +template +struct dense_hashtable_iterator; + +template +struct dense_hashtable_const_iterator; + +// We're just an array, but we need to skip over empty and deleted elements +template +struct dense_hashtable_iterator { + private: + using value_alloc_type = + typename std::allocator_traits::template rebind_alloc; + + public: + typedef dense_hashtable_iterator iterator; + typedef dense_hashtable_const_iterator + const_iterator; + + typedef std::forward_iterator_tag iterator_category; // very little defined! + typedef V value_type; + typedef typename value_alloc_type::difference_type difference_type; + typedef typename value_alloc_type::size_type size_type; + typedef typename value_alloc_type::reference reference; + typedef typename value_alloc_type::pointer pointer; + + // "Real" constructor and default constructor + dense_hashtable_iterator( + const dense_hashtable* h, pointer it, + pointer it_end, bool advance) + : ht(h), pos(it), end(it_end) { + if (advance) advance_past_empty_and_deleted(); + } + dense_hashtable_iterator() {} + // The default destructor is fine; we don't define one + // The default operator= is fine; we don't define one + + // Happy dereferencer + reference operator*() const { return *pos; } + pointer operator->() const { return &(operator*()); } + + // Arithmetic. The only hard part is making sure that + // we're not on an empty or marked-deleted array element + void advance_past_empty_and_deleted() { + while (pos != end && (ht->test_empty(*this) || ht->test_deleted(*this))) + ++pos; + } + iterator& operator++() { + assert(pos != end); + ++pos; + advance_past_empty_and_deleted(); + return *this; + } + iterator operator++(int) { + iterator tmp(*this); + ++*this; + return tmp; + } + + // Comparison. + bool operator==(const iterator& it) const { return pos == it.pos; } + bool operator!=(const iterator& it) const { return pos != it.pos; } + + // The actual data + const dense_hashtable* ht; + pointer pos, end; +}; + +// Now do it all again, but with const-ness! +template +struct dense_hashtable_const_iterator { + private: + using value_alloc_type = + typename std::allocator_traits::template rebind_alloc; + + public: + typedef dense_hashtable_iterator iterator; + typedef dense_hashtable_const_iterator + const_iterator; + + typedef std::forward_iterator_tag iterator_category; // very little defined! 
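To make the probing scheme concrete, a standalone sketch of the bucket sequence the JUMP_ macro above produces (quadratic probing via triangular offsets; the hash value here is arbitrary):

#include <cstdio>

int main() {
  const unsigned num_buckets = 32;  // must be a power of two
  unsigned bucknum = 1234567u & (num_buckets - 1);
  unsigned num_probes = 0;
  for (int i = 0; i < 5; ++i) {
    std::printf("probe %u -> bucket %u\n", num_probes, bucknum);
    ++num_probes;                   // offsets 1, 2, 3, ... so the cumulative
    bucknum = (bucknum + num_probes) & (num_buckets - 1);  // offset is quadratic
  }
  return 0;
}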
+ typedef V value_type; + typedef typename value_alloc_type::difference_type difference_type; + typedef typename value_alloc_type::size_type size_type; + typedef typename value_alloc_type::const_reference reference; + typedef typename value_alloc_type::const_pointer pointer; + + // "Real" constructor and default constructor + dense_hashtable_const_iterator( + const dense_hashtable* h, pointer it, + pointer it_end, bool advance) + : ht(h), pos(it), end(it_end) { + if (advance) advance_past_empty_and_deleted(); + } + dense_hashtable_const_iterator() : ht(NULL), pos(pointer()), end(pointer()) {} + // This lets us convert regular iterators to const iterators + dense_hashtable_const_iterator(const iterator& it) + : ht(it.ht), pos(it.pos), end(it.end) {} + // The default destructor is fine; we don't define one + // The default operator= is fine; we don't define one + + // Happy dereferencer + reference operator*() const { return *pos; } + pointer operator->() const { return &(operator*()); } + + // Arithmetic. The only hard part is making sure that + // we're not on an empty or marked-deleted array element + void advance_past_empty_and_deleted() { + while (pos != end && (ht->test_empty(*this) || ht->test_deleted(*this))) + ++pos; + } + const_iterator& operator++() { + assert(pos != end); + ++pos; + advance_past_empty_and_deleted(); + return *this; + } + const_iterator operator++(int) { + const_iterator tmp(*this); + ++*this; + return tmp; + } + + // Comparison. + bool operator==(const const_iterator& it) const { return pos == it.pos; } + bool operator!=(const const_iterator& it) const { return pos != it.pos; } + + // The actual data + const dense_hashtable* ht; + pointer pos, end; +}; + +template +class dense_hashtable { + private: + using value_alloc_type = + typename std::allocator_traits::template rebind_alloc; + + public: + typedef Key key_type; + typedef Value value_type; + typedef HashFcn hasher; + typedef EqualKey key_equal; + typedef Alloc allocator_type; + + typedef typename value_alloc_type::size_type size_type; + typedef typename value_alloc_type::difference_type difference_type; + typedef typename value_alloc_type::reference reference; + typedef typename value_alloc_type::const_reference const_reference; + typedef typename value_alloc_type::pointer pointer; + typedef typename value_alloc_type::const_pointer const_pointer; + typedef dense_hashtable_iterator iterator; + + typedef dense_hashtable_const_iterator< + Value, Key, HashFcn, ExtractKey, SetKey, EqualKey, Alloc> const_iterator; + + // These come from tr1. For us they're the same as regular iterators. + typedef iterator local_iterator; + typedef const_iterator const_local_iterator; + + // How full we let the table get before we resize, by default. + // Knuth says .8 is good -- higher causes us to probe too much, + // though it saves memory. + static const int HT_OCCUPANCY_PCT; // defined at the bottom of this file + + // How empty we let the table get before we resize lower, by default. + // (0.0 means never resize lower.) + // It should be less than OCCUPANCY_PCT / 2 or we thrash resizing + static const int HT_EMPTY_PCT; // defined at the bottom of this file + + // Minimum size we're willing to let hashtables be. + // Must be a power of two, and at least 4. + // Note, however, that for a given hashtable, the initial size is a + // function of the first constructor arg, and may be >HT_MIN_BUCKETS. 
+ static const size_type HT_MIN_BUCKETS = 4; + + // By default, if you don't specify a hashtable size at + // construction-time, we use this size. Must be a power of two, and + // at least HT_MIN_BUCKETS. + static const size_type HT_DEFAULT_STARTING_BUCKETS = 32; + + // ITERATOR FUNCTIONS + iterator begin() { return iterator(this, table, table + num_buckets, true); } + iterator end() { + return iterator(this, table + num_buckets, table + num_buckets, true); + } + const_iterator begin() const { + return const_iterator(this, table, table + num_buckets, true); + } + const_iterator end() const { + return const_iterator(this, table + num_buckets, table + num_buckets, true); + } + + // These come from tr1 unordered_map. They iterate over 'bucket' n. + // We'll just consider bucket n to be the n-th element of the table. + local_iterator begin(size_type i) { + return local_iterator(this, table + i, table + i + 1, false); + } + local_iterator end(size_type i) { + local_iterator it = begin(i); + if (!test_empty(i) && !test_deleted(i)) ++it; + return it; + } + const_local_iterator begin(size_type i) const { + return const_local_iterator(this, table + i, table + i + 1, false); + } + const_local_iterator end(size_type i) const { + const_local_iterator it = begin(i); + if (!test_empty(i) && !test_deleted(i)) ++it; + return it; + } + + // ACCESSOR FUNCTIONS for the things we templatize on, basically + hasher hash_funct() const { return settings; } + key_equal key_eq() const { return key_info; } + allocator_type get_allocator() const { return allocator_type(val_info); } + + // Accessor function for statistics gathering. + int num_table_copies() const { return settings.num_ht_copies(); } + + private: + // Annoyingly, we can't copy values around, because they might have + // const components (they're probably pair). We use + // explicit destructor invocation and placement new to get around + // this. Arg. + template + void set_value(pointer dst, Args&&... args) { + dst->~value_type(); // delete the old value, if any + new (dst) value_type(std::forward(args)...); + } + + void destroy_buckets(size_type first, size_type last) { + for (; first != last; ++first) table[first].~value_type(); + } + + // DELETE HELPER FUNCTIONS + // This lets the user describe a key that will indicate deleted + // table entries. This key should be an "impossible" entry -- + // if you try to insert it for real, you won't be able to retrieve it! + // (NB: while you pass in an entire value, only the key part is looked + // at. This is just because I don't know how to assign just a key.) + private: + void squash_deleted() { // gets rid of any deleted entries we have + if (num_deleted) { // get rid of deleted before writing + size_type resize_to = settings.min_buckets( + num_elements, bucket_count()); + dense_hashtable tmp(std::move(*this), resize_to); // copying will get rid of deleted + swap(tmp); // now we are tmp + } + assert(num_deleted == 0); + } + + // Test if the given key is the deleted indicator. Requires + // num_deleted > 0, for correctness of read(), and because that + // guarantees that key_info.delkey is valid. 
+ bool test_deleted_key(const key_type& key) const { + assert(num_deleted > 0); + return equals(key_info.delkey, key); + } + + public: + void set_deleted_key(const key_type& key) { + // the empty indicator (if specified) and the deleted indicator + // must be different + assert( + (!settings.use_empty() || !equals(key, key_info.empty_key)) && + "Passed the empty-key to set_deleted_key"); + // It's only safe to change what "deleted" means if we purge deleted guys + squash_deleted(); + settings.set_use_deleted(true); + key_info.delkey = key; + } + void clear_deleted_key() { + squash_deleted(); + settings.set_use_deleted(false); + } + key_type deleted_key() const { + assert(settings.use_deleted() && + "Must set deleted key before calling deleted_key"); + return key_info.delkey; + } + + // These are public so the iterators can use them + // True if the item at position bucknum is "deleted" marker + bool test_deleted(size_type bucknum) const { + // Invariant: !use_deleted() implies num_deleted is 0. + assert(settings.use_deleted() || num_deleted == 0); + return num_deleted > 0 && test_deleted_key(get_key(table[bucknum])); + } + bool test_deleted(const iterator& it) const { + // Invariant: !use_deleted() implies num_deleted is 0. + assert(settings.use_deleted() || num_deleted == 0); + return num_deleted > 0 && test_deleted_key(get_key(*it)); + } + bool test_deleted(const const_iterator& it) const { + // Invariant: !use_deleted() implies num_deleted is 0. + assert(settings.use_deleted() || num_deleted == 0); + return num_deleted > 0 && test_deleted_key(get_key(*it)); + } + + private: + void check_use_deleted(const char* caller) { + (void)caller; // could log it if the assert failed + assert(settings.use_deleted()); + } + + // Set it so test_deleted is true. true if object didn't used to be deleted. + bool set_deleted(iterator& it) { + check_use_deleted("set_deleted()"); + bool retval = !test_deleted(it); + // &* converts from iterator to value-type. + set_key(&(*it), key_info.delkey); + return retval; + } + // Set it so test_deleted is false. true if object used to be deleted. + bool clear_deleted(iterator& it) { + check_use_deleted("clear_deleted()"); + // Happens automatically when we assign something else in its place. + return test_deleted(it); + } + + // We also allow to set/clear the deleted bit on a const iterator. + // We allow a const_iterator for the same reason you can delete a + // const pointer: it's convenient, and semantically you can't use + // 'it' after it's been deleted anyway, so its const-ness doesn't + // really matter. + bool set_deleted(const_iterator& it) { + check_use_deleted("set_deleted()"); + bool retval = !test_deleted(it); + set_key(const_cast(&(*it)), key_info.delkey); + return retval; + } + // Set it so test_deleted is false. true if object used to be deleted. + bool clear_deleted(const_iterator& it) { + check_use_deleted("clear_deleted()"); + return test_deleted(it); + } + + // EMPTY HELPER FUNCTIONS + // This lets the user describe a key that will indicate empty (unused) + // table entries. This key should be an "impossible" entry -- + // if you try to insert it for real, you won't be able to retrieve it! + // (NB: while you pass in an entire value, only the key part is looked + // at. This is just because I don't know how to assign just a key.) 
+ public: + // These are public so the iterators can use them + // True if the item at position bucknum is "empty" marker + bool test_empty(size_type bucknum) const { + assert(settings.use_empty()); // we always need to know what's empty! + return equals(key_info.empty_key, get_key(table[bucknum])); + } + bool test_empty(const iterator& it) const { + assert(settings.use_empty()); // we always need to know what's empty! + return equals(key_info.empty_key, get_key(*it)); + } + bool test_empty(const const_iterator& it) const { + assert(settings.use_empty()); // we always need to know what's empty! + return equals(key_info.empty_key, get_key(*it)); + } + + private: + void fill_range_with_empty(pointer table_start, size_type count) { + for (size_type i = 0; i < count; ++i) + { + construct_key(&table_start[i], key_info.empty_key); + } + } + + public: + void set_empty_key(const key_type& key) { + // Once you set the empty key, you can't change it + assert(!settings.use_empty() && "Calling set_empty_key multiple times"); + // The deleted indicator (if specified) and the empty indicator + // must be different. + assert( + (!settings.use_deleted() || !equals(key, key_info.delkey)) && + "Setting the empty key the same as the deleted key"); + settings.set_use_empty(true); + key_info.empty_key = key; + + assert(!table); // must set before first use + // num_buckets was set in constructor even though table was NULL + table = val_info.allocate(num_buckets); + assert(table); + fill_range_with_empty(table, num_buckets); + } + key_type empty_key() const { + assert(settings.use_empty()); + return key_info.empty_key; + } + + // FUNCTIONS CONCERNING SIZE + public: + size_type size() const { return num_elements - num_deleted; } + size_type max_size() const { return val_info.max_size(); } + bool empty() const { return size() == 0; } + size_type bucket_count() const { return num_buckets; } + size_type max_bucket_count() const { return max_size(); } + size_type nonempty_bucket_count() const { return num_elements; } + // These are tr1 methods. Their idea of 'bucket' doesn't map well to + // what we do. We just say every bucket has 0 or 1 items in it. + size_type bucket_size(size_type i) const { + return begin(i) == end(i) ? 0 : 1; + } + + private: + // Because of the above, size_type(-1) is never legal; use it for errors + static const size_type ILLEGAL_BUCKET = size_type(-1); + + // Used after a string of deletes. Returns true if we actually shrunk. + // TODO(csilvers): take a delta so we can take into account inserts + // done after shrinking. Maybe make part of the Settings class? + bool maybe_shrink() { + assert(num_elements >= num_deleted); + assert((bucket_count() & (bucket_count() - 1)) == 0); // is a power of two + assert(bucket_count() >= HT_MIN_BUCKETS); + bool retval = false; + + // If you construct a hashtable with < HT_DEFAULT_STARTING_BUCKETS, + // we'll never shrink until you get relatively big, and we'll never + // shrink below HT_DEFAULT_STARTING_BUCKETS. Otherwise, something + // like "dense_hash_set x; x.insert(4); x.erase(4);" will + // shrink us down to HT_MIN_BUCKETS buckets, which is too small. 
+ const size_type num_remain = num_elements - num_deleted; + const size_type shrink_threshold = settings.shrink_threshold(); + if (shrink_threshold > 0 && num_remain < shrink_threshold && + bucket_count() > HT_DEFAULT_STARTING_BUCKETS) { + const float shrink_factor = settings.shrink_factor(); + size_type sz = bucket_count() / 2; // find how much we should shrink + while (sz > HT_DEFAULT_STARTING_BUCKETS && + num_remain < sz * shrink_factor) { + sz /= 2; // stay a power of 2 + } + dense_hashtable tmp(std::move(*this), sz); // Do the actual resizing + swap(tmp); // now we are tmp + retval = true; + } + settings.set_consider_shrink(false); // because we just considered it + return retval; + } + + // We'll let you resize a hashtable -- though this makes us copy all! + // When you resize, you say, "make it big enough for this many more elements" + // Returns true if we actually resized, false if size was already ok. + bool resize_delta(size_type delta) { + bool did_resize = false; + if (settings.consider_shrink()) { // see if lots of deletes happened + if (maybe_shrink()) did_resize = true; + } + if (num_elements >= (std::numeric_limits::max)() - delta) { + throw std::length_error("resize overflow"); + } + if (bucket_count() >= HT_MIN_BUCKETS && + (num_elements + delta) <= settings.enlarge_threshold()) + return did_resize; // we're ok as we are + + // Sometimes, we need to resize just to get rid of all the + // "deleted" buckets that are clogging up the hashtable. So when + // deciding whether to resize, count the deleted buckets (which + // are currently taking up room). But later, when we decide what + // size to resize to, *don't* count deleted buckets, since they + // get discarded during the resize. + size_type needed_size = settings.min_buckets(num_elements + delta, 0); + if (needed_size <= bucket_count()) // we have enough buckets + return did_resize; + + size_type resize_to = settings.min_buckets( + num_elements - num_deleted + delta, bucket_count()); + + // When num_deleted is large, we may still grow but we do not want to + // over expand. So we reduce needed_size by a portion of num_deleted + // (the exact portion does not matter). This is especially helpful + // when min_load_factor is zero (no shrink at all) to avoid doubling + // the bucket count to infinity. See also test ResizeWithoutShrink. + needed_size = settings.min_buckets(num_elements - num_deleted / 4 + delta, 0); + + if (resize_to < needed_size && // may double resize_to + resize_to < (std::numeric_limits::max)() / 2) { + // This situation means that we have enough deleted elements, + // that once we purge them, we won't actually have needed to + // grow. But we may want to grow anyway: if we just purge one + // element, say, we'll have to grow anyway next time we + // insert. Might as well grow now, since we're already going + // through the trouble of copying (in order to purge the + // deleted elements). + const size_type target = + static_cast(settings.shrink_size(resize_to * 2)); + if (num_elements - num_deleted + delta >= target) { + // Good, we won't be below the shrink threshhold even if we double. + resize_to *= 2; + } + } + dense_hashtable tmp(std::move(*this), resize_to); + swap(tmp); // now we are tmp + return true; + } + + // We require table be not-NULL and empty before calling this. 
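A brief sketch of the resize() semantics documented just above: the argument is an element count rather than a bucket count, and zero means "shrink now if possible" (illustrative only):

void resize_example(google::dense_hash_map<int, int>& m) {
  m.resize(1000000);  // make room for ~1M *elements*, not 1M buckets
  // ... bulk inserts ...
  m.resize(0);        // special case: shrink immediately if the load is low
}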
+ void resize_table(size_type /*old_size*/, size_type new_size, + std::true_type) { + table = val_info.realloc_or_die(table, new_size); + } + + void resize_table(size_type old_size, size_type new_size, std::false_type) { + val_info.deallocate(table, old_size); + table = val_info.allocate(new_size); + } + + // Used to actually do the rehashing when we grow/shrink a hashtable + template + void copy_or_move_from(Hashtable&& ht, size_type min_buckets_wanted) { + clear_to_size(settings.min_buckets(ht.size(), min_buckets_wanted)); + + // We use a normal iterator to get non-deleted bcks from ht + // We could use insert() here, but since we know there are + // no duplicates and no deleted items, we can be more efficient + assert((bucket_count() & (bucket_count() - 1)) == 0); // a power of two + for (auto&& value : ht) { + size_type num_probes = 0; // how many times we've probed + size_type bucknum; + const size_type bucket_count_minus_one = bucket_count() - 1; + for (bucknum = hash(get_key(value)) & bucket_count_minus_one; + !test_empty(bucknum); // not empty + bucknum = + (bucknum + JUMP_(key, num_probes)) & bucket_count_minus_one) { + ++num_probes; + assert(num_probes < bucket_count() && + "Hashtable is full: an error in key_equal<> or hash<>"); + } + + using will_move = std::is_rvalue_reference; + using value_t = typename std::conditional::type; + + set_value(&table[bucknum], std::forward(value)); + num_elements++; + } + settings.inc_num_ht_copies(); + } + + // Required by the spec for hashed associative container + public: + // Though the docs say this should be num_buckets, I think it's much + // more useful as num_elements. As a special feature, calling with + // req_elements==0 will cause us to shrink if we can, saving space. + void resize(size_type req_elements) { // resize to this or larger + if (settings.consider_shrink() || req_elements == 0) maybe_shrink(); + if (req_elements > num_elements) resize_delta(req_elements - num_elements); + } + + // Get and change the value of shrink_factor and enlarge_factor. The + // description at the beginning of this file explains how to choose + // the values. Setting the shrink parameter to 0.0 ensures that the + // table never shrinks. + void get_resizing_parameters(float* shrink, float* grow) const { + *shrink = settings.shrink_factor(); + *grow = settings.enlarge_factor(); + } + void set_resizing_parameters(float shrink, float grow) { + settings.set_resizing_parameters(shrink, grow); + settings.reset_thresholds(bucket_count()); + } + + // CONSTRUCTORS -- as required by the specs, we take a size, + // but also let you specify a hashfunction, key comparator, + // and key extractor. We also define a copy constructor and =. + // DESTRUCTOR -- needs to free the table + explicit dense_hashtable(size_type expected_max_items_in_table = 0, + const HashFcn& hf = HashFcn(), + const EqualKey& eql = EqualKey(), + const ExtractKey& ext = ExtractKey(), + const SetKey& set = SetKey(), + const Alloc& alloc = Alloc()) + : settings(hf), + key_info(ext, set, eql), + num_deleted(0), + num_elements(0), + num_buckets(expected_max_items_in_table == 0 + ? HT_DEFAULT_STARTING_BUCKETS + : settings.min_buckets(expected_max_items_in_table, 0)), + val_info(alloc_impl(alloc)), + table(NULL) { + // table is NULL until emptyval is set. 
However, we set num_buckets + // here so we know how much space to allocate once emptyval is set + settings.reset_thresholds(bucket_count()); + } + + // As a convenience for resize(), we allow an optional second argument + // which lets you make this new hashtable a different size than ht + dense_hashtable(const dense_hashtable& ht, + size_type min_buckets_wanted = HT_DEFAULT_STARTING_BUCKETS) + : settings(ht.settings), + key_info(ht.key_info), + num_deleted(0), + num_elements(0), + num_buckets(0), + val_info(ht.val_info), + table(NULL) { + if (!ht.settings.use_empty()) { + // If use_empty isn't set, copy_from will crash, so we do our own copying. + assert(ht.empty()); + num_buckets = settings.min_buckets(ht.size(), min_buckets_wanted); + settings.reset_thresholds(bucket_count()); + return; + } + settings.reset_thresholds(bucket_count()); + copy_or_move_from(ht, min_buckets_wanted); // copy_or_move_from() ignores deleted entries + } + + dense_hashtable(dense_hashtable&& ht) + : dense_hashtable() { + swap(ht); + } + + dense_hashtable(dense_hashtable&& ht, + size_type min_buckets_wanted) + : settings(ht.settings), + key_info(ht.key_info), + num_deleted(0), + num_elements(0), + num_buckets(0), + val_info(std::move(ht.val_info)), + table(NULL) { + if (!ht.settings.use_empty()) { + // If use_empty isn't set, copy_or_move_from will crash, so we do our own copying. + assert(ht.empty()); + num_buckets = settings.min_buckets(ht.size(), min_buckets_wanted); + settings.reset_thresholds(bucket_count()); + return; + } + settings.reset_thresholds(bucket_count()); + copy_or_move_from(std::move(ht), min_buckets_wanted); // copy_or_move_from() ignores deleted entries + } + + dense_hashtable& operator=(const dense_hashtable& ht) { + if (&ht == this) return *this; // don't copy onto ourselves + if (!ht.settings.use_empty()) { + assert(ht.empty()); + dense_hashtable empty_table(ht); // empty table with ht's thresholds + this->swap(empty_table); + return *this; + } + settings = ht.settings; + key_info = ht.key_info; + // copy_or_move_from() calls clear and sets num_deleted to 0 too + copy_or_move_from(ht, HT_MIN_BUCKETS); + // we purposefully don't copy the allocator, which may not be copyable + return *this; + } + + dense_hashtable& operator=(dense_hashtable&& ht) { + assert(&ht != this); // this should not happen + swap(ht); + return *this; + } + + ~dense_hashtable() { + if (table) { + destroy_buckets(0, num_buckets); + val_info.deallocate(table, num_buckets); + } + } + + // Many STL algorithms use swap instead of copy constructors + void swap(dense_hashtable& ht) { + std::swap(settings, ht.settings); + std::swap(key_info, ht.key_info); + std::swap(num_deleted, ht.num_deleted); + std::swap(num_elements, ht.num_elements); + std::swap(num_buckets, ht.num_buckets); + std::swap(table, ht.table); + settings.reset_thresholds(bucket_count()); // also resets consider_shrink + ht.settings.reset_thresholds(ht.bucket_count()); + // we purposefully don't swap the allocator, which may not be swap-able + } + + private: + void clear_to_size(size_type new_num_buckets) { + if (!table) { + table = val_info.allocate(new_num_buckets); + } else { + destroy_buckets(0, num_buckets); + if (new_num_buckets != num_buckets) { // resize, if necessary + typedef std::integral_constant< + bool, std::is_same>::value> + realloc_ok; + resize_table(num_buckets, new_num_buckets, realloc_ok()); + } + } + assert(table); + fill_range_with_empty(table, new_num_buckets); + num_elements = 0; + num_deleted = 0; + num_buckets = new_num_buckets; // our 
new size
+    settings.reset_thresholds(bucket_count());
+  }
+
+ public:
+  // It's always nice to be able to clear a table without deallocating it
+  void clear() {
+    // If the table is already empty, and the number of buckets is
+    // already as we desire, there's nothing to do.
+    const size_type new_num_buckets = settings.min_buckets(0, 0);
+    if (num_elements == 0 && new_num_buckets == num_buckets) {
+      return;
+    }
+    clear_to_size(new_num_buckets);
+  }
+
+  // Clear the table without resizing it.
+  // Mimics the stl_hashtable's behavior when clear()-ing in that it
+  // does not modify the bucket count
+  void clear_no_resize() {
+    if (num_elements > 0) {
+      assert(table);
+      destroy_buckets(0, num_buckets);
+      fill_range_with_empty(table, num_buckets);
+    }
+    // don't consider to shrink before another erase()
+    settings.reset_thresholds(bucket_count());
+    num_elements = 0;
+    num_deleted = 0;
+  }
+
+  // LOOKUP ROUTINES
+ private:
+  // Returns a pair of positions: 1st where the object is, 2nd where
+  // it would go if you wanted to insert it.  1st is ILLEGAL_BUCKET
+  // if object is not found; 2nd is ILLEGAL_BUCKET if it is.
+  // Note: because of deletions where-to-insert is not trivial: it's the
+  // first deleted bucket we see, as long as we don't find the key later
+  template <typename K>
+  std::pair<size_type, size_type> find_position(const K& key) const {
+    size_type num_probes = 0;  // how many times we've probed
+    const size_type bucket_count_minus_one = bucket_count() - 1;
+    size_type bucknum = hash(key) & bucket_count_minus_one;
+    size_type insert_pos = ILLEGAL_BUCKET;  // where we would insert
+    while (1) {                   // probe until something happens
+      if (test_empty(bucknum)) {  // bucket is empty
+        if (insert_pos == ILLEGAL_BUCKET)  // found no prior place to insert
+          return std::pair<size_type, size_type>(ILLEGAL_BUCKET, bucknum);
+        else
+          return std::pair<size_type, size_type>(ILLEGAL_BUCKET, insert_pos);
+
+      } else if (test_deleted(bucknum)) {  // keep searching, but mark to insert
+        if (insert_pos == ILLEGAL_BUCKET) insert_pos = bucknum;
+
+      } else if (equals(key, get_key(table[bucknum]))) {
+        return std::pair<size_type, size_type>(bucknum, ILLEGAL_BUCKET);
+      }
+      ++num_probes;  // we're doing another probe
+      bucknum = (bucknum + JUMP_(key, num_probes)) & bucket_count_minus_one;
+      assert(num_probes < bucket_count() &&
+             "Hashtable is full: an error in key_equal<> or hash<>");
+    }
+  }
+
+ public:
+  template <typename K>
+  iterator find(const K& key) {
+    if (size() == 0) return end();
+    std::pair<size_type, size_type> pos = find_position(key);
+    if (pos.first == ILLEGAL_BUCKET)  // alas, not there
+      return end();
+    else
+      return iterator(this, table + pos.first, table + num_buckets, false);
+  }
+
+  template <typename K>
+  const_iterator find(const K& key) const {
+    if (size() == 0) return end();
+    std::pair<size_type, size_type> pos = find_position(key);
+    if (pos.first == ILLEGAL_BUCKET)  // alas, not there
+      return end();
+    else
+      return const_iterator(this, table + pos.first, table + num_buckets,
+                            false);
+  }
+
+  // This is a tr1 method: the bucket a given key is in, or what bucket
+  // it would be put in, if it were to be inserted.  Shrug.
+  size_type bucket(const key_type& key) const {
+    std::pair<size_type, size_type> pos = find_position(key);
+    return pos.first == ILLEGAL_BUCKET ? pos.second : pos.first;
+  }
+
+  // Counts how many elements have key key.  For maps, it's either 0 or 1.
+  template <typename K>
+  size_type count(const K& key) const {
+    std::pair<size_type, size_type> pos = find_position(key);
+    return pos.first == ILLEGAL_BUCKET ? 0 : 1;
+  }
+
+  // Likewise, equal_range doesn't really make sense for us.  Oh well.
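At the call-site level, the lookup routines above (including the equal_range overloads that follow) behave as in this short sketch. It assumes the public dense_hash_map wrapper from upstream sparsehash is available; the include path shown is the upstream one and will differ in this vendored tree. Illustrative aside, not part of the patch:

// Lookup semantics: find() returns end() on a miss, count() is 0 or 1,
// and equal_range() yields a zero- or one-element range.
#include <cassert>
#include <iterator>
#include <string>
#include <sparsehash/dense_hash_map>

int main() {
  google::dense_hash_map<std::string, int> m;
  m.set_empty_key("");              // dense tables require an empty-key sentinel
  m["a"] = 1;
  assert(m.find("a") != m.end());   // found: iterator to the element
  assert(m.find("b") == m.end());   // absent: end()
  assert(m.count("a") == 1);        // maps never hold duplicate keys
  auto r = m.equal_range("a");
  assert(std::distance(r.first, r.second) == 1);
}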
+ template + std::pair equal_range(const K& key) { + iterator pos = find(key); // either an iterator or end + if (pos == end()) { + return std::pair(pos, pos); + } else { + const iterator startpos = pos++; + return std::pair(startpos, pos); + } + } + template + std::pair equal_range( + const K& key) const { + const_iterator pos = find(key); // either an iterator or end + if (pos == end()) { + return std::pair(pos, pos); + } else { + const const_iterator startpos = pos++; + return std::pair(startpos, pos); + } + } + + // INSERTION ROUTINES + private: + // Private method used by insert_noresize and find_or_insert. + template + iterator insert_at(size_type pos, Args&&... args) { + if (size() >= max_size()) { + throw std::length_error("insert overflow"); + } + if (test_deleted(pos)) { // just replace if it's been del. + // shrug: shouldn't need to be const. + const_iterator delpos(this, table + pos, table + num_buckets, false); + clear_deleted(delpos); + assert(num_deleted > 0); + --num_deleted; // used to be, now it isn't + } else { + ++num_elements; // replacing an empty bucket + } + set_value(&table[pos], std::forward(args)...); + return iterator(this, table + pos, table + num_buckets, false); + } + + // If you know *this is big enough to hold obj, use this routine + template + std::pair insert_noresize(K&& key, Args&&... args) { + // First, double-check we're not inserting delkey or emptyval + assert(settings.use_empty() && "Inserting without empty key"); + assert(!equals(std::forward(key), key_info.empty_key) && "Inserting the empty key"); + assert((!settings.use_deleted() || !equals(key, key_info.delkey)) && "Inserting the deleted key"); + + const std::pair pos = find_position(key); + if (pos.first != ILLEGAL_BUCKET) { // object was already there + return std::pair( + iterator(this, table + pos.first, table + num_buckets, false), + false); // false: we didn't insert + } else { // pos.second says where to put it + return std::pair(insert_at(pos.second, std::forward(args)...), true); + } + } + + // Specializations of insert(it, it) depending on the power of the iterator: + // (1) Iterator supports operator-, resize before inserting + template + void insert(ForwardIterator f, ForwardIterator l, std::forward_iterator_tag) { + size_t dist = std::distance(f, l); + if (dist >= (std::numeric_limits::max)()) { + throw std::length_error("insert-range overflow"); + } + resize_delta(static_cast(dist)); + for (; dist > 0; --dist, ++f) { + insert_noresize(get_key(*f), *f); + } + } + + // (2) Arbitrary iterator, can't tell how much to resize + template + void insert(InputIterator f, InputIterator l, std::input_iterator_tag) { + for (; f != l; ++f) insert(*f); + } + + public: + // This is the normal insert routine, used by the outside world + template + std::pair insert(Arg&& obj) { + resize_delta(1); // adding an object, grow if need be + return insert_noresize(get_key(std::forward(obj)), std::forward(obj)); + } + + template + std::pair emplace(K&& key, Args&&... args) { + resize_delta(1); + // here we push key twice as we need it once for the indexing, and the rest of the params are for the emplace itself + return insert_noresize(std::forward(key), std::forward(key), std::forward(args)...); + } + + /* Overload for maps: Here, K != V, and we need to pass hint->first to the equal() function. */ + template + typename std::enable_if::value, + std::pair>::type + emplace_hint(const_iterator hint, K&& key, Args&&... 
args) { + resize_delta(1); + + if ((hint != this->end()) && (equals(key, hint->first))) { + return {iterator(this, const_cast(hint.pos), const_cast(hint.end), false), false}; + } + + // here we push key twice as we need it once for the indexing, and the rest of the params are for the emplace itself + return insert_noresize(std::forward(key), std::forward(key), std::forward(args)...); + } + + /* Overload for sets: Here, K == V, and we need to pass *hint to the equal() function. */ + template + typename std::enable_if::value, + std::pair>::type + emplace_hint(const_iterator hint, K&& key, Args&&... args) { + resize_delta(1); + + if ((hint != this->end()) && (equals(key, *hint))) { + return {iterator(this, const_cast(hint.pos), const_cast(hint.end), false), false}; + } + + // here we push key twice as we need it once for the indexing, and the rest of the params are for the emplace itself + return insert_noresize(std::forward(key), std::forward(key), std::forward(args)...); + } + + // When inserting a lot at a time, we specialize on the type of iterator + template + void insert(InputIterator f, InputIterator l) { + // specializes on iterator type + insert(f, l, + typename std::iterator_traits::iterator_category()); + } + + // DefaultValue is a functor that takes a key and returns a value_type + // representing the default value to be inserted if none is found. + template + value_type& find_or_insert(K&& key) { + // First, double-check we're not inserting emptykey or delkey + assert( + (!settings.use_empty() || !equals(key, key_info.empty_key)) && + "Inserting the empty key"); + assert((!settings.use_deleted() || !equals(key, key_info.delkey)) && + "Inserting the deleted key"); + const std::pair pos = find_position(key); + if (pos.first != ILLEGAL_BUCKET) { // object was already there + return table[pos.first]; + } else if (resize_delta(1)) { // needed to rehash to make room + // Since we resized, we can't use pos, so recalculate where to insert. + return *insert_noresize(std::forward(key), std::forward(key), T()).first; + } else { // no need to rehash, insert right here + return *insert_at(pos.second, std::forward(key), T()); + } + } + + // DELETION ROUTINES + size_type erase(const key_type& key) { + // First, double-check we're not trying to erase delkey or emptyval. + assert( + (!settings.use_empty() || !equals(key, key_info.empty_key)) && + "Erasing the empty key"); + assert((!settings.use_deleted() || !equals(key, key_info.delkey)) && + "Erasing the deleted key"); + const_iterator pos = find(key); // shrug: shouldn't need to be const + if (pos != end()) { + assert(!test_deleted(pos)); // or find() shouldn't have returned it + set_deleted(pos); + ++num_deleted; + settings.set_consider_shrink( + true); // will think about shrink after next insert + return 1; // because we deleted one thing + } else { + return 0; // because we deleted nothing + } + } + + // We return the iterator past the deleted item. 
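Before the remaining erase overloads, here is a usage-level sketch of the tombstone scheme they implement: erase() only marks a bucket with the deleted-key sentinel, and the bucket is reclaimed by a later insert or resize. As above, this assumes the upstream dense_hash_map wrapper and include path, and is an aside rather than part of the patch:

// Tombstone deletion: set_deleted_key() must be called before any erase().
#include <cassert>
#include <sparsehash/dense_hash_map>

int main() {
  google::dense_hash_map<int, int> m;
  m.set_empty_key(-1);       // required before any insert
  m.set_deleted_key(-2);     // required before any erase; never a real key
  m[1] = 10;
  assert(m.erase(1) == 1);   // returns the number of elements removed (0 or 1)
  assert(m.find(1) == m.end());
  m[1] = 11;                 // reuses the tombstoned bucket
  assert(m.size() == 1);
}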
+ iterator erase(const_iterator pos) { + if (pos == end()) return end(); // sanity check + if (set_deleted(pos)) { // true if object has been newly deleted + ++num_deleted; + settings.set_consider_shrink( + true); // will think about shrink after next insert + } + return iterator(this, const_cast(pos.pos), const_cast(pos.end), true); + } + + iterator erase(const_iterator f, const_iterator l) { + for (; f != l; ++f) { + if (set_deleted(f)) // should always be true + ++num_deleted; + } + settings.set_consider_shrink( + true); // will think about shrink after next insert + return iterator(this, const_cast(f.pos), const_cast(f.end), false); + } + + // COMPARISON + bool operator==(const dense_hashtable& ht) const { + if (size() != ht.size()) { + return false; + } else if (this == &ht) { + return true; + } else { + // Iterate through the elements in "this" and see if the + // corresponding element is in ht + for (const_iterator it = begin(); it != end(); ++it) { + const_iterator it2 = ht.find(get_key(*it)); + if ((it2 == ht.end()) || (*it != *it2)) { + return false; + } + } + return true; + } + } + bool operator!=(const dense_hashtable& ht) const { return !(*this == ht); } + + // I/O + // We support reading and writing hashtables to disk. Alas, since + // I don't know how to write a hasher or key_equal, you have to make + // sure everything but the table is the same. We compact before writing. + private: + // Every time the disk format changes, this should probably change too + typedef unsigned long MagicNumberType; + static const MagicNumberType MAGIC_NUMBER = 0x13578642; + + public: + // I/O -- this is an add-on for writing hash table to disk + // + // INPUT and OUTPUT must be either a FILE, *or* a C++ stream + // (istream, ostream, etc) *or* a class providing + // Read(void*, size_t) and Write(const void*, size_t) + // (respectively), which writes a buffer into a stream + // (which the INPUT/OUTPUT instance presumably owns). + + typedef sparsehash_internal::pod_serializer NopointerSerializer; + + // ValueSerializer: a functor. operator()(OUTPUT*, const value_type&) + template + bool serialize(ValueSerializer serializer, OUTPUT* fp) { + squash_deleted(); // so we don't have to worry about delkey + if (!sparsehash_internal::write_bigendian_number(fp, MAGIC_NUMBER, 4)) + return false; + if (!sparsehash_internal::write_bigendian_number(fp, num_buckets, 8)) + return false; + if (!sparsehash_internal::write_bigendian_number(fp, num_elements, 8)) + return false; + // Now write a bitmap of non-empty buckets. + for (size_type i = 0; i < num_buckets; i += 8) { + unsigned char bits = 0; + for (int bit = 0; bit < 8; ++bit) { + if (i + bit < num_buckets && !test_empty(i + bit)) bits |= (1 << bit); + } + if (!sparsehash_internal::write_data(fp, &bits, sizeof(bits))) + return false; + for (int bit = 0; bit < 8; ++bit) { + if (bits & (1 << bit)) { + if (!serializer(fp, table[i + bit])) return false; + } + } + } + return true; + } + + // INPUT: anything we've written an overload of read_data() for. + // ValueSerializer: a functor. 
operator()(INPUT*, value_type*) + template + bool unserialize(ValueSerializer serializer, INPUT* fp) { + assert(settings.use_empty() && "empty_key not set for read"); + + clear(); // just to be consistent + MagicNumberType magic_read; + if (!sparsehash_internal::read_bigendian_number(fp, &magic_read, 4)) + return false; + if (magic_read != MAGIC_NUMBER) { + return false; + } + size_type new_num_buckets; + if (!sparsehash_internal::read_bigendian_number(fp, &new_num_buckets, 8)) + return false; + clear_to_size(new_num_buckets); + if (!sparsehash_internal::read_bigendian_number(fp, &num_elements, 8)) + return false; + + // Read the bitmap of non-empty buckets. + for (size_type i = 0; i < num_buckets; i += 8) { + unsigned char bits; + if (!sparsehash_internal::read_data(fp, &bits, sizeof(bits))) + return false; + for (int bit = 0; bit < 8; ++bit) { + if (i + bit < num_buckets && (bits & (1 << bit))) { // not empty + if (!serializer(fp, &table[i + bit])) return false; + } + } + } + return true; + } + + private: + template + class alloc_impl : public A { + public: + typedef typename A::pointer pointer; + typedef typename A::size_type size_type; + + // Convert a normal allocator to one that has realloc_or_die() + alloc_impl(const A& a) : A(a) {} + + // realloc_or_die should only be used when using the default + // allocator (libc_allocator_with_realloc). + pointer realloc_or_die(pointer /*ptr*/, size_type /*n*/) { + fprintf(stderr, + "realloc_or_die is only supported for " + "libc_allocator_with_realloc\n"); + exit(1); + return NULL; + } + }; + + // A template specialization of alloc_impl for + // libc_allocator_with_realloc that can handle realloc_or_die. + template + class alloc_impl> + : public libc_allocator_with_realloc { + public: + typedef typename libc_allocator_with_realloc::pointer pointer; + typedef typename libc_allocator_with_realloc::size_type size_type; + + alloc_impl(const libc_allocator_with_realloc& a) + : libc_allocator_with_realloc(a) {} + + pointer realloc_or_die(pointer ptr, size_type n) { + pointer retval = this->reallocate(ptr, n); + if (retval == NULL) { + fprintf(stderr, + "sparsehash: FATAL ERROR: failed to reallocate " + "%lu elements for ptr %p", + static_cast(n), static_cast(ptr)); + exit(1); + } + return retval; + } + }; + + // Package allocator with emptyval to eliminate memory needed for + // the zero-size allocator. + // If new fields are added to this class, we should add them to + // operator= and swap. + class ValInfo : public alloc_impl { + public: + typedef typename alloc_impl::value_type value_type; + + ValInfo(const alloc_impl& a) + : alloc_impl(a) {} + }; + + // Package functors with another class to eliminate memory needed for + // zero-size functors. Since ExtractKey and hasher's operator() might + // have the same function signature, they must be packaged in + // different classes. + struct Settings + : sparsehash_internal::sh_hashtable_settings { + explicit Settings(const hasher& hf) + : sparsehash_internal::sh_hashtable_settings( + hf, HT_OCCUPANCY_PCT / 100.0f, HT_EMPTY_PCT / 100.0f) {} + }; + + // Packages ExtractKey and SetKey functors. 
+ class KeyInfo : public ExtractKey, public SetKey, public EqualKey { + public: + KeyInfo(const ExtractKey& ek, const SetKey& sk, const EqualKey& eq) + : ExtractKey(ek), SetKey(sk), EqualKey(eq) {} + + // We want to return the exact same type as ExtractKey: Key or const Key& + template + typename ExtractKey::result_type get_key(V&& v) const { + return ExtractKey::operator()(std::forward(v)); + } + void set_key(pointer v, const key_type& k) const { + SetKey::operator()(v, k); + } + void construct_key(pointer v, const key_type& k) const { + SetKey::operator()(v, k, true); + } + template + bool equals(const K1& a, const K2& b) const { + return EqualKey::operator()(a, b); + } + + // Which key marks deleted entries. + // TODO(csilvers): make a pointer, and get rid of use_deleted (benchmark!) + typename std::remove_const::type delkey; + typename std::remove_const::type empty_key; + }; + + // Utility functions to access the templated operators + template + size_type hash(const K& v) const { return settings.hash(v); } + template + bool equals(const K1& a, const K2& b) const { + return key_info.equals(a, b); + } + template + typename ExtractKey::result_type get_key(V&& v) const { + return key_info.get_key(std::forward(v)); + } + void set_key(pointer v, const key_type& k) const { key_info.set_key(v, k); } + void construct_key(pointer v, const key_type& k) const { key_info.construct_key(v, k); } + + private: + // Actual data + Settings settings; + KeyInfo key_info; + + size_type num_deleted; // how many occupied buckets are marked deleted + size_type num_elements; + size_type num_buckets; + ValInfo val_info; // holds emptyval, and also the allocator + pointer table; +}; + +// We need a global swap as well +template +inline void swap(dense_hashtable& x, + dense_hashtable& y) { + x.swap(y); +} + +#undef JUMP_ + +template +const typename dense_hashtable::size_type + dense_hashtable::ILLEGAL_BUCKET; + +// How full we let the table get before we resize. Knuth says .8 is +// good -- higher causes us to probe too much, though saves memory. +// However, we go with .5, getting better performance at the cost of +// more space (a trade-off densehashtable explicitly chooses to make). +// Feel free to play around with different values, though, via +// max_load_factor() and/or set_resizing_parameters(). +template +const int dense_hashtable::HT_OCCUPANCY_PCT = 50; + +// How empty we let the table get before we resize lower. +// It should be less than OCCUPANCY_PCT / 2 or we thrash resizing. +template +const int dense_hashtable::HT_EMPTY_PCT = + static_cast( + 0.4 * dense_hashtable::HT_OCCUPANCY_PCT); + +} // namespace google diff --git a/oap-native-sql/cpp/src/third_party/sparsehash/internal/hashtable-common.h b/oap-native-sql/cpp/src/third_party/sparsehash/internal/hashtable-common.h new file mode 100644 index 000000000..8d4d3f72a --- /dev/null +++ b/oap-native-sql/cpp/src/third_party/sparsehash/internal/hashtable-common.h @@ -0,0 +1,368 @@ +// Copyright (c) 2010, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. 
+// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// +// Provides classes shared by both sparse and dense hashtable. +// +// sh_hashtable_settings has parameters for growing and shrinking +// a hashtable. It also packages zero-size functor (ie. hasher). +// +// Other functions and classes provide common code for serializing +// and deserializing hashtables to a stream (such as a FILE*). + +#pragma once + +#include +#include +#include // for size_t +#include +#include // For length_error + +namespace google { +namespace sparsehash_internal { + +template struct make_void { typedef void type;}; +template using void_t = typename make_void::type; + +template +struct has_is_transparent : std::false_type {}; + +template +struct has_is_transparent> : std::true_type {}; + +template +struct has_transparent_key_equal : std::false_type {}; + +template +struct has_transparent_key_equal> : std::true_type {}; + +template ::value> +struct key_equal_chosen { + using type = EqualKey; +}; + +template +struct key_equal_chosen { + using type = typename HashFcn::transparent_key_equal; +}; + +// Adaptor methods for reading/writing data from an INPUT or OUPTUT +// variable passed to serialize() or unserialize(). For now we +// have implemented INPUT/OUTPUT for FILE*, istream*/ostream* (note +// they are pointers, unlike typical use), or else a pointer to +// something that supports a Read()/Write() method. +// +// For technical reasons, we implement read_data/write_data in two +// stages. The actual work is done in *_data_internal, which takes +// the stream argument twice: once as a template type, and once with +// normal type information. (We only use the second version.) We do +// this because of how C++ picks what function overload to use. If we +// implemented this the naive way: +// bool read_data(istream* is, const void* data, size_t length); +// template read_data(T* fp, const void* data, size_t length); +// C++ would prefer the second version for every stream type except +// istream. However, we want C++ to prefer the first version for +// streams that are *subclasses* of istream, such as istringstream. +// This is not possible given the way template types are resolved. So +// we split the stream argument in two, one of which is templated and +// one of which is not. The specialized functions (like the istream +// version above) ignore the template arg and use the second, 'type' +// arg, getting subclass matching as normal. 
The 'catch-all' +// functions (the second version above) use the template arg to deduce +// the type, and use a second, void* arg to achieve the desired +// 'catch-all' semantics. + +// ----- low-level I/O for FILE* ---- + +template +inline bool read_data_internal(Ignored*, FILE* fp, void* data, size_t length) { + return fread(data, length, 1, fp) == 1; +} + +template +inline bool write_data_internal(Ignored*, FILE* fp, const void* data, + size_t length) { + return fwrite(data, length, 1, fp) == 1; +} + +// ----- low-level I/O for iostream ---- + +// We want the caller to be responsible for #including , not +// us, because iostream is a big header! According to the standard, +// it's only legal to delay the instantiation the way we want to if +// the istream/ostream is a template type. So we jump through hoops. +template +inline bool read_data_internal_for_istream(ISTREAM* fp, void* data, + size_t length) { + return fp->read(reinterpret_cast(data), length).good(); +} +template +inline bool read_data_internal(Ignored*, std::istream* fp, void* data, + size_t length) { + return read_data_internal_for_istream(fp, data, length); +} + +template +inline bool write_data_internal_for_ostream(OSTREAM* fp, const void* data, + size_t length) { + return fp->write(reinterpret_cast(data), length).good(); +} +template +inline bool write_data_internal(Ignored*, std::ostream* fp, const void* data, + size_t length) { + return write_data_internal_for_ostream(fp, data, length); +} + +// ----- low-level I/O for custom streams ---- + +// The INPUT type needs to support a Read() method that takes a +// buffer and a length and returns the number of bytes read. +template +inline bool read_data_internal(INPUT* fp, void*, void* data, size_t length) { + return static_cast(fp->Read(data, length)) == length; +} + +// The OUTPUT type needs to support a Write() operation that takes +// a buffer and a length and returns the number of bytes written. +template +inline bool write_data_internal(OUTPUT* fp, void*, const void* data, + size_t length) { + return static_cast(fp->Write(data, length)) == length; +} + +// ----- low-level I/O: the public API ---- + +template +inline bool read_data(INPUT* fp, void* data, size_t length) { + return read_data_internal(fp, fp, data, length); +} + +template +inline bool write_data(OUTPUT* fp, const void* data, size_t length) { + return write_data_internal(fp, fp, data, length); +} + +// Uses read_data() and write_data() to read/write an integer. +// length is the number of bytes to read/write (which may differ +// from sizeof(IntType), allowing us to save on a 32-bit system +// and load on a 64-bit system). Excess bytes are taken to be 0. +// INPUT and OUTPUT must match legal inputs to read/write_data (above). +template +bool read_bigendian_number(INPUT* fp, IntType* value, size_t length) { + *value = 0; + unsigned char byte; + // We require IntType to be unsigned or else the shifting gets all screwy. + static_assert(static_cast(-1) > static_cast(0), + "serializing int requires an unsigned type"); + for (size_t i = 0; i < length; ++i) { + if (!read_data(fp, &byte, sizeof(byte))) return false; + *value |= static_cast(byte) << ((length - 1 - i) * 8); + } + return true; +} + +template +bool write_bigendian_number(OUTPUT* fp, IntType value, size_t length) { + unsigned char byte; + // We require IntType to be unsigned or else the shifting gets all screwy. 
+ static_assert(static_cast(-1) > static_cast(0), + "serializing int requires an unsigned type"); + for (size_t i = 0; i < length; ++i) { + byte = (sizeof(value) <= length - 1 - i) + ? 0 + : static_cast((value >> ((length - 1 - i) * 8)) & + 255); + if (!write_data(fp, &byte, sizeof(byte))) return false; + } + return true; +} + +// If your keys and values are simple enough, you can pass this +// serializer to serialize()/unserialize(). "Simple enough" means +// value_type is a POD type that contains no pointers. Note, +// however, we don't try to normalize endianness. +// This is the type used for NopointerSerializer. +template +struct pod_serializer { + template + bool operator()(INPUT* fp, value_type* value) const { + return read_data(fp, value, sizeof(*value)); + } + + template + bool operator()(OUTPUT* fp, const value_type& value) const { + return write_data(fp, &value, sizeof(value)); + } +}; + +// Settings contains parameters for growing and shrinking the table. +// It also packages zero-size functor (ie. hasher). +// +// It does some munging of the hash value in cases where we think +// (fear) the original hash function might not be very good. In +// particular, the default hash of pointers is the identity hash, +// so probably all the low bits are 0. We identify when we think +// we're hashing a pointer, and chop off the low bits. Note this +// isn't perfect: even when the key is a pointer, we can't tell +// for sure that the hash is the identity hash. If it's not, this +// is needless work (and possibly, though not likely, harmful). + +template +class sh_hashtable_settings : public HashFunc { + public: + typedef Key key_type; + typedef HashFunc hasher; + typedef SizeType size_type; + static_assert(!has_transparent_key_equal::value || has_is_transparent::value, + "hash provided non-transparent key_equal"); + + public: + sh_hashtable_settings(const hasher& hf, const float ht_occupancy_flt, + const float ht_empty_flt) + : hasher(hf), + enlarge_threshold_(0), + shrink_threshold_(0), + consider_shrink_(false), + use_empty_(false), + use_deleted_(false), + num_ht_copies_(0) { + set_enlarge_factor(ht_occupancy_flt); + set_shrink_factor(ht_empty_flt); + } + + template + size_type hash(const K& v) const { + // We munge the hash value when we don't trust hasher::operator(). 
+    return hash_munger<Key>::MungedHash(hasher::operator()(v));
+  }
+
+  float enlarge_factor() const { return enlarge_factor_; }
+  void set_enlarge_factor(float f) { enlarge_factor_ = f; }
+  float shrink_factor() const { return shrink_factor_; }
+  void set_shrink_factor(float f) { shrink_factor_ = f; }
+
+  size_type enlarge_threshold() const { return enlarge_threshold_; }
+  void set_enlarge_threshold(size_type t) { enlarge_threshold_ = t; }
+  size_type shrink_threshold() const { return shrink_threshold_; }
+  void set_shrink_threshold(size_type t) { shrink_threshold_ = t; }
+
+  size_type enlarge_size(size_type x) const {
+    return static_cast<size_type>(x * enlarge_factor_);
+  }
+  size_type shrink_size(size_type x) const {
+    return static_cast<size_type>(x * shrink_factor_);
+  }
+
+  bool consider_shrink() const { return consider_shrink_; }
+  void set_consider_shrink(bool t) { consider_shrink_ = t; }
+
+  bool use_empty() const { return use_empty_; }
+  void set_use_empty(bool t) { use_empty_ = t; }
+
+  bool use_deleted() const { return use_deleted_; }
+  void set_use_deleted(bool t) { use_deleted_ = t; }
+
+  size_type num_ht_copies() const {
+    return static_cast<size_type>(num_ht_copies_);
+  }
+  void inc_num_ht_copies() { ++num_ht_copies_; }
+
+  // Reset the enlarge and shrink thresholds
+  void reset_thresholds(size_type num_buckets) {
+    set_enlarge_threshold(enlarge_size(num_buckets));
+    set_shrink_threshold(shrink_size(num_buckets));
+    // whatever caused us to reset already considered
+    set_consider_shrink(false);
+  }
+
+  // Caller is responsible for calling reset_threshold right after
+  // set_resizing_parameters.
+  void set_resizing_parameters(float shrink, float grow) {
+    assert(shrink >= 0.0);
+    assert(grow <= 1.0);
+    if (shrink > grow / 2.0f)
+      shrink = grow / 2.0f;  // otherwise we thrash hashtable size
+    set_shrink_factor(shrink);
+    set_enlarge_factor(grow);
+  }
+
+  // This is the smallest size a hashtable can be without being too crowded
+  // If you like, you can give a min #buckets as well as a min #elts
+  size_type min_buckets(size_type num_elts, size_type min_buckets_wanted) {
+    float enlarge = enlarge_factor();
+    size_type sz = HT_MIN_BUCKETS;  // min buckets allowed
+    while (sz < min_buckets_wanted ||
+           num_elts >= static_cast<size_type>(sz * enlarge)) {
+      // This just prevents overflowing size_type, since sz can exceed
+      // max_size() here.
+      if (static_cast<size_type>(sz * 2) < sz) {
+        throw std::length_error("resize overflow");  // protect against overflow
+      }
+      sz *= 2;
+    }
+    return sz;
+  }
+
+ private:
+  template <class HashKey>
+  class hash_munger {
+   public:
+    static size_t MungedHash(size_t hash) { return hash; }
+  };
+  // This matches when the hashtable key is a pointer.
+  template <class HashKey>
+  class hash_munger<HashKey*> {
+   public:
+    static size_t MungedHash(size_t hash) {
+      // TODO(csilvers): consider rotating instead:
+      //    static const int shift = (sizeof(void *) == 4) ? 2 : 3;
+      //    return (hash << (sizeof(hash) * 8) - shift)) | (hash >> shift);
+      // This matters if we ever change sparse/dense_hash_* to compare
+      // hashes before comparing actual values.  It's speedy on x86.
+ return hash / sizeof(void*); // get rid of known-0 bits + } + }; + + size_type enlarge_threshold_; // table.size() * enlarge_factor + size_type shrink_threshold_; // table.size() * shrink_factor + float enlarge_factor_; // how full before resize + float shrink_factor_; // how empty before resize + // consider_shrink=true if we should try to shrink before next insert + bool consider_shrink_; + bool use_empty_; // used only by densehashtable, not sparsehashtable + bool use_deleted_; // false until delkey has been set + // num_ht_copies is a counter incremented every Copy/Move + unsigned int num_ht_copies_; +}; + +} // namespace sparsehash_internal +} // namespace google diff --git a/oap-native-sql/cpp/src/third_party/sparsehash/internal/libc_allocator_with_realloc.h b/oap-native-sql/cpp/src/third_party/sparsehash/internal/libc_allocator_with_realloc.h new file mode 100644 index 000000000..5b447d35c --- /dev/null +++ b/oap-native-sql/cpp/src/third_party/sparsehash/internal/libc_allocator_with_realloc.h @@ -0,0 +1,113 @@ +// Copyright (c) 2010, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
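The header body that follows defines libc_allocator_with_realloc. As a toy contrast of the two resize_table() strategies seen earlier (realloc_or_die(), which may resize the block in place, versus the generic deallocate-then-allocate fallback), consider this sketch; it is illustrative only and not the class below:

// realloc may grow the block in place; free+malloc always yields a fresh
// block, which is fine for the hashtable since it re-fills every bucket
// after a resize anyway.
#include <cstdio>
#include <cstdlib>

int main() {
  size_t old_n = 4, new_n = 8;
  // Strategy 1: realloc (libc_allocator_with_realloc's fast path).
  int* a = static_cast<int*>(malloc(old_n * sizeof(int)));
  a = static_cast<int*>(realloc(a, new_n * sizeof(int)));
  // Strategy 2: deallocate, then allocate (any standard allocator).
  int* b = static_cast<int*>(malloc(old_n * sizeof(int)));
  free(b);
  b = static_cast<int*>(malloc(new_n * sizeof(int)));
  printf("resized to %zu ints (a=%p, b=%p)\n", new_n, (void*)a, (void*)b);
  free(a);
  free(b);
}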
+
+// ---
+
+#pragma once
+
+#include <cstdlib>  // for malloc/realloc/free
+#include <cstddef>  // for ptrdiff_t
+#include <new>      // for placement new
+
+namespace google {
+template <class T>
+class libc_allocator_with_realloc {
+ public:
+  typedef T value_type;
+  typedef size_t size_type;
+  typedef ptrdiff_t difference_type;
+
+  typedef T* pointer;
+  typedef const T* const_pointer;
+  typedef T& reference;
+  typedef const T& const_reference;
+
+  libc_allocator_with_realloc() {}
+  libc_allocator_with_realloc(const libc_allocator_with_realloc&) {}
+  ~libc_allocator_with_realloc() {}
+
+  pointer address(reference r) const { return &r; }
+  const_pointer address(const_reference r) const { return &r; }
+
+  pointer allocate(size_type n, const_pointer = 0) {
+    return static_cast<pointer>(malloc(n * sizeof(value_type)));
+  }
+  void deallocate(pointer p, size_type) { free(p); }
+  pointer reallocate(pointer p, size_type n) {
+    // p points to a storage array whose objects have already been destroyed
+    // cast to void* to prevent compiler warnings about calling realloc() on
+    // an object which cannot be relocated in memory
+    return static_cast<pointer>(
+        realloc(static_cast<void*>(p), n * sizeof(value_type)));
+  }
+
+  size_type max_size() const {
+    return static_cast<size_type>(-1) / sizeof(value_type);
+  }
+
+  void construct(pointer p, const value_type& val) { new (p) value_type(val); }
+  void destroy(pointer p) { p->~value_type(); }
+
+  template <class U>
+  libc_allocator_with_realloc(const libc_allocator_with_realloc<U>&) {}
+
+  template <class U>
+  struct rebind {
+    typedef libc_allocator_with_realloc<U> other;
+  };
+};
+
+// libc_allocator_with_realloc<void> specialization.
+template <>
+class libc_allocator_with_realloc<void> {
+ public:
+  typedef void value_type;
+  typedef size_t size_type;
+  typedef ptrdiff_t difference_type;
+  typedef void* pointer;
+  typedef const void* const_pointer;
+
+  template <class U>
+  struct rebind {
+    typedef libc_allocator_with_realloc<U> other;
+  };
+};
+
+template <class T>
+inline bool operator==(const libc_allocator_with_realloc<T>&,
+                       const libc_allocator_with_realloc<T>&) {
+  return true;
+}
+
+template <class T>
+inline bool operator!=(const libc_allocator_with_realloc<T>&,
+                       const libc_allocator_with_realloc<T>&) {
+  return false;
+}
+
+}  // namespace google
diff --git a/oap-native-sql/cpp/src/third_party/sparsehash/internal/sparsehashtable.h b/oap-native-sql/cpp/src/third_party/sparsehash/internal/sparsehashtable.h
new file mode 100644
index 000000000..46878749a
--- /dev/null
+++ b/oap-native-sql/cpp/src/third_party/sparsehash/internal/sparsehashtable.h
@@ -0,0 +1,1265 @@
+// Copyright (c) 2005, Google Inc.
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+//       notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+//       copyright notice, this list of conditions and the following disclaimer
+//       in the documentation and/or other materials provided with the
+//       distribution.
+//     * Neither the name of Google Inc. nor the names of its
+//       contributors may be used to endorse or promote products derived from
+//       this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// +// A sparse hashtable is a particular implementation of +// a hashtable: one that is meant to minimize memory use. +// It does this by using a *sparse table* (cf sparsetable.h), +// which uses between 1 and 2 bits to store empty buckets +// (we may need another bit for hashtables that support deletion). +// +// When empty buckets are so cheap, an appealing hashtable +// implementation is internal probing, in which the hashtable +// is a single table, and collisions are resolved by trying +// to insert again in another bucket. The most cache-efficient +// internal probing schemes are linear probing (which suffers, +// alas, from clumping) and quadratic probing, which is what +// we implement by default. +// +// Deleted buckets are a bit of a pain. We have to somehow mark +// deleted buckets (the probing must distinguish them from empty +// buckets). The most principled way is to have another bitmap, +// but that's annoying and takes up space. Instead we let the +// user specify an "impossible" key. We set deleted buckets +// to have the impossible key. +// +// Note it is possible to change the value of the delete key +// on the fly; you can even remove it, though after that point +// the hashtable is insert_only until you set it again. +// +// You probably shouldn't use this code directly. Use +// sparse_hash_map<> or sparse_hash_set<> instead. +// +// You can modify the following, below: +// HT_OCCUPANCY_PCT -- how full before we double size +// HT_EMPTY_PCT -- how empty before we halve size +// HT_MIN_BUCKETS -- smallest bucket size +// HT_DEFAULT_STARTING_BUCKETS -- default bucket size at construct-time +// +// You can also change enlarge_factor (which defaults to +// HT_OCCUPANCY_PCT), and shrink_factor (which defaults to +// HT_EMPTY_PCT) with set_resizing_parameters(). +// +// How to decide what values to use? +// shrink_factor's default of .4 * OCCUPANCY_PCT, is probably good. +// HT_MIN_BUCKETS is probably unnecessary since you can specify +// (indirectly) the starting number of buckets at construct-time. +// For enlarge_factor, you can use this chart to try to trade-off +// expected lookup time to the space taken up. By default, this +// code uses quadratic probing, though you can change it to linear +// via _JUMP below if you really want to. +// +// From +// http://www.augustana.ca/~mohrj/courses/1999.fall/csc210/lecture_notes/hashing.html +// NUMBER OF PROBES / LOOKUP Successful Unsuccessful +// Quadratic collision resolution 1 - ln(1-L) - L/2 1/(1-L) - L - ln(1-L) +// Linear collision resolution [1+1/(1-L)]/2 [1+1/(1-L)2]/2 +// +// -- enlarge_factor -- 0.10 0.50 0.60 0.75 0.80 0.90 0.99 +// QUADRATIC COLLISION RES. +// probes/successful lookup 1.05 1.44 1.62 2.01 2.21 2.85 5.11 +// probes/unsuccessful lookup 1.11 2.19 2.82 4.64 5.81 11.4 103.6 +// LINEAR COLLISION RES. 
+// probes/successful lookup 1.06 1.5 1.75 2.5 3.0 5.5 50.5 +// probes/unsuccessful lookup 1.12 2.5 3.6 8.5 13.0 50.0 5000.0 +// +// The value type is required to be copy constructible and default +// constructible, but it need not be (and commonly isn't) assignable. + +#pragma once + +#include +#include // For swap(), eg +#include // for iterator tags +#include // for numeric_limits +#include // for pair +#include // for remove_const +#include +#include // IWYU pragma: export +#include // For length_error + +namespace google { + +#ifndef SPARSEHASH_STAT_UPDATE +#define SPARSEHASH_STAT_UPDATE(x) ((void)0) +#endif + +// The probing method +// Linear probing +// #define JUMP_(key, num_probes) ( 1 ) +// Quadratic probing +#define JUMP_(key, num_probes) (num_probes) + +// The smaller this is, the faster lookup is (because the group bitmap is +// smaller) and the faster insert is, because there's less to move. +// On the other hand, there are more groups. Since group::size_type is +// a short, this number should be of the form 32*x + 16 to avoid waste. +static const uint16_t DEFAULT_GROUP_SIZE = 48; // fits in 1.5 words + +// Hashtable class, used to implement the hashed associative containers +// hash_set and hash_map. +// +// Value: what is stored in the table (each bucket is a Value). +// Key: something in a 1-to-1 correspondence to a Value, that can be used +// to search for a Value in the table (find() takes a Key). +// HashFcn: Takes a Key and returns an integer, the more unique the better. +// ExtractKey: given a Value, returns the unique Key associated with it. +// Must inherit from unary_function, or at least have a +// result_type enum indicating the return type of operator(). +// SetKey: given a Value* and a Key, modifies the value such that +// ExtractKey(value) == key. We guarantee this is only called +// with key == deleted_key. +// EqualKey: Given two Keys, says whether they are the same (that is, +// if they are both associated with the same Value). +// Alloc: STL allocator to use to allocate memory. + +template +class sparse_hashtable; + +template +struct sparse_hashtable_iterator; + +template +struct sparse_hashtable_const_iterator; + +// As far as iterating, we're basically just a sparsetable +// that skips over deleted elements. +template +struct sparse_hashtable_iterator { + private: + using value_alloc_type = + typename std::allocator_traits::template rebind_alloc; + + public: + typedef sparse_hashtable_iterator iterator; + typedef sparse_hashtable_const_iterator + const_iterator; + typedef typename sparsetable::nonempty_iterator st_iterator; + + typedef std::forward_iterator_tag iterator_category; // very little defined! + typedef V value_type; + typedef typename value_alloc_type::difference_type difference_type; + typedef typename value_alloc_type::size_type size_type; + typedef typename value_alloc_type::reference reference; + typedef typename value_alloc_type::pointer pointer; + + // "Real" constructor and default constructor + sparse_hashtable_iterator( + const sparse_hashtable* h, st_iterator it, + st_iterator it_end) + : ht(h), pos(it), end(it_end) { + advance_past_deleted(); + } + sparse_hashtable_iterator() {} // not ever used internally + // The default destructor is fine; we don't define one + // The default operator= is fine; we don't define one + + // Happy dereferencer + reference operator*() const { return *pos; } + pointer operator->() const { return &(operator*()); } + + // Arithmetic. 
The only hard part is making sure that + // we're not on a marked-deleted array element + void advance_past_deleted() { + while (pos != end && ht->test_deleted(*this)) ++pos; + } + iterator& operator++() { + assert(pos != end); + ++pos; + advance_past_deleted(); + return *this; + } + iterator operator++(int) { + iterator tmp(*this); + ++*this; + return tmp; + } + + // Comparison. + bool operator==(const iterator& it) const { return pos == it.pos; } + bool operator!=(const iterator& it) const { return pos != it.pos; } + + // The actual data + const sparse_hashtable* ht; + st_iterator pos, end; +}; + +// Now do it all again, but with const-ness! +template +struct sparse_hashtable_const_iterator { + private: + using value_alloc_type = + typename std::allocator_traits::template rebind_alloc; + + public: + typedef sparse_hashtable_iterator iterator; + typedef sparse_hashtable_const_iterator + const_iterator; + typedef typename sparsetable::const_nonempty_iterator + st_iterator; + + typedef std::forward_iterator_tag iterator_category; // very little defined! + typedef V value_type; + typedef typename value_alloc_type::difference_type difference_type; + typedef typename value_alloc_type::size_type size_type; + typedef typename value_alloc_type::const_reference reference; + typedef typename value_alloc_type::const_pointer pointer; + + // "Real" constructor and default constructor + sparse_hashtable_const_iterator( + const sparse_hashtable* h, st_iterator it, + st_iterator it_end) + : ht(h), pos(it), end(it_end) { + advance_past_deleted(); + } + // This lets us convert regular iterators to const iterators + sparse_hashtable_const_iterator() {} // never used internally + sparse_hashtable_const_iterator(const iterator& it) + : ht(it.ht), pos(it.pos), end(it.end) {} + // The default destructor is fine; we don't define one + // The default operator= is fine; we don't define one + + // Happy dereferencer + reference operator*() const { return *pos; } + pointer operator->() const { return &(operator*()); } + + // Arithmetic. The only hard part is making sure that + // we're not on a marked-deleted array element + void advance_past_deleted() { + while (pos != end && ht->test_deleted(*this)) ++pos; + } + const_iterator& operator++() { + assert(pos != end); + ++pos; + advance_past_deleted(); + return *this; + } + const_iterator operator++(int) { + const_iterator tmp(*this); + ++*this; + return tmp; + } + + // Comparison. + bool operator==(const const_iterator& it) const { return pos == it.pos; } + bool operator!=(const const_iterator& it) const { return pos != it.pos; } + + // The actual data + const sparse_hashtable* ht; + st_iterator pos, end; +}; + +// And once again, but this time freeing up memory as we iterate +template +struct sparse_hashtable_destructive_iterator { + private: + using value_alloc_type = + typename std::allocator_traits::template rebind_alloc; + + public: + typedef sparse_hashtable_destructive_iterator + iterator; + typedef + typename sparsetable::destructive_iterator st_iterator; + + typedef std::forward_iterator_tag iterator_category; // very little defined! 
+ typedef V value_type; + typedef typename value_alloc_type::difference_type difference_type; + typedef typename value_alloc_type::size_type size_type; + typedef typename value_alloc_type::reference reference; + typedef typename value_alloc_type::pointer pointer; + + // "Real" constructor and default constructor + sparse_hashtable_destructive_iterator( + const sparse_hashtable* h, st_iterator it, + st_iterator it_end) + : ht(h), pos(it), end(it_end) { + advance_past_deleted(); + } + sparse_hashtable_destructive_iterator() {} // never used internally + // The default destructor is fine; we don't define one + // The default operator= is fine; we don't define one + + // Happy dereferencer + reference operator*() const { return *pos; } + pointer operator->() const { return &(operator*()); } + + // Arithmetic. The only hard part is making sure that + // we're not on a marked-deleted array element + void advance_past_deleted() { + while (pos != end && ht->test_deleted(*this)) ++pos; + } + iterator& operator++() { + assert(pos != end); + ++pos; + advance_past_deleted(); + return *this; + } + iterator operator++(int) { + iterator tmp(*this); + ++*this; + return tmp; + } + + // Comparison. + bool operator==(const iterator& it) const { return pos == it.pos; } + bool operator!=(const iterator& it) const { return pos != it.pos; } + + // The actual data + const sparse_hashtable* ht; + st_iterator pos, end; +}; + +template +class sparse_hashtable { + private: + using value_alloc_type = + typename std::allocator_traits::template rebind_alloc; + + public: + typedef Key key_type; + typedef Value value_type; + typedef HashFcn hasher; + typedef EqualKey key_equal; + typedef Alloc allocator_type; + + typedef typename value_alloc_type::size_type size_type; + typedef typename value_alloc_type::difference_type difference_type; + typedef typename value_alloc_type::reference reference; + typedef typename value_alloc_type::const_reference const_reference; + typedef typename value_alloc_type::pointer pointer; + typedef typename value_alloc_type::const_pointer const_pointer; + typedef sparse_hashtable_iterator iterator; + + typedef sparse_hashtable_const_iterator< + Value, Key, HashFcn, ExtractKey, SetKey, EqualKey, Alloc> const_iterator; + + typedef sparse_hashtable_destructive_iterator destructive_iterator; + + // These come from tr1. For us they're the same as regular iterators. + typedef iterator local_iterator; + typedef const_iterator const_local_iterator; + + // How full we let the table get before we resize, by default. + // Knuth says .8 is good -- higher causes us to probe too much, + // though it saves memory. + static const int HT_OCCUPANCY_PCT; // = 80 (out of 100); + + // How empty we let the table get before we resize lower, by default. + // (0.0 means never resize lower.) + // It should be less than OCCUPANCY_PCT / 2 or we thrash resizing + static const int HT_EMPTY_PCT; // = 0.4 * HT_OCCUPANCY_PCT; + + // Minimum size we're willing to let hashtables be. + // Must be a power of two, and at least 4. + // Note, however, that for a given hashtable, the initial size is a + // function of the first constructor arg, and may be >HT_MIN_BUCKETS. + static const size_type HT_MIN_BUCKETS = 4; + + // By default, if you don't specify a hashtable size at + // construction-time, we use this size. Must be a power of two, and + // at least HT_MIN_BUCKETS. 
+ static const size_type HT_DEFAULT_STARTING_BUCKETS = 32; + + // ITERATOR FUNCTIONS + iterator begin() { + return iterator(this, table.nonempty_begin(), table.nonempty_end()); + } + iterator end() { + return iterator(this, table.nonempty_end(), table.nonempty_end()); + } + const_iterator begin() const { + return const_iterator(this, table.nonempty_begin(), table.nonempty_end()); + } + const_iterator end() const { + return const_iterator(this, table.nonempty_end(), table.nonempty_end()); + } + + // These come from tr1 unordered_map. They iterate over 'bucket' n. + // For sparsehashtable, we could consider each 'group' to be a bucket, + // I guess, but I don't really see the point. We'll just consider + // bucket n to be the n-th element of the sparsetable, if it's occupied, + // or some empty element, otherwise. + local_iterator begin(size_type i) { + if (table.test(i)) + return local_iterator(this, table.get_iter(i), table.nonempty_end()); + else + return local_iterator(this, table.nonempty_end(), table.nonempty_end()); + } + local_iterator end(size_type i) { + local_iterator it = begin(i); + if (table.test(i) && !test_deleted(i)) ++it; + return it; + } + const_local_iterator begin(size_type i) const { + if (table.test(i)) + return const_local_iterator(this, table.get_iter(i), + table.nonempty_end()); + else + return const_local_iterator(this, table.nonempty_end(), + table.nonempty_end()); + } + const_local_iterator end(size_type i) const { + const_local_iterator it = begin(i); + if (table.test(i) && !test_deleted(i)) ++it; + return it; + } + + // This is used when resizing + destructive_iterator destructive_begin() { + return destructive_iterator(this, table.destructive_begin(), + table.destructive_end()); + } + destructive_iterator destructive_end() { + return destructive_iterator(this, table.destructive_end(), + table.destructive_end()); + } + + // ACCESSOR FUNCTIONS for the things we templatize on, basically + hasher hash_funct() const { return settings; } + key_equal key_eq() const { return key_info; } + allocator_type get_allocator() const { return table.get_allocator(); } + + // Accessor function for statistics gathering. + int num_table_copies() const { return settings.num_ht_copies(); } + + private: + // We need to copy values when we set the special marker for deleted + // elements, but, annoyingly, we can't just use the copy assignment + // operator because value_type might not be assignable (it's often + // pair). We use explicit destructor invocation and + // placement new to get around this. Arg. + void set_value(pointer dst, const_reference src) { + dst->~value_type(); // delete the old value, if any + new (dst) value_type(src); + } + + // This is used as a tag for the copy constructor, saying to destroy its + // arg We have two ways of destructively copying: with potentially growing + // the hashtable as we copy, and without. To make sure the outside world + // can't do a destructive copy, we make the typename private. + enum MoveDontCopyT { MoveDontCopy, MoveDontGrow }; + + // DELETE HELPER FUNCTIONS + // This lets the user describe a key that will indicate deleted + // table entries. This key should be an "impossible" entry -- + // if you try to insert it for real, you won't be able to retrieve it! + // (NB: while you pass in an entire value, only the key part is looked + // at. This is just because I don't know how to assign just a key.) 
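The deleted-key contract described above can be seen at usage level in the sketch below. It assumes the public sparse_hash_map wrapper from upstream sparsehash (upstream include path shown; the vendored path in this tree will differ), and is an aside rather than part of the patch:

// The sentinel must be a key that is never inserted for real; it can be
// changed or cleared on the fly, which first squashes existing tombstones.
#include <cassert>
#include <sparsehash/sparse_hash_map>

int main() {
  google::sparse_hash_map<int, int> m;  // sparse tables need no empty key
  m.set_deleted_key(-1);                // -1 must never be a real key
  m[7] = 70;
  m.erase(7);                           // bucket now holds the sentinel
  m.clear_deleted_key();                // squashes tombstones first
  assert(m.size() == 0);
  m[8] = 80;                            // inserts still work, but the table is
                                        // insert-only until set_deleted_key()
                                        // is called again
}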
+ private: + void squash_deleted() { // gets rid of any deleted entries we have + if (num_deleted) { // get rid of deleted before writing + sparse_hashtable tmp(MoveDontGrow, *this); + swap(tmp); // now we are tmp + } + assert(num_deleted == 0); + } + + // Test if the given key is the deleted indicator. Requires + // num_deleted > 0, for correctness of read(), and because that + // guarantees that key_info.delkey is valid. + bool test_deleted_key(const key_type& key) const { + assert(num_deleted > 0); + return equals(key_info.delkey, key); + } + + public: + void set_deleted_key(const key_type& key) { + // It's only safe to change what "deleted" means if we purge deleted + // guys + squash_deleted(); + settings.set_use_deleted(true); + key_info.delkey = key; + } + void clear_deleted_key() { + squash_deleted(); + settings.set_use_deleted(false); + } + key_type deleted_key() const { + assert(settings.use_deleted() && + "Must set deleted key before calling deleted_key"); + return key_info.delkey; + } + + // These are public so the iterators can use them + // True if the item at position bucknum is "deleted" marker + bool test_deleted(size_type bucknum) const { + // Invariant: !use_deleted() implies num_deleted is 0. + assert(settings.use_deleted() || num_deleted == 0); + return num_deleted > 0 && table.test(bucknum) && + test_deleted_key(get_key(table.unsafe_get(bucknum))); + } + bool test_deleted(const iterator& it) const { + // Invariant: !use_deleted() implies num_deleted is 0. + assert(settings.use_deleted() || num_deleted == 0); + return num_deleted > 0 && test_deleted_key(get_key(*it)); + } + bool test_deleted(const const_iterator& it) const { + // Invariant: !use_deleted() implies num_deleted is 0. + assert(settings.use_deleted() || num_deleted == 0); + return num_deleted > 0 && test_deleted_key(get_key(*it)); + } + bool test_deleted(const destructive_iterator& it) const { + // Invariant: !use_deleted() implies num_deleted is 0. + assert(settings.use_deleted() || num_deleted == 0); + return num_deleted > 0 && test_deleted_key(get_key(*it)); + } + + private: + void check_use_deleted(const char* caller) { + (void)caller; // could log it if the assert failed + assert(settings.use_deleted()); + } + + // Set it so test_deleted is true. true if object didn't used to be + // deleted. + // TODO(csilvers): make these private (also in densehashtable.h) + bool set_deleted(iterator& it) { + check_use_deleted("set_deleted()"); + bool retval = !test_deleted(it); + // &* converts from iterator to value-type. + set_key(&(*it), key_info.delkey); + return retval; + } + // Set it so test_deleted is false. true if object used to be deleted. + bool clear_deleted(iterator& it) { + check_use_deleted("clear_deleted()"); + // Happens automatically when we assign something else in its place. + return test_deleted(it); + } + + // We also allow to set/clear the deleted bit on a const iterator. + // We allow a const_iterator for the same reason you can delete a + // const pointer: it's convenient, and semantically you can't use + // 'it' after it's been deleted anyway, so its const-ness doesn't + // really matter. + bool set_deleted(const_iterator& it) { + check_use_deleted("set_deleted()"); + bool retval = !test_deleted(it); + set_key(const_cast(&(*it)), key_info.delkey); + return retval; + } + // Set it so test_deleted is false. true if object used to be deleted. 
+ bool clear_deleted(const_iterator& it) { + check_use_deleted("clear_deleted()"); + return test_deleted(it); + } + + // FUNCTIONS CONCERNING SIZE + public: + size_type size() const { return table.num_nonempty() - num_deleted; } + size_type max_size() const { return table.max_size(); } + bool empty() const { return size() == 0; } + size_type bucket_count() const { return table.size(); } + size_type max_bucket_count() const { return max_size(); } + // These are tr1 methods. Their idea of 'bucket' doesn't map well to + // what we do. We just say every bucket has 0 or 1 items in it. + size_type bucket_size(size_type i) const { + return begin(i) == end(i) ? 0 : 1; + } + + private: + // Because of the above, size_type(-1) is never legal; use it for errors + static const size_type ILLEGAL_BUCKET = size_type(-1); + + // Used after a string of deletes. Returns true if we actually shrunk. + // TODO(csilvers): take a delta so we can take into account inserts + // done after shrinking. Maybe make part of the Settings class? + bool maybe_shrink() { + assert(table.num_nonempty() >= num_deleted); + assert((bucket_count() & (bucket_count() - 1)) == 0); // is a power of two + assert(bucket_count() >= HT_MIN_BUCKETS); + bool retval = false; + + // If you construct a hashtable with < HT_DEFAULT_STARTING_BUCKETS, + // we'll never shrink until you get relatively big, and we'll never + // shrink below HT_DEFAULT_STARTING_BUCKETS. Otherwise, something + // like "dense_hash_set x; x.insert(4); x.erase(4);" will + // shrink us down to HT_MIN_BUCKETS buckets, which is too small. + const size_type num_remain = table.num_nonempty() - num_deleted; + const size_type shrink_threshold = settings.shrink_threshold(); + if (shrink_threshold > 0 && num_remain < shrink_threshold && + bucket_count() > HT_DEFAULT_STARTING_BUCKETS) { + const float shrink_factor = settings.shrink_factor(); + size_type sz = bucket_count() / 2; // find how much we should shrink + while (sz > HT_DEFAULT_STARTING_BUCKETS && + num_remain < static_cast(sz * shrink_factor)) { + sz /= 2; // stay a power of 2 + } + sparse_hashtable tmp(MoveDontCopy, *this, sz); + swap(tmp); // now we are tmp + retval = true; + } + settings.set_consider_shrink(false); // because we just considered it + return retval; + } + + // We'll let you resize a hashtable -- though this makes us copy all! + // When you resize, you say, "make it big enough for this many more + // elements" + // Returns true if we actually resized, false if size was already ok. + bool resize_delta(size_type delta) { + bool did_resize = false; + if (settings.consider_shrink()) { // see if lots of deletes happened + if (maybe_shrink()) did_resize = true; + } + if (table.num_nonempty() >= + (std::numeric_limits::max)() - delta) { + throw std::length_error("resize overflow"); + } + if (bucket_count() >= HT_MIN_BUCKETS && + (table.num_nonempty() + delta) <= settings.enlarge_threshold()) + return did_resize; // we're ok as we are + + // Sometimes, we need to resize just to get rid of all the + // "deleted" buckets that are clogging up the hashtable. So when + // deciding whether to resize, count the deleted buckets (which + // are currently taking up room). But later, when we decide what + // size to resize to, *don't* count deleted buckets, since they + // get discarded during the resize. 
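+    // Worked example (illustrative numbers): with num_nonempty() == 150,
+    // of which num_deleted == 50, and delta == 1, needed_size below is
+    // computed from all 150 + 1 = 151 occupied buckets (deleted markers
+    // still take room), while resize_to is computed from only
+    // 150 - 50 + 1 = 101 live entries, since the rehash drops the markers.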
+ const size_type needed_size = + settings.min_buckets(table.num_nonempty() + delta, 0); + if (needed_size <= bucket_count()) // we have enough buckets + return did_resize; + + size_type resize_to = settings.min_buckets( + table.num_nonempty() - num_deleted + delta, bucket_count()); + if (resize_to < needed_size && // may double resize_to + resize_to < (std::numeric_limits::max)() / 2) { + // This situation means that we have enough deleted elements, + // that once we purge them, we won't actually have needed to + // grow. But we may want to grow anyway: if we just purge one + // element, say, we'll have to grow anyway next time we + // insert. Might as well grow now, since we're already going + // through the trouble of copying (in order to purge the + // deleted elements). + const size_type target = + static_cast(settings.shrink_size(resize_to * 2)); + if (table.num_nonempty() - num_deleted + delta >= target) { + // Good, we won't be below the shrink threshhold even if we + // double. + resize_to *= 2; + } + } + + sparse_hashtable tmp(MoveDontCopy, *this, resize_to); + swap(tmp); // now we are tmp + return true; + } + + // Used to actually do the rehashing when we grow/shrink a hashtable + void copy_from(const sparse_hashtable& ht, size_type min_buckets_wanted) { + clear(); // clear table, set num_deleted to 0 + + // If we need to change the size of our table, do it now + const size_type resize_to = + settings.min_buckets(ht.size(), min_buckets_wanted); + if (resize_to > bucket_count()) { // we don't have enough buckets + table.resize(resize_to); // sets the number of buckets + settings.reset_thresholds(bucket_count()); + } + + // We use a normal iterator to get non-deleted bcks from ht + // We could use insert() here, but since we know there are + // no duplicates and no deleted items, we can be more efficient + assert((bucket_count() & (bucket_count() - 1)) == 0); // a power of two + for (const_iterator it = ht.begin(); it != ht.end(); ++it) { + size_type num_probes = 0; // how many times we've probed + size_type bucknum; + const size_type bucket_count_minus_one = bucket_count() - 1; + for (bucknum = hash(get_key(*it)) & bucket_count_minus_one; + table.test(bucknum); // not empty + bucknum = + (bucknum + JUMP_(key, num_probes)) & bucket_count_minus_one) { + ++num_probes; + assert(num_probes < bucket_count() && + "Hashtable is full: an error in key_equal<> or hash<>"); + } + table.set(bucknum, *it); // copies the value to here + } + settings.inc_num_ht_copies(); + } + + // Implementation is like copy_from, but it destroys the table of the + // "from" guy by freeing sparsetable memory as we iterate. This is + // useful in resizing, since we're throwing away the "from" guy anyway. 
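+  // Both copy_from() above and move_from() below rehash with the same
+  // open-addressing probe loop: start at hash(key) & (bucket_count() - 1)
+  // and step by JUMP_(key, num_probes) after each collision.  With
+  // JUMP_(key, n) expanding to n (as defined earlier in this header,
+  // following upstream sparsehash), the probe offsets are the
+  // triangular numbers
+  //   h, h+1, h+3, h+6, h+10, ...   (mod bucket_count())
+  // which visit every bucket when bucket_count() is a power of two.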
+ void move_from(MoveDontCopyT mover, sparse_hashtable& ht, + size_type min_buckets_wanted) { + clear(); // clear table, set num_deleted to 0 + + // If we need to change the size of our table, do it now + size_type resize_to; + if (mover == MoveDontGrow) + resize_to = ht.bucket_count(); // keep same size as old ht + else // MoveDontCopy + resize_to = settings.min_buckets(ht.size(), min_buckets_wanted); + if (resize_to > bucket_count()) { // we don't have enough buckets + table.resize(resize_to); // sets the number of buckets + settings.reset_thresholds(bucket_count()); + } + + // We use a normal iterator to get non-deleted bcks from ht + // We could use insert() here, but since we know there are + // no duplicates and no deleted items, we can be more efficient + assert((bucket_count() & (bucket_count() - 1)) == 0); // a power of two + // THIS IS THE MAJOR LINE THAT DIFFERS FROM COPY_FROM(): + for (destructive_iterator it = ht.destructive_begin(); + it != ht.destructive_end(); ++it) { + size_type num_probes = 0; // how many times we've probed + size_type bucknum; + for (bucknum = hash(get_key(*it)) & (bucket_count() - 1); // h % buck_cnt + table.test(bucknum); // not empty + bucknum = + (bucknum + JUMP_(key, num_probes)) & (bucket_count() - 1)) { + ++num_probes; + assert(num_probes < bucket_count() && + "Hashtable is full: an error in key_equal<> or hash<>"); + } + table.set(bucknum, *it); // copies the value to here + } + settings.inc_num_ht_copies(); + } + + // Required by the spec for hashed associative container + public: + // Though the docs say this should be num_buckets, I think it's much + // more useful as num_elements. As a special feature, calling with + // req_elements==0 will cause us to shrink if we can, saving space. + void resize(size_type req_elements) { // resize to this or larger + if (settings.consider_shrink() || req_elements == 0) maybe_shrink(); + if (req_elements > table.num_nonempty()) // we only grow + resize_delta(req_elements - table.num_nonempty()); + } + + // Get and change the value of shrink_factor and enlarge_factor. The + // description at the beginning of this file explains how to choose + // the values. Setting the shrink parameter to 0.0 ensures that the + // table never shrinks. + void get_resizing_parameters(float* shrink, float* grow) const { + *shrink = settings.shrink_factor(); + *grow = settings.enlarge_factor(); + } + void set_resizing_parameters(float shrink, float grow) { + settings.set_resizing_parameters(shrink, grow); + settings.reset_thresholds(bucket_count()); + } + + // CONSTRUCTORS -- as required by the specs, we take a size, + // but also let you specify a hashfunction, key comparator, + // and key extractor. We also define a copy constructor and =. + // DESTRUCTOR -- the default is fine, surprisingly. + explicit sparse_hashtable(size_type expected_max_items_in_table = 0, + const HashFcn& hf = HashFcn(), + const EqualKey& eql = EqualKey(), + const ExtractKey& ext = ExtractKey(), + const SetKey& set = SetKey(), + const Alloc& alloc = Alloc()) + : settings(hf), + key_info(ext, set, eql), + num_deleted(0), + table((expected_max_items_in_table == 0 + ? HT_DEFAULT_STARTING_BUCKETS + : settings.min_buckets(expected_max_items_in_table, 0)), + alloc) { + settings.reset_thresholds(bucket_count()); + } + + // As a convenience for resize(), we allow an optional second argument + // which lets you make this new hashtable a different size than ht. 
+ // We also provide a mechanism of saying you want to "move" the ht argument + // into us instead of copying. + sparse_hashtable(const sparse_hashtable& ht, + size_type min_buckets_wanted = HT_DEFAULT_STARTING_BUCKETS) + : settings(ht.settings), + key_info(ht.key_info), + num_deleted(0), + table(0, ht.get_allocator()) { + settings.reset_thresholds(bucket_count()); + copy_from(ht, min_buckets_wanted); // copy_from() ignores deleted entries + } + sparse_hashtable(MoveDontCopyT mover, sparse_hashtable& ht, + size_type min_buckets_wanted = HT_DEFAULT_STARTING_BUCKETS) + : settings(ht.settings), + key_info(ht.key_info), + num_deleted(0), + table(0, ht.get_allocator()) { + settings.reset_thresholds(bucket_count()); + move_from(mover, ht, min_buckets_wanted); // ignores deleted entries + } + + sparse_hashtable& operator=(const sparse_hashtable& ht) { + if (&ht == this) return *this; // don't copy onto ourselves + settings = ht.settings; + key_info = ht.key_info; + num_deleted = ht.num_deleted; + // copy_from() calls clear and sets num_deleted to 0 too + copy_from(ht, HT_MIN_BUCKETS); + // we purposefully don't copy the allocator, which may not be copyable + return *this; + } + + // Many STL algorithms use swap instead of copy constructors + void swap(sparse_hashtable& ht) { + std::swap(settings, ht.settings); + std::swap(key_info, ht.key_info); + std::swap(num_deleted, ht.num_deleted); + table.swap(ht.table); + settings.reset_thresholds(bucket_count()); // also resets consider_shrink + ht.settings.reset_thresholds(ht.bucket_count()); + // we purposefully don't swap the allocator, which may not be swap-able + } + + // It's always nice to be able to clear a table without deallocating it + void clear() { + if (!empty() || (num_deleted != 0)) { + table.clear(); + } + settings.reset_thresholds(bucket_count()); + num_deleted = 0; + } + + // LOOKUP ROUTINES + private: + // Returns a pair of positions: 1st where the object is, 2nd where + // it would go if you wanted to insert it. 1st is ILLEGAL_BUCKET + // if object is not found; 2nd is ILLEGAL_BUCKET if it is. 
+ // Note: because of deletions where-to-insert is not trivial: it's the + // first deleted bucket we see, as long as we don't find the key later + template + std::pair find_position(const K& key) const { + size_type num_probes = 0; // how many times we've probed + const size_type bucket_count_minus_one = bucket_count() - 1; + size_type bucknum = hash(key) & bucket_count_minus_one; + size_type insert_pos = ILLEGAL_BUCKET; // where we would insert + SPARSEHASH_STAT_UPDATE(total_lookups += 1); + while (1) { // probe until something happens + if (!table.test(bucknum)) { // bucket is empty + SPARSEHASH_STAT_UPDATE(total_probes += num_probes); + if (insert_pos == ILLEGAL_BUCKET) // found no prior place to insert + return std::pair(ILLEGAL_BUCKET, bucknum); + else + return std::pair(ILLEGAL_BUCKET, insert_pos); + } else if (test_deleted(bucknum)) { // keep searching, but mark to insert + if (insert_pos == ILLEGAL_BUCKET) insert_pos = bucknum; + } else if (equals(key, get_key(table.unsafe_get(bucknum)))) { + SPARSEHASH_STAT_UPDATE(total_probes += num_probes); + return std::pair(bucknum, ILLEGAL_BUCKET); + } + ++num_probes; // we're doing another probe + bucknum = (bucknum + JUMP_(key, num_probes)) & bucket_count_minus_one; + assert(num_probes < bucket_count() && + "Hashtable is full: an error in key_equal<> or hash<>"); + } + } + + public: + template + iterator find(const K& key) { + if (size() == 0) return end(); + std::pair pos = find_position(key); + if (pos.first == ILLEGAL_BUCKET) // alas, not there + return end(); + else + return iterator(this, table.get_iter(pos.first), table.nonempty_end()); + } + + template + const_iterator find(const K& key) const { + if (size() == 0) return end(); + std::pair pos = find_position(key); + if (pos.first == ILLEGAL_BUCKET) // alas, not there + return end(); + else + return const_iterator(this, table.get_iter(pos.first), + table.nonempty_end()); + } + + // This is a tr1 method: the bucket a given key is in, or what bucket + // it would be put in, if it were to be inserted. Shrug. + size_type bucket(const key_type& key) const { + std::pair pos = find_position(key); + return pos.first == ILLEGAL_BUCKET ? pos.second : pos.first; + } + + // Counts how many elements have key key. For maps, it's either 0 or 1. + template + size_type count(const K& key) const { + std::pair pos = find_position(key); + return pos.first == ILLEGAL_BUCKET ? 0 : 1; + } + + // Likewise, equal_range doesn't really make sense for us. Oh well. + template + std::pair equal_range(const K& key) { + iterator pos = find(key); // either an iterator or end + if (pos == end()) { + return std::pair(pos, pos); + } else { + const iterator startpos = pos++; + return std::pair(startpos, pos); + } + } + template + std::pair equal_range( + const K& key) const { + const_iterator pos = find(key); // either an iterator or end + if (pos == end()) { + return std::pair(pos, pos); + } else { + const const_iterator startpos = pos++; + return std::pair(startpos, pos); + } + } + + // INSERTION ROUTINES + private: + // Private method used by insert_noresize and find_or_insert. + iterator insert_at(const_reference obj, size_type pos) { + if (size() >= max_size()) { + throw std::length_error("insert overflow"); + } + if (test_deleted(pos)) { // just replace if it's been deleted + // The set() below will undelete this object. 
We just worry about + // stats + assert(num_deleted > 0); + --num_deleted; // used to be, now it isn't + } + table.set(pos, obj); + return iterator(this, table.get_iter(pos), table.nonempty_end()); + } + + // If you know *this is big enough to hold obj, use this routine + std::pair insert_noresize(const_reference obj) { + // First, double-check we're not inserting delkey + assert( + (!settings.use_deleted() || !equals(get_key(obj), key_info.delkey)) && + "Inserting the deleted key"); + const std::pair pos = find_position(get_key(obj)); + if (pos.first != ILLEGAL_BUCKET) { // object was already there + return std::pair( + iterator(this, table.get_iter(pos.first), table.nonempty_end()), + false); // false: we didn't insert + } else { // pos.second says where to put it + return std::pair(insert_at(obj, pos.second), true); + } + } + + // Specializations of insert(it, it) depending on the power of the iterator: + // (1) Iterator supports operator-, resize before inserting + template + void insert(ForwardIterator f, ForwardIterator l, std::forward_iterator_tag) { + size_t dist = std::distance(f, l); + if (dist >= (std::numeric_limits::max)()) { + throw std::length_error("insert-range overflow"); + } + resize_delta(static_cast(dist)); + for (; dist > 0; --dist, ++f) { + insert_noresize(*f); + } + } + + // (2) Arbitrary iterator, can't tell how much to resize + template + void insert(InputIterator f, InputIterator l, std::input_iterator_tag) { + for (; f != l; ++f) insert(*f); + } + + public: + // This is the normal insert routine, used by the outside world + std::pair insert(const_reference obj) { + resize_delta(1); // adding an object, grow if need be + return insert_noresize(obj); + } + + // When inserting a lot at a time, we specialize on the type of iterator + template + void insert(InputIterator f, InputIterator l) { + // specializes on iterator type + insert(f, l, + typename std::iterator_traits::iterator_category()); + } + + // DefaultValue is a functor that takes a key and returns a value_type + // representing the default value to be inserted if none is found. + template + value_type& find_or_insert(const key_type& key) { + // First, double-check we're not inserting delkey + assert((!settings.use_deleted() || !equals(key, key_info.delkey)) && + "Inserting the deleted key"); + const std::pair pos = find_position(key); + DefaultValue default_value; + if (pos.first != ILLEGAL_BUCKET) { // object was already there + return *table.get_iter(pos.first); + } else if (resize_delta(1)) { // needed to rehash to make room + // Since we resized, we can't use pos, so recalculate where to + // insert. + return *insert_noresize(default_value(key)).first; + } else { // no need to rehash, insert right here + return *insert_at(default_value(key), pos.second); + } + } + + // DELETION ROUTINES + size_type erase(const key_type& key) { + // First, double-check we're not erasing delkey. + assert((!settings.use_deleted() || !equals(key, key_info.delkey)) && + "Erasing the deleted key"); + assert(!settings.use_deleted() || !equals(key, key_info.delkey)); + const_iterator pos = find(key); // shrug: shouldn't need to be const + if (pos != end()) { + assert(!test_deleted(pos)); // or find() shouldn't have returned it + set_deleted(pos); + ++num_deleted; + // will think about shrink after next insert + settings.set_consider_shrink(true); + return 1; // because we deleted one thing + } else { + return 0; // because we deleted nothing + } + } + + // We return the iterator past the deleted item. 
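+  // Deletion is lazy: the erase() overloads below only overwrite the
+  // entry's key with the deleted-key marker, so size() shrinks but no
+  // memory is freed.  Sketch (ht has set_deleted_key() already called):
+  //   ht.erase(k);     // marks the bucket; size() decreases by one
+  //   ht.resize(0);    // forces a copy that discards marked buckets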
+ void erase(iterator pos) { + if (pos == end()) return; // sanity check + if (set_deleted(pos)) { // true if object has been newly deleted + ++num_deleted; + // will think about shrink after next insert + settings.set_consider_shrink(true); + } + } + + void erase(iterator f, iterator l) { + for (; f != l; ++f) { + if (set_deleted(f)) // should always be true + ++num_deleted; + } + // will think about shrink after next insert + settings.set_consider_shrink(true); + } + + // We allow you to erase a const_iterator just like we allow you to + // erase an iterator. This is in parallel to 'delete': you can delete + // a const pointer just like a non-const pointer. The logic is that + // you can't use the object after it's erased anyway, so it doesn't matter + // if it's const or not. + void erase(const_iterator pos) { + if (pos == end()) return; // sanity check + if (set_deleted(pos)) { // true if object has been newly deleted + ++num_deleted; + // will think about shrink after next insert + settings.set_consider_shrink(true); + } + } + void erase(const_iterator f, const_iterator l) { + for (; f != l; ++f) { + if (set_deleted(f)) // should always be true + ++num_deleted; + } + // will think about shrink after next insert + settings.set_consider_shrink(true); + } + + // COMPARISON + bool operator==(const sparse_hashtable& ht) const { + if (size() != ht.size()) { + return false; + } else if (this == &ht) { + return true; + } else { + // Iterate through the elements in "this" and see if the + // corresponding element is in ht + for (const_iterator it = begin(); it != end(); ++it) { + const_iterator it2 = ht.find(get_key(*it)); + if ((it2 == ht.end()) || (*it != *it2)) { + return false; + } + } + return true; + } + } + bool operator!=(const sparse_hashtable& ht) const { return !(*this == ht); } + + // I/O + // We support reading and writing hashtables to disk. NOTE that + // this only stores the hashtable metadata, not the stuff you've + // actually put in the hashtable! Alas, since I don't know how to + // write a hasher or key_equal, you have to make sure everything + // but the table is the same. We compact before writing. + // + // The OUTPUT type needs to support a Write() operation. File and + // OutputBuffer are appropriate types to pass in. + // + // The INPUT type needs to support a Read() operation. File and + // InputBuffer are appropriate types to pass in. + template + bool write_metadata(OUTPUT* fp) { + squash_deleted(); // so we don't have to worry about delkey + return table.write_metadata(fp); + } + + template + bool read_metadata(INPUT* fp) { + num_deleted = 0; // since we got rid before writing + const bool result = table.read_metadata(fp); + settings.reset_thresholds(bucket_count()); + return result; + } + + // Only meaningful if value_type is a POD. + template + bool write_nopointer_data(OUTPUT* fp) { + return table.write_nopointer_data(fp); + } + + // Only meaningful if value_type is a POD. + template + bool read_nopointer_data(INPUT* fp) { + return table.read_nopointer_data(fp); + } + + // INPUT and OUTPUT must be either a FILE, *or* a C++ stream + // (istream, ostream, etc) *or* a class providing + // Read(void*, size_t) and Write(const void*, size_t) + // (respectively), which writes a buffer into a stream + // (which the INPUT/OUTPUT instance presumably owns). + + typedef sparsehash_internal::pod_serializer NopointerSerializer; + + // ValueSerializer: a functor. 
operator()(OUTPUT*, const value_type&) + template + bool serialize(ValueSerializer serializer, OUTPUT* fp) { + squash_deleted(); // so we don't have to worry about delkey + return table.serialize(serializer, fp); + } + + // ValueSerializer: a functor. operator()(INPUT*, value_type*) + template + bool unserialize(ValueSerializer serializer, INPUT* fp) { + num_deleted = 0; // since we got rid before writing + const bool result = table.unserialize(serializer, fp); + settings.reset_thresholds(bucket_count()); + return result; + } + + private: + // Table is the main storage class. + typedef sparsetable Table; + + // Package templated functors with the other types to eliminate memory + // needed for storing these zero-size operators. Since ExtractKey and + // hasher's operator() might have the same function signature, they + // must be packaged in different classes. + struct Settings + : sparsehash_internal::sh_hashtable_settings { + explicit Settings(const hasher& hf) + : sparsehash_internal::sh_hashtable_settings( + hf, HT_OCCUPANCY_PCT / 100.0f, HT_EMPTY_PCT / 100.0f) {} + }; + + // KeyInfo stores delete key and packages zero-size functors: + // ExtractKey and SetKey. + class KeyInfo : public ExtractKey, public SetKey, public EqualKey { + public: + KeyInfo(const ExtractKey& ek, const SetKey& sk, const EqualKey& eq) + : ExtractKey(ek), SetKey(sk), EqualKey(eq) {} + // We want to return the exact same type as ExtractKey: Key or const + // Key& + typename ExtractKey::result_type get_key(const_reference v) const { + return ExtractKey::operator()(v); + } + void set_key(pointer v, const key_type& k) const { + SetKey::operator()(v, k); + } + template + bool equals(const K1& a, const K2& b) const { + return EqualKey::operator()(a, b); + } + + // Which key marks deleted entries. + // TODO(csilvers): make a pointer, and get rid of use_deleted + // (benchmark!) + typename std::remove_const::type delkey; + }; + + // Utility functions to access the templated operators + template + size_type hash(const K& v) const { return settings.hash(v); } + template + bool equals(const K1& a, const K2& b) const { + return key_info.equals(a, b); + } + typename ExtractKey::result_type get_key(const_reference v) const { + return key_info.get_key(v); + } + void set_key(pointer v, const key_type& k) const { key_info.set_key(v, k); } + + private: + // Actual data + Settings settings; + KeyInfo key_info; + size_type num_deleted; // how many occupied buckets are marked deleted + Table table; // holds num_buckets and num_elements too +}; + +// We need a global swap as well +template +inline void swap(sparse_hashtable& x, + sparse_hashtable& y) { + x.swap(y); +} + +#undef JUMP_ + +template +const typename sparse_hashtable::size_type + sparse_hashtable::ILLEGAL_BUCKET; + +// How full we let the table get before we resize. Knuth says .8 is +// good -- higher causes us to probe too much, though saves memory +template +const int sparse_hashtable::HT_OCCUPANCY_PCT = 80; + +// How empty we let the table get before we resize lower. +// It should be less than OCCUPANCY_PCT / 2 or we thrash resizing +template +const int sparse_hashtable::HT_EMPTY_PCT = + static_cast( + 0.4 * sparse_hashtable::HT_OCCUPANCY_PCT); +} diff --git a/oap-native-sql/cpp/src/third_party/sparsehash/sparse_hash_map b/oap-native-sql/cpp/src/third_party/sparsehash/sparse_hash_map new file mode 100644 index 000000000..f2b405132 --- /dev/null +++ b/oap-native-sql/cpp/src/third_party/sparsehash/sparse_hash_map @@ -0,0 +1,385 @@ +// Copyright (c) 2005, Google Inc. 
+// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// +// This is just a very thin wrapper over sparsehashtable.h, just +// like sgi stl's stl_hash_map is a very thin wrapper over +// stl_hashtable. The major thing we define is operator[], because +// we have a concept of a data_type which stl_hashtable doesn't +// (it only has a key and a value). +// +// We adhere mostly to the STL semantics for hash-map. One important +// exception is that insert() may invalidate iterators entirely -- STL +// semantics are that insert() may reorder iterators, but they all +// still refer to something valid in the hashtable. Not so for us. +// Likewise, insert() may invalidate pointers into the hashtable. +// (Whether insert invalidates iterators and pointers depends on +// whether it results in a hashtable resize). On the plus side, +// delete() doesn't invalidate iterators or pointers at all, or even +// change the ordering of elements. +// +// Here are a few "power user" tips: +// +// 1) set_deleted_key(): +// Unlike STL's hash_map, if you want to use erase() you +// *must* call set_deleted_key() after construction. +// +// 2) resize(0): +// When an item is deleted, its memory isn't freed right +// away. This is what allows you to iterate over a hashtable +// and call erase() without invalidating the iterator. +// To force the memory to be freed, call resize(0). +// For tr1 compatibility, this can also be called as rehash(0). +// +// 3) min_load_factor(0.0) +// Setting the minimum load factor to 0.0 guarantees that +// the hash table will never shrink. +// +// Roughly speaking: +// (1) dense_hash_map: fastest, uses the most memory unless entries are small +// (2) sparse_hash_map: slowest, uses the least memory +// (3) hash_map / unordered_map (STL): in the middle +// +// Typically I use sparse_hash_map when I care about space and/or when +// I need to save the hashtable on disk. I use hash_map otherwise. I +// don't personally use dense_hash_map ever; some people use it for +// small maps with lots of lookups. 
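+//
+// For example (sketch), all of them expose the same unordered_map-style API:
+//
+//   google::sparse_hash_map<std::string, int> counts;
+//   counts.set_deleted_key("");   // "" must never be used as a real key
+//   ++counts["hello"];            // operator[] default-constructs on miss
+//   counts.erase("hello");
+//   counts.resize(0);             // reclaims the erased entry's memory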
+//
+// - dense_hash_map has, typically, about 78% memory overhead (if your
+//   data takes up X bytes, the hash_map uses .78X more bytes in overhead).
+// - sparse_hash_map has about 4 bits overhead per entry.
+// - sparse_hash_map can be 3-7 times slower than the others for lookup and,
+//   especially, inserts.  See time_hash_map.cc for details.
+//
+// See /usr/(local/)?doc/sparsehash-*/sparse_hash_map.html
+// for information about how to use this class.
+
+#pragma once
+
+#include <algorithm>   // needed by stl_alloc
+#include <functional>  // for equal_to<>, select1st<>, etc
+#include <memory>      // for alloc
+#include <utility>     // for pair<>
+#include <sparsehash/internal/libc_allocator_with_realloc.h>
+#include <sparsehash/internal/sparsehashtable.h>  // IWYU pragma: export
+
+namespace google {
+
+template <class Key, class T, class HashFcn = std::hash<Key>,
+          class EqualKey = std::equal_to<Key>,
+          class Alloc = libc_allocator_with_realloc<std::pair<const Key, T>>>
+class sparse_hash_map {
+ private:
+  // Apparently select1st is not stl-standard, so we define our own
+  struct SelectKey {
+    typedef const Key& result_type;
+    const Key& operator()(const std::pair<const Key, T>& p) const {
+      return p.first;
+    }
+  };
+  struct SetKey {
+    void operator()(std::pair<const Key, T>* value, const Key& new_key) const {
+      *const_cast<Key*>(&value->first) = new_key;
+      // It would be nice to clear the rest of value here as well, in
+      // case it's taking up a lot of memory.  We do this by clearing
+      // the value.  This assumes T has a zero-arg constructor!
+      value->second = T();
+    }
+  };
+  // For operator[].
+  struct DefaultValue {
+    std::pair<const Key, T> operator()(const Key& key) {
+      return std::make_pair(key, T());
+    }
+  };
+
+  // The actual data
+  typedef typename sparsehash_internal::key_equal_chosen<HashFcn,
+                                                         EqualKey>::type
+      EqualKeyChosen;
+  typedef sparse_hashtable<std::pair<const Key, T>, Key, HashFcn, SelectKey,
+                           SetKey, EqualKeyChosen, Alloc> ht;
+  ht rep;
+
+  static_assert(
+      !sparsehash_internal::has_transparent_key_equal<HashFcn>::value ||
+          std::is_same<EqualKey, std::equal_to<Key>>::value ||
+          std::is_same<EqualKey, EqualKeyChosen>::value,
+      "Heterogeneous lookup requires key_equal to either be the default "
+      "container value or the same as the type provided by hash");
+
+ public:
+  typedef typename ht::key_type key_type;
+  typedef T data_type;
+  typedef T mapped_type;
+  typedef typename ht::value_type value_type;
+  typedef typename ht::hasher hasher;
+  typedef typename ht::key_equal key_equal;
+  typedef Alloc allocator_type;
+
+  typedef typename ht::size_type size_type;
+  typedef typename ht::difference_type difference_type;
+  typedef typename ht::pointer pointer;
+  typedef typename ht::const_pointer const_pointer;
+  typedef typename ht::reference reference;
+  typedef typename ht::const_reference const_reference;
+
+  typedef typename ht::iterator iterator;
+  typedef typename ht::const_iterator const_iterator;
+  typedef typename ht::local_iterator local_iterator;
+  typedef typename ht::const_local_iterator const_local_iterator;
+
+  // Iterator functions
+  iterator begin() { return rep.begin(); }
+  iterator end() { return rep.end(); }
+  const_iterator begin() const { return rep.begin(); }
+  const_iterator end() const { return rep.end(); }
+
+  // These come from tr1's unordered_map. For us, a bucket has 0 or 1 elements.
+ local_iterator begin(size_type i) { return rep.begin(i); } + local_iterator end(size_type i) { return rep.end(i); } + const_local_iterator begin(size_type i) const { return rep.begin(i); } + const_local_iterator end(size_type i) const { return rep.end(i); } + + // Accessor functions + allocator_type get_allocator() const { return rep.get_allocator(); } + hasher hash_funct() const { return rep.hash_funct(); } + hasher hash_function() const { return hash_funct(); } + key_equal key_eq() const { return rep.key_eq(); } + + // Constructors + explicit sparse_hash_map(size_type expected_max_items_in_table = 0, + const hasher& hf = hasher(), + const key_equal& eql = key_equal(), + const allocator_type& alloc = allocator_type()) + : rep(expected_max_items_in_table, hf, eql, SelectKey(), SetKey(), + alloc) {} + + template + sparse_hash_map(InputIterator f, InputIterator l, + size_type expected_max_items_in_table = 0, + const hasher& hf = hasher(), + const key_equal& eql = key_equal(), + const allocator_type& alloc = allocator_type()) + : rep(expected_max_items_in_table, hf, eql, SelectKey(), SetKey(), + alloc) { + rep.insert(f, l); + } + // We use the default copy constructor + // We use the default operator=() + // We use the default destructor + + void clear() { rep.clear(); } + void swap(sparse_hash_map& hs) { rep.swap(hs.rep); } + + // Functions concerning size + size_type size() const { return rep.size(); } + size_type max_size() const { return rep.max_size(); } + bool empty() const { return rep.empty(); } + size_type bucket_count() const { return rep.bucket_count(); } + size_type max_bucket_count() const { return rep.max_bucket_count(); } + + // These are tr1 methods. bucket() is the bucket the key is or would be in. + size_type bucket_size(size_type i) const { return rep.bucket_size(i); } + size_type bucket(const key_type& key) const { return rep.bucket(key); } + float load_factor() const { return size() * 1.0f / bucket_count(); } + float max_load_factor() const { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + return grow; + } + void max_load_factor(float new_grow) { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + rep.set_resizing_parameters(shrink, new_grow); + } + // These aren't tr1 methods but perhaps ought to be. + float min_load_factor() const { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + return shrink; + } + void min_load_factor(float new_shrink) { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + rep.set_resizing_parameters(new_shrink, grow); + } + // Deprecated; use min_load_factor() or max_load_factor() instead. + void set_resizing_parameters(float shrink, float grow) { + rep.set_resizing_parameters(shrink, grow); + } + + void reserve(size_type size) { rehash(size); } // note: rehash internally treats hint/size as number of elements + void resize(size_type hint) { rep.resize(hint); } + void rehash(size_type hint) { resize(hint); } // the tr1 name + + // Lookup routines + iterator find(const key_type& key) { return rep.find(key); } + const_iterator find(const key_type& key) const { return rep.find(key); } + + template + typename std::enable_if::value, iterator>::type + find(const K& key) { return rep.find(key); } + template + typename std::enable_if::value, const_iterator>::type + find(const K& key) const { return rep.find(key); } + + data_type& operator[](const key_type& key) { // This is our value-add! 
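+    // e.g. (sketch): ++counts["hello"] first inserts ("hello", 0) when
+    // "hello" is absent, then increments the mapped value in place.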
+ // If key is in the hashtable, returns find(key)->second, + // otherwise returns insert(value_type(key, T()).first->second. + // Note it does not create an empty T unless the find fails. + return rep.template find_or_insert(key).second; + } + + size_type count(const key_type& key) const { return rep.count(key); } + + template + typename std::enable_if::value, size_type>::type + count(const K& key) const { return rep.count(key); } + + std::pair equal_range(const key_type& key) { + return rep.equal_range(key); + } + std::pair equal_range( + const key_type& key) const { + return rep.equal_range(key); + } + + template + typename std::enable_if::value, std::pair>::type + equal_range(const K& key) { + return rep.equal_range(key); + } + template + typename std::enable_if::value, std::pair>::type + equal_range(const K& key) const { + return rep.equal_range(key); + } + + // Insertion routines + std::pair insert(const value_type& obj) { + return rep.insert(obj); + } + template + void insert(InputIterator f, InputIterator l) { + rep.insert(f, l); + } + void insert(const_iterator f, const_iterator l) { rep.insert(f, l); } + // Required for std::insert_iterator; the passed-in iterator is ignored. + iterator insert(iterator, const value_type& obj) { return insert(obj).first; } + + // Deletion routines + // THESE ARE NON-STANDARD! I make you specify an "impossible" key + // value to identify deleted buckets. You can change the key as + // time goes on, or get rid of it entirely to be insert-only. + void set_deleted_key(const key_type& key) { rep.set_deleted_key(key); } + void clear_deleted_key() { rep.clear_deleted_key(); } + key_type deleted_key() const { return rep.deleted_key(); } + + // These are standard + size_type erase(const key_type& key) { return rep.erase(key); } + void erase(iterator it) { rep.erase(it); } + void erase(iterator f, iterator l) { rep.erase(f, l); } + + // Comparison + bool operator==(const sparse_hash_map& hs) const { return rep == hs.rep; } + bool operator!=(const sparse_hash_map& hs) const { return rep != hs.rep; } + + // I/O -- this is an add-on for writing metainformation to disk + // + // For maximum flexibility, this does not assume a particular + // file type (though it will probably be a FILE *). We just pass + // the fp through to rep. + + // If your keys and values are simple enough, you can pass this + // serializer to serialize()/unserialize(). "Simple enough" means + // value_type is a POD type that contains no pointers. Note, + // however, we don't try to normalize endianness. + typedef typename ht::NopointerSerializer NopointerSerializer; + + // serializer: a class providing operator()(OUTPUT*, const value_type&) + // (writing value_type to OUTPUT). You can specify a + // NopointerSerializer object if appropriate (see above). + // fp: either a FILE*, OR an ostream*/subclass_of_ostream*, OR a + // pointer to a class providing size_t Write(const void*, size_t), + // which writes a buffer into a stream (which fp presumably + // owns) and returns the number of bytes successfully written. + // Note basic_ostream is not currently supported. + template + bool serialize(ValueSerializer serializer, OUTPUT* fp) { + return rep.serialize(serializer, fp); + } + + // serializer: a functor providing operator()(INPUT*, value_type*) + // (reading from INPUT and into value_type). You can specify a + // NopointerSerializer object if appropriate (see above). 
+  //    fp: either a FILE*, OR an istream*/subclass_of_istream*, OR a
+  //    pointer to a class providing size_t Read(void*, size_t),
+  //    which reads into a buffer from a stream (which fp presumably
+  //    owns) and returns the number of bytes successfully read.
+  //    Note basic_istream<not_char> is not currently supported.
+  // NOTE: Since value_type is std::pair<const Key, T>, ValueSerializer
+  // may need to do a const cast in order to fill in the key.
+  // NOTE: if Key or T are not POD types, the serializer MUST use
+  // placement-new to initialize their values, rather than a normal
+  // equals-assignment or similar. (The value_type* passed into the
+  // serializer points to garbage memory.)
+  template <typename ValueSerializer, typename INPUT>
+  bool unserialize(ValueSerializer serializer, INPUT* fp) {
+    return rep.unserialize(serializer, fp);
+  }
+
+  // The four methods below are DEPRECATED.
+  // Use serialize() and unserialize() for new code.
+  template <typename OUTPUT>
+  bool write_metadata(OUTPUT* fp) {
+    return rep.write_metadata(fp);
+  }
+
+  template <typename INPUT>
+  bool read_metadata(INPUT* fp) {
+    return rep.read_metadata(fp);
+  }
+
+  template <typename OUTPUT>
+  bool write_nopointer_data(OUTPUT* fp) {
+    return rep.write_nopointer_data(fp);
+  }
+
+  template <typename INPUT>
+  bool read_nopointer_data(INPUT* fp) {
+    return rep.read_nopointer_data(fp);
+  }
+};
+
+// We need a global swap as well
+template <class Key, class T, class HashFcn, class EqualKey, class Alloc>
+inline void swap(sparse_hash_map<Key, T, HashFcn, EqualKey, Alloc>& hm1,
+                 sparse_hash_map<Key, T, HashFcn, EqualKey, Alloc>& hm2) {
+  hm1.swap(hm2);
+}
+
+}  // namespace google
diff --git a/oap-native-sql/cpp/src/third_party/sparsehash/sparse_hash_map.h b/oap-native-sql/cpp/src/third_party/sparsehash/sparse_hash_map.h
new file mode 100644
index 000000000..c967cfe20
--- /dev/null
+++ b/oap-native-sql/cpp/src/third_party/sparsehash/sparse_hash_map.h
@@ -0,0 +1,62 @@
+#include <arrow/memory_pool.h>  // for arrow::MemoryPool
+#include <arrow/status.h>       // for arrow::Status
+#include "sparsehash/dense_hash_map"
+
+using google::dense_hash_map;
+
+#define NOTFOUND -1
+
+// Thin wrapper exposing a memo-style GetOrInsert() interface. Despite the
+// name, it is currently backed by google::dense_hash_map (see dense_map_
+// below); set_empty_key(0) reserves key value 0 as the empty-bucket
+// marker, so 0 must never be inserted as a real key.
+//
+// Usage sketch:
+//   SparseHashMap<int64_t> map;
+//   int32_t idx;
+//   map.GetOrInsert(42, [](int32_t) {}, [](int32_t) {}, &idx);
+template <typename Scalar>
+class SparseHashMap {
+ public:
+  SparseHashMap() { dense_map_.set_empty_key(0); }
+  // pool is accepted for interface parity; it is not used here.
+  SparseHashMap(arrow::MemoryPool* pool) { dense_map_.set_empty_key(0); }
+  // Looks up value and, if absent, assigns it the next memo index.
+  // Exactly one of on_found/on_not_found is invoked with that index.
+  template <typename Func1, typename Func2>
+  arrow::Status GetOrInsert(const Scalar& value, Func1&& on_found,
+                            Func2&& on_not_found, int32_t* out_memo_index) {
+    if (dense_map_.find(value) == dense_map_.end()) {
+      auto index = size_++;
+      dense_map_[value] = index;
+      *out_memo_index = index;
+      on_not_found(index);
+    } else {
+      auto index = dense_map_[value];
+      *out_memo_index = index;
+      on_found(index);
+    }
+    return arrow::Status::OK();
+  }
+  // Null values share a single memo index, assigned on first use.
+  template <typename Func1, typename Func2>
+  int32_t GetOrInsertNull(Func1&& on_found, Func2&& on_not_found) {
+    if (!null_index_set_) {
+      null_index_set_ = true;
+      null_index_ = size_++;
+      on_not_found(null_index_);
+    } else {
+      on_found(null_index_);
+    }
+    return null_index_;
+  }
+  // Returns the memo index for value, or NOTFOUND (-1) if absent.
+  int32_t Get(const Scalar& value) {
+    if (dense_map_.find(value) == dense_map_.end()) {
+      return NOTFOUND;
+    } else {
+      auto ret = dense_map_[value];
+      return ret;
+    }
+  }
+  int32_t GetNull() {
+    if (!null_index_set_) {
+      return NOTFOUND;
+    } else {
+      auto ret = null_index_;
+      return ret;
+    }
+  }
+
+ private:
+  dense_hash_map<Scalar, int32_t> dense_map_;
+  int32_t size_ = 0;
+  bool null_index_set_ = false;
+  int32_t null_index_;
+};
diff --git a/oap-native-sql/cpp/src/third_party/sparsehash/sparse_hash_set b/oap-native-sql/cpp/src/third_party/sparsehash/sparse_hash_set
new file mode 100644
index 000000000..dde3142a6
--- /dev/null
+++ b/oap-native-sql/cpp/src/third_party/sparsehash/sparse_hash_set
@@ -0,0 +1,349 @@
+// Copyright (c) 2005, Google Inc.
+// All rights reserved.
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// +// This is just a very thin wrapper over sparsehashtable.h, just +// like sgi stl's stl_hash_set is a very thin wrapper over +// stl_hashtable. The major thing we define is operator[], because +// we have a concept of a data_type which stl_hashtable doesn't +// (it only has a key and a value). +// +// This is more different from sparse_hash_map than you might think, +// because all iterators for sets are const (you obviously can't +// change the key, and for sets there is no value). +// +// We adhere mostly to the STL semantics for hash-map. One important +// exception is that insert() may invalidate iterators entirely -- STL +// semantics are that insert() may reorder iterators, but they all +// still refer to something valid in the hashtable. Not so for us. +// Likewise, insert() may invalidate pointers into the hashtable. +// (Whether insert invalidates iterators and pointers depends on +// whether it results in a hashtable resize). On the plus side, +// delete() doesn't invalidate iterators or pointers at all, or even +// change the ordering of elements. +// +// Here are a few "power user" tips: +// +// 1) set_deleted_key(): +// Unlike STL's hash_map, if you want to use erase() you +// *must* call set_deleted_key() after construction. +// +// 2) resize(0): +// When an item is deleted, its memory isn't freed right +// away. This allows you to iterate over a hashtable, +// and call erase(), without invalidating the iterator. +// To force the memory to be freed, call resize(0). +// For tr1 compatibility, this can also be called as rehash(0). +// +// 3) min_load_factor(0.0) +// Setting the minimum load factor to 0.0 guarantees that +// the hash table will never shrink. 
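+//
+// For example (sketch):
+//   google::sparse_hash_set<int> s;
+//   s.set_deleted_key(-1);    // tip (1): required before erase()
+//   s.insert(7);
+//   s.erase(7);               // memory is not freed yet...
+//   s.resize(0);              // ...tip (2): forces compaction
+//   s.min_load_factor(0.0f);  // tip (3): never shrink automatically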
+// +// Roughly speaking: +// (1) dense_hash_set: fastest, uses the most memory unless entries are small +// (2) sparse_hash_set: slowest, uses the least memory +// (3) hash_set / unordered_set (STL): in the middle +// +// Typically I use sparse_hash_set when I care about space and/or when +// I need to save the hashtable on disk. I use hash_set otherwise. I +// don't personally use dense_hash_set ever; some people use it for +// small sets with lots of lookups. +// +// - dense_hash_set has, typically, about 78% memory overhead (if your +// data takes up X bytes, the hash_set uses .78X more bytes in overhead). +// - sparse_hash_set has about 4 bits overhead per entry. +// - sparse_hash_set can be 3-7 times slower than the others for lookup and, +// especially, inserts. See time_hash_map.cc for details. +// +// See /usr/(local/)?doc/sparsehash-*/sparse_hash_set.html +// for information about how to use this class. + +#pragma once + +#include // needed by stl_alloc +#include // for equal_to<> +#include // for alloc (which we don't use) +#include // for pair<> +#include +#include // IWYU pragma: export + +namespace google { + +template , + class EqualKey = std::equal_to, + class Alloc = libc_allocator_with_realloc> +class sparse_hash_set { + private: + // Apparently identity is not stl-standard, so we define our own + struct Identity { + typedef const Value& result_type; + const Value& operator()(const Value& v) const { return v; } + }; + struct SetKey { + void operator()(Value* value, const Value& new_key) const { + *value = new_key; + } + }; + + typedef typename sparsehash_internal::key_equal_chosen::type EqualKeyChosen; + typedef sparse_hashtable ht; + ht rep; + + static_assert(!sparsehash_internal::has_transparent_key_equal::value + || std::is_same>::value + || std::is_same::value, + "Heterogeneous lookup requires key_equal to either be the default container value or the same as the type provided by hash"); + + public: + typedef typename ht::key_type key_type; + typedef typename ht::value_type value_type; + typedef typename ht::hasher hasher; + typedef typename ht::key_equal key_equal; + typedef Alloc allocator_type; + + typedef typename ht::size_type size_type; + typedef typename ht::difference_type difference_type; + typedef typename ht::const_pointer pointer; + typedef typename ht::const_pointer const_pointer; + typedef typename ht::const_reference reference; + typedef typename ht::const_reference const_reference; + + typedef typename ht::const_iterator iterator; + typedef typename ht::const_iterator const_iterator; + typedef typename ht::const_local_iterator local_iterator; + typedef typename ht::const_local_iterator const_local_iterator; + + // Iterator functions -- recall all iterators are const + iterator begin() const { return rep.begin(); } + iterator end() const { return rep.end(); } + + // These come from tr1's unordered_set. For us, a bucket has 0 or 1 elements. 
+ local_iterator begin(size_type i) const { return rep.begin(i); } + local_iterator end(size_type i) const { return rep.end(i); } + + // Accessor functions + allocator_type get_allocator() const { return rep.get_allocator(); } + hasher hash_funct() const { return rep.hash_funct(); } + hasher hash_function() const { return hash_funct(); } // tr1 name + key_equal key_eq() const { return rep.key_eq(); } + + // Constructors + explicit sparse_hash_set(size_type expected_max_items_in_table = 0, + const hasher& hf = hasher(), + const key_equal& eql = key_equal(), + const allocator_type& alloc = allocator_type()) + : rep(expected_max_items_in_table, hf, eql, Identity(), SetKey(), alloc) { + } + + template + sparse_hash_set(InputIterator f, InputIterator l, + size_type expected_max_items_in_table = 0, + const hasher& hf = hasher(), + const key_equal& eql = key_equal(), + const allocator_type& alloc = allocator_type()) + : rep(expected_max_items_in_table, hf, eql, Identity(), SetKey(), alloc) { + rep.insert(f, l); + } + // We use the default copy constructor + // We use the default operator=() + // We use the default destructor + + void clear() { rep.clear(); } + void swap(sparse_hash_set& hs) { rep.swap(hs.rep); } + + // Functions concerning size + size_type size() const { return rep.size(); } + size_type max_size() const { return rep.max_size(); } + bool empty() const { return rep.empty(); } + size_type bucket_count() const { return rep.bucket_count(); } + size_type max_bucket_count() const { return rep.max_bucket_count(); } + + // These are tr1 methods. bucket() is the bucket the key is or would be in. + size_type bucket_size(size_type i) const { return rep.bucket_size(i); } + size_type bucket(const key_type& key) const { return rep.bucket(key); } + float load_factor() const { return size() * 1.0f / bucket_count(); } + float max_load_factor() const { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + return grow; + } + void max_load_factor(float new_grow) { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + rep.set_resizing_parameters(shrink, new_grow); + } + // These aren't tr1 methods but perhaps ought to be. + float min_load_factor() const { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + return shrink; + } + void min_load_factor(float new_shrink) { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + rep.set_resizing_parameters(new_shrink, grow); + } + // Deprecated; use min_load_factor() or max_load_factor() instead. 
+ void set_resizing_parameters(float shrink, float grow) { + rep.set_resizing_parameters(shrink, grow); + } + + void reserve(size_type size) { rehash(size); } // note: rehash internally treats hint/size as number of elements + void resize(size_type hint) { rep.resize(hint); } + void rehash(size_type hint) { resize(hint); } // the tr1 name + + // Lookup routines + iterator find(const key_type& key) const { return rep.find(key); } + + template + typename std::enable_if::value, iterator>::type + find(const K& key) const { return rep.find(key); } + + size_type count(const key_type& key) const { return rep.count(key); } + + template + typename std::enable_if::value, size_type>::type + count(const K& key) const { return rep.count(key); } + + std::pair equal_range(const key_type& key) const { + return rep.equal_range(key); + } + + template + typename std::enable_if::value, std::pair>::type + equal_range(const K& key) const { + return rep.equal_range(key); + } + + // Insertion routines + std::pair insert(const value_type& obj) { + std::pair p = rep.insert(obj); + return std::pair(p.first, p.second); // const to non-const + } + template + void insert(InputIterator f, InputIterator l) { + rep.insert(f, l); + } + void insert(const_iterator f, const_iterator l) { rep.insert(f, l); } + // Required for std::insert_iterator; the passed-in iterator is ignored. + iterator insert(iterator, const value_type& obj) { return insert(obj).first; } + + // Deletion routines + // THESE ARE NON-STANDARD! I make you specify an "impossible" key + // value to identify deleted buckets. You can change the key as + // time goes on, or get rid of it entirely to be insert-only. + void set_deleted_key(const key_type& key) { rep.set_deleted_key(key); } + void clear_deleted_key() { rep.clear_deleted_key(); } + key_type deleted_key() const { return rep.deleted_key(); } + + // These are standard + size_type erase(const key_type& key) { return rep.erase(key); } + void erase(iterator it) { rep.erase(it); } + void erase(iterator f, iterator l) { rep.erase(f, l); } + + // Comparison + bool operator==(const sparse_hash_set& hs) const { return rep == hs.rep; } + bool operator!=(const sparse_hash_set& hs) const { return rep != hs.rep; } + + // I/O -- this is an add-on for writing metainformation to disk + // + // For maximum flexibility, this does not assume a particular + // file type (though it will probably be a FILE *). We just pass + // the fp through to rep. + + // If your keys and values are simple enough, you can pass this + // serializer to serialize()/unserialize(). "Simple enough" means + // value_type is a POD type that contains no pointers. Note, + // however, we don't try to normalize endianness. + typedef typename ht::NopointerSerializer NopointerSerializer; + + // serializer: a class providing operator()(OUTPUT*, const value_type&) + // (writing value_type to OUTPUT). You can specify a + // NopointerSerializer object if appropriate (see above). + // fp: either a FILE*, OR an ostream*/subclass_of_ostream*, OR a + // pointer to a class providing size_t Write(const void*, size_t), + // which writes a buffer into a stream (which fp presumably + // owns) and returns the number of bytes successfully written. + // Note basic_ostream is not currently supported. + template + bool serialize(ValueSerializer serializer, OUTPUT* fp) { + return rep.serialize(serializer, fp); + } + + // serializer: a functor providing operator()(INPUT*, value_type*) + // (reading from INPUT and into value_type). 
You can specify a + // NopointerSerializer object if appropriate (see above). + // fp: either a FILE*, OR an istream*/subclass_of_istream*, OR a + // pointer to a class providing size_t Read(void*, size_t), + // which reads into a buffer from a stream (which fp presumably + // owns) and returns the number of bytes successfully read. + // Note basic_istream is not currently supported. + // NOTE: Since value_type is const Key, ValueSerializer + // may need to do a const cast in order to fill in the key. + // NOTE: if Key is not a POD type, the serializer MUST use + // placement-new to initialize its value, rather than a normal + // equals-assignment or similar. (The value_type* passed into + // the serializer points to garbage memory.) + template + bool unserialize(ValueSerializer serializer, INPUT* fp) { + return rep.unserialize(serializer, fp); + } + + // The four methods below are DEPRECATED. + // Use serialize() and unserialize() for new code. + template + bool write_metadata(OUTPUT* fp) { + return rep.write_metadata(fp); + } + + template + bool read_metadata(INPUT* fp) { + return rep.read_metadata(fp); + } + + template + bool write_nopointer_data(OUTPUT* fp) { + return rep.write_nopointer_data(fp); + } + + template + bool read_nopointer_data(INPUT* fp) { + return rep.read_nopointer_data(fp); + } +}; + +template +inline void swap(sparse_hash_set& hs1, + sparse_hash_set& hs2) { + hs1.swap(hs2); +} + +} // namespace google diff --git a/oap-native-sql/cpp/src/third_party/sparsehash/sparsetable b/oap-native-sql/cpp/src/third_party/sparsehash/sparsetable new file mode 100644 index 000000000..bcc4b98ee --- /dev/null +++ b/oap-native-sql/cpp/src/third_party/sparsehash/sparsetable @@ -0,0 +1,1830 @@ +// Copyright (c) 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// +// +// A sparsetable is a random container that implements a sparse array, +// that is, an array that uses very little memory to store unassigned +// indices (in this case, between 1-2 bits per unassigned index). 
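Stepping back to the sparse_hash_set interface completed above, here is a usage sketch covering deletion and the serialize()/unserialize() round trip (the file path and include path are illustrative, not part of the source):

    #include <cstdio>
    #include <sparsehash/sparse_hash_set>  // illustrative include path

    int main() {
      google::sparse_hash_set<int> s;
      s.set_deleted_key(-1);  // an "impossible" key must be set before erase()
      s.insert(3);
      s.insert(7);
      s.erase(3);

      // int is a pointer-free POD, so NopointerSerializer applies.
      FILE* fp = std::fopen("/tmp/set.bin", "wb");  // illustrative path
      s.serialize(google::sparse_hash_set<int>::NopointerSerializer(), fp);
      std::fclose(fp);

      google::sparse_hash_set<int> restored;
      fp = std::fopen("/tmp/set.bin", "rb");
      restored.unserialize(google::sparse_hash_set<int>::NopointerSerializer(), fp);
      std::fclose(fp);
      return restored.count(7) == 1 ? 0 : 1;
    }

With that, back to the sparsetable description: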
+// For instance, if you allocate an array of size 5 and assign
+// a[2] = <big struct>, then a[2] will take up a lot of memory but
+// a[0], a[1], a[3], and a[4] will not. Array elements that have a
+// value are called "assigned". Array elements that have no value yet,
+// or have had their value cleared using erase() or clear(), are
+// called "unassigned".
+//
+// Unassigned values seem to have the default value of T (see below).
+// Nevertheless, there is a difference between an unassigned index and
+// one explicitly assigned the value of T(). The latter is considered
+// assigned.
+//
+// Access to an array element is constant time, as is insertion and
+// deletion. Insertion and deletion may be fairly slow, however:
+// because of this container's memory economy, each insert and delete
+// causes a memory reallocation.
+//
+// NOTE: You should not test(), get(), or set() any index that is
+// greater than sparsetable.size(). If you need to do that, call
+// resize() first.
+//
+// --- Template parameters
+// PARAMETER   DESCRIPTION                          DEFAULT
+// T           The value of the array: the type of  --
+//             object that is stored in the array.
+//
+// GROUP_SIZE  How large each "group" in the table  48
+//             is (see below). Larger values use
+//             a little less memory but cause most
+//             operations to be a little slower
+//
+// Alloc:      Allocator to use to allocate memory. libc_allocator_with_realloc
+//
+// --- Model of
+// Random Access Container
+//
+// --- Type requirements
+// T must be Copy Constructible. It need not be Assignable.
+//
+// --- Public base classes
+// None.
+//
+// --- Members
+// Type members
+//
+// MEMBER                   WHERE DEFINED  DESCRIPTION
+// value_type               container      The type of object, T, stored in the array
+// allocator_type           container      Allocator to use
+// pointer                  container      Pointer to p
+// const_pointer            container      Const pointer to p
+// reference                container      Reference to t
+// const_reference          container      Const reference to t
+// size_type                container      An unsigned integral type
+// difference_type          container      A signed integral type
+// iterator [*]             container      Iterator used to iterate over a sparsetable
+// const_iterator           container      Const iterator used to iterate over a table
+// reverse_iterator         reversible     Iterator used to iterate backwards over
+//                          container      a sparsetable
+// const_reverse_iterator   reversible container   Guess
+// nonempty_iterator [+]    sparsetable    Iterates over assigned
+//                                         array elements only
+// const_nonempty_iterator  sparsetable    Iterates over assigned
+//                                         array elements only
+// reverse_nonempty_iterator sparsetable   Iterates backwards over
+//                                         assigned array elements only
+// const_reverse_nonempty_iterator sparsetable  Iterates backwards over
+//                                         assigned array elements only
+//
+// [*] All iterators are const in a sparsetable (though nonempty_iterators
+//     may not be). Use get() and set() to assign values, not iterators.
+//
+// [+] iterators are random-access iterators. nonempty_iterators are
+//     bidirectional iterators.
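A short sketch of the accessor semantics tabulated above; the include path is illustrative of the vendored layout:

    #include <sparsehash/sparsetable>  // illustrative include path

    int main() {
      google::sparsetable<int> t(100);         // 100 buckets, all unassigned
      t.set(5, 42);                            // bucket 5 becomes assigned
      bool ok = t.test(5) && t.get(5) == 42    // assigned: stored value
                && !t.test(6)                  // unassigned: test() is false
                && t.get(6) == 0;              // unassigned: default T()
      t.erase(5);                              // back to unassigned
      return ok && t.num_nonempty() == 0 ? 0 : 1;
    }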
+
+// Iterator members
+// MEMBER                   WHERE DEFINED  DESCRIPTION
+//
+// iterator begin()         container      An iterator to the beginning of the table
+// iterator end()           container      An iterator to the end of the table
+// const_iterator           container      A const_iterator pointing to the
+//   begin() const                         beginning of a sparsetable
+// const_iterator           container      A const_iterator pointing to the
+//   end() const                           end of a sparsetable
+//
+// reverse_iterator         reversible     Points to beginning of a reversed
+//   rbegin()               container      sparsetable
+// reverse_iterator         reversible     Points to end of a reversed table
+//   rend()                 container
+// const_reverse_iterator   reversible     Points to beginning of a
+//   rbegin() const         container      reversed sparsetable
+// const_reverse_iterator   reversible     Points to end of a reversed table
+//   rend() const           container
+//
+// nonempty_iterator        sparsetable    Points to first assigned element
+//   begin()                               of a sparsetable
+// nonempty_iterator        sparsetable    Points past last assigned element
+//   end()                                 of a sparsetable
+// const_nonempty_iterator  sparsetable    Points to first assigned element
+//   begin() const                         of a sparsetable
+// const_nonempty_iterator  sparsetable    Points past last assigned element
+//   end() const                           of a sparsetable
+//
+// reverse_nonempty_iterator sparsetable   Points to first assigned element
+//   begin()                               of a reversed sparsetable
+// reverse_nonempty_iterator sparsetable   Points past last assigned element
+//   end()                                 of a reversed sparsetable
+// const_reverse_nonempty_iterator sparsetable  Points to first assigned
+//   begin() const                         elt of a reversed sparsetable
+// const_reverse_nonempty_iterator sparsetable  Points past last assigned
+//   end() const                           elt of a reversed sparsetable
+//
+//
+// Other members
+// MEMBER                       WHERE DEFINED  DESCRIPTION
+// sparsetable()                sparsetable    A table of size 0; must resize()
+//                                             before using.
+// sparsetable(size_type size)  sparsetable    A table of size size. All
+//                                             indices are unassigned.
+// sparsetable(
+//   const sparsetable &tbl)    sparsetable    Copy constructor
+// ~sparsetable()               sparsetable    The destructor
+// sparsetable &operator=(      sparsetable    The assignment operator
+//   const sparsetable &tbl)
+//
+// void resize(size_type size)  sparsetable    Grow or shrink a table to
+//                                             have size indices [*]
+//
+// void swap(sparsetable &x)    sparsetable    Swap two sparsetables
+// void swap(sparsetable &x,    sparsetable    Swap two sparsetables
+//   sparsetable &y)                           (global, not member, function)
+//
+// size_type size() const       sparsetable    Number of "buckets" in the table
+// size_type max_size() const   sparsetable    Max allowed size of a sparsetable
+// bool empty() const           sparsetable    true if size() == 0
+// size_type num_nonempty() const sparsetable  Number of assigned "buckets"
+//
+// const_reference get(         sparsetable    Value at index i, or default
+//   size_type i) const                        value if i is unassigned
+// const_reference operator[](  sparsetable    Identical to get(i) [+]
+//   difference_type i) const
+// reference set(size_type i,   sparsetable    Set element at index i to
+//   const_reference val)                      be a copy of val
+// bool test(size_type i)       sparsetable    True if element at index i
+//   const                                     has been assigned to
+// bool test(iterator pos)      sparsetable    True if element pointed to
+//   const                                     by pos has been assigned to
+// void erase(iterator pos)     sparsetable    Set element pointed to by
+//                                             pos to be unassigned [!]
+// void erase(size_type i)      sparsetable    Set element i to be unassigned
+// void erase(iterator start,   sparsetable    Erases all elements between
+//   iterator end)                             start and end
+// void clear()                 sparsetable    Erases all elements in the table
+//
+// I/O versions exist for both FILE* and for File* (Google2-style files):
+// bool write_metadata(FILE *fp) sparsetable   Writes a sparsetable to the
+// bool write_metadata(File *fp)               given file. true if write
+//                                             completes successfully
+// bool read_metadata(FILE *fp) sparsetable    Replaces sparsetable with
+// bool read_metadata(File *fp)                version read from fp. true
+//                                             if read completes successfully
+// bool write_nopointer_data(FILE *fp)         Read/write the data stored in
+// bool read_nopointer_data(FILE*fp)           the table, if it's simple
+//
+// bool operator==(             forward        Tests two tables for equality.
+//   const sparsetable &t1,     container      This is a global function,
+//   const sparsetable &t2)                    not a member function.
+// bool operator<(              forward        Lexicographical comparison.
+//   const sparsetable &t1,     container      This is a global function,
+//   const sparsetable &t2)                    not a member function.
+//
+// [*] If you shrink a sparsetable using resize(), assigned elements
+//     past the end of the table are removed using erase(). If you grow
+//     a sparsetable, new unassigned indices are created.
+//
+// [+] Note that operator[] returns a const reference. You must use
+//     set() to change the value of a table element.
+//
+// [!] Unassignment also calls the destructor.
+//
+// Iterators are invalidated whenever an item is inserted or
+// deleted (ie set() or erase() is used) or when the size of
+// the table changes (ie resize() or clear() is used).
+//
+// See doc/sparsetable.html for more information about how to use this class.
+
+// Note: this uses STL style for naming, rather than Google naming.
+// That's because this is an STL-y container.
+
+#pragma once
+
+#include <stdlib.h>   // for malloc/free
+#include <stdio.h>    // to read/write tables
+#include <string.h>   // for memcpy
+#include <stdint.h>   // the normal place uint16_t is defined
+#include <assert.h>   // for bounds checking
+#include <iterator>   // to define reverse_iterator for me
+#include <algorithm>  // equal, lexicographical_compare, swap,...
+#include <memory>     // uninitialized_copy, uninitialized_fill
+#include <vector>     // a sparsetable is a vector of groups
+#include <type_traits>
+#include <sparsehash/internal/hashtable-common.h>
+#include <sparsehash/internal/libc_allocator_with_realloc.h>
+#include <sparsehash/traits>
+
+namespace google {
+// The smaller this is, the faster lookup is (because the group bitmap is
+// smaller) and the faster insert is, because there's less to move.
+// On the other hand, there are more groups. Since group::size_type is
+// a short, this number should be of the form 32*x + 16 to avoid waste.
+static const uint16_t DEFAULT_SPARSEGROUP_SIZE = 48;  // fits in 1.5 words
+
+// Our iterator is as simple as iterators can be: basically it's just
+// the index into our table. Dereference, the only complicated
+// thing, we punt to the table class. This just goes to show how
+// much machinery STL requires to do even the most trivial tasks.
+//
+// A NOTE ON ASSIGNING:
+// A sparse table does not actually allocate memory for entries
+// that are not filled. Because of this, it becomes complicated
+// to have a non-const iterator: we don't know, if the iterator points
+// to a not-filled bucket, whether you plan to fill it with something
+// or whether you plan to read its value (in which case you'll get
+// the default bucket value).
Therefore, while we can define const +// operations in a pretty 'normal' way, for non-const operations, we +// define something that returns a helper object with operator= and +// operator& that allocate a bucket lazily. We use this for table[] +// and also for regular table iterators. + +template +class table_element_adaptor { + public: + typedef typename tabletype::value_type value_type; + typedef typename tabletype::size_type size_type; + typedef typename tabletype::reference reference; + typedef typename tabletype::pointer pointer; + + table_element_adaptor(tabletype* tbl, size_type p) : table(tbl), pos(p) {} + table_element_adaptor& operator=(const value_type& val) { + table->set(pos, val); + return *this; + } + operator value_type() { return table->get(pos); } // we look like a value + pointer operator&() { return &table->mutating_get(pos); } + + private: + tabletype* table; + size_type pos; +}; + +// Our iterator as simple as iterators can be: basically it's just +// the index into our table. Dereference, the only complicated +// thing, we punt to the table class. This just goes to show how +// much machinery STL requires to do even the most trivial tasks. +// +// By templatizing over tabletype, we have one iterator type which +// we can use for both sparsetables and sparsebins. In fact it +// works on any class that allows size() and operator[] (eg vector), +// as long as it does the standard STL typedefs too (eg value_type). + +template +class table_iterator { + public: + typedef table_iterator iterator; + + typedef std::random_access_iterator_tag iterator_category; + typedef typename tabletype::value_type value_type; + typedef typename tabletype::difference_type difference_type; + typedef typename tabletype::size_type size_type; + typedef table_element_adaptor reference; + typedef table_element_adaptor* pointer; + typedef typename tabletype::const_reference const_reference; // we're const-only + + // The "real" constructor + table_iterator(tabletype* tbl, size_type p) : table(tbl), pos(p) {} + // The default constructor, used when I define vars of type table::iterator + table_iterator() : table(NULL), pos(0) {} + // The copy constructor, for when I say table::iterator foo = tbl.begin() + // The default destructor is fine; we don't define one + // The default operator= is fine; we don't define one + + // The main thing our iterator does is dereference. If the table entry + // we point to is empty, we return the default value type. + // This is the big different function from the const iterator. + reference operator*() { return table_element_adaptor(table, pos); } + const_reference operator*() const { return table_element_adaptor(table, pos); } + pointer operator->() { return &(operator*()); } + + // Helper function to assert things are ok; eg pos is still in range + void check() const { + assert(table); + assert(pos <= table->size()); + } + + // Arithmetic: we just do arithmetic on pos. We don't even need to + // do bounds checking, since STL doesn't consider that its job. 
:-) + iterator& operator+=(size_type t) { + pos += t; + check(); + return *this; + } + iterator& operator-=(size_type t) { + pos -= t; + check(); + return *this; + } + iterator& operator++() { + ++pos; + check(); + return *this; + } + iterator& operator--() { + --pos; + check(); + return *this; + } + iterator operator++(int) { + iterator tmp(*this); // for x++ + ++pos; + check(); + return tmp; + } + iterator operator--(int) { + iterator tmp(*this); // for x-- + --pos; + check(); + return tmp; + } + iterator operator+(difference_type i) const { + iterator tmp(*this); + tmp += i; + return tmp; + } + iterator operator-(difference_type i) const { + iterator tmp(*this); + tmp -= i; + return tmp; + } + difference_type operator-(iterator it) const { // for "x = it2 - it" + assert(table == it.table); + return pos - it.pos; + } + reference operator[](difference_type n) const { + return *(*this + n); // simple though not totally efficient + } + + // Comparisons. + bool operator==(const iterator& it) const { + return table == it.table && pos == it.pos; + } + bool operator<(const iterator& it) const { + assert(table == it.table); // life is bad bad bad otherwise + return pos < it.pos; + } + bool operator!=(const iterator& it) const { return !(*this == it); } + bool operator<=(const iterator& it) const { return !(it < *this); } + bool operator>(const iterator& it) const { return it < *this; } + bool operator>=(const iterator& it) const { return !(*this < it); } + + // Here's the info we actually need to be an iterator + tabletype* table; // so we can dereference and bounds-check + size_type pos; // index into the table +}; + +// support for "3 + iterator" has to be defined outside the class, alas +template +table_iterator operator+(typename table_iterator::difference_type i, + table_iterator it) { + return it + i; // so people can say it2 = 3 + it +} + +template +class const_table_iterator { + public: + typedef table_iterator iterator; + typedef const_table_iterator const_iterator; + + typedef std::random_access_iterator_tag iterator_category; + typedef typename tabletype::value_type value_type; + typedef typename tabletype::difference_type difference_type; + typedef typename tabletype::size_type size_type; + typedef typename tabletype::const_reference reference; // we're const-only + typedef typename tabletype::const_pointer pointer; + + // The "real" constructor + const_table_iterator(const tabletype* tbl, size_type p) + : table(tbl), pos(p) {} + // The default constructor, used when I define vars of type table::iterator + const_table_iterator() : table(NULL), pos(0) {} + // The copy constructor, for when I say table::iterator foo = tbl.begin() + // Also converts normal iterators to const iterators + const_table_iterator(const iterator& from) + : table(from.table), pos(from.pos) {} + // The default destructor is fine; we don't define one + // The default operator= is fine; we don't define one + + // The main thing our iterator does is dereference. If the table entry + // we point to is empty, we return the default value type. + reference operator*() const { return (*table)[pos]; } + pointer operator->() const { return &(operator*()); } + + // Helper function to assert things are ok; eg pos is still in range + void check() const { + assert(table); + assert(pos <= table->size()); + } + + // Arithmetic: we just do arithmetic on pos. We don't even need to + // do bounds checking, since STL doesn't consider that its job. 
:-) + const_iterator& operator+=(size_type t) { + pos += t; + check(); + return *this; + } + const_iterator& operator-=(size_type t) { + pos -= t; + check(); + return *this; + } + const_iterator& operator++() { + ++pos; + check(); + return *this; + } + const_iterator& operator--() { + --pos; + check(); + return *this; + } + const_iterator operator++(int) { + const_iterator tmp(*this); // for x++ + ++pos; + check(); + return tmp; + } + const_iterator operator--(int) { + const_iterator tmp(*this); // for x-- + --pos; + check(); + return tmp; + } + const_iterator operator+(difference_type i) const { + const_iterator tmp(*this); + tmp += i; + return tmp; + } + const_iterator operator-(difference_type i) const { + const_iterator tmp(*this); + tmp -= i; + return tmp; + } + difference_type operator-(const_iterator it) const { // for "x = it2 - it" + assert(table == it.table); + return pos - it.pos; + } + reference operator[](difference_type n) const { + return *(*this + n); // simple though not totally efficient + } + + // Comparisons. + bool operator==(const const_iterator& it) const { + return table == it.table && pos == it.pos; + } + bool operator<(const const_iterator& it) const { + assert(table == it.table); // life is bad bad bad otherwise + return pos < it.pos; + } + bool operator!=(const const_iterator& it) const { return !(*this == it); } + bool operator<=(const const_iterator& it) const { return !(it < *this); } + bool operator>(const const_iterator& it) const { return it < *this; } + bool operator>=(const const_iterator& it) const { return !(*this < it); } + + // Here's the info we actually need to be an iterator + const tabletype* table; // so we can dereference and bounds-check + size_type pos; // index into the table +}; + +// support for "3 + iterator" has to be defined outside the class, alas +template +const_table_iterator operator+( + typename const_table_iterator::difference_type i, + const_table_iterator it) { + return it + i; // so people can say it2 = 3 + it +} + +// --------------------------------------------------------------------------- + +/* +// This is a 2-D iterator. You specify a begin and end over a list +// of *containers*. We iterate over each container by iterating over +// it. It's actually simple: +// VECTOR.begin() VECTOR[0].begin() --------> VECTOR[0].end() ---, +// | ________________________________________________/ +// | \_> VECTOR[1].begin() --------> VECTOR[1].end() -, +// | ___________________________________________________/ +// v \_> ...... +// VECTOR.end() +// +// It's impossible to do random access on one of these things in constant +// time, so it's just a bidirectional iterator. +// +// Unfortunately, because we need to use this for a non-empty iterator, +// we use nonempty_begin() and nonempty_end() instead of begin() and end() +// (though only going across, not down). +*/ + +#define TWOD_BEGIN_ nonempty_begin +#define TWOD_END_ nonempty_end +#define TWOD_ITER_ nonempty_iterator +#define TWOD_CONST_ITER_ const_nonempty_iterator + +template +class two_d_iterator { + public: + typedef two_d_iterator iterator; + + typedef std::bidirectional_iterator_tag iterator_category; + // apparently some versions of VC++ have trouble with two ::'s in a typename + typedef typename containertype::value_type _tmp_vt; + typedef typename _tmp_vt::value_type value_type; + typedef typename _tmp_vt::difference_type difference_type; + typedef typename _tmp_vt::reference reference; + typedef typename _tmp_vt::pointer pointer; + + // The "real" constructor. 
begin and end specify how many rows we have + // (in the diagram above); we always iterate over each row completely. + two_d_iterator(typename containertype::iterator begin, + typename containertype::iterator end, + typename containertype::iterator curr) + : row_begin(begin), row_end(end), row_current(curr), col_current() { + if (row_current != row_end) { + col_current = row_current->TWOD_BEGIN_(); + advance_past_end(); // in case cur->begin() == cur->end() + } + } + // If you want to start at an arbitrary place, you can, I guess + two_d_iterator(typename containertype::iterator begin, + typename containertype::iterator end, + typename containertype::iterator curr, + typename containertype::value_type::TWOD_ITER_ col) + : row_begin(begin), row_end(end), row_current(curr), col_current(col) { + advance_past_end(); // in case cur->begin() == cur->end() + } + // The default constructor, used when I define vars of type table::iterator + two_d_iterator() : row_begin(), row_end(), row_current(), col_current() {} + // The default destructor is fine; we don't define one + // The default operator= is fine; we don't define one + + // Happy dereferencer + reference operator*() const { return *col_current; } + pointer operator->() const { return &(operator*()); } + + // Arithmetic: we just do arithmetic on pos. We don't even need to + // do bounds checking, since STL doesn't consider that its job. :-) + // NOTE: this is not amortized constant time! What do we do about it? + void advance_past_end() { // used when col_current points to end() + while (col_current == row_current->TWOD_END_()) { // end of current row + ++row_current; // go to beginning of next + if (row_current != row_end) // col is irrelevant at end + col_current = row_current->TWOD_BEGIN_(); + else + break; // don't go past row_end + } + } + + iterator& operator++() { + assert(row_current != row_end); // how to ++ from there? + ++col_current; + advance_past_end(); // in case col_current is at end() + return *this; + } + iterator& operator--() { + while (row_current == row_end || + col_current == row_current->TWOD_BEGIN_()) { + assert(row_current != row_begin); + --row_current; + col_current = row_current->TWOD_END_(); // this is 1 too far + } + --col_current; + return *this; + } + iterator operator++(int) { + iterator tmp(*this); + ++*this; + return tmp; + } + iterator operator--(int) { + iterator tmp(*this); + --*this; + return tmp; + } + + // Comparisons. + bool operator==(const iterator& it) const { + return (row_begin == it.row_begin && row_end == it.row_end && + row_current == it.row_current && + (row_current == row_end || col_current == it.col_current)); + } + bool operator!=(const iterator& it) const { return !(*this == it); } + + // Here's the info we actually need to be an iterator + // These need to be public so we convert from iterator to const_iterator + typename containertype::iterator row_begin, row_end, row_current; + typename containertype::value_type::TWOD_ITER_ col_current; +}; + +// The same thing again, but this time const. 
:-( +template +class const_two_d_iterator { + public: + typedef const_two_d_iterator iterator; + + typedef std::bidirectional_iterator_tag iterator_category; + // apparently some versions of VC++ have trouble with two ::'s in a typename + typedef typename containertype::value_type _tmp_vt; + typedef typename _tmp_vt::value_type value_type; + typedef typename _tmp_vt::difference_type difference_type; + typedef typename _tmp_vt::const_reference reference; + typedef typename _tmp_vt::const_pointer pointer; + + const_two_d_iterator(typename containertype::const_iterator begin, + typename containertype::const_iterator end, + typename containertype::const_iterator curr) + : row_begin(begin), row_end(end), row_current(curr), col_current() { + if (curr != end) { + col_current = curr->TWOD_BEGIN_(); + advance_past_end(); // in case cur->begin() == cur->end() + } + } + const_two_d_iterator(typename containertype::const_iterator begin, + typename containertype::const_iterator end, + typename containertype::const_iterator curr, + typename containertype::value_type::TWOD_CONST_ITER_ col) + : row_begin(begin), row_end(end), row_current(curr), col_current(col) { + advance_past_end(); // in case cur->begin() == cur->end() + } + const_two_d_iterator() + : row_begin(), row_end(), row_current(), col_current() {} + // Need this explicitly so we can convert normal iterators to const + // iterators + const_two_d_iterator(const two_d_iterator& it) + : row_begin(it.row_begin), + row_end(it.row_end), + row_current(it.row_current), + col_current(it.col_current) {} + + typename containertype::const_iterator row_begin, row_end, row_current; + typename containertype::value_type::TWOD_CONST_ITER_ col_current; + + // EVERYTHING FROM HERE DOWN IS THE SAME AS THE NON-CONST ITERATOR + reference operator*() const { return *col_current; } + pointer operator->() const { return &(operator*()); } + + void advance_past_end() { // used when col_current points to end() + while (col_current == row_current->TWOD_END_()) { // end of current row + ++row_current; // go to beginning of next + if (row_current != row_end) // col is irrelevant at end + col_current = row_current->TWOD_BEGIN_(); + else + break; // don't go past row_end + } + } + iterator& operator++() { + assert(row_current != row_end); // how to ++ from there? + ++col_current; + advance_past_end(); // in case col_current is at end() + return *this; + } + iterator& operator--() { + while (row_current == row_end || + col_current == row_current->TWOD_BEGIN_()) { + assert(row_current != row_begin); + --row_current; + col_current = row_current->TWOD_END_(); // this is 1 too far + } + --col_current; + return *this; + } + iterator operator++(int) { + iterator tmp(*this); + ++*this; + return tmp; + } + iterator operator--(int) { + iterator tmp(*this); + --*this; + return tmp; + } + + bool operator==(const iterator& it) const { + return (row_begin == it.row_begin && row_end == it.row_end && + row_current == it.row_current && + (row_current == row_end || col_current == it.col_current)); + } + bool operator!=(const iterator& it) const { return !(*this == it); } +}; + +// We provide yet another version, to be as frugal with memory as +// possible. This one frees each block of memory as it finishes +// iterating over it. By the end, the entire table is freed. 
+// For understandable reasons, you can only iterate over it once, +// which is why it's an input iterator +template +class destructive_two_d_iterator { + public: + typedef destructive_two_d_iterator iterator; + + typedef std::input_iterator_tag iterator_category; + // apparently some versions of VC++ have trouble with two ::'s in a typename + typedef typename containertype::value_type _tmp_vt; + typedef typename _tmp_vt::value_type value_type; + typedef typename _tmp_vt::difference_type difference_type; + typedef typename _tmp_vt::reference reference; + typedef typename _tmp_vt::pointer pointer; + + destructive_two_d_iterator(typename containertype::iterator begin, + typename containertype::iterator end, + typename containertype::iterator curr) + : row_begin(begin), row_end(end), row_current(curr), col_current() { + if (curr != end) { + col_current = curr->TWOD_BEGIN_(); + advance_past_end(); // in case cur->begin() == cur->end() + } + } + destructive_two_d_iterator(typename containertype::iterator begin, + typename containertype::iterator end, + typename containertype::iterator curr, + typename containertype::value_type::TWOD_ITER_ col) + : row_begin(begin), row_end(end), row_current(curr), col_current(col) { + advance_past_end(); // in case cur->begin() == cur->end() + } + destructive_two_d_iterator() + : row_begin(), row_end(), row_current(), col_current() {} + + typename containertype::iterator row_begin, row_end, row_current; + typename containertype::value_type::TWOD_ITER_ col_current; + + // This is the part that destroys + void advance_past_end() { // used when col_current points to end() + while (col_current == row_current->TWOD_END_()) { // end of current row + row_current->clear(); // the destructive part + // It would be nice if we could decrement sparsetable->num_buckets + // here + ++row_current; // go to beginning of next + if (row_current != row_end) // col is irrelevant at end + col_current = row_current->TWOD_BEGIN_(); + else + break; // don't go past row_end + } + } + + // EVERYTHING FROM HERE DOWN IS THE SAME AS THE REGULAR ITERATOR + reference operator*() const { return *col_current; } + pointer operator->() const { return &(operator*()); } + + iterator& operator++() { + assert(row_current != row_end); // how to ++ from there? + ++col_current; + advance_past_end(); // in case col_current is at end() + return *this; + } + iterator operator++(int) { + iterator tmp(*this); + ++*this; + return tmp; + } + + bool operator==(const iterator& it) const { + return (row_begin == it.row_begin && row_end == it.row_end && + row_current == it.row_current && + (row_current == row_end || col_current == it.col_current)); + } + bool operator!=(const iterator& it) const { return !(*this == it); } +}; + +#undef TWOD_BEGIN_ +#undef TWOD_END_ +#undef TWOD_ITER_ +#undef TWOD_CONST_ITER_ + +// SPARSE-TABLE +// ------------ +// The idea is that a table with (logically) t buckets is divided +// into t/M *groups* of M buckets each. (M is a constant set in +// GROUP_SIZE for efficiency.) Each group is stored sparsely. +// Thus, inserting into the table causes some array to grow, which is +// slow but still constant time. Lookup involves doing a +// logical-position-to-sparse-position lookup, which is also slow but +// constant time. The larger M is, the slower these operations are +// but the less overhead (slightly). +// +// To store the sparse array, we store a bitmap B, where B[i] = 1 iff +// bucket i is non-empty. Then to look up bucket i we really look up +// array[# of 1s before i in B]. 
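To make the bitmap indexing concrete, a self-contained check with illustrative values (bits numbered from the least-significant end):

    #include <bitset>
    #include <cstdint>

    int main() {
      // Illustrative group of M = 8 buckets; bit i of B is 1 iff bucket i
      // is assigned. Here buckets 2, 5 and 6 are assigned.
      const std::uint8_t B = 0b01100100;
      // Bucket 5 is stored at array[# of 1s before position 5 in B]:
      const auto offset = std::bitset<8>(B & ((1u << 5) - 1)).count();
      return offset == 1 ? 0 : 1;  // bucket 5 lives at array[1]
    }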
This is constant time for fixed M. +// +// Terminology: the position of an item in the overall table (from +// 1 .. t) is called its "location." The logical position in a group +// (from 1 .. M ) is called its "position." The actual location in +// the array (from 1 .. # of non-empty buckets in the group) is +// called its "offset." + +template +class sparsegroup { + public: + typedef T value_type; + typedef Alloc allocator_type; + + private: + using value_alloc_type = + typename std::allocator_traits::template rebind_alloc; + typedef std::integral_constant< + bool, (is_relocatable::value && + std::is_same>::value)> + realloc_and_memmove_ok; // we pretend mv(x,y) == "x.~T(); + // new(x) T(y)" + public: + // Basic types + typedef typename value_alloc_type::reference reference; + typedef typename value_alloc_type::const_reference const_reference; + typedef typename value_alloc_type::pointer pointer; + typedef typename value_alloc_type::const_pointer const_pointer; + + typedef table_iterator> iterator; + typedef const_table_iterator> + const_iterator; + typedef table_element_adaptor> + element_adaptor; + typedef uint16_t size_type; // max # of buckets + typedef int16_t difference_type; + typedef std::reverse_iterator const_reverse_iterator; + typedef std::reverse_iterator reverse_iterator; // from iterator.h + + // These are our special iterators, that go over non-empty buckets in a + // group. These aren't const-only because you can change non-empty bcks. + typedef pointer nonempty_iterator; + typedef const_pointer const_nonempty_iterator; + typedef std::reverse_iterator reverse_nonempty_iterator; + typedef std::reverse_iterator + const_reverse_nonempty_iterator; + + // Iterator functions + iterator begin() { return iterator(this, 0); } + const_iterator begin() const { return const_iterator(this, 0); } + iterator end() { return iterator(this, size()); } + const_iterator end() const { return const_iterator(this, size()); } + reverse_iterator rbegin() { return reverse_iterator(end()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); + } + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); + } + + // We'll have versions for our special non-empty iterator too + nonempty_iterator nonempty_begin() { return group; } + const_nonempty_iterator nonempty_begin() const { return group; } + nonempty_iterator nonempty_end() { return group + settings.num_buckets; } + const_nonempty_iterator nonempty_end() const { + return group + settings.num_buckets; + } + reverse_nonempty_iterator nonempty_rbegin() { + return reverse_nonempty_iterator(nonempty_end()); + } + const_reverse_nonempty_iterator nonempty_rbegin() const { + return const_reverse_nonempty_iterator(nonempty_end()); + } + reverse_nonempty_iterator nonempty_rend() { + return reverse_nonempty_iterator(nonempty_begin()); + } + const_reverse_nonempty_iterator nonempty_rend() const { + return const_reverse_nonempty_iterator(nonempty_begin()); + } + + // This gives us the "default" value to return for an empty bucket. + // We just use the default constructor on T, the template type + const_reference default_value() const { + static value_type defaultval = value_type(); + return defaultval; + } + + private: + // We need to do all this bit manipulation, of course. 
ick + static size_type charbit(size_type i) { return i >> 3; } + static size_type modbit(size_type i) { return 1 << (i & 7); } + int bmtest(size_type i) const { return bitmap[charbit(i)] & modbit(i); } + void bmset(size_type i) { bitmap[charbit(i)] |= modbit(i); } + void bmclear(size_type i) { bitmap[charbit(i)] &= ~modbit(i); } + + pointer allocate_group(size_type n) { + pointer retval = settings.allocate(n); + if (retval == NULL) { + // We really should use PRIuS here, but I don't want to have to add + // a whole new configure option, with concomitant macro namespace + // pollution, just to print this (unlikely) error message. So I + // cast. + fprintf(stderr, "sparsehash FATAL ERROR: failed to allocate %lu groups\n", + static_cast(n)); + exit(1); + } + return retval; + } + + void free_group() { + if (!group) return; + pointer end_it = group + settings.num_buckets; + for (pointer p = group; p != end_it; ++p) p->~value_type(); + settings.deallocate(group, settings.num_buckets); + group = NULL; + } + + static size_type bits_in_char(unsigned char c) { + // We could make these ints. The tradeoff is size (eg does it overwhelm + // the cache?) vs efficiency in referencing sub-word-sized array + // elements. + static const char bits_in[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, + 2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4, + 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, + 4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, + 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, + 4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, + }; + return bits_in[c]; + } + + public: // get_iter() in sparsetable needs it + // We need a small function that tells us how many set bits there are + // in positions 0..i-1 of the bitmap. It uses a big table. + // We make it static so templates don't allocate lots of these tables. + // There are lots of ways to do this calculation (called 'popcount'). + // The 8-bit table lookup is one of the fastest, though this + // implementation suffers from not doing any loop unrolling. See, eg, + // http://www.dalkescientific.com/writings/diary/archive/2008/07/03/hakmem_and_other_popcounts.html + // http://gurmeetsingh.wordpress.com/2008/08/05/fast-bit-counting-routines/ + static size_type pos_to_offset(const unsigned char* bm, size_type pos) { + size_type retval = 0; + + // [Note: condition pos > 8 is an optimization; convince yourself we + // give exactly the same result as if we had pos >= 8 here instead.] + for (; pos > 8; pos -= 8) // bm[0..pos/8-1] + retval += bits_in_char(*bm++); // chars we want *all* bits in + return retval + bits_in_char(*bm & ((1 << pos) - 1)); // char including pos + } + + size_type pos_to_offset(size_type pos) const { // not static but still const + return pos_to_offset(bitmap, pos); + } + + // Returns the (logical) position in the bm[] array, i, such that + // bm[i] is the offset-th set bit in the array. It is the inverse + // of pos_to_offset. get_pos() uses this function to find the index + // of an nonempty_iterator in the table. 
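The inverse mapping can be sanity-checked against the same illustrative bitmap used in the sketch earlier: for B = 0b01100100 the set bits sit at positions 2, 5 and 6, so offsets 0, 1 and 2 map back to positions 2, 5 and 6 respectively. In particular, offset_to_pos(bm, 1) == 5 undoes pos_to_offset(bm, 5) == 1.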
Bit-twiddling from + // http://hackersdelight.org/basics.pdf + static size_type offset_to_pos(const unsigned char* bm, size_type offset) { + size_type retval = 0; + // This is sizeof(this->bitmap). + const size_type group_size = (GROUP_SIZE - 1) / 8 + 1; + for (size_type i = 0; i < group_size; i++) { // forward scan + const size_type pop_count = bits_in_char(*bm); + if (pop_count > offset) { + unsigned char last_bm = *bm; + for (; offset > 0; offset--) { + last_bm &= (last_bm - 1); // remove right-most set bit + } + // Clear all bits to the left of the rightmost bit (the &), + // and then clear the rightmost bit but set all bits to the + // right of it (the -1). + last_bm = (last_bm & -last_bm) - 1; + retval += bits_in_char(last_bm); + return retval; + } + offset -= pop_count; + retval += 8; + bm++; + } + return retval; + } + + size_type offset_to_pos(size_type offset) const { + return offset_to_pos(bitmap, offset); + } + + public: + // Constructors -- default and copy -- and destructor + explicit sparsegroup(allocator_type& a) + : group(0), settings(alloc_impl(a)) { + memset(bitmap, 0, sizeof(bitmap)); + } + sparsegroup(const sparsegroup& x) : group(0), settings(x.settings) { + if (settings.num_buckets) { + group = allocate_group(x.settings.num_buckets); + std::uninitialized_copy(x.group, x.group + x.settings.num_buckets, group); + } + memcpy(bitmap, x.bitmap, sizeof(bitmap)); + } + ~sparsegroup() { free_group(); } + + // Operator= is just like the copy constructor, I guess + // TODO(austern): Make this exception safe. Handle exceptions in + // value_type's + // copy constructor. + sparsegroup& operator=(const sparsegroup& x) { + if (&x == this) return *this; // x = x + if (x.settings.num_buckets == 0) { + free_group(); + } else { + pointer p = allocate_group(x.settings.num_buckets); + std::uninitialized_copy(x.group, x.group + x.settings.num_buckets, p); + free_group(); + group = p; + } + memcpy(bitmap, x.bitmap, sizeof(bitmap)); + settings.num_buckets = x.settings.num_buckets; + return *this; + } + + // Many STL algorithms use swap instead of copy constructors + void swap(sparsegroup& x) { + std::swap(group, x.group); // defined in + for (int i = 0; i < sizeof(bitmap) / sizeof(*bitmap); ++i) + std::swap(bitmap[i], x.bitmap[i]); // swap not defined on arrays + std::swap(settings.num_buckets, x.settings.num_buckets); + // we purposefully don't swap the allocator, which may not be swap-able + } + + // It's always nice to be able to clear a table without deallocating it + void clear() { + free_group(); + memset(bitmap, 0, sizeof(bitmap)); + settings.num_buckets = 0; + } + + // Functions that tell you about size. Alas, these aren't so useful + // because our table is always fixed size. + size_type size() const { return GROUP_SIZE; } + size_type max_size() const { return GROUP_SIZE; } + bool empty() const { return false; } + // We also may want to know how many *used* buckets there are + size_type num_nonempty() const { return settings.num_buckets; } + + // get()/set() are explicitly const/non-const. You can use [] if + // you want something that can be either (potentially more expensive). + const_reference get(size_type i) const { + if (bmtest(i)) // bucket i is occupied + return group[pos_to_offset(bitmap, i)]; + else + return default_value(); // return the default reference + } + + // TODO(csilvers): make protected + friend + // This is used by sparse_hashtable to get an element from the table + // when we know it exists. 
+ const_reference unsafe_get(size_type i) const { + assert(bmtest(i)); + return group[pos_to_offset(bitmap, i)]; + } + + // TODO(csilvers): make protected + friend + reference mutating_get(size_type i) { // fills bucket i before getting + if (!bmtest(i)) set(i, default_value()); + return group[pos_to_offset(bitmap, i)]; + } + + // Syntactic sugar. It's easy to return a const reference. To + // return a non-const reference, we need to use the assigner adaptor. + const_reference operator[](size_type i) const { return get(i); } + + element_adaptor operator[](size_type i) { return element_adaptor(this, i); } + + private: + // Create space at group[offset], assuming value_type has trivial + // copy constructor and destructor, and the allocator_type is + // the default libc_allocator_with_alloc. (Really, we want it to have + // "trivial move", because that's what realloc and memmove both do. + // But there's no way to capture that using type_traits, so we + // pretend that move(x, y) is equivalent to "x.~T(); new(x) T(y);" + // which is pretty much correct, if a bit conservative.) + void set_aux(size_type offset, std::true_type) { + group = settings.realloc_or_die(group, settings.num_buckets + 1); + // This is equivalent to memmove(), but faster on my Intel P4, + // at least with gcc4.1 -O2 / glibc 2.3.6. + for (size_type i = settings.num_buckets; i > offset; --i) + // cast to void* to prevent compiler warnings about writing to an object + // with no trivial copy-assignment + memcpy(static_cast(group + i), group + i - 1, sizeof(*group)); + } + + // Create space at group[offset], without special assumptions about + // value_type + // and allocator_type. + void set_aux(size_type offset, std::false_type) { + // This is valid because 0 <= offset <= num_buckets + pointer p = allocate_group(settings.num_buckets + 1); + std::uninitialized_copy(group, group + offset, p); + std::uninitialized_copy(group + offset, group + settings.num_buckets, + p + offset + 1); + free_group(); + group = p; + } + + public: + // This returns a reference to the inserted item (which is a copy of val). + // TODO(austern): Make this exception safe: handle exceptions from + // value_type's copy constructor. + reference set(size_type i, const_reference val) { + size_type offset = + pos_to_offset(bitmap, i); // where we'll find (or insert) + if (bmtest(i)) { + // Delete the old value, which we're replacing with the new one + group[offset].~value_type(); + } else { + set_aux(offset, realloc_and_memmove_ok()); + ++settings.num_buckets; + bmset(i); + } + // This does the actual inserting. Since we made the array using + // malloc, we use "placement new" to just call the constructor. + new (&group[offset]) value_type(val); + return group[offset]; + } + + // We let you see if a bucket is non-empty without retrieving it + bool test(size_type i) const { return bmtest(i) != 0; } + bool test(iterator pos) const { return bmtest(pos.pos) != 0; } + + private: + // Shrink the array, assuming value_type has trivial copy + // constructor and destructor, and the allocator_type is the default + // libc_allocator_with_alloc. (Really, we want it to have "trivial + // move", because that's what realloc and memmove both do. But + // there's no way to capture that using type_traits, so we pretend + // that move(x, y) is equivalent to ""x.~T(); new(x) T(y);" + // which is pretty much correct, if a bit conservative.) 
+  void erase_aux(size_type offset, std::true_type) {
+    // This isn't technically necessary, since we know we have a
+    // trivial destructor, but is a cheap way to get a bit more safety.
+    group[offset].~value_type();
+    // This is equivalent to memmove(), but faster on my Intel P4,
+    // at least with gcc4.1 -O2 / glibc 2.3.6.
+    assert(settings.num_buckets > 0);
+    for (size_type i = offset; i < settings.num_buckets - 1; ++i)
+      // cast to void* to prevent compiler warnings about writing to an object
+      // with no trivial copy-assignment
+      // hopefully inlined!
+      memcpy(static_cast<void*>(group + i), group + i + 1, sizeof(*group));
+    group = settings.realloc_or_die(group, settings.num_buckets - 1);
+  }
+
+  // Shrink the array, without any special assumptions about value_type and
+  // allocator_type.
+  void erase_aux(size_type offset, std::false_type) {
+    // This is valid because 0 <= offset < num_buckets. Note the inequality.
+    pointer p = allocate_group(settings.num_buckets - 1);
+    std::uninitialized_copy(group, group + offset, p);
+    std::uninitialized_copy(group + offset + 1, group + settings.num_buckets,
+                            p + offset);
+    free_group();
+    group = p;
+  }
+
+ public:
+  // This takes the specified elements out of the group. This is
+  // "undefining", rather than "clearing".
+  // TODO(austern): Make this exception safe: handle exceptions from
+  // value_type's copy constructor.
+  void erase(size_type i) {
+    if (bmtest(i)) {  // trivial to erase empty bucket
+      size_type offset =
+          pos_to_offset(bitmap, i);  // where we'll find (or insert)
+      if (settings.num_buckets == 1) {
+        free_group();
+        group = NULL;
+      } else {
+        erase_aux(offset, realloc_and_memmove_ok());
+      }
+      --settings.num_buckets;
+      bmclear(i);
+    }
+  }
+
+  void erase(iterator pos) { erase(pos.pos); }
+
+  void erase(iterator start_it, iterator end_it) {
+    // This could be more efficient, but to do so we'd need to make
+    // bmclear() clear a range of indices. Doesn't seem worth it.
+    for (; start_it != end_it; ++start_it) erase(start_it);
+  }
+
+  // I/O
+  // We support reading and writing groups to disk. We don't store
+  // the actual array contents (which we don't know how to store),
+  // just the bitmap and size. Meant to be used with table I/O.
+
+  template <typename OUTPUT>
+  bool write_metadata(OUTPUT* fp) const {
+    // we explicitly set to uint16_t
+    assert(sizeof(settings.num_buckets) == 2);
+    if (!sparsehash_internal::write_bigendian_number(fp, settings.num_buckets,
+                                                     2))
+      return false;
+    if (!sparsehash_internal::write_data(fp, bitmap, sizeof(bitmap)))
+      return false;
+    return true;
+  }
+
+  // Reading destroys the old group contents! Returns true if all was ok.
+  template <typename INPUT>
+  bool read_metadata(INPUT* fp) {
+    clear();
+    if (!sparsehash_internal::read_bigendian_number(fp, &settings.num_buckets,
+                                                    2))
+      return false;
+    if (!sparsehash_internal::read_data(fp, bitmap, sizeof(bitmap)))
+      return false;
+    // We'll allocate the space, but we won't fill it: it will be
+    // left as uninitialized raw memory.
+    group = allocate_group(settings.num_buckets);
+    return true;
+  }
+
+  // Again, only meaningful if value_type is a POD.
+  template <typename INPUT>
+  bool read_nopointer_data(INPUT* fp) {
+    for (nonempty_iterator it = nonempty_begin(); it != nonempty_end(); ++it) {
+      if (!sparsehash_internal::read_data(fp, &(*it), sizeof(*it)))
+        return false;
+    }
+    return true;
+  }
+
+  // If your keys and values are simple enough, we can write them
+  // to disk for you. "simple enough" means POD and no pointers.
+  // However, we don't try to normalize endianness.
+ template + bool write_nopointer_data(OUTPUT* fp) const { + for (const_nonempty_iterator it = nonempty_begin(); it != nonempty_end(); + ++it) { + if (!sparsehash_internal::write_data(fp, &(*it), sizeof(*it))) + return false; + } + return true; + } + + // Comparisons. We only need to define == and < -- we get + // != > <= >= via relops.h (which we happily included above). + // Note the comparisons are pretty arbitrary: we compare + // values of the first index that isn't equal (using default + // value for empty buckets). + bool operator==(const sparsegroup& x) const { + return (settings.num_buckets == x.settings.num_buckets && + memcmp(bitmap, x.bitmap, sizeof(bitmap)) == 0 && + std::equal(begin(), end(), x.begin())); // from + } + + bool operator<(const sparsegroup& x) const { // also from + return std::lexicographical_compare(begin(), end(), x.begin(), x.end()); + } + bool operator!=(const sparsegroup& x) const { return !(*this == x); } + bool operator<=(const sparsegroup& x) const { return !(x < *this); } + bool operator>(const sparsegroup& x) const { return x < *this; } + bool operator>=(const sparsegroup& x) const { return !(*this < x); } + + private: + template + class alloc_impl : public A { + public: + typedef typename A::pointer pointer; + typedef typename A::size_type size_type; + + // Convert a normal allocator to one that has realloc_or_die() + alloc_impl(const A& a) : A(a) {} + + // realloc_or_die should only be used when using the default + // allocator (libc_allocator_with_realloc). + pointer realloc_or_die(pointer /*ptr*/, size_type /*n*/) { + fprintf(stderr, + "realloc_or_die is only supported for " + "libc_allocator_with_realloc\n"); + exit(1); + return NULL; + } + }; + + // A template specialization of alloc_impl for + // libc_allocator_with_realloc that can handle realloc_or_die. + template + class alloc_impl> + : public libc_allocator_with_realloc { + public: + typedef typename libc_allocator_with_realloc::pointer pointer; + typedef typename libc_allocator_with_realloc::size_type size_type; + + alloc_impl(const libc_allocator_with_realloc& a) + : libc_allocator_with_realloc(a) {} + + pointer realloc_or_die(pointer ptr, size_type n) { + pointer retval = this->reallocate(ptr, n); + if (retval == NULL) { + fprintf(stderr, + "sparsehash: FATAL ERROR: failed to reallocate " + "%lu elements for ptr %p", + static_cast(n), static_cast(ptr)); + exit(1); + } + return retval; + } + }; + + // Package allocator with num_buckets to eliminate memory needed for the + // zero-size allocator. + // If new fields are added to this class, we should add them to + // operator= and swap. 
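The Settings class defined next packages the allocator together with num_buckets via inheritance, so a stateless allocator contributes no per-group storage (the empty-base-class optimization). A minimal illustration of the trick, using hypothetical stand-in types:

    #include <cstdint>

    struct empty_alloc {};  // stateless allocator stand-in (hypothetical)
    struct with_member { empty_alloc a; std::uint16_t n; };  // empty member still pads
    struct with_base : empty_alloc { std::uint16_t n; };     // empty base takes no space
    static_assert(sizeof(with_base) <= sizeof(with_member),
                  "inheriting from the empty allocator saves per-group bytes");

    int main() { return 0; }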
+ class Settings : public alloc_impl { + public: + Settings(const alloc_impl& a, uint16_t n = 0) + : alloc_impl(a), num_buckets(n) {} + Settings(const Settings& s) + : alloc_impl(s), num_buckets(s.num_buckets) {} + + uint16_t num_buckets; // limits GROUP_SIZE to 64K + }; + + // The actual data + pointer group; // (small) array of T's + Settings settings; // allocator and num_buckets + unsigned char + bitmap[(GROUP_SIZE - 1) / 8 + 1]; // fancy math is so we round up +}; + +// We need a global swap as well +template +inline void swap(sparsegroup& x, + sparsegroup& y) { + x.swap(y); +} + +// --------------------------------------------------------------------------- + +template > +class sparsetable { + private: + using value_alloc_type = + typename std::allocator_traits::template rebind_alloc; + using vector_alloc = + typename std::allocator_traits::template rebind_alloc< + sparsegroup>; + + public: + // Basic types + typedef T value_type; // stolen from stl_vector.h + typedef Alloc allocator_type; + typedef typename value_alloc_type::size_type size_type; + typedef typename value_alloc_type::difference_type difference_type; + typedef typename value_alloc_type::reference reference; + typedef typename value_alloc_type::const_reference const_reference; + typedef typename value_alloc_type::pointer pointer; + typedef typename value_alloc_type::const_pointer const_pointer; + typedef table_iterator> iterator; + typedef const_table_iterator> + const_iterator; + typedef table_element_adaptor> + element_adaptor; + typedef std::reverse_iterator const_reverse_iterator; + typedef std::reverse_iterator reverse_iterator; // from iterator.h + + // These are our special iterators, that go over non-empty buckets in a + // table. These aren't const only because you can change non-empty bcks. 
+template <class T, uint16_t GROUP_SIZE = DEFAULT_SPARSEGROUP_SIZE,
+          class Alloc = libc_allocator_with_realloc<T>>
+class sparsetable {
+ private:
+  using value_alloc_type =
+      typename std::allocator_traits<Alloc>::template rebind_alloc<T>;
+  using vector_alloc =
+      typename std::allocator_traits<Alloc>::template rebind_alloc<
+          sparsegroup<T, GROUP_SIZE, value_alloc_type>>;
+
+ public:
+  // Basic types
+  typedef T value_type;  // stolen from stl_vector.h
+  typedef Alloc allocator_type;
+  typedef typename value_alloc_type::size_type size_type;
+  typedef typename value_alloc_type::difference_type difference_type;
+  typedef typename value_alloc_type::reference reference;
+  typedef typename value_alloc_type::const_reference const_reference;
+  typedef typename value_alloc_type::pointer pointer;
+  typedef typename value_alloc_type::const_pointer const_pointer;
+  typedef table_iterator<sparsetable<T, GROUP_SIZE, Alloc>> iterator;
+  typedef const_table_iterator<sparsetable<T, GROUP_SIZE, Alloc>>
+      const_iterator;
+  typedef table_element_adaptor<sparsetable<T, GROUP_SIZE, Alloc>>
+      element_adaptor;
+  typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
+  typedef std::reverse_iterator<iterator> reverse_iterator;  // from iterator.h
+
+  // These are our special iterators, that go over non-empty buckets in a
+  // table.  These aren't const only because you can change non-empty bcks.
+  typedef two_d_iterator<std::vector<
+      sparsegroup<value_type, GROUP_SIZE, value_alloc_type>, vector_alloc>>
+      nonempty_iterator;
+  typedef const_two_d_iterator<std::vector<
+      sparsegroup<value_type, GROUP_SIZE, value_alloc_type>, vector_alloc>>
+      const_nonempty_iterator;
+  typedef std::reverse_iterator<nonempty_iterator> reverse_nonempty_iterator;
+  typedef std::reverse_iterator<const_nonempty_iterator>
+      const_reverse_nonempty_iterator;
+  // Another special iterator: it frees memory as it iterates (used to resize)
+  typedef destructive_two_d_iterator<std::vector<
+      sparsegroup<value_type, GROUP_SIZE, value_alloc_type>, vector_alloc>>
+      destructive_iterator;
+
+  // Iterator functions
+  iterator begin() { return iterator(this, 0); }
+  const_iterator begin() const { return const_iterator(this, 0); }
+  iterator end() { return iterator(this, size()); }
+  const_iterator end() const { return const_iterator(this, size()); }
+  reverse_iterator rbegin() { return reverse_iterator(end()); }
+  const_reverse_iterator rbegin() const {
+    return const_reverse_iterator(end());
+  }
+  reverse_iterator rend() { return reverse_iterator(begin()); }
+  const_reverse_iterator rend() const {
+    return const_reverse_iterator(begin());
+  }
+
+  // Versions for our special non-empty iterator
+  nonempty_iterator nonempty_begin() {
+    return nonempty_iterator(groups.begin(), groups.end(), groups.begin());
+  }
+  const_nonempty_iterator nonempty_begin() const {
+    return const_nonempty_iterator(groups.begin(), groups.end(),
+                                   groups.begin());
+  }
+  nonempty_iterator nonempty_end() {
+    return nonempty_iterator(groups.begin(), groups.end(), groups.end());
+  }
+  const_nonempty_iterator nonempty_end() const {
+    return const_nonempty_iterator(groups.begin(), groups.end(), groups.end());
+  }
+  reverse_nonempty_iterator nonempty_rbegin() {
+    return reverse_nonempty_iterator(nonempty_end());
+  }
+  const_reverse_nonempty_iterator nonempty_rbegin() const {
+    return const_reverse_nonempty_iterator(nonempty_end());
+  }
+  reverse_nonempty_iterator nonempty_rend() {
+    return reverse_nonempty_iterator(nonempty_begin());
+  }
+  const_reverse_nonempty_iterator nonempty_rend() const {
+    return const_reverse_nonempty_iterator(nonempty_begin());
+  }
+  destructive_iterator destructive_begin() {
+    return destructive_iterator(groups.begin(), groups.end(), groups.begin());
+  }
+  destructive_iterator destructive_end() {
+    return destructive_iterator(groups.begin(), groups.end(), groups.end());
+  }
+
+  typedef sparsegroup<value_type, GROUP_SIZE, value_alloc_type> group_type;
+  using group_vector_type_allocator_type =
+      typename std::allocator_traits<Alloc>::template rebind_alloc<group_type>;
+  typedef std::vector<group_type, group_vector_type_allocator_type>
+      group_vector_type;
+
+  typedef typename group_vector_type::reference GroupsReference;
+  typedef typename group_vector_type::const_reference GroupsConstReference;
+  typedef typename group_vector_type::iterator GroupsIterator;
+  typedef typename group_vector_type::const_iterator GroupsConstIterator;
+
+  // How to deal with the proper group
+  static size_type num_groups(size_type num) {  // how many to hold num buckets
+    return num == 0 ? 0 : ((num - 1) / GROUP_SIZE) + 1;
+  }
+
+  uint16_t pos_in_group(size_type i) const {
+    return static_cast<uint16_t>(i % GROUP_SIZE);
+  }
+  size_type group_num(size_type i) const { return i / GROUP_SIZE; }
+  GroupsReference which_group(size_type i) { return groups[group_num(i)]; }
+  GroupsConstReference which_group(size_type i) const {
+    return groups[group_num(i)];
+  }
+
+ public:
+  // Constructors -- default, normal (when you specify size), and copy
+  explicit sparsetable(size_type sz = 0, Alloc alloc = Alloc())
+      : groups(vector_alloc(alloc)), settings(alloc, sz) {
+    groups.resize(num_groups(sz), group_type(settings));
+  }
+  // We can get away with using the default copy constructor,
+  // and default destructor, and hence the default operator=.  Huzzah!
+
+  // Many STL algorithms use swap instead of copy constructors
+  void swap(sparsetable& x) {
+    std::swap(groups, x.groups);  // defined in stl_algobase.h
+    std::swap(settings.table_size, x.settings.table_size);
+    std::swap(settings.num_buckets, x.settings.num_buckets);
+  }
+
+  // It's always nice to be able to clear a table without deallocating it
+  void clear() {
+    GroupsIterator group;
+    for (group = groups.begin(); group != groups.end(); ++group) {
+      group->clear();
+    }
+    settings.num_buckets = 0;
+  }
+
+  // ACCESSOR FUNCTIONS for the things we templatize on, basically
+  allocator_type get_allocator() const { return allocator_type(settings); }
+
+  // Functions that tell you about size.
+  // NOTE: empty() is non-intuitive!  It does not tell you the number
+  // of not-empty buckets (use num_nonempty() for that).  Instead
+  // it says whether you've allocated any buckets or not.
+  size_type size() const { return settings.table_size; }
+  size_type max_size() const { return settings.max_size(); }
+  bool empty() const { return settings.table_size == 0; }
+  // We also may want to know how many *used* buckets there are
+  size_type num_nonempty() const { return settings.num_buckets; }
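The NOTE above is worth making concrete: `size()`/`empty()` describe allocated buckets, while `num_nonempty()` describes assigned ones. A minimal usage sketch (assuming the standard google/sparsehash header path for this vendored code):

```cpp
#include <cassert>
#include <sparsehash/sparsetable>  // vendored header; path assumed

int main() {
  google::sparsetable<int> t(100);  // 100 buckets allocated, none assigned
  assert(!t.empty());               // buckets exist...
  assert(t.num_nonempty() == 0);    // ...but none holds a value yet
  t.set(42, 7);
  assert(t.num_nonempty() == 1);
}
```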
+  // OK, we'll let you resize one of these puppies
+  void resize(size_type new_size) {
+    groups.resize(num_groups(new_size), group_type(settings));
+    if (new_size < settings.table_size) {
+      // lower num_buckets, clear last group
+      if (pos_in_group(new_size) > 0)  // need to clear inside last group
+        groups.back().erase(groups.back().begin() + pos_in_group(new_size),
+                            groups.back().end());
+      settings.num_buckets = 0;  // refigure # of used buckets
+      GroupsConstIterator group;
+      for (group = groups.begin(); group != groups.end(); ++group)
+        settings.num_buckets += group->num_nonempty();
+    }
+    settings.table_size = new_size;
+  }
+
+  // We let you see if a bucket is non-empty without retrieving it
+  bool test(size_type i) const {
+    assert(i < settings.table_size);
+    return which_group(i).test(pos_in_group(i));
+  }
+  bool test(iterator pos) const {
+    return which_group(pos.pos).test(pos_in_group(pos.pos));
+  }
+  bool test(const_iterator pos) const {
+    return which_group(pos.pos).test(pos_in_group(pos.pos));
+  }
+
+  // We only return const_references because it's really hard to
+  // return something settable for empty buckets.  Use set() instead.
+  const_reference get(size_type i) const {
+    assert(i < settings.table_size);
+    return which_group(i).get(pos_in_group(i));
+  }
+
+  // TODO(csilvers): make protected + friend
+  // This is used by sparse_hashtable to get an element from the table
+  // when we know it exists (because the caller has called test(i)).
+  const_reference unsafe_get(size_type i) const {
+    assert(i < settings.table_size);
+    assert(test(i));
+    return which_group(i).unsafe_get(pos_in_group(i));
+  }
+
+  // TODO(csilvers): make protected + friend element_adaptor
+  reference mutating_get(size_type i) {  // fills bucket i before getting
+    assert(i < settings.table_size);
+    typename group_type::size_type old_numbuckets =
+        which_group(i).num_nonempty();
+    reference retval = which_group(i).mutating_get(pos_in_group(i));
+    settings.num_buckets += which_group(i).num_nonempty() - old_numbuckets;
+    return retval;
+  }
+
+  // Syntactic sugar.  As in sparsegroup, the non-const version is harder
+  const_reference operator[](size_type i) const { return get(i); }
+
+  element_adaptor operator[](size_type i) { return element_adaptor(this, i); }
+
+  // Needed for hashtables, gets as a nonempty_iterator.  Crashes for empty
+  // bcks
+  const_nonempty_iterator get_iter(size_type i) const {
+    assert(test(i));  // how can a nonempty_iterator point to an empty bucket?
+    return const_nonempty_iterator(
+        groups.begin(), groups.end(), groups.begin() + group_num(i),
+        (groups[group_num(i)].nonempty_begin() +
+         groups[group_num(i)].pos_to_offset(pos_in_group(i))));
+  }
+  // For nonempty we can return a non-const version
+  nonempty_iterator get_iter(size_type i) {
+    assert(test(i));  // how can a nonempty_iterator point to an empty bucket?
+    return nonempty_iterator(
+        groups.begin(), groups.end(), groups.begin() + group_num(i),
+        (groups[group_num(i)].nonempty_begin() +
+         groups[group_num(i)].pos_to_offset(pos_in_group(i))));
+  }
+
+  // And the reverse transformation.
+  size_type get_pos(const const_nonempty_iterator& it) const {
+    difference_type current_row = it.row_current - it.row_begin;
+    difference_type current_col =
+        (it.col_current - groups[current_row].nonempty_begin());
+    return ((current_row * GROUP_SIZE) +
+            groups[current_row].offset_to_pos(current_col));
+  }
+
+  // This returns a reference to the inserted item (which is a copy of val)
+  // The trick is to figure out whether we're replacing or inserting anew
+  reference set(size_type i, const_reference val) {
+    assert(i < settings.table_size);
+    typename group_type::size_type old_numbuckets =
+        which_group(i).num_nonempty();
+    reference retval = which_group(i).set(pos_in_group(i), val);
+    settings.num_buckets += which_group(i).num_nonempty() - old_numbuckets;
+    return retval;
+  }
+
+  // This takes the specified elements out of the table.  This is
+  // "undefining", rather than "clearing".
+  void erase(size_type i) {
+    assert(i < settings.table_size);
+    typename group_type::size_type old_numbuckets =
+        which_group(i).num_nonempty();
+    which_group(i).erase(pos_in_group(i));
+    settings.num_buckets += which_group(i).num_nonempty() - old_numbuckets;
+  }
+
+  void erase(iterator pos) { erase(pos.pos); }
+
+  void erase(iterator start_it, iterator end_it) {
+    // This could be more efficient, but then we'd need to figure
+    // out if we spanned groups or not.  Doesn't seem worth it.
+    for (; start_it != end_it; ++start_it) erase(start_it);
+  }
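The nonempty iterators above are the reason sparsetable scans stay cheap: they walk only assigned buckets and skip empty groups entirely. A small sketch of set/erase plus nonempty iteration (header path assumed, as before):

```cpp
#include <iostream>
#include <sparsehash/sparsetable>  // vendored header; path assumed

int main() {
  google::sparsetable<double> t(1000);
  t.set(3, 1.5);
  t.set(900, 2.5);
  t.erase(3);  // "undefine" bucket 3 again
  // Visits only the assigned buckets, however sparse the table is.
  for (auto it = t.nonempty_begin(); it != t.nonempty_end(); ++it)
    std::cout << *it << '\n';  // prints 2.5
}
```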
+  // We support reading and writing tables to disk.  We don't store
+  // the actual array contents (which we don't know how to store),
+  // just the groups and sizes.  Returns true if all went ok.
+
+ private:
+  // Every time the disk format changes, this should probably change too
+  typedef unsigned long MagicNumberType;
+  static const MagicNumberType MAGIC_NUMBER = 0x24687531;
+
+  // Old versions of this code write all data in 32 bits.  We need to
+  // support these files as well as having support for 64-bit systems.
+  // So we use the following encoding scheme: for values < 2^32-1, we
+  // store in 4 bytes in big-endian order.  For values >= 2^32-1, we
+  // store 0xFFFFFFFF followed by 8 bytes in big-endian order.  This
+  // causes us to mis-read old-version code that stores exactly
+  // 0xFFFFFFFF, but I don't think that is likely to have happened for
+  // these particular values.
+  template <typename OUTPUT, typename IntType>
+  static bool write_32_or_64(OUTPUT* fp, IntType value) {
+    if (value < 0xFFFFFFFFULL) {  // fits in 4 bytes
+      if (!sparsehash_internal::write_bigendian_number(fp, value, 4))
+        return false;
+    } else {
+      if (!sparsehash_internal::write_bigendian_number(fp, 0xFFFFFFFFUL, 4))
+        return false;
+      if (!sparsehash_internal::write_bigendian_number(fp, value, 8))
+        return false;
+    }
+    return true;
+  }
+
+  template <typename INPUT, typename IntType>
+  static bool read_32_or_64(INPUT* fp, IntType* value) {  // reads into value
+    MagicNumberType first4 = 0;  // a convenient 32-bit unsigned type
+    if (!sparsehash_internal::read_bigendian_number(fp, &first4, 4))
+      return false;
+    if (first4 < 0xFFFFFFFFULL) {
+      *value = first4;
+    } else {
+      if (!sparsehash_internal::read_bigendian_number(fp, value, 8))
+        return false;
+    }
+  return true;
+  }
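The framing decision is: small values take 4 bytes on disk, large ones take 12 (a 4-byte `0xFFFFFFFF` marker plus 8 payload bytes). A standalone sketch with the big-endian helper inlined (the real code defers to `sparsehash_internal`):

```cpp
#include <cstdint>
#include <cstdio>

// Write the `bytes` low-order bytes of `value` in big-endian order.
static bool write_bigendian(std::FILE* fp, uint64_t value, int bytes) {
  for (int i = bytes - 1; i >= 0; --i) {
    unsigned char b = static_cast<unsigned char>(value >> (8 * i));
    if (std::fwrite(&b, 1, 1, fp) != 1) return false;
  }
  return true;
}

static bool write_32_or_64(std::FILE* fp, uint64_t value) {
  if (value < 0xFFFFFFFFULL)  // small: 4 bytes total
    return write_bigendian(fp, value, 4);
  return write_bigendian(fp, 0xFFFFFFFFULL, 4) &&  // marker
         write_bigendian(fp, value, 8);            // large: 12 bytes total
}
```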
+ public:
+  // read/write_metadata() and read_write/nopointer_data() are DEPRECATED.
+  // Use serialize() and unserialize(), below, for new code.
+
+  template <typename OUTPUT>
+  bool write_metadata(OUTPUT* fp) const {
+    if (!write_32_or_64(fp, MAGIC_NUMBER)) return false;
+    if (!write_32_or_64(fp, settings.table_size)) return false;
+    if (!write_32_or_64(fp, settings.num_buckets)) return false;
+
+    GroupsConstIterator group;
+    for (group = groups.begin(); group != groups.end(); ++group)
+      if (group->write_metadata(fp) == false) return false;
+    return true;
+  }
+
+  // Reading destroys the old table contents!  Returns true if read ok.
+  template <typename INPUT>
+  bool read_metadata(INPUT* fp) {
+    size_type magic_read = 0;
+    if (!read_32_or_64(fp, &magic_read)) return false;
+    if (magic_read != MAGIC_NUMBER) {
+      clear();  // just to be consistent
+      return false;
+    }
+
+    if (!read_32_or_64(fp, &settings.table_size)) return false;
+    if (!read_32_or_64(fp, &settings.num_buckets)) return false;
+
+    resize(settings.table_size);  // so the vector's sized ok
+    GroupsIterator group;
+    for (group = groups.begin(); group != groups.end(); ++group)
+      if (group->read_metadata(fp) == false) return false;
+    return true;
+  }
+
+  // This code is identical to that for SparseGroup
+  // If your keys and values are simple enough, we can write them
+  // to disk for you.  "simple enough" means no pointers.
+  // However, we don't try to normalize endianness
+  bool write_nopointer_data(FILE* fp) const {
+    for (const_nonempty_iterator it = nonempty_begin(); it != nonempty_end();
+         ++it) {
+      if (!fwrite(&*it, sizeof(*it), 1, fp)) return false;
+    }
+    return true;
+  }
+
+  // When reading, we have to override the potential const-ness of *it
+  bool read_nopointer_data(FILE* fp) {
+    for (nonempty_iterator it = nonempty_begin(); it != nonempty_end(); ++it) {
+      if (!fread(reinterpret_cast<void*>(&(*it)), sizeof(*it), 1, fp))
+        return false;
+    }
+    return true;
+  }
+
+  // INPUT and OUTPUT must be either a FILE, *or* a C++ stream
+  // (istream, ostream, etc) *or* a class providing
+  // Read(void*, size_t) and Write(const void*, size_t)
+  // (respectively), which writes a buffer into a stream
+  // (which the INPUT/OUTPUT instance presumably owns).
+
+  typedef sparsehash_internal::pod_serializer<value_type> NopointerSerializer;
+
+  // ValueSerializer: a functor.  operator()(OUTPUT*, const value_type&)
+  template <typename ValueSerializer, typename OUTPUT>
+  bool serialize(ValueSerializer serializer, OUTPUT* fp) {
+    if (!write_metadata(fp)) return false;
+    for (const_nonempty_iterator it = nonempty_begin(); it != nonempty_end();
+         ++it) {
+      if (!serializer(fp, *it)) return false;
+    }
+    return true;
+  }
+
+  // ValueSerializer: a functor.  operator()(INPUT*, value_type*)
+  template <typename ValueSerializer, typename INPUT>
+  bool unserialize(ValueSerializer serializer, INPUT* fp) {
+    clear();
+    if (!read_metadata(fp)) return false;
+    for (nonempty_iterator it = nonempty_begin(); it != nonempty_end(); ++it) {
+      if (!serializer(fp, &*it)) return false;
+    }
+    return true;
+  }
+
+  // Comparisons.  Note the comparisons are pretty arbitrary: we
+  // compare values of the first index that isn't equal (using default
+  // value for empty buckets).
+  bool operator==(const sparsetable& x) const {
+    return (settings.table_size == x.settings.table_size &&
+            settings.num_buckets == x.settings.num_buckets &&
+            groups == x.groups);
+  }
+
+  bool operator<(const sparsetable& x) const {
+    return std::lexicographical_compare(begin(), end(), x.begin(), x.end());
+  }
+  bool operator!=(const sparsetable& x) const { return !(*this == x); }
+  bool operator<=(const sparsetable& x) const { return !(x < *this); }
+  bool operator>(const sparsetable& x) const { return x < *this; }
+  bool operator>=(const sparsetable& x) const { return !(*this < x); }
+
+ private:
+  // Package allocator with table_size and num_buckets to eliminate memory
+  // needed for the zero-size allocator.
+  // If new fields are added to this class, we should add them to
+  // operator= and swap.
+  class Settings : public allocator_type {
+   public:
+    typedef typename allocator_type::size_type size_type;
+
+    Settings(const allocator_type& a, size_type sz = 0, size_type n = 0)
+        : allocator_type(a), table_size(sz), num_buckets(n) {}
+
+    Settings(const Settings& s)
+        : allocator_type(s),
+          table_size(s.table_size),
+          num_buckets(s.num_buckets) {}
+
+    size_type table_size;   // how many buckets they want
+    size_type num_buckets;  // number of non-empty buckets
+  };
+
+  // The actual data
+  group_vector_type groups;  // our list of groups
+  Settings settings;         // allocator, table size, buckets
+};
+
+// We need a global swap as well
+template <class T, uint16_t GROUP_SIZE, class Alloc>
+inline void swap(sparsetable<T, GROUP_SIZE, Alloc>& x,
+                 sparsetable<T, GROUP_SIZE, Alloc>& y) {
+  x.swap(y);
+}
+}  // namespace google
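How serialize/unserialize compose with the `NopointerSerializer` typedef, as a round-trip sketch (header path assumed; works only for pointer-free value types):

```cpp
#include <cassert>
#include <cstdio>
#include <sparsehash/sparsetable>  // vendored header; path assumed

int main() {
  google::sparsetable<int> t(64);
  t.set(5, 99);

  std::FILE* fp = std::fopen("table.bin", "wb");
  assert(t.serialize(google::sparsetable<int>::NopointerSerializer(), fp));
  std::fclose(fp);

  google::sparsetable<int> u;  // unserialize() resizes from the metadata
  fp = std::fopen("table.bin", "rb");
  assert(u.unserialize(google::sparsetable<int>::NopointerSerializer(), fp));
  std::fclose(fp);
  assert(u.test(5) && u.get(5) == 99);
}
```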
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#pragma once
+#include <type_traits>
+#include <utility>  // For pair
+
+namespace google {
+
+// trait which can be added to user types to enable use of memcpy in
+// sparsetable
+// Example:
+// namespace google{
+// template <>
+// struct is_relocatable<MyType> : std::true_type {};
+// }
+
+template <class T>
+struct is_relocatable
+    : std::integral_constant<bool,
+                             (std::is_trivially_copy_constructible<T>::value &&
+                              std::is_trivially_destructible<T>::value)> {};
+template <class T, class U>
+struct is_relocatable<std::pair<T, U>>
+    : std::integral_constant<bool, (is_relocatable<T>::value &&
+                                    is_relocatable<U>::value)> {};
+
+template <class T>
+struct is_relocatable<const T> : is_relocatable<T> {};
+}  // namespace google
\ No newline at end of file
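What the trait evaluates to, as compile-time checks (a sketch; the struct name is illustrative and the include path is a repo-relative assumption):

```cpp
#include <string>
#include <utility>
#include "third_party/sparsehash/traits"  // path assumed; adjust to your include dirs

struct Point { int x, y; };  // trivially copyable and destructible

static_assert(google::is_relocatable<Point>::value,
              "trivially copyable + destructible types may be memcpy-moved");
static_assert(google::is_relocatable<std::pair<int, Point>>::value,
              "pairs of relocatable types are relocatable");
static_assert(!google::is_relocatable<std::string>::value,
              "std::string has a non-trivial destructor");
```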
diff --git a/oap-native-sql/cpp/src/utils/macros.h b/oap-native-sql/cpp/src/utils/macros.h
index 9e9b3e3de..3f01888f1 100644
--- a/oap-native-sql/cpp/src/utils/macros.h
+++ b/oap-native-sql/cpp/src/utils/macros.h
@@ -39,5 +39,23 @@
     time += std::chrono::duration_cast<std::chrono::microseconds>(end - start).count(); \
   } while (false);
 
+#define VECTOR_PRINT(v, name)          \
+  std::cout << "[" << name << "]:";    \
+  for (int i = 0; i < v.size(); i++) { \
+    if (i != v.size() - 1)             \
+      std::cout << v[i] << ",";        \
+    else                               \
+      std::cout << v[i];               \
+  }                                    \
+  std::cout << std::endl;
+
+#define THROW_NOT_OK(expr)                     \
+  do {                                         \
+    auto __s = (expr);                         \
+    if (!__s.ok()) {                           \
+      throw std::runtime_error(__s.message()); \
+    }                                          \
+  } while (false);
+
 #define TIME_TO_STRING(time) \
   (time > 10000 ? time / 1000 : time) << (time > 10000 ? " ms" : " us")
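THROW_NOT_OK converts a non-OK `arrow::Status` into a C++ exception, which is useful at boundaries (such as JNI) where status codes cannot propagate. A self-contained usage sketch, with the macro duplicated from above and an illustrative function name:

```cpp
#include <stdexcept>
#include <arrow/status.h>

// Mirrors the THROW_NOT_OK macro added to utils/macros.h above.
#define THROW_NOT_OK(expr)                     \
  do {                                         \
    auto __s = (expr);                         \
    if (!__s.ok()) {                           \
      throw std::runtime_error(__s.message()); \
    }                                          \
  } while (false);

arrow::Status MightFail() { return arrow::Status::Invalid("bad input"); }

int main() {
  try {
    THROW_NOT_OK(MightFail());  // throws std::runtime_error("bad input")
  } catch (const std::runtime_error&) {
    return 0;
  }
  return 1;
}
```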
diff --git a/oap-native-sql/resource/ApacheArrowInstallation.md b/oap-native-sql/resource/ApacheArrowInstallation.md
index d0755ac8d..f6ea405ae 100644
--- a/oap-native-sql/resource/ApacheArrowInstallation.md
+++ b/oap-native-sql/resource/ApacheArrowInstallation.md
@@ -42,7 +42,7 @@ git clone https://github.com/Intel-bigdata/arrow.git
 cd arrow && git checkout native-sql-engine-clean
 mkdir -p arrow/cpp/release-build
 cd arrow/cpp/release-build
-cmake -DARROW_GANDIVA_JAVA=ON -DARROW_GANDIVA=ON -DARROW_PARQUET=ON -DARROW_HDFS=ON -DARROW_BOOST_USE_SHARED=ON -DARROW_JNI=ON -DARROW_WITH_SNAPPY=ON -DARROW_FILESYSTEM=ON -DARROW_JSON=ON ..
+cmake -DARROW_GANDIVA_JAVA=ON -DARROW_GANDIVA=ON -DARROW_PARQUET=ON -DARROW_HDFS=ON -DARROW_BOOST_USE_SHARED=ON -DARROW_JNI=ON -DARROW_WITH_SNAPPY=ON -DARROW_WITH_LZ4=ON -DARROW_FILESYSTEM=ON -DARROW_JSON=ON ..
 make -j
 make install
diff --git a/oap-native-sql/resource/Native_SQL_Engine_Intro.jpg b/oap-native-sql/resource/Native_SQL_Engine_Intro.jpg
deleted file mode 100755
index bb0754e2e..000000000
Binary files a/oap-native-sql/resource/Native_SQL_Engine_Intro.jpg and /dev/null differ
diff --git a/oap-native-sql/resource/columnar.png b/oap-native-sql/resource/columnar.png
new file mode 100644
index 000000000..d89074905
Binary files /dev/null and b/oap-native-sql/resource/columnar.png differ
diff --git a/oap-native-sql/resource/dataset.png b/oap-native-sql/resource/dataset.png
new file mode 100644
index 000000000..5d3e607ab
Binary files /dev/null and b/oap-native-sql/resource/dataset.png differ
diff --git a/oap-native-sql/resource/installation.md b/oap-native-sql/resource/installation.md
new file mode 100644
index 000000000..8f5a6aeb2
--- /dev/null
+++ b/oap-native-sql/resource/installation.md
@@ -0,0 +1,136 @@
+## Installation
+
+For detailed testing scripts, please refer to the [solution guide](https://github.com/Intel-bigdata/Solution_navigator/tree/master/nativesql)
+
+### Installation option 1: For evaluation, simple and fast
+
+#### install spark 3.0.0 or above
+
+[spark download](https://spark.apache.org/downloads.html)
+
+Remove the original Arrow jars inside the Spark assembly folder
+``` shell
+yes | rm assembly/target/scala-2.12/jars/arrow-format-0.15.1.jar
+yes | rm assembly/target/scala-2.12/jars/arrow-vector-0.15.1.jar
+yes | rm assembly/target/scala-2.12/jars/arrow-memory-0.15.1.jar
+```
+
+#### install arrow 0.17.0
+
+```
+git clone https://github.com/apache/arrow && cd arrow && git checkout apache-arrow-0.17.0
+vim ci/conda_env_gandiva.yml   # pin the LLVM toolchain to:
+clangdev=7
+llvmdev=7
+
+conda create -y -n pyarrow-dev -c conda-forge \
+    --file ci/conda_env_unix.yml \
+    --file ci/conda_env_cpp.yml \
+    --file ci/conda_env_python.yml \
+    --file ci/conda_env_gandiva.yml \
+    compilers \
+    python=3.7 \
+    pandas
+conda activate pyarrow-dev
+```
+
+#### Build native-sql cpp
+
+``` shell
+git clone https://github.com/Intel-bigdata/OAP.git
+cd OAP && git checkout branch-nativesql-spark-3.0.0
+cd oap-native-sql
+cp cpp/src/resources/libhdfs.so ${HADOOP_HOME}/lib/native/
+cp cpp/src/resources/libprotobuf.so.13 /usr/lib64/
+```
+
+Download spark-columnar-core-1.0-jar-with-dependencies.jar to the local machine, then add its path to spark.driver.extraClassPath and spark.executor.extraClassPath
+``` shell
+Internal Location: vsr602://mnt/nvme2/chendi/000000/spark-columnar-core-1.0-jar-with-dependencies.jar
+```
+
+Download spark-sql_2.12-3.1.0-SNAPSHOT.jar to ${SPARK_HOME}/assembly/target/scala-2.12/jars/spark-sql_2.12-3.1.0-SNAPSHOT.jar
+``` shell
+Internal Location: vsr602://mnt/nvme2/chendi/000000/spark-sql_2.12-3.1.0-SNAPSHOT.jar
+```
+
+### Installation option 2: For contribution, patch and build
+
+#### install spark 3.0.0 or above
+
+Please refer to this link to install Spark.
+[Apache Spark Installation](/oap-native-sql/resource/SparkInstallation.md)
+
+Remove the original Arrow jars inside the Spark assembly folder
+``` shell
+yes | rm assembly/target/scala-2.12/jars/arrow-format-0.15.1.jar
+yes | rm assembly/target/scala-2.12/jars/arrow-vector-0.15.1.jar
+yes | rm assembly/target/scala-2.12/jars/arrow-memory-0.15.1.jar
+```
+
+#### install arrow 0.17.0
+
+Please refer to this markdown to install Apache Arrow and Gandiva.
+[Apache Arrow Installation](/oap-native-sql/resource/ApacheArrowInstallation.md)
+
+#### compile and install oap-native-sql
+
+##### Install Googletest and Googlemock
+
+``` shell
+yum install gtest-devel
+yum install gmock
+```
+
+##### Build this project
+
+``` shell
+git clone https://github.com/Intel-bigdata/OAP.git
+cd OAP && git checkout branch-nativesql-spark-3.0.0
+cd oap-native-sql
+cd cpp/
+mkdir build/
+cd build/
+cmake .. -DTESTS=ON
+make -j
+make install
+# when deploying on multiple nodes, make sure libhdfs.so and libprotobuf.so.13 are copied to all nodes
+```
+
+``` shell
+cd SparkColumnarPlugin/core/
+mvn clean package -DskipTests
+```
+### Additional Notes
+[Notes for Installation Issues](/oap-native-sql/resource/InstallationNotes.md)
+
+
+## Spark Configuration
+
+Add the configuration below to spark-defaults.conf
+
+```
+##### Columnar Process Configuration
+
+spark.sql.parquet.columnarReaderBatchSize 4096
+spark.sql.sources.useV1SourceList avro
+spark.sql.join.preferSortMergeJoin false
+spark.sql.extensions com.intel.sparkColumnarPlugin.ColumnarPlugin
+spark.shuffle.manager org.apache.spark.shuffle.sort.ColumnarShuffleManager
+
+spark.driver.extraClassPath ${PATH_TO_OAP_NATIVE_SQL}/core/target/spark-columnar-core-1.0-jar-with-dependencies.jar
+spark.executor.extraClassPath ${PATH_TO_OAP_NATIVE_SQL}/core/target/spark-columnar-core-1.0-jar-with-dependencies.jar
+
+######
+```
+## Benchmark
+
+For an initial microbenchmark, we used Spark to sum 10 fields over a 200 GB dataset.
+
+![Performance](/oap-native-sql/resource/performance.png)
+
+## Coding Style
+
+* For Java code, we use [google-java-format](https://github.com/google/google-java-format)
+* For Scala code, we follow the [Spark Scala Format](https://github.com/apache/spark/blob/master/dev/.scalafmt.conf); please use [scalafmt](https://github.com/scalameta/scalafmt) or run ./scalafmt to format Scala code
+* For C++ code, we use clang-format; see [google-vim-codefmt](https://github.com/google/vim-codefmt) for details.
diff --git a/oap-native-sql/resource/kernel.png b/oap-native-sql/resource/kernel.png
new file mode 100644
index 000000000..f88b002aa
Binary files /dev/null and b/oap-native-sql/resource/kernel.png differ
diff --git a/oap-native-sql/resource/nativesql_arch.png b/oap-native-sql/resource/nativesql_arch.png
new file mode 100644
index 000000000..a8304f5af
Binary files /dev/null and b/oap-native-sql/resource/nativesql_arch.png differ
diff --git a/oap-native-sql/resource/shuffle.png b/oap-native-sql/resource/shuffle.png
new file mode 100644
index 000000000..504234536
Binary files /dev/null and b/oap-native-sql/resource/shuffle.png differ