Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prevent duplicate queueing in the prio semaphore #11389

Merged
merged 9 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ private final class SemaphoreTaskInfo() extends Logging {
} else {
if (blockedThreads.size() == 0) {
// No other threads for this task are waiting, so we might be able to grab this directly
val ret = semaphore.tryAcquire(numPermits)
val ret = semaphore.tryAcquire(numPermits, lastHeld)
if (ret) {
hasSemaphore = true
activeThreads.add(t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,9 @@

package com.nvidia.spark.rapids

import java.util.PriorityQueue
import java.util.concurrent.locks.{Condition, ReentrantLock}

import scala.collection.mutable.PriorityQueue

// Companion object of the PrioritySemaphore class; holds defaults used by its constructors.
object PrioritySemaphore {
// Permit capacity passed to the primary constructor by the auxiliary `def this()` constructor.
private val DEFAULT_MAX_PERMITS = 1000
}

class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T]) {
// This lock is used to generate condition variables, which affords us the flexibility to notify
// specific threads at a time. If we use the regular synchronized pattern, we have to either
Expand All @@ -32,22 +27,25 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
private val lock = new ReentrantLock()
private var occupiedSlots: Int = 0

private case class ThreadInfo(priority: T, condition: Condition)
private case class ThreadInfo(priority: T, condition: Condition, numPermits: Int) {
var signaled: Boolean = false
}

// We expect a relatively small number of threads to be contending for this lock at any given
// time, therefore we are not concerned with the insertion/removal time complexity.
private val waitingQueue: PriorityQueue[ThreadInfo] = PriorityQueue()(Ordering.by(_.priority))
private val waitingQueue: PriorityQueue[ThreadInfo] =
new PriorityQueue[ThreadInfo](Ordering.by[ThreadInfo, T](_.priority).reverse)

def this()(implicit ordering: Ordering[T]) = this(PrioritySemaphore.DEFAULT_MAX_PERMITS)(ordering)

def tryAcquire(numPermits: Int): Boolean = {
def tryAcquire(numPermits: Int, priority: T): Boolean = {
lock.lock()
try {
if (canAcquire(numPermits)) {
if (waitingQueue.size() > 0 && ordering.gt(waitingQueue.peek.priority, priority)) {
false
} else if (!canAcquire(numPermits)) {
false
} else {
commitAcquire(numPermits)
true
} else {
false
}
} finally {
lock.unlock()
Expand All @@ -57,16 +55,27 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
def acquire(numPermits: Int, priority: T): Unit = {
lock.lock()
try {
val condition = lock.newCondition()
while (!canAcquire(numPermits)) {
waitingQueue.enqueue(ThreadInfo(priority, condition))
condition.await()
if (!tryAcquire(numPermits, priority)) {
val condition = lock.newCondition()
val info = ThreadInfo(priority, condition, numPermits)
try {
waitingQueue.add(info)
while (!info.signaled) {
info.condition.await()
}
} catch {
case e: Exception =>
waitingQueue.remove(info)
if (info.signaled) {
release(numPermits)
}
throw e
}
}
commitAcquire(numPermits)

} finally {
lock.unlock()
}}
}
}

private def commitAcquire(numPermits: Int): Unit = {
occupiedSlots += numPermits
Expand All @@ -76,9 +85,19 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
lock.lock()
try {
occupiedSlots -= numPermits
if (waitingQueue.nonEmpty) {
val nextThread = waitingQueue.dequeue()
nextThread.condition.signal()
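// Wake waiters from the head of the queue: while the head's request fits in the free permits,
// acquire on its behalf and signal it; stop at the first head that does not fit.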
var done = false
while (!done && waitingQueue.size() > 0) {
val nextThread = waitingQueue.peek()
if (canAcquire(nextThread.numPermits)) {
val popped = waitingQueue.poll()
assert(popped eq nextThread)
commitAcquire(nextThread.numPermits)
nextThread.signaled = true
nextThread.condition.signal()
} else {
done = true
}
}
} finally {
lock.unlock()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package com.nvidia.spark.rapids

import java.util.concurrent.{CountDownLatch, TimeUnit}

import scala.collection.JavaConverters._

import org.scalatest.funsuite.AnyFunSuite
Expand All @@ -28,27 +26,24 @@ class PrioritySemaphoreSuite extends AnyFunSuite {
test("tryAcquire should return true if permits are available") {
val semaphore = new TestPrioritySemaphore(10)

assert(semaphore.tryAcquire(5))
assert(semaphore.tryAcquire(3))
assert(semaphore.tryAcquire(2))
assert(!semaphore.tryAcquire(1))
assert(semaphore.tryAcquire(5, 0))
assert(semaphore.tryAcquire(3, 0))
assert(semaphore.tryAcquire(2, 0))
assert(!semaphore.tryAcquire(1, 0))
}

test("acquire and release should work correctly") {
val semaphore = new TestPrioritySemaphore(1)

assert(semaphore.tryAcquire(1))
assert(semaphore.tryAcquire(1, 0))

val latch = new CountDownLatch(1)
val t = new Thread(() => {
try {
semaphore.acquire(1, 1)
fail("Should not acquire permit")
} catch {
case _: InterruptedException =>
semaphore.acquire(1, 1)
} finally {
latch.countDown()
}
})
t.start()
Expand All @@ -58,34 +53,51 @@ class PrioritySemaphoreSuite extends AnyFunSuite {

semaphore.release(1)

latch.await(1, TimeUnit.SECONDS)
t.join(1000)
}

test("multiple threads should handle permits and priority correctly") {
val semaphore = new TestPrioritySemaphore(0)
val latch = new CountDownLatch(3)
val results = new java.util.ArrayList[Int]()

def taskWithPriority(priority: Int) = new Runnable {
override def run(): Unit = {
try {
semaphore.acquire(1, priority)
results.add(priority)
semaphore.release(1)
} finally {
latch.countDown()
}
semaphore.acquire(1, priority)
results.add(priority)
semaphore.release(1)
}
}

new Thread(taskWithPriority(2)).start()
new Thread(taskWithPriority(1)).start()
new Thread(taskWithPriority(3)).start()
val threads = List(
new Thread(taskWithPriority(2)),
new Thread(taskWithPriority(1)),
new Thread(taskWithPriority(3))
)
threads.foreach(_.start)

Thread.sleep(100)
semaphore.release(1)

latch.await(1, TimeUnit.SECONDS)
threads.foreach(_.join(1000))
assert(results.asScala.toList == List(3, 2, 1))
}

// Verifies that tryAcquire is refused while a higher-priority request is already waiting,
// even though enough permits are free to satisfy the lower-priority request.
test("low priority thread cannot surpass high priority thread") {
// NOTE(review): TestPrioritySemaphore is defined outside this view; presumably the argument
// is the total number of permits — confirm.
val semaphore = new TestPrioritySemaphore(10)
// The test thread takes 5 permits at priority 0.
semaphore.acquire(5, 0)
// The background thread asks for 10 permits at priority 2, then returns them once it gets them.
val t = new Thread(() => {
semaphore.acquire(10, 2)
semaphore.release(10)
})
t.start()
// Timing-based coordination: give the background thread time to reach its acquire call.
// There is no explicit handshake, so this relies on 100 ms being long enough.
Thread.sleep(100)

// At this point 5 permits should be available, but a thread with higher priority (2)
// is waiting to acquire, so this lower-priority request must be rejected.
assert(!semaphore.tryAcquire(5, 0))
// Return the 5 permits held by the test thread.
semaphore.release(5)
// Bounded wait (1 second) for the background thread to finish.
t.join(1000)
// After the high-priority thread finishes, we can acquire with the lower priority.
assert(semaphore.tryAcquire(5, 0))
}
}