Skip to content

Commit

Permalink
Create a PrioritySemaphore to back the GpuSemaphore (#11376)
Browse files Browse the repository at this point in the history
* priority semaphore implementation and tests

Signed-off-by: Zach Puller <zpuller@nvidia.com>

---------

Signed-off-by: Zach Puller <zpuller@nvidia.com>
  • Loading branch information
zpuller authored Aug 22, 2024
1 parent 0f5254b commit 35d2163
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package com.nvidia.spark.rapids

import java.util
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, Semaphore}
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
Expand Down Expand Up @@ -183,6 +183,9 @@ private final class SemaphoreTaskInfo() extends Logging {
* If this task holds the GPU semaphore or not.
*/
private var hasSemaphore = false
private var lastHeld: Long = 0

type GpuBackingSemaphore = PrioritySemaphore[Long]

/**
* Does this task have the GPU semaphore or not. Be careful because it can change at
Expand Down Expand Up @@ -216,7 +219,7 @@ private final class SemaphoreTaskInfo() extends Logging {
* Block the current thread until we have the semaphore.
* @param semaphore what we are going to wait on.
*/
def blockUntilReady(semaphore: Semaphore): Unit = {
def blockUntilReady(semaphore: GpuBackingSemaphore): Unit = {
val t = Thread.currentThread()
// All threads start out in blocked, but will move out of it inside of the while loop.
synchronized {
Expand Down Expand Up @@ -250,7 +253,7 @@ private final class SemaphoreTaskInfo() extends Logging {
if (!done && shouldBlockOnSemaphore) {
// We cannot be in a synchronized block and wait on the semaphore
// so we have to release it and grab it again afterwards.
semaphore.acquire(numPermits)
semaphore.acquire(numPermits, lastHeld)
synchronized {
// We now own the semaphore so we need to wake up all of the other tasks that are
// waiting.
Expand All @@ -277,7 +280,7 @@ private final class SemaphoreTaskInfo() extends Logging {
}
}

def tryAcquire(semaphore: Semaphore): Boolean = synchronized {
def tryAcquire(semaphore: GpuBackingSemaphore): Boolean = synchronized {
val t = Thread.currentThread()
if (hasSemaphore) {
activeThreads.add(t)
Expand All @@ -299,12 +302,13 @@ private final class SemaphoreTaskInfo() extends Logging {
}
}

def releaseSemaphore(semaphore: Semaphore): Unit = synchronized {
def releaseSemaphore(semaphore: GpuBackingSemaphore): Unit = synchronized {
val t = Thread.currentThread()
activeThreads.remove(t)
if (hasSemaphore) {
semaphore.release(numPermits)
hasSemaphore = false
lastHeld = System.currentTimeMillis()
}
// It should be impossible for the current thread to be blocked when releasing the semaphore
// because no blocked thread should ever leave `blockUntilReady`, which is where we put it in
Expand All @@ -317,7 +321,9 @@ private final class SemaphoreTaskInfo() extends Logging {

private final class GpuSemaphore() extends Logging {
import GpuSemaphore._
private val semaphore = new Semaphore(MAX_PERMITS)

type GpuBackingSemaphore = PrioritySemaphore[Long]
private val semaphore = new GpuBackingSemaphore(MAX_PERMITS)
// Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU
private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import java.util.concurrent.locks.{Condition, ReentrantLock}

import scala.collection.mutable.PriorityQueue

object PrioritySemaphore {
private val DEFAULT_MAX_PERMITS = 1000
}

class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T]) {
// This lock is used to generate condition variables, which affords us the flexibility to notify
// specific threads at a time. If we use the regular synchronized pattern, we have to either
// notify randomly, or if we try creating condition variables not tied to a shared lock, they
// won't work together properly, and we see things like deadlocks.
private val lock = new ReentrantLock()
private var occupiedSlots: Int = 0

private case class ThreadInfo(priority: T, condition: Condition)

// We expect a relatively small number of threads to be contending for this lock at any given
// time, therefore we are not concerned with the insertion/removal time complexity.
private val waitingQueue: PriorityQueue[ThreadInfo] = PriorityQueue()(Ordering.by(_.priority))

def this()(implicit ordering: Ordering[T]) = this(PrioritySemaphore.DEFAULT_MAX_PERMITS)(ordering)

def tryAcquire(numPermits: Int): Boolean = {
lock.lock()
try {
if (canAcquire(numPermits)) {
commitAcquire(numPermits)
true
} else {
false
}
} finally {
lock.unlock()
}
}

def acquire(numPermits: Int, priority: T): Unit = {
lock.lock()
try {
val condition = lock.newCondition()
while (!canAcquire(numPermits)) {
waitingQueue.enqueue(ThreadInfo(priority, condition))
condition.await()
}
commitAcquire(numPermits)

} finally {
lock.unlock()
}}

private def commitAcquire(numPermits: Int): Unit = {
occupiedSlots += numPermits
}

def release(numPermits: Int): Unit = {
lock.lock()
try {
occupiedSlots -= numPermits
if (waitingQueue.nonEmpty) {
val nextThread = waitingQueue.dequeue()
nextThread.condition.signal()
}
} finally {
lock.unlock()
}
}

private def canAcquire(numPermits: Int): Boolean = {
occupiedSlots + numPermits <= maxPermits
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import java.util.concurrent.{CountDownLatch, TimeUnit}

import scala.collection.JavaConverters._

import org.scalatest.funsuite.AnyFunSuite

class PrioritySemaphoreSuite extends AnyFunSuite {
type TestPrioritySemaphore = PrioritySemaphore[Long]

test("tryAcquire should return true if permits are available") {
val semaphore = new TestPrioritySemaphore(10)

assert(semaphore.tryAcquire(5))
assert(semaphore.tryAcquire(3))
assert(semaphore.tryAcquire(2))
assert(!semaphore.tryAcquire(1))
}

test("acquire and release should work correctly") {
val semaphore = new TestPrioritySemaphore(1)

assert(semaphore.tryAcquire(1))

val latch = new CountDownLatch(1)
val t = new Thread(() => {
try {
semaphore.acquire(1, 1)
fail("Should not acquire permit")
} catch {
case _: InterruptedException =>
semaphore.acquire(1, 1)
} finally {
latch.countDown()
}
})
t.start()

Thread.sleep(100)
t.interrupt()

semaphore.release(1)

latch.await(1, TimeUnit.SECONDS)
}

test("multiple threads should handle permits and priority correctly") {
val semaphore = new TestPrioritySemaphore(0)
val latch = new CountDownLatch(3)
val results = new java.util.ArrayList[Int]()

def taskWithPriority(priority: Int) = new Runnable {
override def run(): Unit = {
try {
semaphore.acquire(1, priority)
results.add(priority)
semaphore.release(1)
} finally {
latch.countDown()
}
}
}

new Thread(taskWithPriority(2)).start()
new Thread(taskWithPriority(1)).start()
new Thread(taskWithPriority(3)).start()

Thread.sleep(100)
semaphore.release(1)

latch.await(1, TimeUnit.SECONDS)
assert(results.asScala.toList == List(3, 2, 1))
}
}

0 comments on commit 35d2163

Please sign in to comment.