Skip to content

Commit

Permalink
fixed crashes related to operations on tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
dpvanbalen committed Apr 30, 2024
1 parent f6e6209 commit f38b65a
Show file tree
Hide file tree
Showing 12 changed files with 45 additions and 80,720 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,16 @@ instance EvalOp NativeOp where
i <- intOfIndex shr2 sh' sh2
readBuffer tp TypeInt (aprjBuffer (unsafeCoerce buf) gamma) (op TypeInt i)
| otherwise = pure CN
readInput tp _ (TupRsingle buf) gamma a (_,i,_) = -- assuming no bp, and I'll just make a read at every depth?
-- lift $ CJ . ir tp <$> readBuffer tp TypeInt (aprjBuffer (unsafeCoerce buf) gamma) (op TypeInt i)
-- second attempt, the above segfaults: never read instead
pure CN
-- also segfaults :(
{- weird: this implies that a is a `IsUnit`, but it happens on Int
error $ show tp <> case buf of
TupRsingle _ -> "single"
TupRpair _ _ -> "pair"
TupRunit -> "unit" -}
readInput _ _ _ _ _ _ = error "not single"

evalOp :: (Int, Operands Int, [Operands Int])
Expand Down Expand Up @@ -422,6 +432,7 @@ instance EvalOp (JustAccumulator NativeOp) where
subtup (SubTupRpair a b) (TupRpair x y) = TupRpair (subtup @(JustAccumulator NativeOp) a x) (subtup @(JustAccumulator NativeOp) b y)
subtup _ _ = error "subtup-pair with non-pair TypeR"

readInput ty sh _ gamma (BCA2JA IsUnit) _ = pure TupRunit
readInput ty sh _ gamma (BCA2JA (BCAN2 Nothing d)) _ = StateT $ \(acc,ls) -> pure (TupRsingle ty, (acc, merge ls sh gamma))
readInput ty sh _ gamma (BCA2JA (BCAN2 (Just (BP _ _ _ ls')) d)) _ = StateT $ \(acc,ls) -> pure (TupRsingle ty, (acc, merge ls ls' gamma))

Expand Down Expand Up @@ -480,11 +491,11 @@ instance (StaticClusterAnalysis op, EnvF (JustAccumulator op) ~ EnvF op) => Stat
varToValue x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ varToValue $ coerce x
varToSh x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ varToSh $ coerce x
shToVar x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ shToVar $ coerce x
shrinkOrGrow x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ shrinkOrGrow $ coerce x
shrinkOrGrow a b x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ shrinkOrGrow a b $ coerce x
addTup x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ addTup $ coerce x
unitToVar x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ unitToVar $ coerce x
varToUnit x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ varToUnit $ coerce x
pairinfo x y = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ pairinfo (coerce x) (coerce y)
pairinfo a x y = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ pairinfo a (coerce x) (coerce y)

deriving instance (Eq (BackendClusterArg2 op x y)) => Eq (BackendClusterArg2 (JustAccumulator op) x y)
deriving instance (Show (BackendClusterArg2 op x y)) => Show (BackendClusterArg2 (JustAccumulator op) x y)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module : Data.Array.Accelerate.LLVM.Native.Accelerate
Expand Down Expand Up @@ -313,7 +314,7 @@ instance NFData' (BackendClusterArg NativeOp) where

instance ShrinkArg (BackendClusterArg NativeOp) where
shrinkArg _ (BCAN i) = BCAN i
deadArg (BCAN _) = BCAN 0
deadArg (BCAN i) = BCAN i

data IndexPermutation env where
BP :: ShapeR sh1 -> ShapeR sh2 -> Fun env (sh1 -> sh2) -> GroundVars env sh1 -> IndexPermutation env
Expand All @@ -327,9 +328,9 @@ instance Show (IndexPermutation env) where
infenv i = unsafeCoerce $ infenv (i+1) `Push` (pretty $ "x"<>show i)
instance StaticClusterAnalysis NativeOp where
data BackendClusterArg2 NativeOp env arg where
BCAN2 :: Maybe (IndexPermutation env) -> IterationDepth -> BackendClusterArg2 NativeOp env arg
BCAN2 :: Maybe (IndexPermutation env) -> IterationDepth -> BackendClusterArg2 NativeOp env arg -- non-array args just get ths one with '999', should make a new constructor for them
IsUnit ::BackendClusterArg2 NativeOp env (m sh ()) -- units don't get backpermuted because they don't exist
def (ArgArray _ (ArrayR _ TupRunit) _ _) _ _ = IsUnit
def (ArgArray _ (ArrayR _ TupRunit) _ TupRunit) _ _ = IsUnit
def _ _ (BCAN i) = BCAN2 Nothing i
unitToVar = bcan2id
varToUnit = bcan2id
Expand All @@ -343,36 +344,53 @@ instance StaticClusterAnalysis NativeOp where
varToValue = bcan2id
varToSh = bcan2id
shToVar = bcan2id
shrinkOrGrow = bcan2id
shrinkOrGrow _ (ArgArray _ (ArrayR _ TupRunit) _ _) _ = IsUnit
shrinkOrGrow _ a IsUnit = error "can't grow from unit"
shrinkOrGrow _ _ x = bcan2id x
addTup = bcan2id
inToVar = bcan2id
-- onOp propagates the backpermute information from the outputs to the inputs of each operation
onOp NMap (bp :>: ArgsNil) _ _ = BCAN2 Nothing undefined :>: bcan2id bp :>: bp :>: ArgsNil
onOp NBackpermute (BCAN2 (Just bp@(BP shr1 shr2 g sh)) d :>: ArgsNil) (ArgFun f :>: ArgArray In (ArrayR shrI _) _ _ :>: ArgArray Out (ArrayR shrO _) _ _ :>: ArgsNil) _
| Just Refl <- matchShapeR shrO shr2 = BCAN2 Nothing 0 :>: BCAN2 (Just (BP shr1 shrI (compose f g) sh)) d :>: BCAN2 (Just bp) d :>: ArgsNil
| Just Refl <- matchShapeR shrO shr2 = BCAN2 Nothing 999 :>: BCAN2 (Just (BP shr1 shrI (compose f g) sh)) d :>: BCAN2 (Just bp) d :>: ArgsNil
| otherwise = error "BP shapeR doesn't match backpermute output shape"
onOp NBackpermute (BCAN2 Nothing d :>: ArgsNil) (ArgFun f :>: ArgArray In (ArrayR shrI _) _ _ :>: ArgArray Out (ArrayR shrO _) sh _ :>: ArgsNil) _
= BCAN2 Nothing d :>: BCAN2 (Just (BP shrO shrI f sh)) d :>: BCAN2 Nothing d :>: ArgsNil
onOp NGenerate (bp :>: ArgsNil) (_:>:ArgArray Out (ArrayR shR _) _ _ :>:ArgsNil) _ =
bcan2id bp :>: bp :>: ArgsNil -- store the bp in the function, because there is no input array
onOp NPermute ArgsNil (_:>:_:>:_:>:_:>:ArgArray In (ArrayR shR _) _ _ :>:ArgsNil) _ =
BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing (rank shR) :>: ArgsNil
onOp NFold2 (bp :>: ArgsNil) (_ :>: ArgArray In _ fs _ :>: _ :>: ArgsNil) _ = BCAN2 Nothing 0 :>: fold2bp bp (case fs of TupRpair _ x -> x) :>: bp :>: ArgsNil
onOp NFold1 (bp :>: ArgsNil) _ _ = BCAN2 Nothing 0 :>: fold1bp bp :>: bp :>: ArgsNil
onOp NScanl1 (bp :>: ArgsNil) _ _ = BCAN2 Nothing 0 :>: bcan2id bp :>: bp :>: ArgsNil
pairinfo IsUnit x = shrinkOrGrow x
pairinfo x IsUnit = shrinkOrGrow x
pairinfo a b = if shrinkOrGrow a == b then shrinkOrGrow a else error $ "pairing unequal: " <> show a <> ", " <> show b
BCAN2 Nothing 999 :>: BCAN2 Nothing 999 :>: BCAN2 Nothing 999 :>: BCAN2 Nothing 999 :>: BCAN2 Nothing (rank shR) :>: ArgsNil
onOp NFold2 (bp :>: ArgsNil) (_ :>: ArgArray In _ fs _ :>: _ :>: ArgsNil) _ = BCAN2 Nothing 999 :>: fold2bp bp (case fs of TupRpair _ x -> x) :>: bp :>: ArgsNil
onOp NFold1 (bp :>: ArgsNil) _ _ = BCAN2 Nothing 999 :>: fold1bp bp :>: bp :>: ArgsNil
onOp NScanl1 (bp :>: ArgsNil) _ _ = BCAN2 Nothing 999 :>: bcan2id bp :>: bp :>: ArgsNil
pairinfo _ IsUnit IsUnit = error "can't yet"
pairinfo a@(ArgArray m (ArrayR shr (TupRpair l r)) sh (TupRpair bufl bufr)) IsUnit x = shrinkOrGrow (ArgArray m (ArrayR shr r) sh bufr) a x
pairinfo a@(ArgArray m (ArrayR shr (TupRpair l r)) sh (TupRpair bufl bufr)) x IsUnit = shrinkOrGrow (ArgArray m (ArrayR shr l) sh bufl) a x
pairinfo _ x y = if bcan2id x == y then bcan2id x else
case (x,y) of
-- these two cases test whether the function is id, but it's still possible that one of the arguments got backpermuted to be smaller.
-- In that case we 'should' error here, but we can't check it
(BCAN2 Nothing xd, BCAN2 (Just (BP _ _ yp _)) yd)
| xd == yd
, Just Refl <- isIdentity yp
-> bcan2id y
(BCAN2 (Just (BP _ _ yp _)) yd, BCAN2 Nothing xd)
| xd == yd
, Just Refl <- isIdentity yp
-> bcan2id x
_ -> error $ "pairing unequal: " <> show x <> ", " <> show y



bcan2id :: BackendClusterArg2 NativeOp env arg -> BackendClusterArg2 NativeOp env arg'
bcan2id (BCAN2 Nothing i) = BCAN2 Nothing i
bcan2id (BCAN2 (Just (BP a b c d)) i) = BCAN2 (Just (BP a b c d)) i
bcan2id IsUnit = unsafeCoerce IsUnit -- error "bcan2id unit"

fold1bp :: BackendClusterArg2 NativeOp env (Out sh e) -> BackendClusterArg2 NativeOp env (In sh e)
fold1bp (BCAN2 Nothing i) = BCAN2 Nothing i
fold1bp (BCAN2 (Just (BP shr1 shr2 g sh)) i) = flip BCAN2 i $ Just $ BP shr1 shr2 (error "todo: multiply the innermost (outer constructor) dimension by the workstealsize" g) undefined
fold1bp IsUnit = error "unit"

fold2bp :: BackendClusterArg2 NativeOp env (Out sh e) -> GroundVars env Int -> BackendClusterArg2 NativeOp env (In (sh,Int) e)
fold2bp (BCAN2 Nothing i) _ = BCAN2 Nothing (i+1)
Expand All @@ -385,13 +403,11 @@ fold2bp (BCAN2 (Just (BP shr1 shr2 g sh)) i) foldsize = flip BCAN2 (i+1) $ Just
Body $ Pair (weakenE (weakenSucc' weakenId) e) (Evar $ Var scalarTypeInt ZeroIdx)
_ -> error "function type in body or non-body below lam in sh1 -> sh2")
(TupRpair sh foldsize)
fold2bp IsUnit _ = error "unit"

instance Eq (BackendClusterArg2 NativeOp env arg) where
IsUnit == IsUnit = True
x@(BCAN2 p i) == y@(BCAN2 p' i') = f $ p == p' && i == i'
where
f True = True
f False = False
x@(BCAN2 p i) == y@(BCAN2 p' i') = p == p' && i == i'
_ == _ = False

instance Eq (IndexPermutation env) where
Expand Down
Loading

0 comments on commit f38b65a

Please sign in to comment.