This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

[NSE-465] Release JVM Arrow memory using GC (#466)
zhztheplayer committed Sep 1, 2021
1 parent d6bc791 commit 856f19f
Showing 6 changed files with 128 additions and 27 deletions.
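
This commit makes GC-driven release of JVM Arrow memory opt-in through the "spark.oap.sql.columnar.autorelease" flag that appears in the diff below. A minimal sketch of turning it on, assuming the flag only needs to be visible in the session's SQLConf the way the new TaskMemoryResources code reads it; the session setup itself is illustrative:

// Sketch only: the config key comes from this commit; everything else here
// (app name, master, off-heap sizing) is an assumption for illustration.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("autorelease-demo")
  .master("local[*]")
  .config("spark.memory.offHeap.enabled", "true")
  .config("spark.memory.offHeap.size", "2g")
  .config("spark.oap.sql.columnar.autorelease", "true") // let GC drive Arrow buffer release
  .getOrCreate()
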
@@ -17,57 +17,136 @@

package org.apache.spark.sql.execution.datasources.v2.arrow

import java.io.PrintWriter
import java.util
import java.util.UUID

import scala.collection.JavaConverters._

import com.intel.oap.spark.sql.execution.datasources.v2.arrow._
import com.sun.xml.internal.messaging.saaj.util.ByteOutputStream
import org.apache.arrow.dataset.jni.NativeMemoryPool
import org.apache.arrow.memory.AllocationListener
import org.apache.arrow.memory.AllocationOutcome
import org.apache.arrow.memory.AutoBufferLedger
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.memory.BufferLedger
import org.apache.arrow.memory.DirectAllocationListener
import org.apache.arrow.memory.ImmutableConfig
import org.apache.arrow.memory.LegacyBufferLedger
import org.apache.arrow.memory.RootAllocator

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.TaskCompletionListener

object SparkMemoryUtils extends Logging {

private val DEBUG: Boolean = false

class AllocationListenerList(listeners: AllocationListener *)
extends AllocationListener {
override def onPreAllocation(size: Long): Unit = {
listeners.foreach(_.onPreAllocation(size))
}

override def onAllocation(size: Long): Unit = {
listeners.foreach(_.onAllocation(size))
}

override def onRelease(size: Long): Unit = {
listeners.foreach(_.onRelease(size))
}

override def onFailedAllocation(size: Long, outcome: AllocationOutcome): Boolean = {
listeners.forall(_.onFailedAllocation(size, outcome))
}

override def onChildAdded(parentAllocator: BufferAllocator,
childAllocator: BufferAllocator): Unit = {
listeners.foreach(_.onChildAdded(parentAllocator, childAllocator))

}

override def onChildRemoved(parentAllocator: BufferAllocator,
childAllocator: BufferAllocator): Unit = {
listeners.foreach(_.onChildRemoved(parentAllocator, childAllocator))
}
}
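
A composite listener like the one above simply fans each callback out to every member, and onFailedAllocation succeeds only if all members agree. A minimal sketch of handing such a composite to an allocator through the same config-builder path used further down in this file; imports are as at the top of the file, and AllocationListener.NOOP stands in for a real second listener:

// Sketch only: DirectAllocationListener and ImmutableConfig come from this
// patch's Arrow fork; the NOOP placeholder and the wiring are assumptions.
val fanout = new AllocationListenerList(
  DirectAllocationListener.INSTANCE, // lets GC-driven release observe direct allocations
  AllocationListener.NOOP)           // placeholder for e.g. SparkManagedAllocationListener
val alloc = new RootAllocator(ImmutableConfig.builder()
  .maxAllocation(Long.MaxValue)
  .listener(fanout)
  .build)
alloc.close()
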

class TaskMemoryResources {
if (!inSparkTask()) {
throw new IllegalStateException("Creating TaskMemoryResources instance out of Spark task")
}

val sharedMetrics = new NativeSQLMemoryMetrics()

val isArrowAutoReleaseEnabled: Boolean = {
SQLConf.get
.getConfString("spark.oap.sql.columnar.autorelease", "false").toBoolean
}

val ledgerFactory: BufferLedger.Factory = if (isArrowAutoReleaseEnabled) {
AutoBufferLedger.newFactory()
} else {
LegacyBufferLedger.FACTORY
}

val sparkManagedAllocationListener = new SparkManagedAllocationListener(
new NativeSQLMemoryConsumer(getTaskMemoryManager(), Spiller.NO_OP),
sharedMetrics)
val directAllocationListener = DirectAllocationListener.INSTANCE

val allocListener: AllocationListener = if (isArrowAutoReleaseEnabled) {
new AllocationListenerList(sparkManagedAllocationListener, directAllocationListener)
} else {
sparkManagedAllocationListener
}

private def collectStackForDebug = {
if (DEBUG) {
val out = new ByteOutputStream()
val writer = new PrintWriter(out)
new Exception().printStackTrace(writer)
writer.close()
out.toString
} else {
null
}
}

private val allocators = new util.ArrayList[BufferAllocator]()

private val memoryPools = new util.ArrayList[NativeMemoryPoolWrapper]()

val defaultAllocator: BufferAllocator = {
val al = new SparkManagedAllocationListener(
new NativeSQLMemoryConsumer(getTaskMemoryManager(), Spiller.NO_OP),
sharedMetrics)
new RootAllocator(al, Long.MaxValue)
val alloc = new RootAllocator(ImmutableConfig.builder()
.maxAllocation(Long.MaxValue)
.bufferLedgerFactory(ledgerFactory)
.listener(allocListener)
.build)
allocators.add(alloc)
alloc
}

val defaultMemoryPool: NativeMemoryPoolWrapper = {
val rl = new SparkManagedReservationListener(
new NativeSQLMemoryConsumer(getTaskMemoryManager(), Spiller.NO_OP),
sharedMetrics)
NativeMemoryPoolWrapper(NativeMemoryPool.createListenable(rl), rl)
val pool = NativeMemoryPoolWrapper(NativeMemoryPool.createListenable(rl), rl,
collectStackForDebug)
memoryPools.add(pool)
pool
}

private val allocators = new util.ArrayList[BufferAllocator]()
allocators.add(defaultAllocator)

private val memoryPools = new util.ArrayList[NativeMemoryPoolWrapper]()
memoryPools.add(defaultMemoryPool)

def createSpillableMemoryPool(spiller: Spiller): NativeMemoryPool = {
val rl = new SparkManagedReservationListener(
new NativeSQLMemoryConsumer(getTaskMemoryManager(), spiller),
sharedMetrics)
val pool = NativeMemoryPool.createListenable(rl)
memoryPools.add(NativeMemoryPoolWrapper(pool, rl))
memoryPools.add(NativeMemoryPoolWrapper(pool, rl, collectStackForDebug))
pool
}

@@ -119,8 +198,12 @@ object SparkMemoryUtils extends Logging {
}

def release(): Unit = {
ledgerFactory match {
case closeable: AutoCloseable =>
closeable.close()
case _ =>
}
for (allocator <- allocators.asScala.reverse) {
// reversed iterating: close children first
val allocated = allocator.getAllocatedMemory
if (allocated == 0L) {
close(allocator)
@@ -188,8 +271,15 @@ object SparkMemoryUtils extends Logging {
}
}

private val allocator = new RootAllocator(
ImmutableConfig.builder()
.maxAllocation(Long.MaxValue)
.bufferLedgerFactory(AutoBufferLedger.newFactory())
.listener(DirectAllocationListener.INSTANCE)
.build)

def globalAllocator(): BufferAllocator = {
org.apache.spark.sql.util.ArrowUtils.rootAllocator
allocator
}

def globalMemoryPool(): NativeMemoryPool = {
@@ -272,5 +362,5 @@ object SparkMemoryUtils extends Logging {
}

case class NativeMemoryPoolWrapper(pool: NativeMemoryPool,
listener: SparkManagedReservationListener)
listener: SparkManagedReservationListener, log: String = null)
}
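
The ledger factory selected in TaskMemoryResources is what switches the release model: the legacy ledger expects every buffer to be closed explicitly, while the auto ledger (whose AutoCloseable factory is closed in release()) lets unreachable buffers be reclaimed through GC. A small self-contained sketch of the explicit path, with the GC-driven behavior described in comments as an assumption about the patched ledger rather than a verified contract:

import org.apache.arrow.memory.RootAllocator

val allocator = new RootAllocator(Long.MaxValue) // default, legacy-style ledger
val buf = allocator.buffer(1024)
// Legacy ledger: the owner must release the buffer explicitly.
buf.close()
// Auto ledger (spark.oap.sql.columnar.autorelease=true): dropping the last JVM
// reference is expected to be enough, with the factory close() in release()
// acting as a task-end safety net.
allocator.close()
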
@@ -47,6 +47,7 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
override protected def sparkConf: SparkConf = {
val conf = super.sparkConf
conf.set("spark.memory.offHeap.size", String.valueOf(10 * 1024 * 1024))
conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
conf.set(SPARK_SESSION_EXTENSIONS.key, classOf[ArrowWriteExtension].getCanonicalName)
conf
}
@@ -21,13 +21,10 @@
import java.util.ArrayList;
import java.util.List;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.BufferLedger;
import org.apache.arrow.memory.NativeUnderlyingMemory;
import org.apache.arrow.memory.*;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;

import org.apache.arrow.memory.ArrowBuf;
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils;

/** ArrowRecordBatchBuilderImpl used to wrap native returned data into an ArrowRecordBatch. */
@@ -65,8 +62,8 @@ public ArrowRecordBatch build() throws IOException {
BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
NativeUnderlyingMemory am = new Underlying(allocator, tmp.size,
tmp.nativeInstanceId, tmp.memoryAddress);
BufferLedger ledger = am.associate(allocator);
buffers.add(new ArrowBuf(ledger, null, tmp.size, tmp.memoryAddress));
ReferenceManager rm = am.createReferenceManager(allocator);
buffers.add(new ArrowBuf(rm, null, tmp.size, tmp.memoryAddress));
}
try {
return new ArrowRecordBatch(recordBatchBuilder.length, nodes, buffers);
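
The builder above now wraps each native buffer behind a ReferenceManager obtained from NativeUnderlyingMemory instead of associating a BufferLedger directly. For reference, a self-contained sketch of assembling an ArrowRecordBatch from field nodes and buffers; using an ordinary allocator instead of native memory is purely an assumption for illustration:

// Sketch only: the batch layout (one node, two buffers) is made up.
import java.util.Arrays
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.message.{ArrowFieldNode, ArrowRecordBatch}

val alloc = new RootAllocator(Long.MaxValue)
val validity = alloc.buffer(8)  // stands in for a native validity buffer
val values = alloc.buffer(32)   // stands in for a native value buffer
val node = new ArrowFieldNode(4, 0) // 4 values, 0 nulls
val batch = new ArrowRecordBatch(4, Arrays.asList(node), Arrays.asList(validity, values))
// ... hand the batch to a consumer ...
batch.close() // releases the buffers the batch references
alloc.close()
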
@@ -68,7 +68,7 @@ class ColumnarArrowPythonRunner(
context: TaskContext): Iterator[ColumnarBatch] = {

new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
private val allocator = SparkMemoryUtils.globalAllocator().newChildAllocator(
s"stdin reader for $pythonExec", 0, Long.MaxValue)

private var reader: ArrowStreamReader = _
@@ -148,7 +148,7 @@ class ColumnarArrowPythonRunner(
protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
var numRows: Long = 0
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
val allocator = SparkMemoryUtils.globalAllocator().newChildAllocator(
s"stdout writer for $pythonExec", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)

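
Both the reader and writer paths above now carve short-lived child allocators out of SparkMemoryUtils.globalAllocator() rather than ArrowUtils.rootAllocator. A minimal sketch of that pattern; the child name and the try/finally scaffolding are assumptions, while globalAllocator() and newChildAllocator come from this diff and the Arrow allocator API:

// Sketch: a bounded-lifetime child allocator; closing it returns its
// accounting to the shared global allocator (and fails fast on leaks).
val child = SparkMemoryUtils.globalAllocator()
  .newChildAllocator("stdout writer for python", 0, Long.MaxValue)
try {
  // ... build the VectorSchemaRoot / stream writer against child ...
} finally {
  child.close()
}
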
@@ -24,14 +24,13 @@ import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils

object ArrowUtils {

@Deprecated
@deprecated
val rootAllocator = new RootAllocator(Long.MaxValue)

// todo: support more types.
@@ -53,6 +53,8 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
.set("spark.network.io.preferDirectBufs", "false")
.set("spark.sql.sources.useV1SourceList", "arrow,parquet")
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.oap.sql.columnar.sortmergejoin.lazyread", "true")
.set("spark.oap.sql.columnar.autorelease", "true")
return conf
}

@@ -96,8 +98,20 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
runner.runTPCQuery("q1", 1, true)
}

test("smj query 2") {
runner.runTPCQuery("q24a", 1, true)
}

test("smj query 3") {
runner.runTPCQuery("q95", 1, true)
}

test("q47") {
runner.runTPCQuery("q47", 1, true)
}

test("window function with non-decimal input") {
val df = spark.sql("SELECT i_item_sk, i_class_id, SUM(i_category_id)" +
val df = spark.sql("SELECT i_item_sk, i_clalss_id, SUM(i_category_id)" +
" OVER (PARTITION BY i_class_id) FROM item LIMIT 1000")
df.explain()
df.show()
