Skip to content

Commit

Permalink
backport fixes of #11573 to branch 24.10 (#11588)
Browse files Browse the repository at this point in the history
* avoid long tail tasks due to PrioritySemaphore (#11574)

* use task id as tie breaker

Signed-off-by: Hongbin Ma (Mahone) <mahongbin@apache.org>

* save threadlocal lookup

Signed-off-by: Hongbin Ma (Mahone) <mahongbin@apache.org>

---------

Signed-off-by: Hongbin Ma (Mahone) <mahongbin@apache.org>

* addressing jason's comment

Signed-off-by: Hongbin Ma (Mahone) <mahongbin@apache.org>

---------

Signed-off-by: Hongbin Ma (Mahone) <mahongbin@apache.org>
  • Loading branch information
binmahone authored Oct 11, 2024
1 parent b715ef2 commit ec9d008
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ object GpuSemaphore {
* this is considered to be okay as there are other mechanisms in place, and it should be rather
* rare.
*/
private final class SemaphoreTaskInfo() extends Logging {
private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging {
/**
* This holds threads that are not on the GPU yet. Most of the time they are
* blocked waiting for the semaphore to let them on, but it may hold one
Expand Down Expand Up @@ -253,7 +253,7 @@ private final class SemaphoreTaskInfo() extends Logging {
if (!done && shouldBlockOnSemaphore) {
// We cannot be in a synchronized block and wait on the semaphore
// so we have to release it and grab it again afterwards.
semaphore.acquire(numPermits, lastHeld)
semaphore.acquire(numPermits, lastHeld, taskAttemptId)
synchronized {
// We now own the semaphore so we need to wake up all of the other tasks that are
// waiting.
Expand All @@ -280,15 +280,15 @@ private final class SemaphoreTaskInfo() extends Logging {
}
}

def tryAcquire(semaphore: GpuBackingSemaphore): Boolean = synchronized {
def tryAcquire(semaphore: GpuBackingSemaphore, taskAttemptId: Long): Boolean = synchronized {
val t = Thread.currentThread()
if (hasSemaphore) {
activeThreads.add(t)
true
} else {
if (blockedThreads.size() == 0) {
// No other threads for this task are waiting, so we might be able to grab this directly
val ret = semaphore.tryAcquire(numPermits, lastHeld)
val ret = semaphore.tryAcquire(numPermits, lastHeld, taskAttemptId)
if (ret) {
hasSemaphore = true
activeThreads.add(t)
Expand Down Expand Up @@ -333,9 +333,9 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
new SemaphoreTaskInfo(taskAttemptId)
})
if (taskInfo.tryAcquire(semaphore)) {
if (taskInfo.tryAcquire(semaphore, taskAttemptId)) {
GpuDeviceManager.initializeFromTask()
SemaphoreAcquired
} else {
Expand All @@ -357,7 +357,7 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
new SemaphoreTaskInfo(taskAttemptId)
})
taskInfo.blockUntilReady(semaphore)
GpuDeviceManager.initializeFromTask()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,30 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
private val lock = new ReentrantLock()
private var occupiedSlots: Int = 0

private case class ThreadInfo(priority: T, condition: Condition, numPermits: Int) {
private case class ThreadInfo(priority: T, condition: Condition, numPermits: Int, taskId: Long) {
var signaled: Boolean = false
}

// use task id as tie breaker when priorities are equal (both are 0 because never hold lock)
private val priorityComp = Ordering.by[ThreadInfo, T](_.priority).reverse.
thenComparing((a, b) => a.taskId.compareTo(b.taskId))

// We expect a relatively small number of threads to be contending for this lock at any given
// time, therefore we are not concerned with the insertion/removal time complexity.
private val waitingQueue: PriorityQueue[ThreadInfo] =
new PriorityQueue[ThreadInfo](Ordering.by[ThreadInfo, T](_.priority).reverse)
new PriorityQueue[ThreadInfo](priorityComp)

def tryAcquire(numPermits: Int, priority: T): Boolean = {
def tryAcquire(numPermits: Int, priority: T, taskAttemptId: Long): Boolean = {
lock.lock()
try {
if (waitingQueue.size() > 0 && ordering.gt(waitingQueue.peek.priority, priority)) {
if (waitingQueue.size() > 0 &&
priorityComp.compare(
waitingQueue.peek(),
ThreadInfo(priority, null, numPermits, taskAttemptId)
) < 0) {
false
} else if (!canAcquire(numPermits)) {
}
else if (!canAcquire(numPermits)) {
false
} else {
commitAcquire(numPermits)
Expand All @@ -52,12 +61,12 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
}
}

def acquire(numPermits: Int, priority: T): Unit = {
def acquire(numPermits: Int, priority: T, taskAttemptId: Long): Unit = {
lock.lock()
try {
if (!tryAcquire(numPermits, priority)) {
if (!tryAcquire(numPermits, priority, taskAttemptId)) {
val condition = lock.newCondition()
val info = ThreadInfo(priority, condition, numPermits)
val info = ThreadInfo(priority, condition, numPermits, taskAttemptId)
try {
waitingQueue.add(info)
while (!info.signaled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,24 @@ class PrioritySemaphoreSuite extends AnyFunSuite {
test("tryAcquire should return true if permits are available") {
val semaphore = new TestPrioritySemaphore(10)

assert(semaphore.tryAcquire(5, 0))
assert(semaphore.tryAcquire(3, 0))
assert(semaphore.tryAcquire(2, 0))
assert(!semaphore.tryAcquire(1, 0))
assert(semaphore.tryAcquire(5, 0, 0))
assert(semaphore.tryAcquire(3, 0, 0))
assert(semaphore.tryAcquire(2, 0, 0))
assert(!semaphore.tryAcquire(1, 0, 0))
}

test("acquire and release should work correctly") {
val semaphore = new TestPrioritySemaphore(1)

assert(semaphore.tryAcquire(1, 0))
assert(semaphore.tryAcquire(1, 0, 0))

val t = new Thread(() => {
try {
semaphore.acquire(1, 1)
semaphore.acquire(1, 1, 0)
fail("Should not acquire permit")
} catch {
case _: InterruptedException =>
semaphore.acquire(1, 1)
semaphore.acquire(1, 1, 0)
}
})
t.start()
Expand All @@ -62,7 +62,7 @@ class PrioritySemaphoreSuite extends AnyFunSuite {

def taskWithPriority(priority: Int) = new Runnable {
override def run(): Unit = {
semaphore.acquire(1, priority)
semaphore.acquire(1, priority, 0)
results.add(priority)
semaphore.release(1)
}
Expand All @@ -84,20 +84,46 @@ class PrioritySemaphoreSuite extends AnyFunSuite {

test("low priority thread cannot surpass high priority thread") {
val semaphore = new TestPrioritySemaphore(10)
semaphore.acquire(5, 0)
semaphore.acquire(5, 0, 0)
val t = new Thread(() => {
semaphore.acquire(10, 2)
semaphore.acquire(10, 2, 0)
semaphore.release(10)
})
t.start()
Thread.sleep(100)

// Here, there should be 5 available permits, but a thread with higher priority (2)
// is waiting to acquire, therefore we should get rejected here
assert(!semaphore.tryAcquire(5, 0))
assert(!semaphore.tryAcquire(5, 0, 0))
semaphore.release(5)
t.join(1000)
// After the high priority thread finishes, we can acquire with lower priority
assert(semaphore.tryAcquire(5, 0))
assert(semaphore.tryAcquire(5, 0, 0))
}

// this case is described at https://github.com/NVIDIA/spark-rapids/pull/11574/files#r1795652488
test("thread with larger task id should not surpass smaller task id in the waiting queue") {
val semaphore = new TestPrioritySemaphore(10)
semaphore.acquire(8, 0, 0)
val t = new Thread(() => {
semaphore.acquire(5, 0, 0)
semaphore.release(5)
})
t.start()
Thread.sleep(100)

// Here, there should be 2 available permits, and a thread with same task id (0)
// is waiting to acquire 5 permits, in this case we should succeed here
assert(semaphore.tryAcquire(2, 0, 0))
semaphore.release(2)

// Here, there should be 2 available permits, but a thread with smaller task id (0)
// is waiting to acquire, therefore we should get rejected here
assert(!semaphore.tryAcquire(2, 0, 1))

semaphore.release(8)
t.join(1000)
// After the high priority thread finishes, we can acquire with lower priority
assert(semaphore.tryAcquire(2, 0, 1))
}
}

0 comments on commit ec9d008

Please sign in to comment.