SPARK-1145: Memory mapping with many small blocks can cause JVM allocation failures #43

Closed · wants to merge 7 commits
Changes from 2 commits
core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala

@@ -146,6 +146,12 @@ object BlockFetcherIterator {
}

protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
+ // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val maxRequestSize = math.max(maxBytesInFlight / 5, 1L)
Contributor: Why'd you change this to be called maxRequestSize? (as opposed to minRequestSize)

Contributor Author: Because it's the maximum allowable request size, not the minimum.

Contributor Author: If you look at the code, this is used as a threshold: once the size grows past it, the current request is truncated and a new one is created. I think the name was just wrong. Even with my change it's still slightly off, because the request size can in fact end up a bit larger than this.

Contributor: Should we call this "targetRequestSize" or something? I think min is more correct than max -- because it always waits until requests are at least this big before truncating them.

Contributor: Yeah, I agree. Note that the if statement below is if (curRequestSize >= maxRequestSize). In general you might have some requests bigger than this if one map's output is just super large. The threshold is just used to batch together small map outputs so we can grab more than one at once.

+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", maxRequestSize: " + maxRequestSize)

Contributor: This assumes you have at least 5 nodes - which need not be the general case.

Contributor: Was always bothered by this ... not specifically related to this PR btw!

Contributor: Btw, any reason to move this up?

Contributor Author: The only change here is to move this up. There is no reason to calculate and log this in the inner loop, since it's the same for every iteration of that loop; logging it there also means the same message gets logged many times.

// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
@@ -157,11 +163,6 @@ object BlockFetcherIterator {
_numBlocksToFetch += localBlocksToFetch.size
} else {
numRemote += blockInfos.size
- // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
- // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
- // nodes, rather than blocking on reading output from one node.
- val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = new ArrayBuffer[(BlockId, Long)]
@@ -176,11 +177,12 @@ object BlockFetcherIterator {
} else if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
}
- if (curRequestSize >= minRequestSize) {
+ if (curRequestSize >= maxRequestSize) {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
+ logDebug(s"Creating fetch request of $curRequestSize at $address")
curRequestSize = 0
curBlocks = new ArrayBuffer[(BlockId, Long)]
}
}
// Add in the final request
@@ -189,7 +191,7 @@ object BlockFetcherIterator {
}
}
}
- logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " +
+ logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
totalBlocks + " blocks")
remoteRequests
}
@@ -224,8 +226,8 @@ object BlockFetcherIterator {
sendRequest(fetchRequests.dequeue())
}

- val numGets = remoteRequests.size - fetchRequests.size
- logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
+ val numFetches = remoteRequests.size - fetchRequests.size
+ logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

// Get Local Blocks
startTime = System.currentTimeMillis
@@ -325,7 +327,7 @@ object BlockFetcherIterator {
}

copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6))
- logInfo("Started " + fetchRequestsSync.size + " remote gets in " +
+ logInfo("Started " + fetchRequestsSync.size + " remote fetches in " +
Utils.getUsedTimeMs(startTime))

// Get Local Blocks
25 changes: 20 additions & 5 deletions core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -84,12 +84,27 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val segment = diskManager.getBlockLocation(blockId)
val channel = new RandomAccessFile(segment.file, "r").getChannel()
- val buffer = try {
- channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
- } finally {
- channel.close()
- }

+ val buffer =
+ // For small files, directly read rather than memory map
+ if (segment.length < 2 * 4096) {
Contributor: How was this 8k limit arrived at?

Contributor Author: Somewhat arbitrary, but related to the page size. If the segment is smaller than a couple of pages, the fact that memory mapping aligns on page boundaries means a lot of wasted memory.

Contributor: I was hoping it was obtained in a more principled manner ... ah well. In which case, at least make it configurable - I don't think anyone is going to muck around with it, but you never know.

Contributor Author: We need to limit the number of random configuration parameters - just because something could possibly have some minor performance gain if it's tweaked does not mean it should be exposed.

This just needs to be set to a small number of page sizes. Since memory maps are aligned on page boundaries, on average you expect half a page of extra memory use per mapping. If the block size is 2 pages or less, that overhead is really significant. Sure, maybe it could be 3 or 5 instead of 2, but the range of useful values here is small and I don't think any user would ever want to tune this.
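
A rough back-of-the-envelope sketch of the overhead being described (my illustration; the 4 KiB page size is an assumption - the real value is OS-dependent):

    object MmapOverheadSketch {
      val pageSize = 4096L // assumed typical page size

      // Memory-mapped regions occupy whole pages, so a mapping consumes the
      // segment length rounded up to a page multiple.
      def mappedBytes(segmentLength: Long): Long =
        ((segmentLength + pageSize - 1) / pageSize) * pageSize

      def main(args: Array[String]): Unit = {
        for (len <- Seq(512L, 2048L, 8192L, 1L << 20)) {
          val waste = mappedBytes(len) - len
          val pct = 100.0 * waste / mappedBytes(len)
          println(f"segment=$len%7d B  mapped=${mappedBytes(len)}%7d B  wasted=$waste%5d B ($pct%5.1f%%)")
        }
      }
    }

For a 512-byte segment, 87.5% of the mapped page is wasted; at 8 KiB (two pages) and above the relative waste quickly becomes negligible, which is the intuition behind the 2 * 4096 threshold.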

Contributor Author: Another way to think about it is this: this code path really just protects users who are doing something that already leads to very poor performance - shuffles that use way too many partitions, so the files are really tiny. The best thing a user can do there is to coalesce the number of mappers or reducers to get much better performance. If they can figure that out, there is no reason to ever tune this.

Contributor: I would rather have a hidden config param which can be tuned if required (however rarely), as opposed to magic values which have no experimental basis. The impact of this PR is not just on memory usage, but also on the IO characteristics of Spark.

In particular, this is going to have a significant performance impact if there are a large number of small files - which are fairly common, and in a lot of cases (like in our experiments) aggressively aimed for because of the performance characteristics we are shooting for.

Contributor: Why would this have more impact than the current memory-mapping method? Both will lead to a bunch of seeks and filesystem operations - here when you read the blocks, and there when you try to send them over the network.

Contributor: There are performance implications between reading into a vanilla buffer (note, this is not a direct buffer) and using memory-mapped buffers - with the added cost of some wastage for the 'last page'. In addition, you always leave the option open for a zero-copy enhancement in Spark.
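
For reference, a minimal sketch of the zero-copy path being alluded to - my illustration, not part of this PR; sendSegment and its parameters are hypothetical:

    import java.io.{File, IOException, RandomAccessFile}
    import java.nio.channels.WritableByteChannel

    object ZeroCopySketch {
      // Send a file segment over a channel with FileChannel.transferTo, which
      // can avoid copying the bytes through a user-space buffer entirely.
      def sendSegment(file: File, offset: Long, length: Long,
          out: WritableByteChannel): Unit = {
        val channel = new RandomAccessFile(file, "r").getChannel()
        try {
          var written = 0L
          while (written < length) {
            // transferTo may transfer fewer bytes than requested, so loop
            val n = channel.transferTo(offset + written, length - written, out)
            if (n <= 0) throw new IOException("transferTo made no progress")
            written += n
          }
        } finally {
          channel.close()
        }
      }
    }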

Contributor: Ah, so you're saying you might want to make the parameter higher than 8 KB? That makes sense; maybe it would be good to make it tunable. I thought you were saying that sometimes the patch here would create more IO load or use more memory.

Contributor: Actually both: higher or lower (turn it off with 0, for example). I am really not sure how this will interact with some of our experiments.
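
A sketch of what such a hidden setting could look like - the key name spark.storage.memoryMapThreshold is illustrative here, not something this PR adds:

    import org.apache.spark.SparkConf

    // Hypothetical policy object: read the threshold from a (hidden) config
    // key instead of hard-coding 2 * 4096. A threshold of 0 restores the old
    // behavior, since every segment length then satisfies >= 0.
    class MemoryMapPolicy(conf: SparkConf) {
      private val minMemoryMapBytes: Long =
        conf.getLong("spark.storage.memoryMapThreshold", 2L * 4096L)

      /** True if the segment should be memory mapped rather than read directly. */
      def shouldMemoryMap(segmentLength: Long): Boolean =
        segmentLength >= minMemoryMapBytes
    }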

+ val buf = ByteBuffer.allocate(segment.length.toInt)
+ try {
+ // a single read() is assumed to fill the buffer for these small segments
+ channel.read(buf, segment.offset)
+ } finally {
+ channel.close()
+ }
+ buf.flip() // rewind the buffer so callers read from the beginning
+ Some(buf)
+ } else {
+ val buf = try {
+ channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
+ } finally {
+ channel.close()
+ }
+ Some(buf)
Contributor: nit: can you keep the finally outside of the if statement? Just makes the code a little more concise and, I think, easier to read. I.e.,

    val buf = try {
      if (blah)
      else (blah)
    } finally {
      channel.close()
    }

+ }
- Some(buffer)
+ buffer
}

override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
3 changes: 1 addition & 2 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -463,8 +463,7 @@ private[spark] object Utils extends Logging {
}

/**
- * Return the string to tell how long has passed in seconds. The passing parameter should be in
- * millisecond.
+ * Return the string to tell how long has passed in milliseconds.
*/
def getUsedTimeMs(startTimeMs: Long): String = {
return " " + (System.currentTimeMillis - startTimeMs) + " ms"
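
Side note on the fetch-log calls above: getUsedTimeMs returns a string with a leading space, so concatenating it directly after "in" (with no trailing space) yields correct output. A quick illustration, assuming access to the private[spark] Utils object:

    val startTime = System.currentTimeMillis
    Thread.sleep(42) // stand-in for real work
    // getUsedTimeMs returns " 42 ms" (leading space included)
    println("Started 3 remote fetches in" + Utils.getUsedTimeMs(startTime))
    // => Started 3 remote fetches in 42 ms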