Skip to content

Commit

Permalink
Complete contractAst for shaped terms
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 5, 2025
1 parent b87abfc commit 32f1617
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 77 deletions.
163 changes: 90 additions & 73 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ import System.IO.Unsafe (unsafePerformIO)
import Type.Reflection (typeRep)

import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation (DropLen, TakeLen, permInverse)
import Data.Array.Mixed.Permutation (DropLen, Perm (..), TakeLen, permInverse)
import Data.Array.Mixed.Permutation qualified as Permutation
import Data.Array.Mixed.Shape
(ssxAppend, ssxFromShape, ssxReplicate, withKnownShX)
Expand Down Expand Up @@ -3851,7 +3851,7 @@ contractAst t = case t of
Ast.AstSum snat stk@(STKR (SNat @n) (STKScalar rRep))
v@(AstN2R TimesOp (Ast.AstTranspose tperm (Ast.AstReplicate _tk stkt t2))
(Ast.AstTranspose uperm (Ast.AstReplicate _uk stku u2)))
| Just Refl <- sameNat (Proxy @n) (Proxy @2) -> case (stkt, stku) of
| Just Refl <- sameNat (Proxy @n) (Proxy @2) -> case (stkt, stku) of
(STKR{}, STKR{}) ->
let interpretMatmul2 t3 u3 =
let t4 = contractAst t3
Expand Down Expand Up @@ -3904,77 +3904,94 @@ contractAst t = case t of
-- Ast.AstTranspose [1, 0]
-- $ interpretMatmul2 (astTranspose [1, 0] u2) (astTranspose [1, 0] t2)
_ -> astSum snat stk (contractAst v)
{-TODO: Ast.AstSum snat stk@(STKS (_ :$$ _ :$$ ZSS) (STKScalar @r rRep))
v@(AstN2S TimesOp (Ast.AstTransposeS tperm (Ast.AstReplicate _tk stkt t2))
(Ast.AstTransposeS uperm (Ast.AstReplicate _uk stku u2)))
-> case (stkt, stku) of
(STKS{}, STKS{}) ->
let perm10 = Permutation.makePerm @'[1, 0]
attemptMatmul2 :: forall m' n' p'.
AstTensor AstMethodLet s (TKS '[m', n'] r)
-> AstTensor AstMethodLet s (TKS '[n', p'] r)
-> AstTensor AstMethodLet s (TKS '[m', p'] r)
attemptMatmul2 t3 u3 =
let t4 = contractAst t3
u4 = contractAst u3
in case testEquality rRep (typeRep @Double) of
Just Refl -> Ast.AstMatmul2S
(SNat @m') (SNat @n') (SNat @p') t4 u4
_ -> case testEquality rRep (typeRep @Float) of
Just Refl -> Ast.AstMatmul2S
(SNat @m') (SNat @n') (SNat @p') t4 u4
_ -> case testEquality rRep (typeRep @Int64) of
Just Refl -> Ast.AstMatmul2S
(SNat @m') (SNat @n') (SNat @p') t4 u4
_ -> case testEquality rRep (typeRep @CInt) of
Just Refl -> Ast.AstMatmul2S
(SNat @m') (SNat @n') (SNat @p') t4 u4
_ -> astSum snat stk (contractAst v)
in if | Just Refl <- geq tperm (Permutation.makePerm @'[2, 1, 0])
, Just Refl <- geq uperm (Permutation.makePerm @'[1, 0]) ->
-- tk and uk are fine due to perms matching
attemptMatmul2 t2 u2
([1, 0], [2, 1, 0]) ->
attemptMatmul2 u2 t2
([2, 1, 0], [2, 0, 1]) ->
attemptMatmul2 t2 (astTransposeS perm10 u2)
([2, 0, 1], [2, 1, 0]) ->
attemptMatmul2 u2 (astTransposeS perm10 t2)
([1, 2, 0], [1, 0]) ->
attemptMatmul2 (astTransposeS perm10 t2) u2
([1, 0], [1, 2, 0]) ->
attemptMatmul2 (astTransposeS perm10 u2) t2
-- ([1, 2, 0], [2, 0, 1]) ->
-- attemptMatmul2 (astTransposeS perm10 t2)
-- (astTransposeS perm10 u2)
-- ([2, 0, 1], [1, 2, 0]) ->
-- attemptMatmul2 (astTransposeS perm10 u2)
-- (astTransposeS perm10 t2)
-- The variants below emerge when the whole term is transposed.
-- All overlap with variants above and the cheaper one is selected.
([2, 0, 1], [1, 2, 0]) ->
Ast.AstTransposeS perm10 $ attemptMatmul2 t2 u2
([1, 2, 0], [2, 0, 1]) ->
Ast.AstTransposeS perm10 $ attemptMatmul2 u2 t2
-- ([2, 0, 1], [2, 1, 0]) ->
-- Ast.AstTranspose [1, 0]
-- $ attemptMatmul2 t2 (astTransposeS perm10 u2)
-- ([2, 1, 0], [2, 0, 1]) ->
-- Ast.AstTranspose [1, 0]
-- $ attemptMatmul2 u2 (astTransposeS perm10 t2)
-- ([1, 2, 0], [1, 0]) ->
-- Ast.AstTranspose [1, 0]
-- $ attemptMatmul2 (astTransposeS perm10 u2) t2
-- ([1, 0], [2, 1, 0]) ->
-- Ast.AstTransposeS perm10
-- $ attemptMatmul2 (astTransposeS perm10 t2)
-- (astTransposeS perm10 u2)
-- ([2, 1, 0], [1, 0]) ->
-- Ast.AstTransposeS perm10
-- $ attemptMatmul2 (astTranspose 0S perm10 u2)
-- (astTransposeS perm10 t2)
_ -> astSum snat stk (contractAst v) -}
Ast.AstSum
snat@(SNat @m2)
stk@(STKS (SNat @n2 :$$ SNat @p2 :$$ ZSS) (STKScalar @r rRep))
v@(AstN2S TimesOp (Ast.AstTransposeS @permt permt
(Ast.AstReplicate (SNat @kt) (STKS @sht _ _) t2))
(Ast.AstTransposeS @permu permu
(Ast.AstReplicate (SNat @ku) (STKS @shu _ _) u2))) ->
let perm10 = Permutation.makePerm @'[1, 0]
attemptMatmul2
:: forall m' n' p'. (KnownNat m', KnownNat n', KnownNat p')
=> AstTensor AstMethodLet s (TKS '[m', n'] r)
-> AstTensor AstMethodLet s (TKS '[n', p'] r)
-> Maybe (AstTensor AstMethodLet s (TKS '[m', p'] r))
attemptMatmul2 t3 u3 =
let t4 = contractAst t3
u4 = contractAst u3
in case testEquality rRep (typeRep @Double) of
Just Refl ->
Just $ Ast.AstMatmul2S
(SNat @m') (SNat @n') (SNat @p') t4 u4
_ -> case testEquality rRep (typeRep @Float) of
Just Refl ->
Just $ Ast.AstMatmul2S
(SNat @m') (SNat @n') (SNat @p') t4 u4
_ -> case testEquality rRep (typeRep @Int64) of
Just Refl ->
Just $ Ast.AstMatmul2S
(SNat @m') (SNat @n') (SNat @p') t4 u4
_ -> case testEquality rRep (typeRep @CInt) of
Just Refl ->
Just $ Ast.AstMatmul2S
(SNat @m') (SNat @n') (SNat @p') t4 u4
_ -> Nothing
in fromMaybe (astSum snat stk (contractAst v))
$ case (permt, permu) of
( SNat' @2 `PCons` SNat' @1 `PCons` SNat' @0 `PCons` PNil
,SNat' @1 `PCons` SNat' @0 `PCons` PNil ) ->
gcastWith (unsafeCoerceRefl
:: Permutation.PermutePrefix permt (kt ': sht)
:~: [m2, n2, p2]) $
gcastWith (unsafeCoerceRefl
:: Permutation.PermutePrefix permu (ku ': shu)
:~: [m2, n2, p2]) $
-- Sadly, the casts below, though implied by the permutations
-- (as edundantly spelled out by the casts above) are required
-- to make it type-check and they easily mask bugs, too.
-- In the result, this is as type-unsafe as ranked code would be.
gcastWith (unsafeCoerceRefl :: sht :~: [n2, m2]) $
gcastWith (unsafeCoerceRefl :: shu :~: [m2, p2]) $
attemptMatmul2 t2 u2
( SNat' @1 `PCons` SNat' @0 `PCons` PNil
,SNat' @2 `PCons` SNat' @1 `PCons` SNat' @0 `PCons` PNil ) ->
gcastWith (unsafeCoerceRefl :: sht :~: [m2, p2]) $
gcastWith (unsafeCoerceRefl :: shu :~: [n2, m2]) $
attemptMatmul2 u2 t2
( SNat' @2 `PCons` SNat' @1 `PCons` SNat' @0 `PCons` PNil
,SNat' @2 `PCons` SNat' @0 `PCons` SNat' @1 `PCons` PNil ) ->
gcastWith (unsafeCoerceRefl :: sht :~: [n2, m2]) $
gcastWith (unsafeCoerceRefl :: shu :~: [p2, m2]) $
attemptMatmul2 t2 (astTransposeS perm10 u2)
( SNat' @2 `PCons` SNat' @0 `PCons` SNat' @1 `PCons` PNil
,SNat' @2 `PCons` SNat' @1 `PCons` SNat' @0 `PCons` PNil ) ->
gcastWith (unsafeCoerceRefl :: sht :~: [p2, m2]) $
gcastWith (unsafeCoerceRefl :: shu :~: [n2, m2]) $
attemptMatmul2 u2 (astTransposeS perm10 t2)
( SNat' @1 `PCons` SNat' @2 `PCons` SNat' @0 `PCons` PNil
,SNat' @1 `PCons` SNat' @0 `PCons` PNil ) ->
gcastWith (unsafeCoerceRefl :: sht :~: [m2, n2]) $
gcastWith (unsafeCoerceRefl :: shu :~: [m2, p2]) $
attemptMatmul2 (astTransposeS perm10 t2) u2
( SNat' @1 `PCons` SNat' @0 `PCons` PNil
,SNat' @1 `PCons` SNat' @2 `PCons` SNat' @0 `PCons` PNil ) ->
gcastWith (unsafeCoerceRefl :: sht :~: [m2, p2]) $
gcastWith (unsafeCoerceRefl :: shu :~: [m2, n2]) $
attemptMatmul2 (astTransposeS perm10 u2) t2
( SNat' @1 `PCons` SNat' @2 `PCons` SNat' @0 `PCons` PNil
,SNat' @2 `PCons` SNat' @0 `PCons` SNat' @1 `PCons` PNil ) ->
gcastWith (unsafeCoerceRefl :: sht :~: [m2, n2]) $
gcastWith (unsafeCoerceRefl :: shu :~: [p2, m2]) $
attemptMatmul2 (astTransposeS perm10 t2)
(astTransposeS perm10 u2)
( SNat' @2 `PCons` SNat' @0 `PCons` SNat' @1 `PCons` PNil
,SNat' @1 `PCons` SNat' @2 `PCons` SNat' @0 `PCons` PNil ) ->
gcastWith (unsafeCoerceRefl :: sht :~: [p2, m2]) $
gcastWith (unsafeCoerceRefl :: shu :~: [m2, n2]) $
attemptMatmul2 (astTransposeS perm10 u2)
(astTransposeS perm10 t2)
_ -> Nothing
Ast.AstSum _ (STKR (SNat @n) _) (AstN2R TimesOp t2 u)
| Just Refl <- sameNat (Proxy @n) (Proxy @0) ->
Ast.AstDot0R (SNat @1) (contractAst t2) (contractAst u)
Expand Down
5 changes: 3 additions & 2 deletions src/HordeAd/Core/Types.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{-# LANGUAGE AllowAmbiguousTypes, DerivingVia, ImpredicativeTypes,
UndecidableInstances, UndecidableSuperClasses #-}
UndecidableInstances, UndecidableSuperClasses, ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-- | Some fundamental type families and types.
Expand Down Expand Up @@ -49,7 +49,7 @@ import Data.Int (Int64)
import Data.Kind (Type)
import Data.List (sort)
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality (gcastWith, testEquality, (:~:))
import Data.Type.Equality (gcastWith, testEquality, (:~:) (Refl))
import Data.Vector.Storable qualified as V
import Foreign.C (CInt)
import Foreign.Storable (Storable (..))
Expand All @@ -61,6 +61,7 @@ import GHC.TypeLits
, SNat
, fromSNat
, pattern SNat
, sameNat
, type (+)
, type (-)
, withSomeSNat
Expand Down
4 changes: 2 additions & 2 deletions test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1216,9 +1216,9 @@ testMatmul2PPS = do
printArtifactPrimalPretty renames artifactRev
@?= "\\x1 -> ssum (stranspose (sreplicate (tproject1 m1)) * stranspose (sreplicate (tproject2 m1)))"
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\m2 x1 -> tpair (ssum (stranspose (sreplicate (tproject2 m1)) * stranspose (sreplicate m2)), ssum (stranspose (sreplicate (tproject1 m1)) * stranspose (sreplicate m2)))"
@?= "\\m2 x1 -> tpair (smatmul2 m2 (stranspose (tproject2 m1)), smatmul2 (stranspose (tproject1 m1)) m2)"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\x1 -> ssum (stranspose (sreplicate (tproject1 m1)) * stranspose (sreplicate (tproject2 m1)))"
@?= "\\x1 -> smatmul2 (tproject1 m1) (tproject2 m1)"

testAddSpeedBig :: Assertion
testAddSpeedBig =
Expand Down

0 comments on commit 32f1617

Please sign in to comment.