From 931cdafb231ae10529dfd53f9df1c5802cdfd07d Mon Sep 17 00:00:00 2001 From: philo Date: Thu, 2 Jun 2022 17:10:39 +0800 Subject: [PATCH] Replace SortAggregate at columnar override --- .../UnsafeFixedWidthAggregationMap.java | 265 ------------------ .../oap/extension/ColumnarOverrides.scala | 14 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 5 +- 3 files changed, 13 insertions(+), 271 deletions(-) delete mode 100644 native-sql-engine/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java diff --git a/native-sql-engine/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/native-sql-engine/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java deleted file mode 100644 index 1889cd26c..000000000 --- a/native-sql-engine/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution; - -import java.io.IOException; - -import com.intel.oap.GazellePluginConfig; -import org.apache.spark.SparkEnv; -import org.apache.spark.TaskContext; -import org.apache.spark.internal.config.package$; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.KVIterator; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.map.BytesToBytesMap; - -import static org.apache.spark.sql.types.DataTypes.StringType; - -/** - * This file is copied from spark with few changes for supportsAggregationBufferSchema. - */ - -/** - * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. - * - * This map supports a maximum of 2 billion keys. - */ -public final class UnsafeFixedWidthAggregationMap { - - /** - * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the - * map, we copy this buffer and use it as the value. - */ - private final byte[] emptyAggregationBuffer; - - private final StructType aggregationBufferSchema; - - private final StructType groupingKeySchema; - - /** - * Encodes grouping keys as UnsafeRows. - */ - private final UnsafeProjection groupingKeyProjection; - - /** - * A hashmap which maps from opaque bytearray keys to bytearray values. 
- */ - private final BytesToBytesMap map; - - /** - * Re-used pointer to the current aggregation buffer - */ - private final UnsafeRow currentAggregationBuffer; - - /** - * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given - * schema, false otherwise. - */ - public static boolean supportsAggregationBufferSchema(StructType schema) { - GazellePluginConfig columnarConf = GazellePluginConfig.getSessionConf(); - for (StructField field: schema.fields()) { - if (!UnsafeRow.isMutable(field.dataType())) { - // To force going into hash agg for StringType input. Because sort agg is not supported - // in Gazelle. - if (columnarConf.enableHashAggForStringType() && field.dataType() == StringType) { - continue; - } - return false; - } - } - return true; - } - - /** - * Create a new UnsafeFixedWidthAggregationMap. - * - * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) - * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. - * @param groupingKeySchema the schema of the grouping key, used for row conversion. - * @param taskContext the current task context. - * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). - * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. - */ - public UnsafeFixedWidthAggregationMap( - InternalRow emptyAggregationBuffer, - StructType aggregationBufferSchema, - StructType groupingKeySchema, - TaskContext taskContext, - int initialCapacity, - long pageSizeBytes) { - this.aggregationBufferSchema = aggregationBufferSchema; - this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); - this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); - this.groupingKeySchema = groupingKeySchema; - this.map = new BytesToBytesMap( - taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes); - - // Initialize the buffer for aggregation value - final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); - this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); - - // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at - // the end of the task. This is necessary to avoid memory leaks in when the downstream operator - // does not fully consume the aggregation map's output (e.g. aggregate followed by limit). - taskContext.addTaskCompletionListener(context -> { - free(); - }); - } - - /** - * Return the aggregation buffer for the current group. For efficiency, all calls to this method - * return the same object. If additional memory could not be allocated, then this method will - * signal an error by returning null. 
- */ - public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { - final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey); - - return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow); - } - - public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key) { - return getAggregationBufferFromUnsafeRow(key, key.hashCode()); - } - - public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) { - // Probe our map using the serialized key - final BytesToBytesMap.Location loc = map.lookup( - key.getBaseObject(), - key.getBaseOffset(), - key.getSizeInBytes(), - hash); - if (!loc.isDefined()) { - // This is the first time that we've seen this grouping key, so we'll insert a copy of the - // empty aggregation buffer into the map: - boolean putSucceeded = loc.append( - key.getBaseObject(), - key.getBaseOffset(), - key.getSizeInBytes(), - emptyAggregationBuffer, - Platform.BYTE_ARRAY_OFFSET, - emptyAggregationBuffer.length - ); - if (!putSucceeded) { - return null; - } - } - - // Reset the pointer to point to the value that we just stored or looked up: - currentAggregationBuffer.pointTo( - loc.getValueBase(), - loc.getValueOffset(), - loc.getValueLength() - ); - return currentAggregationBuffer; - } - - /** - * Returns an iterator over the keys and values in this map. This uses destructive iterator of - * BytesToBytesMap. So it is illegal to call any other method on this map after `iterator()` has - * been called. - * - * For efficiency, each call returns the same object. - */ - public KVIterator iterator() { - return new KVIterator() { - - private final BytesToBytesMap.MapIterator mapLocationIterator = - map.destructiveIterator(); - private final UnsafeRow key = new UnsafeRow(groupingKeySchema.length()); - private final UnsafeRow value = new UnsafeRow(aggregationBufferSchema.length()); - - @Override - public boolean next() { - if (mapLocationIterator.hasNext()) { - final BytesToBytesMap.Location loc = mapLocationIterator.next(); - key.pointTo( - loc.getKeyBase(), - loc.getKeyOffset(), - loc.getKeyLength() - ); - value.pointTo( - loc.getValueBase(), - loc.getValueOffset(), - loc.getValueLength() - ); - return true; - } else { - return false; - } - } - - @Override - public UnsafeRow getKey() { - return key; - } - - @Override - public UnsafeRow getValue() { - return value; - } - - @Override - public void close() { - // Do nothing. - } - }; - } - - /** - * Return the peak memory used so far, in bytes. - */ - public long getPeakMemoryUsedBytes() { - return map.getPeakMemoryUsedBytes(); - } - - /** - * Free the memory associated with this map. This is idempotent and can be called multiple times. - */ - public void free() { - map.free(); - } - - /** - * Gets the average bucket list iterations per lookup in the underlying `BytesToBytesMap`. - */ - public double getAvgHashProbeBucketListIterations() { - return map.getAvgHashProbeBucketListIterations(); - } - - /** - * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]] - * - * Note that the map will be reset for inserting new records, and the returned sorter can NOT be - * used to insert records. 
- */ - public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException { - return new UnsafeKVExternalSorter( - groupingKeySchema, - aggregationBufferSchema, - SparkEnv.get().blockManager(), - SparkEnv.get().serializerManager(), - map.getPageSizeBytes(), - (int) SparkEnv.get().conf().get( - package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), - map); - } -} diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala index d10bbeebe..3dae2df29 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.ShufflePartitionSpec import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{ShuffleStageInfo, _} -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.exchange._ @@ -50,7 +50,6 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, ColumnarArrowEvalPythonExec} import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf - import org.apache.spark.util.ShufflePartitionUtils import scala.collection.mutable @@ -122,6 +121,17 @@ case class ColumnarPreOverrides(session: SparkSession) extends Rule[SparkPlan] { plan.initialInputBufferOffset, plan.resultExpressions, child) + case plan: SortAggregateExec if (columnarConf.enableHashAggForStringType) => + val child = replaceWithColumnarPlan(plan.child) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + child) case plan: UnionExec => val children = plan.children.map(replaceWithColumnarPlan) logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 27ac51c1a..97e5c1148 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution import java.util.Properties -import com.intel.oap.GazellePluginConfig - import scala.collection.mutable import scala.util.{Random, Try} import scala.util.control.NonFatal @@ -109,8 +107,7 @@ class UnsafeFixedWidthAggregationMapSuite StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) assert(supportsAggregationBufferSchema( StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) - // Force to use hash agg for string type. 
- assert(supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) + assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) assert( !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) }
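The new case in ColumnarPreOverrides rewrites SortAggregateExec into ColumnarHashAggregateExec when enableHashAggForStringType is on; SortAggregateExec is the operator vanilla Spark falls back to when an aggregation buffer contains a string field. The sketch below is illustrative only and is not part of the patch: it shows the kind of query this rule targets and how to confirm the rewrite from the physical plan. The data, column names, and spark-shell context are assumptions; it presumes a session with the Gazelle columnar extensions installed and the string-type hash-agg option enabled.

// Illustrative sketch (not part of the patch): run in spark-shell with the Gazelle
// columnar extensions installed and enableHashAggForStringType turned on.
import org.apache.spark.sql.functions.max
import spark.implicits._

// max() over a string column needs a string-typed aggregation buffer, which vanilla
// Spark cannot keep in its UnsafeRow-based hash map, so it plans SortAggregateExec.
val df = Seq(("a", "x"), ("a", "y"), ("b", "z")).toDF("k", "v")
val agg = df.groupBy("k").agg(max("v"))

// With the ColumnarPreOverrides rule above, the SortAggregateExec node is expected
// to be replaced by ColumnarHashAggregateExec; explain() shows which operator was planned.
agg.explain()
agg.show()

With the forked UnsafeFixedWidthAggregationMap removed, supportsAggregationBufferSchema again reports false for StringType buffers, which is what the updated test asserts; forcing hash aggregation for string-typed buffers now happens entirely at the plan-rewrite level rather than inside the row-level aggregation map.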