Introduce a trait that encapsulates host-backed ColumnarBatch pieces
abellina committed Aug 28, 2023
1 parent d557ef1 commit ff6ab4c
Showing 6 changed files with 169 additions and 108 deletions.
@@ -1,6 +1,5 @@

/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
package com.nvidia.spark.rapids

import java.io.File
import java.io.OutputStream
import java.nio.channels.WritableByteChannel

import scala.collection.mutable.ArrayBuffer

@@ -245,24 +245,12 @@ trait RapidsBuffer extends AutoCloseable {
def getCopyIterator: RapidsBufferCopyIterator =
new RapidsBufferCopyIterator(this)

/**
* At spill time, the tier we are spilling to may need to hand the rapids buffer an output stream
* to write to. This is the case for a `RapidsHostColumnarBatch`.
* @param outputStream stream that the `RapidsBuffer` will serialize itself to
*/
def serializeToStream(outputStream: OutputStream): Unit = {
throw new IllegalStateException(s"Buffer $this does not support serializeToStream")
}

/** Descriptor for how the memory buffer is formatted */
def meta: TableMeta

/** The storage tier for this buffer */
val storageTier: StorageTier

/** A RapidsBuffer that needs to be serialized/deserialized at spill/materialization time */
val needsSerialization: Boolean = false

/**
* Get the columnar batch within this buffer. The caller must have
* successfully acquired the buffer beforehand.
@@ -448,3 +436,30 @@ sealed class DegenerateRapidsBuffer(

override def close(): Unit = {}
}

trait RapidsHostBatchBuffer extends AutoCloseable {
/**
* Get the host-backed columnar batch from this buffer. The caller must have
* successfully acquired the buffer beforehand.
*
* If this `RapidsBuffer` was originally added to the device tier, or if this is
* just a buffer (not a batch), this function will throw.
*
* @param sparkTypes the spark data types the batch should have
* @see [[addReference]]
* @note It is the responsibility of the caller to close the batch.
*/
def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch

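/** The amount of host memory used by this buffer, in bytes. */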
def getMemoryUsedBytes(): Long
}

trait RapidsBufferChannelWritable {
/**
* At spill time, write this buffer to an NIO WritableByteChannel.
* @param writableChannel the channel this buffer writes itself to, either byte-for-byte
*   or via serialization if needed.
* @return the number of bytes written to the channel
*/
def writeToChannel(writableChannel: WritableByteChannel): Long
}
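
For illustration, here is a minimal sketch (not part of this commit) of the two flavors of `writeToChannel` the comment above allows: a buffer whose bytes already live in host memory writes itself byte-for-byte, while a host-backed batch serializes itself through an `OutputStream` adapter over the channel. The class names and the `RapidsHostColumnVector.extractBases` helper are assumptions made for the example.

import java.io.DataOutputStream
import java.nio.channels.{Channels, WritableByteChannel}

import ai.rapids.cudf.{HostMemoryBuffer, JCudfSerialization}

import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical: the bytes already live in host memory, so they can be
// pushed through the channel directly, one ByteBuffer at a time.
class RawHostBufferWritable(hmb: HostMemoryBuffer) extends RapidsBufferChannelWritable {
  override def writeToChannel(writableChannel: WritableByteChannel): Long = {
    var written = 0L
    // HostByteBufferIterator exposes the HostMemoryBuffer as NIO ByteBuffers
    new HostByteBufferIterator(hmb).foreach { bb =>
      while (bb.hasRemaining) {
        written += writableChannel.write(bb)
      }
    }
    written
  }
}

// Hypothetical: a host-backed batch has no single contiguous buffer, so it
// serializes itself to the channel through an OutputStream adapter.
class HostBatchWritable(batch: ColumnarBatch) extends RapidsBufferChannelWritable {
  override def writeToChannel(writableChannel: WritableByteChannel): Long = {
    val out = new DataOutputStream(Channels.newOutputStream(writableChannel))
    // assumed helper: extract the cuDF HostColumnVectors backing the batch
    val columns = RapidsHostColumnVector.extractBases(batch)
    JCudfSerialization.writeToStream(columns, out, 0, batch.numRows())
    out.flush()
    out.size().toLong // DataOutputStream counts the bytes written through it
  }
}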
@@ -456,6 +456,23 @@ class RapidsBufferCatalog(
throw new IllegalStateException(s"Unable to acquire buffer for ID: $id")
}

/**
* Acquires a RapidsBuffer that the caller expects to be host-backed and not
* device-bound. This ensures the acquired buffer implements the correct
* trait; otherwise the acquisition is released and an exception is thrown.
*
* @param handle handle associated with this `RapidsBuffer`
* @return host-backed RapidsBuffer that has been acquired
*/
def acquireHostBatchBuffer(handle: RapidsBufferHandle): RapidsHostBatchBuffer = {
closeOnExcept(acquireBuffer(handle)) {
case hrb: RapidsHostBatchBuffer => hrb
case other =>
throw new IllegalStateException(
s"Attempted to acquire a RapidsHostBatchBuffer, but got $other instead")
}
}

/**
* Lookup the buffer that corresponds to the specified buffer ID at the specified storage tier,
* and acquire it.
@@ -940,6 +957,17 @@ object RapidsBufferCatalog extends Logging {
def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer =
singleton.acquireBuffer(handle)

/**
* Acquires a RapidsBuffer that the caller expects to be host-backed and not
* device-bound. This ensures the acquired buffer implements the correct
* trait; otherwise the acquisition is released and an exception is thrown.
*
* @param handle handle associated with this `RapidsBuffer`
* @return host-backed RapidsBuffer that has been acquired
*/
def acquireHostBatchBuffer(handle: RapidsBufferHandle): RapidsHostBatchBuffer =
singleton.acquireHostBatchBuffer(handle)
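
A hypothetical call site (not from this commit) ties the two new pieces together: acquire the host-backed form through the catalog and materialize the batch without touching the device. Here `handle`, `sparkTypes`, and `process` are assumed to be supplied by the caller.

// `handle` identifies a spilled host-backed batch and `sparkTypes` describes
// its columns; both resources are released by withResource.
withResource(RapidsBufferCatalog.acquireHostBatchBuffer(handle)) { hostBuffer =>
  withResource(hostBuffer.getHostColumnarBatch(sparkTypes)) { batch =>
    process(batch) // placeholder for the caller's host-side logic
  }
}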

def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager

/**
@@ -49,86 +49,62 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
val (fileOffset, diskLength) = if (id.canShareDiskPaths) {
// only one writer at a time for now when using shared files
path.synchronized {
if (incoming.needsSerialization) {
// only shuffle buffers share paths, and adding host-backed ColumnarBatch is
// not the shuffle case, so this is not supported and is never exercised.
throw new IllegalStateException(
s"Attempted spilling to disk a RapidsBuffer $incoming that needs serialization " +
s"while sharing spill paths.")
} else {
copyBufferToPath(incoming, path, append = true)
}
writeToFile(incoming, path, append = true)
}
} else {
if (incoming.needsSerialization) {
serializeBufferToStream(incoming, path)
} else {
copyBufferToPath(incoming, path, append = false)
}
writeToFile(incoming, path, append = false)
}

logDebug(s"Spilled to $path $fileOffset:$diskLength")
new RapidsDiskBuffer(
id,
fileOffset,
diskLength,
incoming.meta,
incoming.getSpillPriority,
incoming.needsSerialization)
incoming match {
case _: RapidsHostBatchBuffer =>
new RapidsDiskColumnarBatch(
id,
fileOffset,
diskLength,
incoming.meta,
incoming.getSpillPriority)

case _ =>
new RapidsDiskBuffer(
id,
fileOffset,
diskLength,
incoming.meta,
incoming.getSpillPriority)
}
}

/** Write a buffer to a file, returning the file offset and the number of bytes written. */
private def copyBufferToPath(
private def writeToFile(
incoming: RapidsBuffer,
path: File,
append: Boolean): (Long, Long) = {
val incomingBuffer =
withResource(incoming.getCopyIterator) { incomingCopyIterator =>
incomingCopyIterator.next()
}
withResource(incomingBuffer) { _ =>
val hostBuffer = incomingBuffer match {
case h: HostMemoryBuffer => h
case _ => throw new UnsupportedOperationException("buffer without host memory")
}
val iter = new HostByteBufferIterator(hostBuffer)
val fos = new FileOutputStream(path, append)
try {
val channel = fos.getChannel
val fileOffset = channel.position
iter.foreach { bb =>
while (bb.hasRemaining) {
channel.write(bb)
incoming match {
case fileWritable: RapidsBufferChannelWritable =>
withResource(new FileOutputStream(path, append)) { fos =>
withResource(fos.getChannel) { outputChannel =>
val startOffset = outputChannel.position()
val writtenBytes = fileWritable.writeToChannel(outputChannel)
(startOffset, writtenBytes)
}
}
(fileOffset, channel.position())
} finally {
fos.close()
}
}
}

private def serializeBufferToStream(
incoming: RapidsBuffer,
path: File): (Long, Long) = {
withResource(new FileOutputStream(path, false /*append not supported*/)) { fos =>
withResource(fos.getChannel) { outputChannel =>
val startOffset = outputChannel.position()
incoming.serializeToStream(fos)
val endingOffset = outputChannel.position()
val writtenBytes = endingOffset - startOffset
(startOffset, writtenBytes)
}
case other =>
throw new IllegalStateException(
s"Unable to write $other to file")
}
}

/**
* A RapidsDiskBuffer that is meant to represent device-bound memory. This
* buffer can produce a device-backed ColumnarBatch.
*/
class RapidsDiskBuffer(
id: RapidsBufferId,
fileOffset: Long,
size: Long,
meta: TableMeta,
spillPriority: Long,
override val needsSerialization: Boolean = false)
spillPriority: Long)
extends RapidsBufferBase(
id, meta, spillPriority) {
private[this] var hostBuffer: Option[HostMemoryBuffer] = None
@@ -138,8 +114,6 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
override val storageTier: StorageTier = StorageTier.DISK

override def getMemoryBuffer: MemoryBuffer = synchronized {
require(!needsSerialization,
"Called getMemoryBuffer on a disk buffer that needs deserialization")
if (hostBuffer.isEmpty) {
val path = id.getDiskPath(diskBlockManager)
val mappedBuffer = HostMemoryBuffer.mapFile(path, MapMode.READ_WRITE,
@@ -151,22 +125,6 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
hostBuffer.get
}

override def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = {
require(needsSerialization,
"Disk buffer was not serialized yet getHostColumnarBatch is being invoked")
require(fileOffset == 0,
"Attempted to obtain a HostColumnarBatch from a spilled RapidsBuffer that is sharing " +
"paths on disk")
val path = id.getDiskPath(diskBlockManager)
withResource(new FileInputStream(path)) { fis =>
val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(fis)
val hostCols = closeOnExcept(hostBuffer) { _ =>
SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes)
}
new ColumnarBatch(hostCols.toArray, header.getNumRows)
}
}

override def close(): Unit = synchronized {
if (refcount == 1) {
// free the memory mapping since this is the last active reader
@@ -193,4 +151,43 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
}
}
}

/**
* A RapidsDiskBuffer that should remain on the host, producing a host-backed
* ColumnarBatch when the caller invokes getHostColumnarBatch, and never
* producing anything on the device.
*/
class RapidsDiskColumnarBatch(
id: RapidsBufferId,
fileOffset: Long,
size: Long,
// TODO: remove meta
meta: TableMeta,
spillPriority: Long)
extends RapidsDiskBuffer(
id, fileOffset, size, meta, spillPriority)
with RapidsHostBatchBuffer {

override def getMemoryBuffer: MemoryBuffer =
throw new IllegalStateException(
"Called getMemoryBuffer on a disk buffer that needs deserialization")

override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch =
throw new IllegalStateException(
"Called getColumnarBatch on a disk buffer that needs deserialization")

override def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = {
require(fileOffset == 0,
"Attempted to obtain a HostColumnarBatch from a spilled RapidsBuffer that is sharing " +
"paths on disk")
val path = id.getDiskPath(diskBlockManager)
withResource(new FileInputStream(path)) { fis =>
val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(fis)
val hostCols = closeOnExcept(hostBuffer) { _ =>
SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes)
}
new ColumnarBatch(hostCols.toArray, header.getNumRows)
}
}
}
}