Skip to content

Commit

Permalink
Make sure tryToAcquire won't return a negative value
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxwing committed Dec 3, 2014
1 parent 17c162f commit a193ae6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
val curMem = threadMemory(threadId)
val freeMemory = maxMemory - threadMemory.values.sum

// How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads
val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem)
// How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads;
// don't let it be negative
val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem))

if (curMem < maxMemory / (2 * numActiveThreads)) {
// We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class ShuffleMemoryManagerSuite extends FunSuite with Timeouts {

test("threads can block to get at least 1 / 2N memory") {
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
// for a bit and releases 250 bytes, which should then be greanted to t2. Further requests
// for a bit and releases 250 bytes, which should then be granted to t2. Further requests
// by t2 will return false right away because it now has 1 / 2N of the memory.

val manager = new ShuffleMemoryManager(1000L)
Expand Down Expand Up @@ -291,4 +291,19 @@ class ShuffleMemoryManagerSuite extends FunSuite with Timeouts {
assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
}
}

test("threads should not be granted a negative size") {
val manager = new ShuffleMemoryManager(1000L)
manager.tryToAcquire(700L)

val latch = new CountDownLatch(1)
startThread("t1") {
manager.tryToAcquire(300L)
latch.countDown()
}
latch.await() // Wait until `t1` calls `tryToAcquire`

val granted = manager.tryToAcquire(300L)
assert(0 === granted, "granted is negative")
}
}

0 comments on commit a193ae6

Please sign in to comment.