wip: data parallelism
dpvanbalen committed May 8, 2024
1 parent f38b65a commit 9c42bbc
Showing 3 changed files with 172 additions and 90 deletions.
@@ -98,15 +98,32 @@ codegen :: ShortByteString
codegen name env (Clustered c b) args =
codeGenFunction name (LLVM.Lam argTp "arg") $ do
extractEnv
workstealLoop workstealIndex workstealActiveThreads (op scalarTypeInt32 $ constant (TupRsingle scalarTypeInt32) 1) $ \_ -> do
let b' = mapArgs BCAJA b
(acc, loopsize') <- execStateT (evalCluster (toOnlyAcc c) b' args gamma ()) (mempty, LS ShapeRz OP_Unit)
body acc loopsize'
retval_ $ boolean True
-- workstealLoop workstealIndex workstealActiveThreads (op scalarTypeInt32 $ constant (TupRsingle scalarTypeInt32) 1) $ \_ -> do
let b' = mapArgs BCAJA b
(acc, loopsize) <- execStateT (evalCluster (toOnlyAcc c) b' args gamma ()) (mempty, LS ShapeRz OP_Unit)
-- body acc loopsize'
acc' <- operandsMapToPairs acc $ \(accTypeR, toOp, fromOp) -> fmap fromOp $ flip execStateT (toOp acc) $ case loopsize of
LS loopshr loopsh ->
workstealChunked loopshr workstealIndex workstealActiveThreads loopsh accTypeR (body loopshr toOp fromOp)

retval_ $ boolean True
where
ba = makeBackendArg @NativeOp args gamma c b
(argTp, extractEnv, workstealIndex, workstealActiveThreads, gamma) = bindHeaderEnv env
body :: Accumulated -> Loopsizes -> CodeGen Native ()
body initialAcc partialLoopSize =
body :: ShapeR sh -> (Accumulated -> a) -> (a -> Accumulated) -> LoopWork sh (StateT a (CodeGen Native))
body ShapeRz _ _ = LoopWorkZ
body (ShapeRsnoc shr) toOp fromOp = LoopWorkSnoc (body shr toOp fromOp) (\i is -> StateT $ \op -> second toOp <$> runStateT (foo i is) (fromOp op))
where
foo :: Operands Int -> [Operands Int] -> StateT Accumulated (CodeGen Native) ()
foo linix ixs = do
let d = length ixs -- TODO check: this or its inverse (i.e. totalDepth - length ixs)?
let i = (d, linix, ixs)
newInputs <- readInputs @_ @_ @NativeOp i args ba gamma
outputs <- evalOps @NativeOp i c newInputs args gamma
writeOutputs @_ @_ @NativeOp i args outputs gamma

body' :: Accumulated -> Loopsizes -> CodeGen Native ()
body' initialAcc partialLoopSize =
case partialLoopSize of -- used to combine with loopSize here, but I think we can do everything in the static analysis?
LS shr' sh' ->
let go :: ShapeR sh -> Operands sh -> (Int, Operands Int, [Operands Int]) -> StateT Accumulated (CodeGen Native) ()
@@ -529,8 +546,6 @@ isAtDepth vs x = x == depth vs
isAtDepth' :: ShapeR sh -> Int -> Bool
isAtDepth' vs x = x == rank vs

typerInt :: TypeR Int
typerInt = TupRsingle scalarTypeInt

zeroes :: TypeR a -> Operands a
zeroes TupRunit = OP_Unit
@@ -2,6 +2,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TupleSections #-}
-- |
-- Module : Data.Array.Accelerate.LLVM.CodeGen.Native.Loop
-- Copyright : [2014..2020] The Accelerate Team
@@ -19,7 +20,7 @@ module Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Shape hiding ( eq )

import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A hiding (lift)
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
@@ -34,6 +35,8 @@ import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Atomic
import LLVM.AST.Type.Instruction.Volatile
import qualified LLVM.AST.Type.Instruction.RMW as RMW
import Control.Monad.Trans
import Control.Monad.State

-- | A standard 'for' loop, that steps from the start to end index executing the
-- given function at each index.
@@ -73,6 +76,30 @@ imapNestFromTo shr start end extent body =
$ \sz -> imapFromTo ssz esz
$ \i -> k (OP_Pair sz i)

loopWorkFromTo :: ShapeR sh -> Operands sh -> Operands sh -> Operands sh -> TypeR s -> LoopWork sh (StateT (Operands s) (CodeGen Native)) -> StateT (Operands s) (CodeGen Native) ()
loopWorkFromTo shr start end extent tys loopwork = do
linix <- lift (intOfIndex shr extent start)
loopWorkFromTo' shr start end linix [] tys loopwork

loopWorkFromTo' :: ShapeR sh -> Operands sh -> Operands sh -> Operands Int -> [Operands Int] -> TypeR s -> LoopWork sh (StateT (Operands s) (CodeGen Native)) -> StateT (Operands s) (CodeGen Native) ()
loopWorkFromTo' ShapeRz OP_Unit OP_Unit _ _ _ LoopWorkZ = pure ()
loopWorkFromTo' (ShapeRsnoc shr) (OP_Pair start' start) (OP_Pair end' end) linix ixs tys (LoopWorkSnoc lw foo) = do
StateT $ \s -> ((),) <$> Loop.iter
(TupRpair typerInt typerInt)
tys
(OP_Pair start linix)
s
(\(OP_Pair i _) -> lt singleType i end)
(\(OP_Pair i l) -> OP_Pair <$> add numType (constant typerInt 1) i <*> add numType (constant typerInt 1) l)
(\(OP_Pair i l) -> execStateT $ do
loopWorkFromTo' shr start' end' l (i:ixs) tys lw
foo l (i : ixs))
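
A minimal, self-contained analogue of the control shape above — a hypothetical nestFromTo helper, not part of this patch, with plain lists standing in for ShapeR/Operands — may make the recursion easier to follow: each level iterates its own bounds, recurses into the deeper levels, then runs its own per-index action, all inside one StateT so a single accumulator flows through every iteration.

import Control.Monad (forM_)
import Control.Monad.State (StateT, execStateT, modify)

-- One (start, end) bound and one per-index action per dimension, outermost
-- first; each action receives the indices fixed so far, mirroring the
-- LoopWorkSnoc callbacks.
nestFromTo :: Monad m => [Int] -> [(Int, Int)] -> [[Int] -> StateT s m ()] -> StateT s m ()
nestFromTo _   []              []           = pure ()
nestFromTo ixs ((lo, hi) : bs) (act : acts) =
  forM_ [lo .. hi - 1] $ \i -> do
    nestFromTo (i : ixs) bs acts   -- deeper dimensions first...
    act (i : ixs)                  -- ...then this level's action
nestFromTo _ _ _ = pure ()

-- e.g. counting every visit of a 2x3 nest: 2 outer + 2*3 inner = 8
countVisits :: IO Int
countVisits =
  execStateT (nestFromTo [] [(0, 2), (0, 3)] [\_ -> modify (+ 1), \_ -> modify (+ 1)]) 0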



typerInt :: TypeR Int
typerInt = TupRsingle scalarTypeInt


{--
-- TLM: this version (seems to) compute the corresponding linear index as it
@@ -169,128 +196,167 @@ workstealLoop
:: Operand (Ptr Int32) -- index into work
-> Operand (Ptr Int32) -- number of threads that are working
-> Operand Int32 -- size of total work
-> (Operand Int32 -> CodeGen Native ())
-> CodeGen Native ()
-> (Operand Int32 -> StateT (Operands s) (CodeGen Native) ())
-> StateT (Operands s) (CodeGen Native) ()
workstealLoop counter activeThreads size doWork = do
start <- newBlock "worksteal.start"
work <- newBlock "worksteal.loop.work"
exit <- newBlock "worksteal.exit"
exitLast <- newBlock "worksteal.exit.last"
finished <- newBlock "worksteal.finished"
start <- lift $ newBlock "worksteal.start"
work <- lift $ newBlock "worksteal.loop.work"
exit <- lift $ newBlock "worksteal.exit"
exitLast <- lift $ newBlock "worksteal.exit.last"
finished <- lift $ newBlock "worksteal.finished"

-- Check whether there might be work for us
initialCounter <- instr' $ Load scalarType NonVolatile counter
initialCondition <- lt singleType (OP_Int32 initialCounter) (OP_Int32 size)
cbr initialCondition start finished
initialCounter <- lift $ instr' $ Load scalarType NonVolatile counter
initialCondition <- lift $ lt singleType (OP_Int32 initialCounter) (OP_Int32 size)
lift $ cbr initialCondition start finished

setBlock start
lift $ setBlock start
-- Might be work for us!
-- First mark that this thread is doing work.
atomicAdd Acquire activeThreads (integral TypeInt32 1)
startIndex <- atomicAdd Unordered counter (integral TypeInt32 1)
startCondition <- lt singleType (OP_Int32 startIndex) (OP_Int32 size)
cbr startCondition work exit
lift $ atomicAdd Acquire activeThreads (integral TypeInt32 1)
startIndex <- lift $ atomicAdd Unordered counter (integral TypeInt32 1)
startCondition <- lift $ lt singleType (OP_Int32 startIndex) (OP_Int32 size)
lift $ cbr startCondition work exit

setBlock work
indexName <- freshLocalName
lift $ setBlock work
indexName <- lift $ freshLocalName
let index = LocalReference type' indexName

-- Already claim the next work, to hide the latency of the atomic instruction
nextIndex <- atomicAdd Unordered counter (integral TypeInt32 1)
nextIndex <- lift $ atomicAdd Unordered counter (integral TypeInt32 1)

doWork index
condition <- lt singleType (OP_Int32 nextIndex) (OP_Int32 size)
condition <- lift $ lt singleType (OP_Int32 nextIndex) (OP_Int32 size)

-- Append the phi node to the start of the 'work' block.
-- We can only do this now, as we need to have 'nextIndex', and know the
-- exit block of 'doWork'.
currentBlock <- getBlock
phi1 work indexName [(startIndex, start), (nextIndex, currentBlock)]
currentBlock <- lift $ getBlock
lift $ phi1 work indexName [(startIndex, start), (nextIndex, currentBlock)]

cbr condition work exit
lift $ cbr condition work exit

setBlock exit
lift $ setBlock exit
-- Decrement activeThreads
remaining <- atomicAdd Release activeThreads (integral TypeInt32 (-1))
remaining <- lift $ atomicAdd Release activeThreads (integral TypeInt32 (-1))
-- If 'activeThreads' was 1 (now 0), then all work is definitely done.
-- Note that there may be multiple threads returning true here.
-- It is guaranteed that at least one thread returns true.
allDone <- eq singleType (OP_Int32 remaining) (liftInt32 1)
cbr allDone exitLast finished
allDone <- lift $ eq singleType (OP_Int32 remaining) (liftInt32 1)
lift $ cbr allDone exitLast finished

setBlock exitLast
lift $ setBlock exitLast
-- Use compare-and-set to change the active-threads counter to 1:
-- * Out of all threads that currently see an active-thread count of 0, only
-- 1 will succeed the CAS.
-- * Given that the counter is artificially increased here, no other thread
-- will see the counter ever drop to 0.
-- Hence we get a unique thread to continue the computation after this kernel.
casResult <- instr' $ CmpXchg TypeInt32 NonVolatile activeThreads (integral TypeInt32 0) (integral TypeInt32 1) (CrossThread, Monotonic) Monotonic
last <- instr' $ ExtractValue primType (TupleIdxRight TupleIdxSelf) casResult
retval_ last
casResult <- lift $ instr' $ CmpXchg TypeInt32 NonVolatile activeThreads (integral TypeInt32 0) (integral TypeInt32 1) (CrossThread, Monotonic) Monotonic
last <- lift $ instr' $ ExtractValue primType (TupleIdxRight TupleIdxSelf) casResult
lift $ retval_ last

setBlock finished
lift $ setBlock finished
-- Work was already finished
retval_ $ boolean False
lift $ retval_ $ boolean False
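
The lifting above leaves the block structure untouched; only the accumulator travels through StateT. As a sanity check of the protocol itself, here is a hypothetical runtime analogue with IORefs (illustration only, not part of this patch):

import Control.Monad (when)
import Data.IORef (IORef, readIORef, atomicModifyIORef')

-- fetch-and-add, returning the value held *before* the increment
fetchAdd :: IORef Int -> Int -> IO Int
fetchAdd ref d = atomicModifyIORef' ref (\x -> (x + d, x))

workstealLoopIO :: IORef Int -> IORef Int -> Int -> (Int -> IO ()) -> IO Bool
workstealLoopIO counter activeThreads size doWork = do
  initial <- readIORef counter
  if initial >= size
    then pure False                      -- work was already finished
    else do
      _     <- fetchAdd activeThreads 1  -- mark this thread as working
      first <- fetchAdd counter 1
      let loop ix = when (ix < size) $ do
            next <- fetchAdd counter 1   -- claim the next index before working,
            doWork ix                    -- hiding the latency of the atomic
            loop next
      loop first
      remaining <- fetchAdd activeThreads (-1)
      -- Simplification: a thread that arrives after the count already hit
      -- zero can make it drop to zero a second time, so more than one thread
      -- may see 1 here. The kernel above resolves that with the CmpXchg in
      -- worksteal.exit.last, electing exactly one winner; this sketch omits
      -- that step.
      pure (remaining == 1)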

workstealChunked :: ShapeR sh -> Operand (Ptr Int32) -> Operand (Ptr Int32) -> Operands sh -> (Operands sh -> Operands Int -> CodeGen Native ()) -> CodeGen Native ()
workstealChunked shr counter activeThreads sh doWork = do
let chunkSz = chunkSize shr
chunkCounts <- chunkCount shr sh chunkSz
chunkCnt <- shapeSize shr chunkCounts
chunkCnt' :: Operand Int32 <- instr' $ Trunc boundedType boundedType $ op TypeInt chunkCnt
workstealChunked :: ShapeR sh -> Operand (Ptr Int32) -> Operand (Ptr Int32) -> Operands sh -> TypeR s -> LoopWork sh (StateT (Operands s) (CodeGen Native)) -> StateT (Operands s) (CodeGen Native) ()
workstealChunked shr counter activeThreads sh tys loopwork = do
let chunkSz = chunkSize' shr sh
chunkCounts <- lift $ chunkCount shr sh chunkSz
chunkCnt <- lift $ shapeSize shr chunkCounts
chunkCnt' :: Operand Int32 <- lift $ instr' $ Trunc boundedType boundedType $ op TypeInt chunkCnt

workstealLoop counter activeThreads chunkCnt' $ \chunkLinearIndex -> do
chunkLinearIndex' <- instr' $ Ext boundedType boundedType chunkLinearIndex
chunkIndex <- indexOfInt shr chunkCounts (OP_Int chunkLinearIndex')
start <- chunkStart shr chunkSz chunkIndex
end <- chunkEnd shr sh chunkSz start

imapNestFromTo shr start end sh doWork

chunkSize :: ShapeR sh -> sh
chunkSize ShapeRz = ()
chunkSize (ShapeRsnoc ShapeRz) = ((), 1024 * 16)
chunkSize (ShapeRsnoc (ShapeRsnoc ShapeRz)) = (((), 64), 64)
chunkSize (ShapeRsnoc (ShapeRsnoc (ShapeRsnoc ShapeRz))) = ((((), 16), 16), 32)
chunkSize (ShapeRsnoc (ShapeRsnoc (ShapeRsnoc (ShapeRsnoc sh)))) = ((((go sh, 8), 8), 16), 16)
where
go :: ShapeR sh' -> sh'
go ShapeRz = ()
go (ShapeRsnoc sh') = (go sh', 1)

chunkCount :: ShapeR sh -> Operands sh -> sh -> CodeGen Native (Operands sh)
chunkCount ShapeRz OP_Unit () = return OP_Unit
chunkCount (ShapeRsnoc shr) (OP_Pair sh sz) (chunkSh, chunkSz) = do
chunkLinearIndex' <- lift $ instr' $ Ext boundedType boundedType chunkLinearIndex
chunkIndex <- lift $ indexOfInt shr chunkCounts (OP_Int chunkLinearIndex')
start <- lift $ chunkStart shr chunkSz chunkIndex
end <- lift $ chunkEnd shr sh chunkSz start
-- imapNestFromTo shr start end sh doWork
loopWorkFromTo shr start end sh tys loopwork
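
The same chunked scheduling at the value level, building on the workstealLoopIO sketch above (hypothetical helper; the real code computes per-dimension chunk counts, this collapses everything to one dimension):

workstealChunkedIO :: IORef Int -> IORef Int -> Int -> Int -> (Int -> IO ()) -> IO Bool
workstealChunkedIO counter active totalSize chunkSz doWork =
  let chunks = (totalSize + chunkSz - 1) `quot` chunkSz  -- ceiling division
  in  workstealLoopIO counter active chunks $ \c -> do
        let start = c * chunkSz
            end   = min totalSize (start + chunkSz)
        mapM_ doWork [start .. end - 1]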


chunkSize' :: ShapeR sh -> Operands sh -> Operands sh
chunkSize' ShapeRz OP_Unit = OP_Unit
chunkSize' (ShapeRsnoc ShapeRz) (OP_Pair _ sz) = OP_Pair OP_Unit sz
chunkSize' (ShapeRsnoc shr) (OP_Pair sh _) = OP_Pair (chunkSize' shr sh) (liftInt 1)

-- chunkSize :: ShapeR sh -> Operands sh
-- chunkSize ShapeRz = OP_Unit
-- chunkSize (ShapeRsnoc shr) = OP_Pair (chunkSize shr) (liftInt 1)
-- chunkSize (ShapeRsnoc ShapeRz) = ((), 1024 * 16)
-- chunkSize (ShapeRsnoc (ShapeRsnoc ShapeRz)) = (((), 64), 64)
-- chunkSize (ShapeRsnoc (ShapeRsnoc (ShapeRsnoc ShapeRz))) = ((((), 16), 16), 32)
-- chunkSize (ShapeRsnoc (ShapeRsnoc (ShapeRsnoc (ShapeRsnoc sh)))) = ((((go sh, 8), 8), 16), 16)
-- where
-- go :: ShapeR sh' -> sh'
-- go ShapeRz = ()
-- go (ShapeRsnoc sh') = (go sh', 1)

chunkCount :: ShapeR sh -> Operands sh -> Operands sh -> CodeGen Native (Operands sh)
chunkCount ShapeRz OP_Unit OP_Unit = return OP_Unit
chunkCount (ShapeRsnoc shr) (OP_Pair sh sz) (OP_Pair chunkSh chunkSz) = do
counts <- chunkCount shr sh chunkSh

-- Compute ceil(sz / chunkSz), as
-- (sz + chunkSz - 1) `quot` chunkSz
sz' <- add numType sz (liftInt $ chunkSz - 1)
count <- A.quot TypeInt sz' $ liftInt chunkSz
chunkszsub1 <- sub numType chunkSz $ constant typerInt 1
sz' <- add numType sz chunkszsub1
count <- A.quot TypeInt sz' chunkSz

return $ OP_Pair counts count
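
The ceiling division now happens on runtime operands, but the identity is the same one the old constant-folded version used; a quick pure check (hypothetical helper, illustration only):

ceilDiv :: Int -> Int -> Int
ceilDiv sz c = (sz + c - 1) `quot` c
-- e.g. ceilDiv 100 16 == 7: six full chunks of 16 plus a final chunk of 4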

chunkStart :: ShapeR sh -> sh -> Operands sh -> CodeGen Native (Operands sh)
chunkStart ShapeRz () OP_Unit = return OP_Unit
chunkStart (ShapeRsnoc shr) (chunkSh, chunkSz) (OP_Pair sh sz) = do
-- chunkSize :: ShapeR sh -> sh
-- chunkSize ShapeRz = ()
-- chunkSize (ShapeRsnoc ShapeRz) = ((), 1024 * 16)
-- chunkSize (ShapeRsnoc (ShapeRsnoc ShapeRz)) = (((), 64), 64)
-- chunkSize (ShapeRsnoc (ShapeRsnoc (ShapeRsnoc ShapeRz))) = ((((), 16), 16), 32)
-- chunkSize (ShapeRsnoc (ShapeRsnoc (ShapeRsnoc (ShapeRsnoc sh)))) = ((((go sh, 8), 8), 16), 16)
-- where
-- go :: ShapeR sh' -> sh'
-- go ShapeRz = ()
-- go (ShapeRsnoc sh') = (go sh', 1)

-- chunkCount :: ShapeR sh -> Operands sh -> sh -> CodeGen Native (Operands sh)
-- chunkCount ShapeRz OP_Unit () = return OP_Unit
-- chunkCount (ShapeRsnoc shr) (OP_Pair sh sz) (chunkSh, chunkSz) = do
-- counts <- chunkCount shr sh chunkSh

-- -- Compute ceil(sz / chunkSz), as
-- -- (sz + chunkSz - 1) `quot` chunkSz
-- sz' <- add numType sz (liftInt $ chunkSz - 1)
-- count <- A.quot TypeInt sz' $ liftInt chunkSz

-- return $ OP_Pair counts count

chunkStart :: ShapeR sh -> Operands sh -> Operands sh -> CodeGen Native (Operands sh)
chunkStart ShapeRz OP_Unit OP_Unit = return OP_Unit
chunkStart (ShapeRsnoc shr) (OP_Pair chunkSh chunkSz) (OP_Pair sh sz) = do
ixs <- chunkStart shr chunkSh sh
ix <- mul numType sz $ liftInt chunkSz
ix <- mul numType sz chunkSz
return $ OP_Pair ixs ix

chunkEnd
:: ShapeR sh
-> Operands sh -- Array size
-> sh -- Chunk size
-> Operands sh -- Chunk size
-> Operands sh -- Chunk start
-> CodeGen Native (Operands sh) -- Chunk end
chunkEnd ShapeRz OP_Unit () OP_Unit = return OP_Unit
chunkEnd (ShapeRsnoc shr) (OP_Pair sh0 sz0) (sh1, sz1) (OP_Pair sh2 sz2) = do
chunkEnd ShapeRz OP_Unit OP_Unit OP_Unit = return OP_Unit
chunkEnd (ShapeRsnoc shr) (OP_Pair sh0 sz0) (OP_Pair sh1 sz1) (OP_Pair sh2 sz2) = do
sh3 <- chunkStart shr sh1 sh2
sz3 <- add numType sz2 $ liftInt sz1
sz3 <- add numType sz2 sz1
sz3' <- A.min singleType sz3 sz0
return $ OP_Pair sh3 sz3'
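
On a single dimension, chunkStart and chunkEnd reduce to the following pure rule (hypothetical helper), with the final chunk clamped to the array extent:

chunkBounds :: Int -> Int -> Int -> (Int, Int)
chunkBounds extent chunkSz chunkIx =
  let start = chunkIx * chunkSz
  in  (start, min extent (start + chunkSz))
-- e.g. chunkBounds 100 16 6 == (96, 100): the last chunk is cut short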

atomicAdd :: MemoryOrdering -> Operand (Ptr Int32) -> Operand Int32 -> CodeGen Native (Operand Int32)
atomicAdd ordering ptr increment = do
instr' $ AtomicRMW numType NonVolatile RMW.Add ptr increment (CrossThread, ordering)


data LoopWork sh m where
LoopWorkZ :: LoopWork () m
LoopWorkSnoc :: LoopWork sh m
-- The list contains only the indices available at this level, i.e. not the ones from even deeper nesting
-> (Operands Int -> [Operands Int] -> m ())
-> LoopWork (sh, Int) m
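
For reference, a rank-2 value of this type with no-op callbacks (assuming only the definitions above) shows how one callback is supplied per dimension:

loopWork2 :: Applicative m => LoopWork (((), Int), Int) m
loopWork2 =
  LoopWorkSnoc
    (LoopWorkSnoc LoopWorkZ (\_i _ixs -> pure ()))
    (\_i _ixs -> pure ())  -- each callback gets the linear index and the
                           -- indices available at its nesting level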
23 changes: 12 additions & 11 deletions accelerate-llvm-native/test/nofib/Main.hs
@@ -29,14 +29,20 @@ import Data.Array.Accelerate.Trafo.Partitioning.ILP.Solve
import Data.Array.Accelerate.Data.Bits
import Data.Array.Accelerate.Unsafe
import Control.Concurrent
import Quickhull
-- import Quickhull
main :: IO ()
main = do
let xs = fromList (Z :. 10) [1 :: Int ..]
let ys = map (+1) $
use xs
let f = map (*2)
let program = awhile (map (A.>0) . asnd) (\(T2 a b) -> T2 (f a) (map (\x -> x - 1) b)) (T2 ys $ unit $ constant (100000 :: Int))
let xs = fromList (Z :. 2 :. 2) [1 :: Int ..]
-- let ys = map (+1) $
-- use xs
-- let f = map (*2)
-- let program = awhile (map (A.>0) . asnd) (\(T2 a b) -> T2 (f a) (map (\x -> x - 1) b)) (T2 ys $ unit $ constant (100000 :: Int))

putStrLn "mapscanmap:"
let f = map (*2) $ scanl1 (+) $ map (+4) $ use xs
putStrLn $ test @UniformScheduleFun @NativeKernel f
print $ run @Native f

-- let program xs =
-- -- let xs = A.use (A.fromList (A.Z A.:. 10) ([0..] :: [Int])) in
-- A.map fst $ A.zip (A.reverse xs) (A.reverse $ A.backpermute (A.I1 10) Prelude.id (xs :: A.Acc (A.Vector Int)))
@@ -76,11 +82,6 @@ main = do
-- -- putStrLn $ test @UniformScheduleFun @NativeKernel f
-- print $ run @Native f

putStrLn "mapscanmap:"
let f = map (*2) $ scanl1 (+) $ map (+4) ys
putStrLn $ test @UniformScheduleFun @NativeKernel f
print $ run @Native f



-- Prelude.print $ runNWithObj @Native ArrayReadsWrites $ quicksort $ use $ fromList (Z :. 5) [100::Int, 200, 3, 5, 4]