Skip to content

Commit

Permalink
fix 1-dim folds (sequential); still doesn't support fused fold . fold
Browse files Browse the repository at this point in the history
  • Loading branch information
dpvanbalen committed May 22, 2024
1 parent 20dda7f commit ea63590
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,17 @@ codegen name env (Clustered c b) args =
-- body acc loopsize'
acc' <- operandsMapToPairs acc $ \(accTypeR, toOp, fromOp) -> fmap fromOp $ flip execStateT (toOp acc) $ case loopsize of
LS loopshr loopsh ->
workstealChunked loopshr workstealIndex workstealActiveThreads (flipShape loopshr loopsh) accTypeR (body loopshr toOp fromOp)
workstealChunked loopshr workstealIndex workstealActiveThreads (flipShape loopshr loopsh) accTypeR
(body loopshr toOp fromOp, -- the LoopWork
StateT $ \op -> second toOp <$> runStateT (foo (liftInt 0) []) (fromOp op)) -- the action to run after the outer loop
-- acc'' <- flip execStateT acc' $ foo (liftInt 0) []
pure ()
where
ba = makeBackendArg @NativeOp args gamma c b
(argTp, extractEnv, workstealIndex, workstealActiveThreads, gamma) = bindHeaderEnv env
body :: ShapeR sh -> (Accumulated -> a) -> (a -> Accumulated) -> LoopWork sh (StateT a (CodeGen Native))
body ShapeRz _ _ = LoopWorkZ
body (ShapeRsnoc shr) toOp fromOp = LoopWorkSnoc (body shr toOp fromOp) (\i is -> StateT $ \op -> second toOp <$> runStateT (foo i is) (fromOp op))
where
foo :: Operands Int -> [Operands Int] -> StateT Accumulated (CodeGen Native) ()
foo linix ixs = do
let d = length ixs -- TODO check: this or its inverse (i.e. totalDepth - length ixs)?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ imapNestFromTo shr start end extent body =
$ \sz -> imapFromTo ssz esz
$ \i -> k (OP_Pair sz i)

loopWorkFromTo :: ShapeR sh -> Operands sh -> Operands sh -> Operands sh -> TypeR s -> LoopWork sh (StateT (Operands s) (CodeGen Native)) -> StateT (Operands s) (CodeGen Native) ()
loopWorkFromTo shr start end extent tys loopwork = do
loopWorkFromTo :: ShapeR sh -> Operands sh -> Operands sh -> Operands sh -> TypeR s -> (LoopWork sh (StateT (Operands s) (CodeGen Native)),StateT (Operands s) (CodeGen Native) ()) -> StateT (Operands s) (CodeGen Native) ()
loopWorkFromTo shr start end extent tys (loopwork,finish) = do
linix <- lift (intOfIndex shr extent start)
loopWorkFromTo' shr start end extent linix [] tys loopwork
finish


loopWorkFromTo' :: ShapeR sh -> Operands sh -> Operands sh -> Operands sh -> Operands Int -> [Operands Int] -> TypeR s -> LoopWork sh (StateT (Operands s) (CodeGen Native)) -> StateT (Operands s) (CodeGen Native) ()
loopWorkFromTo' ShapeRz OP_Unit OP_Unit OP_Unit _ _ _ LoopWorkZ = pure ()
Expand Down Expand Up @@ -272,7 +274,7 @@ workstealLoop counter activeThreads size doWork = do
-- lift $ setBlock dummy -- without this, the previous block always returns True for some reason


workstealChunked :: ShapeR sh -> Operand (Ptr Int32) -> Operand (Ptr Int32) -> Operands sh -> TypeR s -> LoopWork sh (StateT (Operands s) (CodeGen Native)) -> StateT (Operands s) (CodeGen Native) ()
workstealChunked :: ShapeR sh -> Operand (Ptr Int32) -> Operand (Ptr Int32) -> Operands sh -> TypeR s -> (LoopWork sh (StateT (Operands s) (CodeGen Native)), StateT (Operands s) (CodeGen Native) ()) -> StateT (Operands s) (CodeGen Native) ()
workstealChunked shr counter activeThreads sh tys loopwork = do
let chunkSz = chunkSize' shr sh
chunkCounts <- lift $ chunkCount shr sh chunkSz
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ compile uid name module' = do
withNativeTargetMachine $ \machine ->
withTargetLibraryInfo triple $ \libinfo -> do
-- dump llvm
-- hPutStrLn stderr . T.unpack . decodeUtf8 =<< moduleLLVMAssembly mdl
hPutStrLn stderr . T.unpack . decodeUtf8 =<< moduleLLVMAssembly mdl

optimiseModule datalayout (Just machine) (Just libinfo) mdl

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,9 @@ instance MakesILP NativeOp where

encodeBackendClusterArg (BCAN i) = intHost $(hashQ ("BCAN" :: String)) <> intHost i

inputConstraints :: HasCallStack => Label -> Labels -> Constraint NativeOp
inputConstraints :: Label -> Labels -> Constraint NativeOp
inputConstraints l = foldMap $ \lIn ->
timesN (fused lIn l) .>=. ILP.c (InDir l) .-. ILP.c (OutDir lIn)
<> (-1) .*. timesN (fused lIn l) .<=. ILP.c (InDir l) .-. ILP.c (OutDir lIn)
<> timesN (fused lIn l) .>=. ILP.c (InDims l) .-. ILP.c (OutDims lIn)
timesN (fused lIn l) .>=. ILP.c (InDims l) .-. ILP.c (OutDims lIn)
<> (-1) .*. timesN (fused lIn l) .<=. ILP.c (InDims l) .-. ILP.c (OutDims lIn)

inrankifmanifest :: ShapeR sh -> Label -> Constraint NativeOp
Expand Down
38 changes: 20 additions & 18 deletions accelerate-llvm-native/test/nofib/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,22 @@ import Control.Concurrent
-- import Quickhull
main :: IO ()
main = do
let xs = fromList (Z :. 5 :. 7) [1 :: Int ..]
-- let ys = map (+1) $
-- use xs
-- let f = map (*2)
-- let program = awhile (map (A.>0) . asnd) (\(T2 a b) -> T2 (f a) (map (\x -> x - 1) b)) (T2 ys $ unit $ constant (100000 :: Int))

putStrLn "scan:"
let f =
--map (*2) $
scanl1 (+) $
--map (+4) $
use xs
putStrLn $ test @UniformScheduleFun @NativeKernel f
print $ run @Native f
let xs = fromList (Z :. 5) [1 :: Int ..]
let ys = map (+1) $
use xs
let f = map (*2)
let program = awhile (map (A.>0) . asnd) (\(T2 a b) -> T2 (f a) (map (\x -> x - 1) b)) (T2 ys $ unit $ constant (100000 :: Int))

-- putStrLn "scan:"
-- let f =
-- --map (*2) $
-- scanl1 (+) $
-- --map (+4) $
-- use xs
-- putStrLn $ test @UniformScheduleFun @NativeKernel f
-- print $ run @Native f

-- putStrLn $ test @UniformScheduleFun @NativeKernel $ map (\(I2 a b)->b) (generate (I2 10 5) (\(I2 i j) -> fromIndex (I2 (5 :: Exp Int) (10 :: Exp Int)) (toIndex (I2 10 5) (I2 i j))))

-- threadDelay 5000000
-- putStrLn "done"
Expand Down Expand Up @@ -83,10 +85,10 @@ main = do
-- print $ runN @Native f xs
-- print $ runN @Native (f ys)

-- putStrLn "fold:"
-- let f = fold1 (+) ys
-- -- putStrLn $ test @UniformScheduleFun @NativeKernel f
-- print $ run @Native f
putStrLn "fold:"
let f = fold1 (+) ys
putStrLn $ test @UniformScheduleFun @NativeKernel f
print $ run @Native f

-- putStrLn "scan:"
-- let f = scanl1 (+) ys
Expand Down

0 comments on commit ea63590

Please sign in to comment.