Skip to content

Commit

Permalink
Minor changes to keep up with accelerate
Browse files Browse the repository at this point in the history
  • Loading branch information
dpvanbalen committed Mar 14, 2024
1 parent 263f8a3 commit 00e2ada
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ codegen uid env (Clustered c b) args =
workstealLoop workstealIndex workstealActiveThreads (op scalarTypeInt32 $ constant (TupRsingle scalarTypeInt32) 1) $ \_ -> do
let b' = mapArgs BCAJA b
(acc, loopsize') <- execStateT (evalCluster (toOnlyAcc c) b' args gamma ()) (mempty, LS ShapeRz OP_Unit)
Debug.Trace.traceShow loopsize' $ return ()
body acc loopsize'
retval_ $ boolean True
where
Expand Down Expand Up @@ -281,8 +280,7 @@ instance EvalOp NativeOp where
(BAE (flip (llvmOfFun2 @Native) gamma -> c) _)) -- combination function
| CJ x <- x'
, shrx `isAtDepth'` d'
= Debug.Trace.trace "generating code for permute" $
lift $ do
= lift $ do
ix' <- app1 f (multidim shrx is)
-- project element onto the destination array and (atomically) update
when (isJust ix') $ do
Expand Down Expand Up @@ -484,6 +482,7 @@ instance (StaticClusterAnalysis op, EnvF (JustAccumulator op) ~ EnvF op) => Stat
varToUnit x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ varToUnit $ coerce x

-- Standalone-derived instances: equality and showing of a cluster argument for
-- the 'JustAccumulator'-wrapped operation are available whenever the same
-- class holds for the underlying operation's 'BackendClusterArg2' (see the
-- instance contexts).
deriving instance (Eq (BackendClusterArg2 op x y)) => Eq (BackendClusterArg2 (JustAccumulator op) x y)
deriving instance (Show (BackendClusterArg2 op x y)) => Show (BackendClusterArg2 (JustAccumulator op) x y)


toOnlyAcc :: Cluster op args -> Cluster (JustAccumulator op) args
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ compile uid module' = do
withModuleFromAST ctx ast $ \mdl ->
withNativeTargetMachine $ \machine ->
withTargetLibraryInfo triple $ \libinfo -> do
hPutStrLn stderr . T.unpack . decodeUtf8 =<< moduleLLVMAssembly mdl
-- hPutStrLn stderr . T.unpack . decodeUtf8 =<< moduleLLVMAssembly mdl

optimiseModule datalayout (Just machine) (Just libinfo) mdl

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ instance DesugarAcc NativeOp where
mkScan _ _ _ _ _ = error "todo"
mkPermute a b@(ArgArray _ (ArrayR shr _) sh _) c d
| DeclareVars lhs w lock <- declareVars $ buffersR $ TupRsingle scalarTypeWord8
= Debug.Trace.trace "hello??" $ aletUnique lhs
= aletUnique lhs
(Alloc shr scalarTypeWord8 $ groundToExpVar (shapeType shr) sh)
$ alet LeftHandSideUnit
(Exec NGenerate ( -- TODO: The old pipeline used a 'memset 0' instead, which sounds faster...
Expand Down Expand Up @@ -206,8 +206,9 @@ pattern OutDims l = BackendSpecific (Dims OutArr l)
-- TODO: constraints and bounds for the new variable(s)
instance MakesILP NativeOp where
type BackendVar NativeOp = NativeILPVar
type BackendArg NativeOp = (Int, Depth) -- direction, depth
data BackendClusterArg NativeOp a = BCAN Depth
type BackendArg NativeOp = (Int, IterationDepth) -- direction, depth
defaultBA = (0,0)
data BackendClusterArg NativeOp a = BCAN IterationDepth

mkGraph NBackpermute (_ :>: L (ArgArray In (ArrayR _shrI _) _ _) (_, lIns) :>: L (ArgArray Out (ArrayR shrO _) _ _) _ :>: ArgsNil) l@(Label i _) =
Graph.Info
Expand Down Expand Up @@ -312,12 +313,14 @@ instance ShrinkArg (BackendClusterArg NativeOp) where

-- | An index transformation between two shapes, with both shape types hidden
-- (they do not appear in the result type @IndexPermutation env@).
--
-- The single constructor 'BP' stores, in order: the shape representation of
-- @sh1@, the shape representation of @sh2@, a function in environment @env@
-- from @sh1@ to @sh2@, and ground variables of type @sh1@.
--
-- In the @NBackpermute@ case of @onOp@ in this file it is built as
-- @BP shrO shrI f sh@, i.e. @sh1@ is the output array's shape and @sh2@ the
-- input array's shape.
data IndexPermutation env where
BP :: ShapeR sh1 -> ShapeR sh2 -> Fun env (sh1 -> sh2) -> GroundVars env sh1 -> IndexPermutation env
-- Both names are plain aliases for 'Int', so they are interchangeable to the
-- type checker. They annotate the depth component carried by 'BCAN' and
-- 'BCAN2' and the second component of @BackendArg NativeOp@.
type Depth = Int
type IterationDepth = Int
-- | Debug rendering of a 'BCAN2' cluster-argument annotation: prints the depth
-- together with the optional index permutation (via the 'Show' instance of
-- 'IndexPermutation').
instance Show (BackendClusterArg2 NativeOp env arg) where
  -- Exactly one equation. A preceding catch-all clause that printed only
  -- @"depth " <> show d@ matched every value and left this richer rendering
  -- unreachable, so it has been dropped.
  show (BCAN2 perm depth) = "{ depth = " <> show depth <> ", perm = " <> show perm <> " }"
-- | Shows only the ranks of the two shapes stored in 'BP', formatted as
-- @\<rank of first shape\>->\<rank of second shape\>@. The function and the
-- ground variables are not rendered.
instance Show (IndexPermutation env) where
  show (BP firstShape secondShape _fun _vars) = firstRank <> "->" <> secondRank
    where
      firstRank  = show (rank firstShape)
      secondRank = show (rank secondShape)
instance StaticClusterAnalysis NativeOp where
data BackendClusterArg2 NativeOp env arg where
BCAN2 :: Maybe (IndexPermutation env) -> Depth -> BackendClusterArg2 NativeOp env arg
BCAN2 :: Maybe (IndexPermutation env) -> IterationDepth -> BackendClusterArg2 NativeOp env arg
def _ _ (BCAN i) = BCAN2 Nothing i
unitToVar = bcan2id
varToUnit = bcan2id
Expand All @@ -341,7 +344,6 @@ instance StaticClusterAnalysis NativeOp where
onOp NBackpermute (BCAN2 Nothing d :>: ArgsNil) (ArgFun f :>: ArgArray In (ArrayR shrI _) _ _ :>: ArgArray Out (ArrayR shrO _) sh _ :>: ArgsNil) _
= BCAN2 Nothing d :>: BCAN2 (Just (BP shrO shrI f sh)) d :>: BCAN2 Nothing d :>: ArgsNil
onOp NGenerate (bp :>: ArgsNil) (_:>:ArgArray Out (ArrayR shR _) _ _ :>:ArgsNil) _ =
Debug.Trace.traceShow (bp, rank shR) $
bcan2id bp :>: bp :>: ArgsNil -- storing the bp in the function argument. Probably not required, could just take it from the array one during codegen
onOp NPermute ArgsNil (_:>:_:>:_:>:_:>:ArgArray In (ArrayR shR _) _ _ :>:ArgsNil) _ =
BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing (rank shR) :>: ArgsNil
Expand Down Expand Up @@ -370,7 +372,10 @@ fold2bp (BCAN2 (Just (BP shr1 shr2 g sh)) i) foldsize = flip BCAN2 (i+1) $ Just
(TupRpair sh foldsize)

-- | Two 'BCAN2' annotations are equal exactly when their optional index
-- permutations are equal and their depths are equal.
instance Eq (BackendClusterArg2 NativeOp env arg) where
  -- One equation suffices: the previous second clause could never be reached,
  -- and its local helper mapped True to True and False to False (the identity
  -- on 'Bool'), so removing it does not change any result.
  BCAN2 p i == BCAN2 p' i' = p == p' && i == i'
instance Eq (IndexPermutation env) where
(BP shr1 shr2 f _) == (BP shr1' shr2' f' _)
| Just Refl <- matchShapeR shr1 shr1'
Expand Down
42 changes: 25 additions & 17 deletions accelerate-llvm-native/test/nofib/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,35 @@ main :: IO ()
main = do
let xs = fromList (Z :. 10) [1 :: Int ..]
let ys = use xs
putStrLn "generate:"
let f = generate (I1 10) (\(I1 x0) -> 10 :: Exp Int)
-- putStrLn $ test @UniformScheduleFun @NativeKernel f
print $ run @Native f

putStrLn "mapmap:"
let f = map (+1) . map (*2) -- $ ys
-- putStrLn $ test @UniformScheduleFun @NativeKernel f
-- putStrLn $ test @UniformScheduleFun @NativeKernel (f ys)
print $ runN @Native f xs
print $ runN @Native (f ys)

putStrLn "fold:"
let f = fold1 (+) ys
-- putStrLn $ test @UniformScheduleFun @NativeKernel f
let f = T2 (map (+1) ys) (map (*2) $ reverse ys)
-- let f = --map (\(T2 a b) -> a + b) $
-- zip ys $ reverse ys
putStrLn $ test @UniformScheduleFun @NativeKernel f
print $ run @Native f

putStrLn "scan:"
let f = scanl1 (+) ys
-- putStrLn $ test @UniformScheduleFun @NativeKernel f
print $ run @Native f
-- putStrLn "generate:"
-- let f = generate (I1 10) (\(I1 x0) -> 10 :: Exp Int)
-- -- putStrLn $ test @UniformScheduleFun @NativeKernel f
-- print $ run @Native f

-- putStrLn "mapmap:"
-- let f = map (+1) . map (*2) -- $ ys
-- -- putStrLn $ test @UniformScheduleFun @NativeKernel f
-- -- putStrLn $ test @UniformScheduleFun @NativeKernel (f ys)
-- print $ runN @Native f xs
-- print $ runN @Native (f ys)

-- putStrLn "fold:"
-- let f = fold1 (+) ys
-- -- putStrLn $ test @UniformScheduleFun @NativeKernel f
-- print $ run @Native f

-- putStrLn "scan:"
-- let f = scanl1 (+) ys
-- -- putStrLn $ test @UniformScheduleFun @NativeKernel f
-- print $ run @Native f

-- Prelude.print $ runNWithObj @Native ArrayReadsWrites $ quicksort $ use $ fromList (Z :. 5) [100::Int, 200, 3, 5, 4]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ createBlocks
in
trace (bformat ("generated " % int % " instructions in " % int % " blocks") (n+m) m) ( F.toList blocks , s' )
where
makeBlock b@Block{..} = --Debug.Trace.traceShow b $
makeBlock b@Block{..} =
LLVM.BasicBlock (downcast blockLabel) (F.toList instructions) (LLVM.Do terminator)


Expand Down
1 change: 0 additions & 1 deletion stack.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ extra-deps:
- OptDir-0.0.4
- bytestring-encoding-0.1.2.0
- ../accelerate
# - MIP-0.1.1.0
- github: msakai/haskell-MIP
commit: 4295aa21a24a30926b55770c55ac00f749fb8a39
subdirs:
Expand Down

0 comments on commit 00e2ada

Please sign in to comment.