Skip to content

Commit

Permalink
add separate unit constructor to bcan2; fixes "unequal pairing" error…
Browse files Browse the repository at this point in the history
… on backpermutes on tuples
  • Loading branch information
dpvanbalen committed Mar 18, 2024
1 parent 00e2ada commit 3ef7bbd
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ instance (StaticClusterAnalysis op, EnvF (JustAccumulator op) ~ EnvF op) => Stat
addTup x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ addTup $ coerce x
unitToVar x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ unitToVar $ coerce x
varToUnit x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ varToUnit $ coerce x
pairinfo x y = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ pairinfo (coerce x) (coerce y)

deriving instance (Eq (BackendClusterArg2 op x y)) => Eq (BackendClusterArg2 (JustAccumulator op) x y)
deriving instance (Show (BackendClusterArg2 op x y)) => Show (BackendClusterArg2 (JustAccumulator op) x y)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ import Data.Array.Accelerate.Trafo.Desugar (desugarAlloc)
import qualified Debug.Trace
import GHC.Stack
import Data.Array.Accelerate.AST.Idx (Idx(..))
import Data.Array.Accelerate.Pretty.Operation (prettyFun)
import Data.Array.Accelerate.Pretty.Exp (PrettyEnv(..), Val (Push))
import Prettyprinter (pretty)
import Unsafe.Coerce (unsafeCoerce)

data NativeOp t where
NMap :: NativeOp (Fun' (s -> t) -> In sh s -> Out sh t -> ())
Expand Down Expand Up @@ -316,11 +320,16 @@ data IndexPermutation env where
type IterationDepth = Int
instance Show (BackendClusterArg2 NativeOp env arg) where
show (BCAN2 i d) = "{ depth = " <> show d <> ", perm = " <> show i <> " }"
show IsUnit = "()"
instance Show (IndexPermutation env) where
show (BP sh1 sh2 _ _) = show (rank sh1) <> "->" <> show (rank sh2)
show (BP sh1 sh2 f _) = show (rank sh1) <> "->" <> show (rank sh2) <> ": " <> show (prettyFun (infenv 0) f)
where
infenv i = unsafeCoerce $ infenv (i+1) `Push` (pretty $ "x"<>show i)
instance StaticClusterAnalysis NativeOp where
data BackendClusterArg2 NativeOp env arg where
BCAN2 :: Maybe (IndexPermutation env) -> IterationDepth -> BackendClusterArg2 NativeOp env arg
IsUnit ::BackendClusterArg2 NativeOp env (m sh ()) -- units don't get backpermuted because they don't exist
def (ArgArray _ (ArrayR _ TupRunit) _ _) _ _ = IsUnit
def _ _ (BCAN i) = BCAN2 Nothing i
unitToVar = bcan2id
varToUnit = bcan2id
Expand All @@ -344,12 +353,17 @@ instance StaticClusterAnalysis NativeOp where
onOp NBackpermute (BCAN2 Nothing d :>: ArgsNil) (ArgFun f :>: ArgArray In (ArrayR shrI _) _ _ :>: ArgArray Out (ArrayR shrO _) sh _ :>: ArgsNil) _
= BCAN2 Nothing d :>: BCAN2 (Just (BP shrO shrI f sh)) d :>: BCAN2 Nothing d :>: ArgsNil
onOp NGenerate (bp :>: ArgsNil) (_:>:ArgArray Out (ArrayR shR _) _ _ :>:ArgsNil) _ =
bcan2id bp :>: bp :>: ArgsNil -- storing the bp in the function argument. Probably not required, could just take it from the array one during codegen
bcan2id bp :>: bp :>: ArgsNil -- store the bp in the function, because there is no input array
onOp NPermute ArgsNil (_:>:_:>:_:>:_:>:ArgArray In (ArrayR shR _) _ _ :>:ArgsNil) _ =
BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing (rank shR) :>: ArgsNil
onOp NFold2 (bp :>: ArgsNil) (_ :>: ArgArray In _ fs _ :>: _ :>: ArgsNil) _ = BCAN2 Nothing 0 :>: fold2bp bp (case fs of TupRpair _ x -> x) :>: bp :>: ArgsNil
onOp NFold1 (bp :>: ArgsNil) _ _ = BCAN2 Nothing 0 :>: fold1bp bp :>: bp :>: ArgsNil
onOp NScanl1 (bp :>: ArgsNil) _ _ = BCAN2 Nothing 0 :>: bcan2id bp :>: bp :>: ArgsNil
pairinfo IsUnit x = shrinkOrGrow x
pairinfo x IsUnit = shrinkOrGrow x
pairinfo a b = if shrinkOrGrow a == b then shrinkOrGrow a else error $ "pairing unequal: " <> show a <> ", " <> show b



bcan2id :: BackendClusterArg2 NativeOp env arg -> BackendClusterArg2 NativeOp env arg'
bcan2id (BCAN2 Nothing i) = BCAN2 Nothing i
Expand All @@ -372,10 +386,13 @@ fold2bp (BCAN2 (Just (BP shr1 shr2 g sh)) i) foldsize = flip BCAN2 (i+1) $ Just
(TupRpair sh foldsize)

instance Eq (BackendClusterArg2 NativeOp env arg) where
IsUnit == IsUnit = True
x@(BCAN2 p i) == y@(BCAN2 p' i') = f $ p == p' && i == i'
where
f True = True
f False = False
_ == _ = False

instance Eq (IndexPermutation env) where
(BP shr1 shr2 f _) == (BP shr1' shr2' f' _)
| Just Refl <- matchShapeR shr1 shr1'
Expand Down
21 changes: 13 additions & 8 deletions accelerate-llvm-native/test/nofib/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,19 @@ import Data.Array.Accelerate.Unsafe
main :: IO ()
main = do
let xs = fromList (Z :. 10) [1 :: Int ..]
let ys = use xs


let f = T2 (map (+1) ys) (map (*2) $ reverse ys)
-- let f = --map (\(T2 a b) -> a + b) $
-- zip ys $ reverse ys
putStrLn $ test @UniformScheduleFun @NativeKernel f
print $ run @Native f
let ys = map (\x -> T2 x x) $
use xs


-- let f = T2 (map (+1) ys) (map (*2) $ reverse ys)
-- let f = sum $ map (\(T2 a b) -> a + b) $
-- zip (reverse $ map (+1) (reverse ys)) $ reverse ys
let Z_ ::. n = shape ys
let f'' = backpermute (Z_ ::. 5 ::. 2) (\(I2 x y) -> I1 (x*y)) ys
let f' = replicate (Z_ ::. All_ ::. n) ys
let f = zip (reverse ys) ys
putStrLn $ test @UniformScheduleFun @NativeKernel $ backpermute (Z_ ::. 5) (\x->x) (reverse ys)
-- print $ run @Native $ f

-- putStrLn "generate:"
-- let f = generate (I1 10) (\(I1 x0) -> 10 :: Exp Int)
Expand Down

0 comments on commit 3ef7bbd

Please sign in to comment.