diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Scheduler.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Scheduler.hs index 52ac156b2..c55056c7c 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Scheduler.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Scheduler.hs @@ -85,6 +85,11 @@ data Workers = Workers -- won't duplicate that activity into this array. -- TODO: Should we add padding in this array to place each element in a different cache line? , workerActivity :: {-# UNPACK #-} !(MutableArray RealWorld Activity) + -- To attach a reliable finalizer to Workers, we add an IORef. + -- Finalizers attached to other kinds of objects are not reliable and may run + -- too early, so the finalizer is attached to this IORef with mkWeakIORef. + -- https://hackage.haskell.org/package/base-4.19.1.0/docs/Data-IORef.html#v:mkWeakIORef + , workerIORef :: {-# UNPACK #-} !(IORef ()) } data Activity where @@ -110,7 +115,7 @@ instance Show Activity where schedule :: Workers -> Job -> IO () schedule workers job = do pushL (workerTaskQueue workers) job - wakeAll $ workerSleep workers + wakeAll (workerSleep workers) Work runWorker :: Workers -> ThreadIdx -> IO () runWorker !workers !threadIdx = do @@ -118,8 +123,10 @@ runWorker !workers !threadIdx = do tryRunWork :: Workers -> ThreadIdx -> Int16 -> IO () tryRunWork !workers !threadIdx 16 = do - sleepIf (workerSleep workers) ({- Last check before sleeping -} noWork workers) - runWorker workers threadIdx + reason <- sleepIf (workerSleep workers) ({- Last check before sleeping -} noWork workers) + case reason of + Work -> runWorker workers threadIdx + Exit -> return () tryRunWork !workers !threadIdx !retries = do mjob <- tryDequeue workers case mjob of @@ -182,7 +189,9 @@ hireWorkersOn caps = do queue <- newQ let count = length caps activities <- newArray count Inactive - let workers = Workers count sleep queue 
activities + ioref <- newIORef () + _ <- mkWeakIORef ioref $ wakeAll sleep Exit + let workers = Workers count sleep queue activities ioref forM_ caps $ \cpu -> do tid <- instrumentedForkOn "worker" cpu $ do tid <- myThreadId @@ -270,7 +279,7 @@ resolveSignal !workers (NativeSignal ioref) = do executeKernel :: forall env. Workers -> ThreadIdx -> KernelCall env -> Job -> IO () executeKernel !workers !myIdx (KernelCall fun arg) continuation = do writeArray (workerActivity workers) myIdx $ Active @env Proxy fun arg continuation - wakeAll $ workerSleep workers + wakeAll (workerSleep workers) Work helpKernel workers myIdx myIdx (return ()) (return ()) {-# INLINE helpKernel #-} diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Sleep.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Sleep.hs index 4e3bbf212..efb1d33ff 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Sleep.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Sleep.hs @@ -18,6 +18,7 @@ module Data.Array.Accelerate.LLVM.Native.Execute.Sleep ( SleepScope, newSleepScope, sleepIf, wakeAll + , WakeReason(..) ) where import Data.Atomics @@ -35,10 +36,12 @@ newSleepScope = do data State -- Some thread is waiting. The MVar is to be filled when new work is -- available. - = Waiting {-# UNPACK #-} !(MVar ()) + = Waiting {-# UNPACK #-} !(MVar WakeReason) -- All threads are busy. The MVar is currently not used (and is empty). -- It will be used when the state changes to waiting. - | Busy {-# UNPACK #-} !(MVar ()) + | Busy {-# UNPACK #-} !(MVar WakeReason) + +data WakeReason = Work | Exit -- Invariants: -- * If the state is Waiting, then 'sleepIf' will not write to the state. @@ -46,7 +49,7 @@ data State -- That will ensure that if a CAS fails, then it was interleaved by (at least) -- another call to the same function. 
-sleepIf :: SleepScope -> IO (Bool) -> IO () +sleepIf :: SleepScope -> IO (Bool) -> IO WakeReason sleepIf (SleepScope ref) condition = do ticket <- readForCAS ref case peekTicket ticket of @@ -60,7 +63,7 @@ sleepIf (SleepScope ref) condition = do -- sleepIf (SleepScope ref) condition else -- Don't wait - return () + return Work Busy mvar -> do -- No thread is waiting yet c <- condition @@ -79,10 +82,10 @@ sleepIf (SleepScope ref) condition = do -- sleepIf (SleepScope ref) condition else -- Don't wait - return () + return Work -wakeAll :: SleepScope -> IO () -wakeAll (SleepScope ref) = do +wakeAll :: SleepScope -> WakeReason -> IO () +wakeAll (SleepScope ref) reason = do ticket <- readForCAS ref case peekTicket ticket of -- No need to wake anyone! @@ -96,5 +99,5 @@ wakeAll (SleepScope ref) = do -- interleaved by other threads doing 'wakeAll' and 'sleepIf'. -- Wake all threads - when success $ putMVar mvar () + when success $ putMVar mvar reason