diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala index 503909eccd5..68dbf84b0d4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala @@ -17,7 +17,7 @@ package com.nvidia.spark.rapids import java.util -import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, Semaphore} +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -183,6 +183,9 @@ private final class SemaphoreTaskInfo() extends Logging { * If this task holds the GPU semaphore or not. */ private var hasSemaphore = false + private var lastHeld: Long = 0 + + type GpuBackingSemaphore = PrioritySemaphore[Long] /** * Does this task have the GPU semaphore or not. Be careful because it can change at @@ -216,7 +219,7 @@ private final class SemaphoreTaskInfo() extends Logging { * Block the current thread until we have the semaphore. * @param semaphore what we are going to wait on. */ - def blockUntilReady(semaphore: Semaphore): Unit = { + def blockUntilReady(semaphore: GpuBackingSemaphore): Unit = { val t = Thread.currentThread() // All threads start out in blocked, but will move out of it inside of the while loop. synchronized { @@ -250,7 +253,7 @@ private final class SemaphoreTaskInfo() extends Logging { if (!done && shouldBlockOnSemaphore) { // We cannot be in a synchronized block and wait on the semaphore // so we have to release it and grab it again afterwards. - semaphore.acquire(numPermits) + semaphore.acquire(numPermits, lastHeld) synchronized { // We now own the semaphore so we need to wake up all of the other tasks that are // waiting. 
@@ -277,7 +280,7 @@ private final class SemaphoreTaskInfo() extends Logging { } } - def tryAcquire(semaphore: Semaphore): Boolean = synchronized { + def tryAcquire(semaphore: GpuBackingSemaphore): Boolean = synchronized { val t = Thread.currentThread() if (hasSemaphore) { activeThreads.add(t) @@ -299,12 +302,13 @@ private final class SemaphoreTaskInfo() extends Logging { } } - def releaseSemaphore(semaphore: Semaphore): Unit = synchronized { + def releaseSemaphore(semaphore: GpuBackingSemaphore): Unit = synchronized { val t = Thread.currentThread() activeThreads.remove(t) if (hasSemaphore) { semaphore.release(numPermits) hasSemaphore = false + lastHeld = System.currentTimeMillis() } // It should be impossible for the current thread to be blocked when releasing the semaphore // because no blocked thread should ever leave `blockUntilReady`, which is where we put it in @@ -317,7 +321,9 @@ private final class SemaphoreTaskInfo() extends Logging { private final class GpuSemaphore() extends Logging { import GpuSemaphore._ - private val semaphore = new Semaphore(MAX_PERMITS) + + type GpuBackingSemaphore = PrioritySemaphore[Long] + private val semaphore = new GpuBackingSemaphore(MAX_PERMITS) // Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo] diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PrioritySemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PrioritySemaphore.scala new file mode 100644 index 00000000000..ae2b0b362a1 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PrioritySemaphore.scala @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
object PrioritySemaphore {
  // Permit count used by the no-argument constructor.
  private val DEFAULT_MAX_PERMITS = 1000
}

/**
 * A counting semaphore whose blocked acquirers are served in priority order: among the threads
 * blocked in `acquire`, the one whose priority is greatest according to `ordering` is granted
 * permits first.
 *
 * Permits are handed over by the releasing thread. `release` (and the cleanup that runs when a
 * waiter is interrupted) assigns permits to queued waiters, in priority order, for as long as
 * the highest-priority waiter fits, and only then signals them. A waiter therefore returns from
 * `acquire` only once its own request has been granted, which makes it immune to spurious
 * wakeups and guarantees that a single release can wake several waiters.
 *
 * @param maxPermits upper bound that `occupied permits + requested permits` must not exceed
 *                   for a request to be granted.
 * @param ordering   ordering of priorities; greater values are served first.
 * @tparam T type of the priority supplied to `acquire`.
 */
class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T]) {
  // This lock is used to generate condition variables, which affords us the flexibility to
  // notify specific threads at a time. If we use the regular synchronized pattern, we have to
  // either notify randomly, or if we try creating condition variables not tied to a shared
  // lock, they won't work together properly, and we see things like deadlocks.
  private val lock = new ReentrantLock()

  // Number of permits currently handed out. Only read or written while `lock` is held.
  private var occupiedSlots: Int = 0

  /**
   * Book-keeping for one thread blocked in `acquire`. A plain class (rather than a case class)
   * so that instances are compared by identity. `granted` is only touched while `lock` is held.
   */
  private final class Waiter(val priority: T, val numPermits: Int, val condition: Condition) {
    var granted: Boolean = false
  }

  // We expect a relatively small number of threads to be contending for this lock at any given
  // time, therefore we are not concerned with the insertion/removal time complexity.
  private val waitingQueue: PriorityQueue[Waiter] =
    PriorityQueue.empty[Waiter](Ordering.by[Waiter, T](_.priority))

  def this()(implicit ordering: Ordering[T]) = this(PrioritySemaphore.DEFAULT_MAX_PERMITS)(ordering)

  /**
   * Takes `numPermits` permits if they are available right now, without blocking. The request
   * carries no priority, so it is not queued and does not consider threads that are waiting.
   *
   * @param numPermits number of permits wanted.
   * @return true if the permits were taken, false otherwise.
   */
  def tryAcquire(numPermits: Int): Boolean = {
    lock.lock()
    try {
      if (canAcquire(numPermits)) {
        commitAcquire(numPermits)
        true
      } else {
        false
      }
    } finally {
      lock.unlock()
    }
  }

  /**
   * Blocks the calling thread until `numPermits` permits have been assigned to it.
   *
   * The permits are taken immediately only when they are available and no queued waiter has a
   * strictly greater priority; otherwise the caller is queued and sleeps until a release grants
   * its request.
   *
   * @param numPermits number of permits wanted.
   * @param priority   priority of this request; greater values are served first.
   * @throws InterruptedException if the thread is interrupted while waiting. In that case the
   *                              caller holds no permits when the exception propagates.
   */
  def acquire(numPermits: Int, priority: T): Unit = {
    lock.lock()
    try {
      if (outranksWaiters(priority) && canAcquire(numPermits)) {
        commitAcquire(numPermits)
      } else {
        val waiter = new Waiter(priority, numPermits, lock.newCondition())
        waitingQueue.enqueue(waiter)
        try {
          // Loop on the grant flag rather than on the wakeup itself so that spurious wakeups
          // are harmless and never lead to a second queue entry for this thread.
          while (!waiter.granted) {
            waiter.condition.await()
          }
        } catch {
          case e: InterruptedException =>
            if (waiter.granted) {
              // The permits were handed over before the interrupt was observed: give them
              // back so that they are not leaked.
              release(numPermits)
            } else {
              // Drop our entry so that a later release cannot signal a dead waiter, then let
              // the waiters that were queued behind us have another go.
              discardWaiter(waiter)
              wakeEligibleWaiters()
            }
            throw e
        }
      }
    } finally {
      lock.unlock()
    }
  }

  private def commitAcquire(numPermits: Int): Unit = {
    occupiedSlots += numPermits
  }

  /**
   * Returns `numPermits` permits and hands the freed capacity to as many waiters as fit, in
   * priority order.
   *
   * @param numPermits number of permits to give back.
   */
  def release(numPermits: Int): Unit = {
    lock.lock()
    try {
      occupiedSlots -= numPermits
      wakeEligibleWaiters()
    } finally {
      lock.unlock()
    }
  }

  private def canAcquire(numPermits: Int): Boolean = {
    occupiedSlots + numPermits <= maxPermits
  }

  // True when no queued waiter has a strictly greater priority than `priority`.
  // Must be called with `lock` held.
  private def outranksWaiters(priority: T): Boolean = {
    waitingQueue.isEmpty || ordering.gteq(priority, waitingQueue.head.priority)
  }

  // Grants permits to waiters in priority order and signals them, stopping at the first waiter
  // whose request does not fit so that lower priorities never overtake it.
  // Must be called with `lock` held.
  private def wakeEligibleWaiters(): Unit = {
    while (waitingQueue.nonEmpty && canAcquire(waitingQueue.head.numPermits)) {
      val next = waitingQueue.dequeue()
      commitAcquire(next.numPermits)
      next.granted = true
      next.condition.signal()
    }
  }

  // Removes `waiter` from the queue by draining and re-inserting everything else. The queue is
  // expected to be short, see the note on `waitingQueue`. Must be called with `lock` held.
  private def discardWaiter(waiter: Waiter): Unit = {
    var survivors = List.empty[Waiter]
    while (waitingQueue.nonEmpty) {
      val candidate = waitingQueue.dequeue()
      if (candidate ne waiter) {
        survivors = candidate :: survivors
      }
    }
    survivors.foreach(w => waitingQueue.enqueue(w))
  }

}
/**
 * Tests for [[PrioritySemaphore]]: non-blocking acquisition, blocking acquisition that is
 * interrupted and retried, and priority ordering between several blocked threads.
 *
 * Worker threads are daemons and report their failures back to the test thread, because an
 * assertion that fails on a worker thread is not seen by the test framework on its own.
 */
class PrioritySemaphoreSuite extends AnyFunSuite {
  type TestPrioritySemaphore = PrioritySemaphore[Long]

  // Starts `body` on a daemon thread so that a worker that never wakes up cannot keep the JVM
  // alive after the test has reported its failure.
  private def startDaemon(body: Runnable): Thread = {
    val worker = new Thread(body)
    worker.setDaemon(true)
    worker.start()
    worker
  }

  test("tryAcquire should return true if permits are available") {
    val semaphore = new TestPrioritySemaphore(10)

    assert(semaphore.tryAcquire(5))
    assert(semaphore.tryAcquire(3))
    assert(semaphore.tryAcquire(2))
    // All 10 permits are taken at this point.
    assert(!semaphore.tryAcquire(1))
  }

  test("acquire and release should work correctly") {
    val semaphore = new TestPrioritySemaphore(1)

    assert(semaphore.tryAcquire(1))

    val latch = new CountDownLatch(1)
    // Carries anything thrown on the worker thread (including `fail`) back to the test thread.
    val failure = new java.util.concurrent.atomic.AtomicReference[Throwable]()
    val t = startDaemon(() => {
      try {
        try {
          // The only permit is held by the test thread, so this must block until interrupted.
          semaphore.acquire(1, 1)
          fail("Should not acquire permit")
        } catch {
          case _: InterruptedException =>
            // Retry; this succeeds once the test thread releases the permit.
            semaphore.acquire(1, 1)
        }
      } catch {
        case scala.util.control.NonFatal(e) => failure.set(e)
      } finally {
        latch.countDown()
      }
    })

    // Intended to give the worker time to block inside acquire before it is interrupted.
    Thread.sleep(100)
    t.interrupt()

    semaphore.release(1)

    assert(latch.await(1, TimeUnit.SECONDS), "worker did not finish after the permit was released")
    assert(failure.get() == null, s"worker thread failed: ${failure.get()}")
  }

  test("multiple threads should handle permits and priority correctly") {
    // With zero permits every acquire below blocks until the test thread releases one.
    val semaphore = new TestPrioritySemaphore(0)
    val latch = new CountDownLatch(3)
    // Written by the workers and read by the test thread.
    val results = java.util.Collections.synchronizedList(new java.util.ArrayList[Int]())

    def taskWithPriority(priority: Int) = new Runnable {
      override def run(): Unit = {
        try {
          semaphore.acquire(1, priority)
          results.add(priority)
          semaphore.release(1)
        } finally {
          latch.countDown()
        }
      }
    }

    startDaemon(taskWithPriority(2))
    startDaemon(taskWithPriority(1))
    startDaemon(taskWithPriority(3))

    // Intended to give all three workers time to block before the first permit appears.
    Thread.sleep(100)
    semaphore.release(1)

    assert(latch.await(1, TimeUnit.SECONDS), "not every worker acquired the semaphore in time")
    assert(results.asScala.toList == List(3, 2, 1))
  }
}