Skip to content

Commit

Permalink
Merge branch 'new-pipeline' of https://github.com/ivogabe/accelerate-…
Browse files Browse the repository at this point in the history
…llvm into new-pipeline
  • Loading branch information
dpvanbalen committed Apr 15, 2024
2 parents 50588bb + 7ebec32 commit 53c3937
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ import Data.Array.Accelerate.AST.Operation (groundToExpVar, Fun, mapArgs)
import Data.Array.Accelerate.LLVM.Native.CodeGen.Permute (atomically)
import Control.Monad.State (StateT(..), lift, evalStateT, execStateT)
import qualified Data.Map as M
import Data.ByteString.Short ( ShortByteString )
import Data.Array.Accelerate.AST.LeftHandSide (Exists (Exists), lhsToTupR)
import Data.Array.Accelerate.Trafo.Partitioning.ILP.Labels (Label)
import Data.Array.Accelerate.LLVM.CodeGen.Constant (constant, boolean)
Expand All @@ -88,13 +89,13 @@ import Data.Array.Accelerate.Backend (SLVOperation(..))



codegen :: UID
codegen :: ShortByteString
-> Env AccessGroundR env
-> Clustered NativeOp args
-> Args env args
-> LLVM Native (Module (KernelType env))
codegen uid env (Clustered c b) args =
codeGenFunction uid "fused_cluster_name" (LLVM.Lam argTp "arg") $ do
codegen name env (Clustered c b) args =
codeGenFunction name (LLVM.Lam argTp "arg") $ do
extractEnv
workstealLoop workstealIndex workstealActiveThreads (op scalarTypeInt32 $ constant (TupRsingle scalarTypeInt32) 1) $ \_ -> do
let b' = mapArgs BCAJA b
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,15 @@ data ObjectR f = ObjectR
, sharedObjPath :: {- LAZY -} FilePath
}

compile :: UID -> Module f -> LLVM Native (ObjectR f)
compile uid module' = do
compile :: UID -> ShortByteString -> Module f -> LLVM Native (ObjectR f)
compile uid name module' = do
cachePath <- cacheOfUID uid
let
ast = downcast module'
staticObjFile = cachePath <.> staticObjExt
sharedObjFile = cachePath <.> sharedObjExt
triple = fromMaybe BS.empty (LLVM.moduleTargetTriple ast)
datalayout = LLVM.moduleDataLayout ast
GlobalFunctionBody (Label name) _ = functionBody $ moduleMain module'
-- Lower the generated LLVM and produce an object file.
--
-- The 'staticObjPath' field is only lazily evaluated since the object
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ prepareKernel env (NativeKernelMetadata envSize) fun args = do

let
go :: forall kenv f'. Int -> OpenKernelFun NativeKernel kenv f' -> SArgs env f' -> IO (Exists KernelCall)
go cursor (KernelFunBody (NativeKernel _ funLifetime)) ArgsNil
go cursor (KernelFunBody (NativeKernel funLifetime _ _ _)) ArgsNil
| cursor == cacheLineSize * 2 + envSize =
return $ Exists $ KernelCall @kenv (unsafeGetValue funLifetime) foreignPtr
| otherwise = internalError "Cursor and size do not match. prepareKernel and sizeOfEnv might be inconsistent."
Expand All @@ -84,7 +84,7 @@ touchKernel :: forall env f. NativeEnv env -> KernelFun NativeKernel f -> SArgs
touchKernel env = go
where
go :: OpenKernelFun NativeKernel kenv f' -> SArgs env f' -> IO ()
go (KernelFunBody (NativeKernel _ funLifetime)) ArgsNil = touchLifetime funLifetime
go (KernelFunBody (NativeKernel funLifetime _ _ _)) ArgsNil = touchLifetime funLifetime
go (KernelFunLam argR fun) (arg :>: args) = do
touchArg env argR arg
go fun args
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE OverloadedStrings #-}
Expand Down Expand Up @@ -29,6 +30,7 @@ import Data.Array.Accelerate.Type
import Data.Array.Accelerate.AST.Exp
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.AST.Kernel
import Data.Array.Accelerate.AST.Schedule
import Data.Array.Accelerate.AST.Schedule.Uniform
import Data.Array.Accelerate.Backend
import Data.Array.Accelerate.Error
Expand All @@ -48,7 +50,8 @@ import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.Execute.Marshal
import qualified LLVM.AST as LLVM
import LLVM.AST.Type.Function
import Data.ByteString.Short ( ShortByteString )
import Data.ByteString.Short ( ShortByteString, fromShort )
import qualified Data.ByteString.Char8 as Char8
import System.FilePath ( FilePath, (<.>) )
import System.IO.Unsafe
import Control.DeepSeq
Expand All @@ -62,13 +65,15 @@ import LLVM.AST.Type.Representation

data NativeKernel env where
NativeKernel
:: { kernelId :: {-# UNPACK #-} !UID
, kernelFunction :: !(Lifetime (FunPtr (KernelType env)))
:: { kernelFunction :: !(Lifetime (FunPtr (KernelType env)))
, kernelId :: {-# UNPACK #-} !ShortByteString
, kernelDescDetail :: String
, kernelDescBrief :: String
}
-> NativeKernel env

instance NFData' NativeKernel where
rnf' (NativeKernel !_ fn) = unsafeGetValue fn `seq` ()
rnf' (NativeKernel fn !_ s l) = unsafeGetValue fn `seq` rnf s `seq` rnf l

newtype NativeKernelMetadata f =
NativeKernelMetadata { kernelArgsSize :: Int }
Expand All @@ -82,11 +87,13 @@ instance IsKernel NativeKernel where
type KernelMetadata NativeKernel = NativeKernelMetadata

compileKernel env cluster args = unsafePerformIO $ evalLLVM defaultTarget $ do
module' <- codegen uid env cluster args
obj <- compile uid module'
module' <- codegen fullName env cluster args
obj <- compile uid fullName module'
funPtr <- link obj
return $ NativeKernel uid funPtr
return $ NativeKernel funPtr fullName detail brief
where
(name, detail, brief) = generateKernelNameAndDescription operationName cluster
fullName = fromString $ name ++ "-" ++ show uid
uid = hashOperation cluster args

kernelMetadata kernel = NativeKernelMetadata $ sizeOfEnv kernel
Expand All @@ -96,4 +103,23 @@ instance PrettyKernel NativeKernel where
where
go :: OpenKernelFun NativeKernel env t -> Adoc
go (KernelFunLam _ f) = go f
go (KernelFunBody kernel) = fromString $ take 16 $ show $ kernelId kernel
go (KernelFunBody (NativeKernel _ name "" _))
= fromString $ take 32 $ toString name
go (KernelFunBody (NativeKernel _ name detail brief))
= fromString (take 32 $ toString name)
<+> flatAlt (group $ line' <> "-- " <> desc)
("{- " <> desc <> "-}")
where desc = group $ flatAlt (fromString brief) (fromString detail)

toString :: ShortByteString -> String
toString = Char8.unpack . fromShort

operationName :: NativeOp t -> (Int, String, String)
operationName = \case
NMap -> (2, "map", "maps")
NBackpermute -> (1, "backpermute", "backpermutes")
NGenerate -> (2, "generate", "generates")
NPermute -> (5, "permute", "permutes")
NScanl1 -> (4, "scan", "scans")
NFold1 -> (3, "fold", "folds")
NFold2 -> (3, "fold", "folds")
13 changes: 4 additions & 9 deletions accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Monad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,11 @@ liftCodeGen = CodeGen . lift

codeGenFunction
:: forall arch f. (HasCallStack, Target arch, Intrinsic arch, Result f ~ Bool)
=> UID
-> Label
=> ShortByteString
-> (GlobalFunctionDefinition Bool -> GlobalFunctionDefinition f)
-> CodeGen arch ()
-> LLVM arch (Module f)
codeGenFunction uid name bind body = do
codeGenFunction name bind body = do
-- Execute the CodeGen monad and retrieve the code of the function and final state.
(code, st) <- runStateT
( runCodeGen $ do
Expand All @@ -147,20 +146,16 @@ codeGenFunction uid name bind body = do
}

let
fullName = name <> fromString ('_' : show uid)
fullName'
| Label s <- fullName = s
| otherwise = "<undefined>"
typeDefs = map (\(n,t) -> LLVM.TypeDefinition (downcast n) t) $ HashMap.toList $ typedefTable st
symbols = map LLVM.GlobalDefinition $ HashMap.elems $ symbolTable st
metadata = createMetadata $ metadataTable st

return $ Module
{ moduleName = fullName'
{ moduleName = name
, moduleSourceFileName = B.empty
, moduleDataLayout = targetDataLayout @arch
, moduleTargetTriple = targetTriple @arch
, moduleMain = bind $ Body (PrimType BoolPrimType) Nothing (GlobalFunctionBody fullName code)
, moduleMain = bind $ Body (PrimType BoolPrimType) Nothing (GlobalFunctionBody (Label name) code)
, moduleOtherDefinitions = typeDefs ++ symbols ++ metadata
}

Expand Down

0 comments on commit 53c3937

Please sign in to comment.