Skip to content

Commit

Permalink
Generalize contraction from recognizing AstDot1InS to also AstMatvecmulS
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Feb 14, 2025
1 parent 217d626 commit 965cabb
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 10 deletions.
48 changes: 48 additions & 0 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3015,6 +3015,54 @@ contractAst t = case t of
Ast.AstDot0S (contractAst t2) (contractAst u)
Ast.AstSum _ (STKS ZSS _) (Ast.AstReshapeS _ (AstN2S TimesOp t2 u)) ->
Ast.AstDot0S (contractAst t2) (contractAst u)
Ast.AstSum
n@(SNat @n)
(STKS (m@(SNat @m) :$$ ZSS) _)
(AstN2S TimesOp
(Ast.AstTransposeS @perm @sh
(SNat' @1 `Permutation.PCons` SNat' @0
`Permutation.PCons` Permutation.PNil)
(Ast.AstReplicate _ STKS{} t2))
u) ->
gcastWith (unsafeCoerceRefl :: Permutation.Permute perm [n, m] :~: sh) $
let perm10 = Permutation.makePerm @'[1, 0]
in Ast.AstMatvecmulS m n (contractAst $ astTransposeS perm10 u)
(contractAst t2)
Ast.AstSum
n@(SNat @n)
(STKS (m@(SNat @m) :$$ ZSS) _)
(AstN2S TimesOp
t2
(Ast.AstTransposeS @perm2 @sh2
(SNat' @1 `Permutation.PCons` SNat' @0
`Permutation.PCons` Permutation.PNil)
(Ast.AstReplicate _ STKS{} u))) ->
gcastWith (unsafeCoerceRefl :: Permutation.Permute perm2 [n, m] :~: sh2) $
let perm10 = Permutation.makePerm @'[1, 0]
in Ast.AstMatvecmulS m n (contractAst $ astTransposeS perm10 t2)
(contractAst u)
Ast.AstSum
n@(SNat @n)
(STKS (m@(SNat @m) :$$ ZSS) _)
(Ast.AstTransposeS @perm @sh
(SNat' @1 `Permutation.PCons` SNat' @0
`Permutation.PCons` Permutation.PNil)
(AstN2S TimesOp
(Ast.AstReplicate _ STKS{} t2)
u)) ->
gcastWith (unsafeCoerceRefl :: Permutation.Permute perm [n, m] :~: sh) $
Ast.AstMatvecmulS m n (contractAst u) (contractAst t2)
Ast.AstSum
n@(SNat @n)
(STKS (m@(SNat @m) :$$ ZSS) _)
(Ast.AstTransposeS @perm @sh
(SNat' @1 `Permutation.PCons` SNat' @0
`Permutation.PCons` Permutation.PNil)
(AstN2S TimesOp
t2
(Ast.AstReplicate _ STKS{} u))) ->
gcastWith (unsafeCoerceRefl :: Permutation.Permute perm [n, m] :~: sh) $
Ast.AstMatvecmulS m n (contractAst t2) (contractAst u)
Ast.AstSum
n@(SNat @n)
(STKS (m@(SNat @m) :$$ ZSS) _)
Expand Down
4 changes: 2 additions & 2 deletions src/HordeAd/Core/Ops.hs
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ class ( Num (IntOf target)
-- differ in types but all are far from matmul2.
-- rmatvecmul m v = rbuild1 (rlength m) (\i -> rdot0 v (m ! [i]))
-- rmatvecmul m v = rflatten $ rmap1 (rreplicate 1 . rdot0 v) m
rmatvecmul m v = rsum (rtranspose [1,0] (rreplicate (rlength m) v * m))
rmatvecmul m v = rsum (rtr (rreplicate (rlength m) v * m))
rmatmul2 :: (GoodScalar r, Numeric r)
=> target (TKR 2 r) -> target (TKR 2 r) -> target (TKR 2 r)
-- How to generalize to tmatmul (#69)?
Expand Down Expand Up @@ -727,7 +727,7 @@ class ( Num (IntOf target)
smatvecmul :: forall r m n. (GoodScalar r, KnownNat m, KnownNat n)
=> target (TKS '[m, n] r) -> target (TKS '[n] r)
-> target (TKS '[m] r)
smatvecmul m v = ssum (stranspose @_ @'[1, 0] (sreplicate @_ @m v * m))
smatvecmul m v = ssum (str (sreplicate @_ @m v * m))
smatmul2 :: forall r n m p.
(GoodScalar r, Numeric r, KnownNat n, KnownNat m, KnownNat p)
=> target (TKS '[m, n] r) -> target (TKS '[n, p] r)
Expand Down
4 changes: 2 additions & 2 deletions test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1137,9 +1137,9 @@ testMatvecmulPP = do
printArtifactPrimalPretty renames artifactRev
@?= "\\m1 -> rfromS (ssum @_ @3 (str (sfromR (rreplicate 2 (tproject2 m1))) * str (sfromR (tproject1 m1))))"
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\v2 m1 -> tfromS (tpair (sreplicate @_ @2 (sfromR (tproject2 m1)) * str (sreplicate @_ @3 (sfromR v2)), ssum @_ @2 (sfromR (tproject1 m1) * str (sreplicate @_ @3 (sfromR v2)))))"
@?= "\\v2 m1 -> tfromS (tpair (sreplicate @_ @2 (sfromR (tproject2 m1)) * str (sreplicate @_ @3 (sfromR v2)), smatvecmul (str (sfromR (tproject1 m1))) (sfromR v2)))"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\m1 -> rfromS (ssdot1In (sreplicate @_ @2 (sfromR (tproject2 m1))) (sfromR (tproject1 m1)))"
@?= "\\m1 -> rfromS (smatvecmul (sfromR (tproject1 m1)) (sfromR (tproject2 m1)))"

testMatmul2PP :: Assertion
testMatmul2PP = do
Expand Down
12 changes: 6 additions & 6 deletions test/simplified/TestMnistFCNNR.hs
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ tensorADOnceMnistTests2 = testGroup "Ranked2 Once MNIST tests"
, mnistTestCase2VTO "VTO2 artificial 1 2 3 4 5" 1 2 3 4 5 500
(0.884 :: Float)
, mnistTestCase2VTO "VTO2 artificial 5 4 3 2 1" 5 4 3 2 1 500
(0.8225 :: Double)
(0.6739999999999999 :: Double)
, mnistTestCase2VTO "VTO2 1 epoch, 0 batch" 1 0 300 100 0.02 500
(1 :: Float)
]
Expand Down Expand Up @@ -785,9 +785,9 @@ testVT2OPP = do
printArtifactPrimalPretty renames artifactRev
@?= "\\m1 -> let m2 = str (sreplicate @_ @5 (ssum @_ @3 (str (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0))) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))))) ; m3 = str (scast (sfromR (tproject1 (tproject2 (tproject1 m1))))) ; m4 = str (sreplicate @_ @2 (ssum @_ @4 (m2 * m3) + sfromR (tproject2 (tproject2 (tproject1 m1))))) in rfromS (ssum @_ @5 (m4 * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1)))"
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\v5 m1 -> tfromS (let m2 = str (sreplicate @_ @5 (ssdot1In (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0))) (sfromR (tproject1 (tproject1 (tproject1 m1)))) + sfromR (tproject2 (tproject1 (tproject1 m1))))) ; m3 = str (scast (sfromR (tproject1 (tproject2 (tproject1 m1))))) ; v6 = ssum @_ @2 (sfromR (tproject1 (tproject2 m1)) * str (sreplicate @_ @5 (sfromR v5))) ; v8 = ssdot1In m3 (sreplicate @_ @4 v6) in tpair (tpair (tpair (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0)) * str (sreplicate @_ @3 v8), v8), tpair (str (scast m2) * str (sreplicate @_ @4 (scast v6)), v6)), tpair (sreplicate @_ @2 (ssum @_ @4 (m2 * m3) + sfromR (tproject2 (tproject2 (tproject1 m1)))) * str (sreplicate @_ @5 (sfromR v5)), v5)))"
@?= "\\v5 m1 -> tfromS (let m2 = str (sreplicate @_ @5 (smatvecmul (sfromR (tproject1 (tproject1 (tproject1 m1)))) (sreplicate @_ @3 (sscalar 7.0)) + sfromR (tproject2 (tproject1 (tproject1 m1))))) ; m3 = str (scast (sfromR (tproject1 (tproject2 (tproject1 m1))))) ; v6 = smatvecmul (str (sfromR (tproject1 (tproject2 m1)))) (sfromR v5) ; v8 = smatvecmul m3 v6 in tpair (tpair (tpair (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0)) * str (sreplicate @_ @3 v8), v8), tpair (str (scast m2) * str (sreplicate @_ @4 (scast v6)), v6)), tpair (sreplicate @_ @2 (ssum @_ @4 (m2 * m3) + sfromR (tproject2 (tproject2 (tproject1 m1)))) * str (sreplicate @_ @5 (sfromR v5)), v5)))"
printArtifactSimple renames (simplifyArtifact artifactRev)
@?= "\\v5 m1 -> tfromS (tlet (str (sreplicate @_ @5 (ssdot1In (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0))) (sfromR (tproject1 (tproject1 (tproject1 m1)))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) (\\m2 -> tlet (str (scast (sfromR (tproject1 (tproject2 (tproject1 m1)))))) (\\m3 -> tlet (ssum @_ @2 (sfromR (tproject1 (tproject2 m1)) * str (sreplicate @_ @5 (sfromR v5)))) (\\v6 -> tlet (ssdot1In m3 (sreplicate @_ @4 v6)) (\\v8 -> tpair (tpair (tpair (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0)) * str (sreplicate @_ @3 v8), v8), tpair (str (scast m2) * str (sreplicate @_ @4 (scast v6)), v6)), tpair (sreplicate @_ @2 (ssum @_ @4 (m2 * m3) + sfromR (tproject2 (tproject2 (tproject1 m1)))) * str (sreplicate @_ @5 (sfromR v5)), v5)))))))"
@?= "\\v5 m1 -> tfromS (tlet (str (sreplicate @_ @5 (smatvecmul (sfromR (tproject1 (tproject1 (tproject1 m1)))) (sreplicate @_ @3 (sscalar 7.0)) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) (\\m2 -> tlet (str (scast (sfromR (tproject1 (tproject2 (tproject1 m1)))))) (\\m3 -> tlet (smatvecmul (str (sfromR (tproject1 (tproject2 m1)))) (sfromR v5)) (\\v6 -> tlet (smatvecmul m3 v6) (\\v8 -> tpair (tpair (tpair (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0)) * str (sreplicate @_ @3 v8), v8), tpair (str (scast m2) * str (sreplicate @_ @4 (scast v6)), v6)), tpair (sreplicate @_ @2 (ssum @_ @4 (m2 * m3) + sfromR (tproject2 (tproject2 (tproject1 m1)))) * str (sreplicate @_ @5 (sfromR v5)), v5)))))))"

testVT2OPPNonLin :: Assertion
testVT2OPPNonLin = do
Expand All @@ -813,7 +813,7 @@ testVT2OPPNonLin = do
"\\dummy" ++ " -> " ++ printAstSimple renames ast3
@?= "\\dummy -> rfromS (tlet (exp (ssum @_ @5 (str (sreplicate @_ @2 (tlet (ssum @_ @4 (str (sreplicate @_ @5 (tlet (ssum @_ @3 (tfromPrimal (STKS [3,4] STKScalar) (str (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0)))) * tfromPrimal (STKS [3,4] STKScalar) (tconcrete (FTKS [3,4] FTKScalar) (sfromListLinear [3,4] [1.0,1.0,1.0,1.0,2.0,2.0,2.0,2.0,3.0,3.0,3.0,3.0]))) + tfromPrimal (STKS [4] STKScalar) (tconcrete (FTKS [4] FTKScalar) (sfromListLinear [4] [1.0,2.0,3.0,4.0]))) (\\v2 -> tlet (tfromPrimal (STKS [4] STKScalar) (recip (sreplicate @_ @4 (sscalar 1.0) + exp (negate (sfromR (tprimalPart (rfromS v2))))))) (\\v3 -> tlet (tfromDual (tdualPart (STKS [4] STKScalar) (tfromPrimal (STKS [4] STKScalar) (sfromR (tprimalPart (rfromS v3)) * (sreplicate @_ @4 (sscalar 1.0) - sfromR (tprimalPart (rfromS v3))))) * tdualPart (STKS [4] STKScalar) (sfromR (tfromDual (tdualPart (STKR (SNat @1) STKScalar) (rfromS v2)))))) (\\v4 -> sfromR (tfromPrimal (STKR (SNat @1) STKScalar) (tprimalPart (rfromS v3))) + v4))))) * tfromPrimal (STKS [4,5] STKScalar) (tconcrete (FTKS [4,5] FTKScalar) (sfromListLinear [4,5] [1.0,1.0,1.0,1.0,1.0,2.0,2.0,2.0,2.0,2.0,3.0,3.0,3.0,3.0,3.0,4.0,4.0,4.0,4.0,4.0]))) + tfromPrimal (STKS [5] STKScalar) (tconcrete (FTKS [5] FTKScalar) (sfromListLinear [5] [1.0,2.0,3.0,4.0,5.0]))) (\\v5 -> tlet (tfromPrimal (STKS [5] STKScalar) (recip (sreplicate @_ @5 (sscalar 1.0) + exp (negate (sfromR (tprimalPart (rfromS v5))))))) (\\v6 -> tlet (tfromDual (tdualPart (STKS [5] STKScalar) (tfromPrimal (STKS [5] STKScalar) (sfromR (tprimalPart (rfromS v6)) * (sreplicate @_ @5 (sscalar 1.0) - sfromR (tprimalPart (rfromS v6))))) * tdualPart (STKS [5] STKScalar) (sfromR (tfromDual (tdualPart (STKR (SNat @1) STKScalar) (rfromS v5)))))) (\\v7 -> sfromR (tfromPrimal (STKR (SNat @1) STKScalar) (tprimalPart (rfromS v6))) + v7))))) * tfromPrimal (STKS [5,2] STKScalar) (tconcrete (FTKS [5,2] FTKScalar) (sfromListLinear [5,2] [1.0,1.0,2.0,2.0,3.0,3.0,4.0,4.0,5.0,5.0]))) + tfromPrimal (STKS [2] STKScalar) (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [1.0,2.0])))) (\\v8 -> sreplicate @_ @2 (recip (ssum @_ @2 v8)) * v8))"
"\\dummy" ++ " -> " ++ printAstSimple renames (simplifyInlineContract ast3)
@?= "\\dummy -> rfromS (tlet (exp (ssum @_ @5 (str (sreplicate @_ @2 (tlet (ssum @_ @4 (str (sreplicate @_ @5 (tlet (ssum @_ @3 (tfromPrimal (STKS [3,4] STKScalar) (str (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0)))) * tfromPrimal (STKS [3,4] STKScalar) (tconcrete (FTKS [3,4] FTKScalar) (sfromListLinear [3,4] [1.0,1.0,1.0,1.0,2.0,2.0,2.0,2.0,3.0,3.0,3.0,3.0]))) + tfromPrimal (STKS [4] STKScalar) (tconcrete (FTKS [4] FTKScalar) (sfromListLinear [4] [1.0,2.0,3.0,4.0]))) (\\v2 -> tlet (tfromPrimal (STKS [4] STKScalar) (recip (sreplicate @_ @4 (sscalar 1.0) + exp (negate (tprimalPart v2))))) (\\v3 -> tfromPrimal (STKS [4] STKScalar) (tprimalPart v3) + tfromDual (tdualPart (STKS [4] STKScalar) (tfromPrimal (STKS [4] STKScalar) (tprimalPart v3 * (sreplicate @_ @4 (sscalar 1.0) - tprimalPart v3))) * tdualPart (STKS [4] STKScalar) v2))))) * tfromPrimal (STKS [4,5] STKScalar) (tconcrete (FTKS [4,5] FTKScalar) (sfromListLinear [4,5] [1.0,1.0,1.0,1.0,1.0,2.0,2.0,2.0,2.0,2.0,3.0,3.0,3.0,3.0,3.0,4.0,4.0,4.0,4.0,4.0]))) + tfromPrimal (STKS [5] STKScalar) (tconcrete (FTKS [5] FTKScalar) (sfromListLinear [5] [1.0,2.0,3.0,4.0,5.0]))) (\\v5 -> tlet (tfromPrimal (STKS [5] STKScalar) (recip (sreplicate @_ @5 (sscalar 1.0) + exp (negate (tprimalPart v5))))) (\\v6 -> tfromPrimal (STKS [5] STKScalar) (tprimalPart v6) + tfromDual (tdualPart (STKS [5] STKScalar) (tfromPrimal (STKS [5] STKScalar) (tprimalPart v6 * (sreplicate @_ @5 (sscalar 1.0) - tprimalPart v6))) * tdualPart (STKS [5] STKScalar) v5))))) * tfromPrimal (STKS [5,2] STKScalar) (tconcrete (FTKS [5,2] FTKScalar) (sfromListLinear [5,2] [1.0,1.0,2.0,2.0,3.0,3.0,4.0,4.0,5.0,5.0]))) + tfromPrimal (STKS [2] STKScalar) (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [1.0,2.0])))) (\\v8 -> sreplicate @_ @2 (recip (ssum @_ @2 v8)) * v8))"
@?= "\\dummy -> rfromS (tlet (exp (smatvecmul (tfromPrimal (STKS [2,5] STKScalar) (tconcrete (FTKS [2,5] FTKScalar) (sfromListLinear [2,5] [1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0]))) (tlet (smatvecmul (tfromPrimal (STKS [5,4] STKScalar) (tconcrete (FTKS [5,4] FTKScalar) (sfromListLinear [5,4] [1.0,2.0,3.0,4.0,1.0,2.0,3.0,4.0,1.0,2.0,3.0,4.0,1.0,2.0,3.0,4.0,1.0,2.0,3.0,4.0]))) (tlet (ssum @_ @3 (tfromPrimal (STKS [3,4] STKScalar) (str (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0)))) * tfromPrimal (STKS [3,4] STKScalar) (tconcrete (FTKS [3,4] FTKScalar) (sfromListLinear [3,4] [1.0,1.0,1.0,1.0,2.0,2.0,2.0,2.0,3.0,3.0,3.0,3.0]))) + tfromPrimal (STKS [4] STKScalar) (tconcrete (FTKS [4] FTKScalar) (sfromListLinear [4] [1.0,2.0,3.0,4.0]))) (\\v2 -> tlet (tfromPrimal (STKS [4] STKScalar) (recip (sreplicate @_ @4 (sscalar 1.0) + exp (negate (tprimalPart v2))))) (\\v3 -> tfromPrimal (STKS [4] STKScalar) (tprimalPart v3) + tfromDual (tdualPart (STKS [4] STKScalar) (tfromPrimal (STKS [4] STKScalar) (tprimalPart v3 * (sreplicate @_ @4 (sscalar 1.0) - tprimalPart v3))) * tdualPart (STKS [4] STKScalar) v2)))) + tfromPrimal (STKS [5] STKScalar) (tconcrete (FTKS [5] FTKScalar) (sfromListLinear [5] [1.0,2.0,3.0,4.0,5.0]))) (\\v5 -> tlet (tfromPrimal (STKS [5] STKScalar) (recip (sreplicate @_ @5 (sscalar 1.0) + exp (negate (tprimalPart v5))))) (\\v6 -> tfromPrimal (STKS [5] STKScalar) (tprimalPart v6) + tfromDual (tdualPart (STKS [5] STKScalar) (tfromPrimal (STKS [5] STKScalar) (tprimalPart v6 * (sreplicate @_ @5 (sscalar 1.0) - tprimalPart v6))) * tdualPart (STKS [5] STKScalar) v5)))) + tfromPrimal (STKS [2] STKScalar) (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [1.0,2.0])))) (\\v8 -> sreplicate @_ @2 (recip (ssum @_ @2 v8)) * v8))"

testVT2OPPNonLin2 :: Assertion
testVT2OPPNonLin2 = do
Expand All @@ -834,6 +834,6 @@ testVT2OPPNonLin2 = do
printArtifactPrimalPretty renames artifactRevnonLin
@?= "\\m1 -> let v9 = ssum @_ @3 (str (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0))) * str (sfromR (tproject1 (tproject1 (tproject1 m1))))) + sfromR (tproject2 (tproject1 (tproject1 m1))) ; v10 = exp (negate v9) ; v11 = sreplicate @_ @4 (sscalar 1.0) + v10 ; v12 = recip v11 ; v16 = sconcrete (sfromListLinear [4] [0.0,0.0,0.0,0.0]) ; m17 = str (sreplicate @_ @5 (v12 + v16)) ; m18 = str (scast (sfromR (tproject1 (tproject2 (tproject1 m1))))) ; v19 = ssum @_ @4 (m17 * m18) + sfromR (tproject2 (tproject2 (tproject1 m1))) ; v20 = exp (negate v19) ; v21 = sreplicate @_ @5 (sscalar 1.0) + v20 ; v22 = recip v21 ; v26 = sconcrete (sfromListLinear [5] [0.0,0.0,0.0,0.0,0.0]) ; m27 = str (sreplicate @_ @2 (v22 + v26)) ; v28 = exp (ssum @_ @5 (m27 * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x29 = ssum @_ @2 v28 ; v30 = sreplicate @_ @2 (recip x29) in rfromS (v30 * v28)"
printArtifactPretty renames (simplifyArtifact artifactRevnonLin)
@?= "\\v31 m1 -> tfromS (let v12 = recip (sreplicate @_ @4 (sscalar 1.0) + exp (negate (ssdot1In (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0))) (sfromR (tproject1 (tproject1 (tproject1 m1)))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; m17 = str (sreplicate @_ @5 (sconcrete (sfromListLinear [4] [0.0,0.0,0.0,0.0]) + v12)) ; m18 = str (scast (sfromR (tproject1 (tproject2 (tproject1 m1))))) ; v22 = recip (sreplicate @_ @5 (sscalar 1.0) + exp (negate (ssum @_ @4 (m17 * m18) + sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; m27 = str (sreplicate @_ @2 (sconcrete (sfromListLinear [5] [0.0,0.0,0.0,0.0,0.0]) + v22)) ; v28 = exp (ssum @_ @5 (m27 * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x29 = ssum @_ @2 v28 ; v32 = v28 * (sreplicate @_ @2 (negate (recip (x29 * x29)) * sdot0 v28 (sfromR v31)) + sreplicate @_ @2 (recip x29) * sfromR v31) ; v35 = (v22 * (sreplicate @_ @5 (sscalar 1.0) - v22)) * ssum @_ @2 (sfromR (tproject1 (tproject2 m1)) * str (sreplicate @_ @5 v32)) ; v38 = (v12 * (sreplicate @_ @4 (sscalar 1.0) - v12)) * ssdot1In m18 (sreplicate @_ @4 v35) in tpair (tpair (tpair (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0)) * str (sreplicate @_ @3 v38), v38), tpair (str (scast m17) * str (sreplicate @_ @4 (scast v35)), v35)), tpair (str m27 * str (sreplicate @_ @5 v32), v32)))"
@?= "\\v31 m1 -> tfromS (let v12 = recip (sreplicate @_ @4 (sscalar 1.0) + exp (negate (smatvecmul (sfromR (tproject1 (tproject1 (tproject1 m1)))) (sreplicate @_ @3 (sscalar 7.0)) + sfromR (tproject2 (tproject1 (tproject1 m1)))))) ; m17 = str (sreplicate @_ @5 (sconcrete (sfromListLinear [4] [0.0,0.0,0.0,0.0]) + v12)) ; m18 = str (scast (sfromR (tproject1 (tproject2 (tproject1 m1))))) ; v22 = recip (sreplicate @_ @5 (sscalar 1.0) + exp (negate (ssum @_ @4 (m17 * m18) + sfromR (tproject2 (tproject2 (tproject1 m1)))))) ; m27 = str (sreplicate @_ @2 (sconcrete (sfromListLinear [5] [0.0,0.0,0.0,0.0,0.0]) + v22)) ; v28 = exp (ssum @_ @5 (m27 * str (sfromR (tproject1 (tproject2 m1)))) + sfromR (tproject2 (tproject2 m1))) ; x29 = ssum @_ @2 v28 ; v32 = v28 * (sreplicate @_ @2 (negate (recip (x29 * x29)) * sdot0 v28 (sfromR v31)) + sreplicate @_ @2 (recip x29) * sfromR v31) ; v35 = (v22 * (sreplicate @_ @5 (sscalar 1.0) - v22)) * smatvecmul (str (sfromR (tproject1 (tproject2 m1)))) v32 ; v38 = (v12 * (sreplicate @_ @4 (sscalar 1.0) - v12)) * smatvecmul m18 v35 in tpair (tpair (tpair (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0)) * str (sreplicate @_ @3 v38), v38), tpair (str (scast m17) * str (sreplicate @_ @4 (scast v35)), v35)), tpair (str m27 * str (sreplicate @_ @5 v32), v32)))"
printArtifactPrimalPretty renames (simplifyArtifact artifactRevnonLin)
@?= "\\m1 -> rfromS (let v28 = exp (ssdot1In (sreplicate @_ @2 (sconcrete (sfromListLinear [5] [0.0,0.0,0.0,0.0,0.0]) + recip (sreplicate @_ @5 (sscalar 1.0) + exp (negate (ssdot1In (sreplicate @_ @5 (sconcrete (sfromListLinear [4] [0.0,0.0,0.0,0.0]) + recip (sreplicate @_ @4 (sscalar 1.0) + exp (negate (ssdot1In (sreplicate @_ @4 (sreplicate @_ @3 (sscalar 7.0))) (sfromR (tproject1 (tproject1 (tproject1 m1)))) + sfromR (tproject2 (tproject1 (tproject1 m1)))))))) (scast (sfromR (tproject1 (tproject2 (tproject1 m1))))) + sfromR (tproject2 (tproject2 (tproject1 m1)))))))) (sfromR (tproject1 (tproject2 m1))) + sfromR (tproject2 (tproject2 m1))) in sreplicate @_ @2 (recip (ssum @_ @2 v28)) * v28)"
@?= "\\m1 -> rfromS (let v28 = exp (smatvecmul (sfromR (tproject1 (tproject2 m1))) (sconcrete (sfromListLinear [5] [0.0,0.0,0.0,0.0,0.0]) + recip (sreplicate @_ @5 (sscalar 1.0) + exp (negate (smatvecmul (scast (sfromR (tproject1 (tproject2 (tproject1 m1))))) (sconcrete (sfromListLinear [4] [0.0,0.0,0.0,0.0]) + recip (sreplicate @_ @4 (sscalar 1.0) + exp (negate (smatvecmul (sfromR (tproject1 (tproject1 (tproject1 m1)))) (sreplicate @_ @3 (sscalar 7.0)) + sfromR (tproject2 (tproject1 (tproject1 m1))))))) + sfromR (tproject2 (tproject2 (tproject1 m1))))))) + sfromR (tproject2 (tproject2 m1))) in sreplicate @_ @2 (recip (ssum @_ @2 v28)) * v28)"

0 comments on commit 965cabb

Please sign in to comment.