avoid long tail tasks due to PrioritySemaphore #11574
```diff
@@ -19,6 +19,8 @@ package com.nvidia.spark.rapids
 import java.util.PriorityQueue
 import java.util.concurrent.locks.{Condition, ReentrantLock}
 
+import org.apache.spark.TaskContext
+
 class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T]) {
   // This lock is used to generate condition variables, which affords us the flexibility to notify
   // specific threads at a time. If we use the regular synchronized pattern, we have to either
```
```diff
@@ -27,14 +29,18 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
   private val lock = new ReentrantLock()
   private var occupiedSlots: Int = 0
 
-  private case class ThreadInfo(priority: T, condition: Condition, numPermits: Int) {
+  private case class ThreadInfo(priority: T, condition: Condition, numPermits: Int, taskId: Long) {
     var signaled: Boolean = false
   }
 
   // We expect a relatively small number of threads to be contending for this lock at any given
   // time, therefore we are not concerned with the insertion/removal time complexity.
   private val waitingQueue: PriorityQueue[ThreadInfo] =
-    new PriorityQueue[ThreadInfo](Ordering.by[ThreadInfo, T](_.priority).reverse)
+    new PriorityQueue[ThreadInfo](
+      // use task id as tie breaker when priorities are equal (both are 0 because never hold lock)
+      Ordering.by[ThreadInfo, T](_.priority).reverse.
+        thenComparing((a, b) => a.taskId.compareTo(b.taskId))
+    )
 
   def tryAcquire(numPermits: Int, priority: T): Boolean = {
     lock.lock()
```
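To see what the new comparator does, here is a minimal, self-contained sketch; `Entry`, `TieBreakerDemo`, and the concrete `Long` priority type are invented for the example (the real code is generic in `T`):

```scala
import java.util.PriorityQueue

object TieBreakerDemo {
  // Mirror of the patch's ThreadInfo, trimmed to the fields that drive ordering.
  case class Entry(priority: Long, taskId: Long)

  def main(args: Array[String]): Unit = {
    // Same construction as the patch: higher priority first, then the smaller
    // task id wins when priorities are equal.
    val queue = new PriorityQueue[Entry](
      Ordering.by[Entry, Long](_.priority).reverse.
        thenComparing((a: Entry, b: Entry) => a.taskId.compareTo(b.taskId))
    )
    queue.add(Entry(priority = 0L, taskId = 7L))
    queue.add(Entry(priority = 0L, taskId = 3L))
    queue.add(Entry(priority = 5L, taskId = 9L))

    // Highest priority polls first; equal priorities fall back to the smaller task id.
    println(queue.poll()) // Entry(5,9)
    println(queue.poll()) // Entry(0,3)
    println(queue.poll()) // Entry(0,7)
  }
}
```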
Review comment:
This also needs a taskAttemptId, and updates to the ordering comparison below; otherwise the algorithm we're using for tryAcquire doesn't match the algorithm being used for the waiting-queue ordering (although it's very close). For example, a task with priority 0 and task attempt ID 2 asking for 5 permits will block a task with priority 0 and task attempt ID 1 asking for 2 permits, even if the semaphore had 4 permits available.

Reply:
You're right; it's a corner case I didn't pay attention to. Since your comment came after my merge, I have submitted another PR to fix this: https://github.com/NVIDIA/spark-rapids/pull/11587/files. BTW, are there any real cases where we'd have different permit counts for different threads?

Reply:
Yes, because the concurrent GPU tasks config can be updated at runtime, and that changes the number of permits for subsequent tasks.

Reply:
I didn't realize the concurrent GPU tasks config could be updated at runtime, thanks!
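To make the corner case above concrete, here is a sketch of the two gating rules being compared. Everything here (`Waiter`, the field names, and both `blocked...` predicates) is a hypothetical simplification for illustration, not the actual PrioritySemaphore code:

```scala
// Hypothetical, simplified model of the corner case in the review comment above.
object TryAcquireMismatchDemo {
  case class Waiter(priority: Long, taskId: Long, numPermits: Int)

  // Four permits are free. The queue head (task attempt 2) needs five, so it
  // waits. Task attempt 1 then arrives needing only two.
  val head = Waiter(priority = 0L, taskId = 2L, numPermits = 5)
  val incoming = Waiter(priority = 0L, taskId = 1L, numPermits = 2)

  // A priority-only gate blocks the newcomer: the waiting head's priority is
  // not lower than the incoming task's priority.
  def blockedByPriorityOnly: Boolean =
    head.priority >= incoming.priority

  // A gate consistent with the queue's (priority desc, then taskId asc)
  // ordering: the incoming task ranks ahead of the head, so it may proceed.
  def blockedByQueueOrdering: Boolean =
    head.priority > incoming.priority ||
      (head.priority == incoming.priority && head.taskId < incoming.taskId)

  def main(args: Array[String]): Unit = {
    println(blockedByPriorityOnly)  // true: task attempt 1 waits despite enough free permits
    println(blockedByQueueOrdering) // false: task attempt 1 could acquire its two permits
  }
}
```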
```diff
@@ -57,7 +63,7 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
     try {
       if (!tryAcquire(numPermits, priority)) {
         val condition = lock.newCondition()
-        val info = ThreadInfo(priority, condition, numPermits)
+        val info = ThreadInfo(priority, condition, numPermits, TaskContext.get().taskAttemptId())
         try {
           waitingQueue.add(info)
           while (!info.signaled) {
```
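One detail worth noting about this line (an observation, not something raised in the review): `TaskContext.get()` returns null when called from a non-task thread, for example in a unit test driving the semaphore directly, so callers off the task path would hit an NPE. A defensive variant might look like the following sketch; the helper name and the -1 sentinel are invented for illustration:

```scala
import org.apache.spark.TaskContext

object TaskIds {
  // TaskContext.get() is null off task threads (e.g. unit tests that exercise
  // the semaphore directly), so a guarded lookup avoids an NPE. -1 is an
  // arbitrary sentinel chosen for this sketch.
  def currentTaskAttemptId(): Long = {
    val tc = TaskContext.get()
    if (tc != null) tc.taskAttemptId() else -1L
  }
}
```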
Review comment:
nit: could we write this as …? (Technically this would flip the taskId comparison, but I don't think we care.)

Reply:
I can see the argument for wanting to be more explicit with thenComparing, so that's totally fine too.

Reply:
I think it matters? We want tasks with smaller task ids to have higher priority, so that we can avoid very long tasks spanning from the start of the stage to the end of the stage.
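The "flip" being discussed can be shown directly: reversing a tuple-based ordering reverses both components, so equal priorities would then prefer the larger task id. A small sketch contrasting the two forms, with invented names and `Long` priorities standing in for `T`; the tuple form is an assumed reconstruction of the elided suggestion:

```scala
object FlipDemo {
  case class Entry(priority: Long, taskId: Long)

  def main(args: Array[String]): Unit = {
    val entries = Seq(Entry(0L, 3L), Entry(0L, 7L), Entry(5L, 9L))

    // The patch's form: priority descending, then taskId ascending.
    val explicit: java.util.Comparator[Entry] =
      Ordering.by[Entry, Long](_.priority).reverse
        .thenComparing((a: Entry, b: Entry) => a.taskId.compareTo(b.taskId))

    // The shorter tuple form: a single reverse flips BOTH fields, so equal
    // priorities now prefer the LARGER task id (the flip noted in the nit).
    val tupled: Ordering[Entry] =
      Ordering.by[Entry, (Long, Long)](e => (e.priority, e.taskId)).reverse

    println(entries.sortWith((a, b) => explicit.compare(a, b) < 0))
    // List(Entry(5,9), Entry(0,3), Entry(0,7))
    println(entries.sorted(tupled))
    // List(Entry(5,9), Entry(0,7), Entry(0,3))
  }
}
```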