Expose host store spill #9189

Merged · 4 commits · Sep 7, 2023
Changes from 2 commits
@@ -544,8 +544,8 @@ class RapidsBufferCatalog(
   * Free memory in `store` by spilling buffers to the spill store synchronously.
   * @param store store to spill from
   * @param targetTotalSize maximum total size of this store after spilling completes
-  * @param stream CUDA stream to use or null for default stream
-  * @return optionally number of bytes that were spilled, or None if this called
+  * @param stream CUDA stream to use or omit for default stream
+  * @return optionally number of bytes that were spilled, or None if this call
   *         made no attempt to spill due to a detected spill race
   */
def synchronousSpill(
@@ -806,6 +806,26 @@ object RapidsBufferCatalog extends Logging {
deviceStorage = rdms
}

/**
* Set a `RapidsDiskStore` instance to use when instantiating our
* catalog.
*
* @note This should only be called from tests!
*/
def setDiskStorage(rdms: RapidsDiskStore): Unit = {
diskStorage = rdms
}

/**
* Set a `RapidsHostMemoryStore` instance to use when instantiating our
* catalog.
*
* @note This should only be called from tests!
*/
def setHostStorage(rhms: RapidsHostMemoryStore): Unit = {
hostStorage = rhms
}
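
For context, a minimal sketch (not part of the diff) of how a suite might wire these test-only setters together with the existing device-store setter; the constructor arguments are assumptions borrowed from the new test later in this PR:

    // Sketch: replace the singleton's stores with test-local instances,
    // chained so spills flow device -> host -> disk.
    val devStore = new RapidsDeviceMemoryStore()
    val hostStore = new RapidsHostMemoryStore(1L * 1024 * 1024)
    val diskStore = new RapidsDiskStore(new RapidsDiskBlockManager(new SparkConf()))
    devStore.setSpillStore(hostStore)
    hostStore.setSpillStore(diskStore)
    RapidsBufferCatalog.setDeviceStorage(devStore)
    RapidsBufferCatalog.setHostStorage(hostStore)
    RapidsBufferCatalog.setDiskStorage(diskStore)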

/**
* Set a `RapidsBufferCatalog` instance to use our singleton.
* @note This should only be called from tests!
@@ -908,6 +928,8 @@ object RapidsBufferCatalog extends Logging {

def getDeviceStorage: RapidsDeviceMemoryStore = deviceStorage

def getHostStorage: RapidsHostMemoryStore = hostStorage

def shouldUnspill: Boolean = _shouldUnspill

/**
@@ -968,6 +990,21 @@ object RapidsBufferCatalog extends Logging {

def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager

/**
* Free memory in `store` by spilling buffers to its spill store synchronously.
* @param store store to spill from
* @param targetTotalSize maximum total size of this store after spilling completes
* @param stream CUDA stream to use or omit for default stream
* @return optionally number of bytes that were spilled, or None if this call
* made no attempt to spill due to a detected spill race
*/
def synchronousSpill(
store: RapidsBufferStore,
targetTotalSize: Long,
stream: Cuda.Stream = Cuda.DEFAULT_STREAM): Option[Long] = {
singleton.synchronousSpill(store, targetTotalSize, stream)
}
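
A minimal usage sketch (not part of the diff) of this newly exposed entry point; the zero target size and the store getter mirror the new test below, and the None case follows the scaladoc above:

    // Spill everything out of the device store into its spill store (host).
    val spilled: Option[Long] = RapidsBufferCatalog.synchronousSpill(
      RapidsBufferCatalog.getDeviceStorage,
      targetTotalSize = 0) // 0 => no bytes may remain in the device store
    spilled match {
      case Some(bytes) => // `bytes` were spilled to the host store
      case None => // another spill was already in flight; nothing was attempted
    }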

/**
* Given a `MemoryBuffer` find out if a `MemoryBuffer.EventHandler` is associated
* with it.
@@ -193,6 +193,65 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
}
}

test("get memory buffer after host spill") {
val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType,
DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5))
val spillPriority = -10
val hostStoreMaxSize = 1L * 1024 * 1024
try {
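// Build the spill chain: the device store spills to the host store,
// which spills to the disk store, before registering with the singleton.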
val bm = new RapidsDiskBlockManager(new SparkConf())
val (catalog, devStore, hostStore, diskStore) =
closeOnExcept(new RapidsDiskStore(bm)) { diskStore =>
closeOnExcept(new RapidsDeviceMemoryStore()) { devStore =>
closeOnExcept(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
devStore.setSpillStore(hostStore)
hostStore.setSpillStore(diskStore)
val catalog = closeOnExcept(
new RapidsBufferCatalog(devStore, hostStore)) { catalog => catalog }
(catalog, devStore, hostStore, diskStore)
}
}
}

RapidsBufferCatalog.setDeviceStorage(devStore)

Contributor:

> Do we need to close any previous RapidsBufferCatalog state that may have been leftover from a previous test before clobbering its settings here?

Collaborator (Author):

> Yes, that seems like a good idea. Will do.

RapidsBufferCatalog.setHostStorage(hostStore)
RapidsBufferCatalog.setDiskStorage(diskStore)
RapidsBufferCatalog.setCatalog(catalog)

var expectedBatch: ColumnarBatch = null
val handle = withResource(buildContiguousTable()) { ct =>
// make a copy of the table so we can compare it later to the
// one reconstituted after the spill
withResource(ct.getTable.contiguousSplit()) { copied =>
expectedBatch = GpuColumnVector.from(copied(0).getTable, sparkTypes)
}
RapidsBufferCatalog.addContiguousTable(
ct,
spillPriority)
}
withResource(expectedBatch) { _ =>
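// A target size of 0 asks synchronousSpill to evict every buffer:
// first device -> host, then host -> disk.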
val spilledToHost =
RapidsBufferCatalog.synchronousSpill(
RapidsBufferCatalog.getDeviceStorage, 0)
assert(spilledToHost.isDefined && spilledToHost.get > 0)

val spilledToDisk =
RapidsBufferCatalog.synchronousSpill(
RapidsBufferCatalog.getHostStorage, 0)
assert(spilledToDisk.isDefined && spilledToDisk.get > 0)

withResource(RapidsBufferCatalog.acquireBuffer(handle)) { buffer =>
assertResult(StorageTier.DISK)(buffer.storageTier)
withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch =>
TestUtils.compareBatches(expectedBatch, actualBatch)
}
}
}
} finally {
RapidsBufferCatalog.close()
}
}

test("host buffer originated: get host memory buffer") {
val spillPriority = -10
val hostStoreMaxSize = 1L * 1024 * 1024