diff --git a/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala b/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala
index 340f8e53f..4aeafe87b 100644
--- a/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala
+++ b/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala
@@ -17,24 +17,65 @@
 
 package org.apache.spark.sql.execution.datasources.v2.arrow
 
+import java.io.PrintWriter
 import java.util
 import java.util.UUID
 
 import scala.collection.JavaConverters._
 
 import com.intel.oap.spark.sql.execution.datasources.v2.arrow._
+import com.sun.xml.internal.messaging.saaj.util.ByteOutputStream
 import org.apache.arrow.dataset.jni.NativeMemoryPool
+import org.apache.arrow.memory.AllocationListener
+import org.apache.arrow.memory.AllocationOutcome
+import org.apache.arrow.memory.AutoBufferLedger
 import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.memory.BufferLedger
+import org.apache.arrow.memory.DirectAllocationListener
+import org.apache.arrow.memory.ImmutableConfig
+import org.apache.arrow.memory.LegacyBufferLedger
 import org.apache.arrow.memory.RootAllocator
 
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.TaskCompletionListener
 
 object SparkMemoryUtils extends Logging {
 
+  private val DEBUG: Boolean = false
+  class AllocationListenerList(listeners: AllocationListener *)
+      extends AllocationListener {
+    override def onPreAllocation(size: Long): Unit = {
+      listeners.foreach(_.onPreAllocation(size))
+    }
+
+    override def onAllocation(size: Long): Unit = {
+      listeners.foreach(_.onAllocation(size))
+    }
+
+    override def onRelease(size: Long): Unit = {
+      listeners.foreach(_.onRelease(size))
+    }
+
+    override def onFailedAllocation(size: Long, outcome: AllocationOutcome): Boolean = {
+      listeners.forall(_.onFailedAllocation(size, outcome))
+    }
+
+    override def onChildAdded(parentAllocator: BufferAllocator,
+        childAllocator: BufferAllocator): Unit = {
+      listeners.foreach(_.onChildAdded(parentAllocator, childAllocator))
+
+    }
+
+    override def onChildRemoved(parentAllocator: BufferAllocator,
+        childAllocator: BufferAllocator): Unit = {
+      listeners.foreach(_.onChildRemoved(parentAllocator, childAllocator))
+    }
+  }
+
   class TaskMemoryResources {
     if (!inSparkTask()) {
       throw new IllegalStateException("Creating TaskMemoryResources instance out of Spark task")
@@ -42,32 +83,70 @@ object SparkMemoryUtils extends Logging {
 
     val sharedMetrics = new NativeSQLMemoryMetrics()
 
+    val isArrowAutoReleaseEnabled: Boolean = {
+      SQLConf.get
+          .getConfString("spark.oap.sql.columnar.autorelease", "false").toBoolean
+    }
+
+    val ledgerFactory: BufferLedger.Factory = if (isArrowAutoReleaseEnabled) {
+      AutoBufferLedger.newFactory()
+    } else {
+      LegacyBufferLedger.FACTORY
+    }
+
+    val sparkManagedAllocationListener = new SparkManagedAllocationListener(
+      new NativeSQLMemoryConsumer(getTaskMemoryManager(), Spiller.NO_OP),
+      sharedMetrics)
+    val directAllocationListener = DirectAllocationListener.INSTANCE
+
+    val allocListener: AllocationListener = if (isArrowAutoReleaseEnabled) {
+      new AllocationListenerList(sparkManagedAllocationListener, directAllocationListener)
+    } else {
+      sparkManagedAllocationListener
+    }
+
+    private def collectStackForDebug = {
+      if (DEBUG) {
+        val out = new ByteOutputStream()
+        val writer = new PrintWriter(out)
+        new Exception().printStackTrace(writer)
+        writer.close()
+        out.toString
+      } else {
+        null
+      }
+    }
+
+    private val allocators = new util.ArrayList[BufferAllocator]()
+
+    private val memoryPools = new util.ArrayList[NativeMemoryPoolWrapper]()
+
     val defaultAllocator: BufferAllocator = {
-      val al = new SparkManagedAllocationListener(
-        new NativeSQLMemoryConsumer(getTaskMemoryManager(), Spiller.NO_OP),
-        sharedMetrics)
-      new RootAllocator(al, Long.MaxValue)
+      val alloc = new RootAllocator(ImmutableConfig.builder()
+          .maxAllocation(Long.MaxValue)
+          .bufferLedgerFactory(ledgerFactory)
+          .listener(allocListener)
+          .build)
+      allocators.add(alloc)
+      alloc
     }
 
     val defaultMemoryPool: NativeMemoryPoolWrapper = {
       val rl = new SparkManagedReservationListener(
        new NativeSQLMemoryConsumer(getTaskMemoryManager(), Spiller.NO_OP),
        sharedMetrics)
-      NativeMemoryPoolWrapper(NativeMemoryPool.createListenable(rl), rl)
+      val pool = NativeMemoryPoolWrapper(NativeMemoryPool.createListenable(rl), rl,
+        collectStackForDebug)
+      memoryPools.add(pool)
+      pool
     }
 
-    private val allocators = new util.ArrayList[BufferAllocator]()
-    allocators.add(defaultAllocator)
-
-    private val memoryPools = new util.ArrayList[NativeMemoryPoolWrapper]()
-    memoryPools.add(defaultMemoryPool)
-
     def createSpillableMemoryPool(spiller: Spiller): NativeMemoryPool = {
       val rl = new SparkManagedReservationListener(
        new NativeSQLMemoryConsumer(getTaskMemoryManager(), spiller),
        sharedMetrics)
      val pool = NativeMemoryPool.createListenable(rl)
-      memoryPools.add(NativeMemoryPoolWrapper(pool, rl))
+      memoryPools.add(NativeMemoryPoolWrapper(pool, rl, collectStackForDebug))
      pool
    }
@@ -119,8 +198,12 @@ object SparkMemoryUtils extends Logging {
     }
 
     def release(): Unit = {
+      ledgerFactory match {
+        case closeable: AutoCloseable =>
+          closeable.close()
+        case _ =>
+      }
       for (allocator <- allocators.asScala.reverse) {
-        // reversed iterating: close children first
         val allocated = allocator.getAllocatedMemory
         if (allocated == 0L) {
           close(allocator)
@@ -188,8 +271,15 @@ object SparkMemoryUtils extends Logging {
     }
   }
 
+  private val allocator = new RootAllocator(
+      ImmutableConfig.builder()
+          .maxAllocation(Long.MaxValue)
+          .bufferLedgerFactory(AutoBufferLedger.newFactory())
+          .listener(DirectAllocationListener.INSTANCE)
+          .build)
+
   def globalAllocator(): BufferAllocator = {
-    org.apache.spark.sql.util.ArrowUtils.rootAllocator
+    allocator
   }
 
   def globalMemoryPool(): NativeMemoryPool = {
@@ -272,5 +362,5 @@ object SparkMemoryUtils extends Logging {
   }
 
   case class NativeMemoryPoolWrapper(pool: NativeMemoryPool,
-      listener: SparkManagedReservationListener)
+      listener: SparkManagedReservationListener, log: String = null)
 }
diff --git a/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala b/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
index 40a51fc30..5ad7596b9 100644
--- a/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
+++ b/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
@@ -47,6 +47,7 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
   override protected def sparkConf: SparkConf = {
     val conf = super.sparkConf
     conf.set("spark.memory.offHeap.size", String.valueOf(10 * 1024 * 1024))
+    conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
     conf.set(SPARK_SESSION_EXTENSIONS.key, classOf[ArrowWriteExtension].getCanonicalName)
     conf
   }
diff --git a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ArrowRecordBatchBuilderImpl.java b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ArrowRecordBatchBuilderImpl.java
index 06116fed3..1341b858c 100644
--- a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ArrowRecordBatchBuilderImpl.java
+++ b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ArrowRecordBatchBuilderImpl.java
@@ -21,13 +21,10 @@
 import java.util.ArrayList;
 import java.util.List;
 
-import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.memory.BufferLedger;
-import org.apache.arrow.memory.NativeUnderlyingMemory;
+import org.apache.arrow.memory.*;
 import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
 
-import org.apache.arrow.memory.ArrowBuf;
 import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils;
 
 /** ArrowRecordBatchBuilderImpl used to wrap native returned data into an ArrowRecordBatch. */
@@ -65,8 +62,8 @@ public ArrowRecordBatch build() throws IOException {
       BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
       NativeUnderlyingMemory am = new Underlying(allocator, tmp.size, tmp.nativeInstanceId,
           tmp.memoryAddress);
-      BufferLedger ledger = am.associate(allocator);
-      buffers.add(new ArrowBuf(ledger, null, tmp.size, tmp.memoryAddress));
+      ReferenceManager rm = am.createReferenceManager(allocator);
+      buffers.add(new ArrowBuf(rm, null, tmp.size, tmp.memoryAddress));
     }
     try {
       return new ArrowRecordBatch(recordBatchBuilder.length, nodes, buffers);
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/python/ColumnarArrowPythonRunner.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/python/ColumnarArrowPythonRunner.scala
index 7a4b85fcb..dc7265f2f 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/python/ColumnarArrowPythonRunner.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/python/ColumnarArrowPythonRunner.scala
@@ -68,7 +68,7 @@ class ColumnarArrowPythonRunner(
       context: TaskContext): Iterator[ColumnarBatch] = {
     new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
 
-      private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      private val allocator = SparkMemoryUtils.globalAllocator().newChildAllocator(
         s"stdin reader for $pythonExec", 0, Long.MaxValue)
 
       private var reader: ArrowStreamReader = _
@@ -148,7 +148,7 @@ class ColumnarArrowPythonRunner(
     protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
       var numRows: Long = 0
       val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-      val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      val allocator = SparkMemoryUtils.globalAllocator().newChildAllocator(
         s"stdout writer for $pythonExec", 0, Long.MaxValue)
       val root = VectorSchemaRoot.create(arrowSchema, allocator)
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index ba184f557..b1d6ef2ff 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -24,14 +24,13 @@ import org.apache.arrow.vector.complex.MapVector
 import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
 
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils
 
 object ArrowUtils {
 
-  @Deprecated
+  @deprecated
   val rootAllocator = new RootAllocator(Long.MaxValue)
 
   // todo: support more types.
diff --git a/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/ds/TPCDSSuite.scala b/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/ds/TPCDSSuite.scala
index 379af32e2..528da8049 100644
--- a/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/ds/TPCDSSuite.scala
+++ b/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/ds/TPCDSSuite.scala
@@ -53,6 +53,8 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
       .set("spark.network.io.preferDirectBufs", "false")
      .set("spark.sql.sources.useV1SourceList", "arrow,parquet")
      .set("spark.sql.autoBroadcastJoinThreshold", "-1")
+      .set("spark.oap.sql.columnar.sortmergejoin.lazyread", "true")
+      .set("spark.oap.sql.columnar.autorelease", "true")
 
     return conf
   }
@@ -96,8 +98,20 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
     runner.runTPCQuery("q1", 1, true)
   }
 
+  test("smj query 2") {
+    runner.runTPCQuery("q24a", 1, true)
+  }
+
+  test("smj query 3") {
+    runner.runTPCQuery("q95", 1, true)
+  }
+
+  test("q47") {
+    runner.runTPCQuery("q47", 1, true)
+  }
+
   test("window function with non-decimal input") {
     val df = spark.sql("SELECT i_item_sk, i_class_id, SUM(i_category_id)" +
       " OVER (PARTITION BY i_class_id) FROM item LIMIT 1000")
     df.explain()
     df.show()
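
Note: ImmutableConfig.bufferLedgerFactory, AutoBufferLedger, LegacyBufferLedger and DirectAllocationListener in this patch come from OAP's patched Arrow build, not stock Arrow. As a minimal, standalone sketch of the fan-out idea behind AllocationListenerList, the snippet below uses only stock Arrow's AllocationListener API and the RootAllocator(listener, limit) constructor that the removed code also relied on; the class and object names here (CountingListener, ListenerList, AllocationListenerFanOutDemo) are illustrative, not part of the patch.

```scala
// Standalone sketch, not part of the patch: same fan-out pattern as
// AllocationListenerList, against stock Arrow (needs arrow-memory-core plus an
// allocation backend such as arrow-memory-netty on the classpath).
import java.util.concurrent.atomic.AtomicLong

import org.apache.arrow.memory.{AllocationListener, AllocationOutcome, RootAllocator}

object AllocationListenerFanOutDemo {

  // Trivial listener that just counts the bytes it observes.
  class CountingListener(val name: String) extends AllocationListener {
    val allocated = new AtomicLong()
    val released = new AtomicLong()

    override def onAllocation(size: Long): Unit = allocated.addAndGet(size)
    override def onRelease(size: Long): Unit = released.addAndGet(size)
  }

  // Composite listener: forwards every event to all wrapped listeners,
  // mirroring how the patch combines the Spark-managed and direct listeners.
  class ListenerList(listeners: AllocationListener*) extends AllocationListener {
    override def onPreAllocation(size: Long): Unit = listeners.foreach(_.onPreAllocation(size))
    override def onAllocation(size: Long): Unit = listeners.foreach(_.onAllocation(size))
    override def onRelease(size: Long): Unit = listeners.foreach(_.onRelease(size))
    override def onFailedAllocation(size: Long, outcome: AllocationOutcome): Boolean =
      listeners.forall(_.onFailedAllocation(size, outcome))
  }

  def main(args: Array[String]): Unit = {
    val sparkManaged = new CountingListener("spark-managed")
    val direct = new CountingListener("direct")
    val allocator = new RootAllocator(new ListenerList(sparkManaged, direct), Long.MaxValue)

    val buf = allocator.buffer(1024) // both listeners observe this allocation...
    buf.close()                      // ...and this release
    allocator.close()

    Seq(sparkManaged, direct).foreach { l =>
      println(s"${l.name}: allocated=${l.allocated.get}, released=${l.released.get}")
    }
  }
}
```

With spark.oap.sql.columnar.autorelease left at its default of false, the patch keeps LegacyBufferLedger.FACTORY and registers only the Spark-managed listener, so task-level accounting behaves as before; only when the flag is set to true (as in the TPCDSSuite config above) does the composite listener and the auto-release ledger take effect.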