Skip to content

Commit

Permalink
Remove most constraints from code to do with Delta
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Feb 10, 2025
1 parent b40cb09 commit 124da73
Show file tree
Hide file tree
Showing 13 changed files with 918 additions and 602 deletions.
6 changes: 4 additions & 2 deletions bench/common/BenchProdTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ revRankedLProdr =
withKnownSTK (stkOfListR (knownSTK @(TKScalar Double)) (SNat @n)) $
rev rankedLProdr

rankedNotSharedLProd :: (BaseTensor target, GoodScalar r)
rankedNotSharedLProd :: ( BaseTensor target, GoodScalar r
, Show (target (TKScalar r)) )
=> ListR n (ADVal target (TKScalar r))
-> ADVal target (TKScalar r)
rankedNotSharedLProd = foldr1 multNotShared
Expand Down Expand Up @@ -167,7 +168,8 @@ revRankedLtProdr =
withKnownSTK (stkOfListR (knownSTK @(TKS '[] Double)) (SNat @n)) $
rev rankedLtProdr

rankedNotSharedLtProd :: (BaseTensor target, GoodScalar r)
rankedNotSharedLtProd :: ( BaseTensor target, GoodScalar r
, Show (target (TKS '[] r)) )
=> ListR n (ADVal target (TKS '[] r))
-> ADVal target (TKS '[] r)
rankedNotSharedLtProd = foldr1 multNotShared
Expand Down
8 changes: 4 additions & 4 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,11 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
AstSFromK :: GoodScalar r
=> AstTensor ms s (TKScalar r) -> AstTensor ms s (TKS '[] r)
AstSFromR :: forall sh x ms s.
ShS sh
-> AstTensor ms s (TKR2 (Rank sh) x) -> AstTensor ms s (TKS2 sh x)
ShS sh -> AstTensor ms s (TKR2 (Rank sh) x)
-> AstTensor ms s (TKS2 sh x)
AstSFromX :: forall sh sh' x ms s. Rank sh ~ Rank sh'
=> ShS sh
-> AstTensor ms s (TKX2 sh' x) -> AstTensor ms s (TKS2 sh x)
=> ShS sh -> AstTensor ms s (TKX2 sh' x)
-> AstTensor ms s (TKS2 sh x)

-- Backend-specific primitives
AstReplicate0NS :: ShS sh -> AstTensor ms s (TKS2 '[] x)
Expand Down
10 changes: 5 additions & 5 deletions src/HordeAd/Core/AstTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ ftkAst t = case t of
FTKS sh FTKScalar -> FTKS sh FTKScalar

AstIndexS shn v _ix -> case ftkAst v of
FTKS _sh1sh2 x -> FTKS shn x
FTKS _ x -> FTKS shn x
AstScatterS shn v (_ , ix) -> case ftkAst v of
FTKS _ x -> FTKS (ixsToShS ix `shsAppend` shn) x
AstGatherS shn v (vars, _) -> case ftkAst v of
Expand All @@ -118,13 +118,13 @@ ftkAst t = case t of
AstIotaS n@SNat -> FTKS (n :$$ ZSS) FTKScalar
AstAppendS a b -> case (ftkAst a, ftkAst b) of
(FTKS (m :$$ sh) x, FTKS (n :$$ _) _) -> FTKS (snatPlus m n :$$ sh) x
AstSliceS _ nsnat@SNat _ a -> case ftkAst a of
FTKS (_ :$$ sh) x -> FTKS (nsnat :$$ sh) x
AstSliceS _ n@SNat _ a -> case ftkAst a of
FTKS (_ :$$ sh) x -> FTKS (n :$$ sh) x
AstReverseS v -> ftkAst v
AstTransposeS perm v -> case ftkAst v of
FTKS sh x -> FTKS (shsPermutePrefix perm sh) x
AstReshapeS sh v -> case ftkAst v of
FTKS _ x -> FTKS sh x
AstReshapeS sh2 v -> case ftkAst v of
FTKS _ x -> FTKS sh2 x
AstZipS v -> case ftkAst v of
FTKProduct (FTKS sh y) (FTKS _ z) -> FTKS sh (FTKProduct y z)
AstUnzipS v -> case ftkAst v of
Expand Down
54 changes: 26 additions & 28 deletions src/HordeAd/Core/CarriersADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ module HordeAd.Core.CarriersADVal
import Prelude

import Data.Type.Equality ((:~:) (Refl))
import GHC.TypeLits (KnownNat)

import Data.Array.Nested (KnownShS (..), KnownShX (..), Rank)
import Data.Array.Nested (Rank, ShS (..))

import HordeAd.Core.Delta
import HordeAd.Core.DeltaFreshId
Expand Down Expand Up @@ -56,8 +55,8 @@ deriving instance (Show (f z), Show (Delta f z))
-- of the dual number is an AST term or not).
-- The bare constructor should not be used directly (which is not enforced
-- by the types yet), except when deconstructing via pattern-matching.
dD :: forall f z. KnownSTK z
=> f z -> Delta f z -> ADVal f z
dD :: forall f z.
f z -> Delta f z -> ADVal f z
dD !a !dual = dDnotShared a (shareDelta dual)

-- | This a not so smart a constructor for 'D' of 'ADVal' that does not record
Expand Down Expand Up @@ -124,8 +123,7 @@ dDnotShared = ADVal
-- terms get an identifier. Alternatively, 'HordeAd.Core.CarriersADVal.dD'
-- or library definitions that use it could be made smarter.

unDeltaPair :: (KnownSTK x, KnownSTK y)
=> Delta target (TKProduct x y) -> (Delta target x, Delta target y)
unDeltaPair :: Delta target (TKProduct x y) -> (Delta target x, Delta target y)
unDeltaPair (DeltaPair a b) = (a, b)
unDeltaPair (DeltaZero (FTKProduct ftk1 ftk2)) = (DeltaZero ftk1, DeltaZero ftk2)
unDeltaPair d = let dShared = shareDelta d -- TODO: more cases
Expand All @@ -142,7 +140,7 @@ unDeltaPairUnshared d = case ftkDelta d of
withKnownSTK (ftkToSTK ftk2) $
(DeltaProject1 d, DeltaProject2 d)

dScale :: Num (f z) => f z -> Delta f z -> Delta f z
dScale :: (Num (f z), Show (f z)) => f z -> Delta f z -> Delta f z
dScale _ (DeltaZero ftk) = DeltaZero ftk
dScale v u' = DeltaScale v u'

Expand All @@ -152,33 +150,33 @@ dAdd v DeltaZero{} = v
dAdd v w = DeltaAdd v w

-- Avoids building huge Delta terms, not only evaluating them.
dFromS :: forall y z target. (KnownSTK y, KnownSTK z)
=> Delta target y -> Delta target z
dFromS (DeltaSFromR @sh @x d)
| Just Refl <- sameKnownSTS @z @(TKR2 (Rank sh) x) = d
dFromS (DeltaSFromX @_ @sh' @x d)
| Just Refl <- sameKnownSTS @z @(TKX2 sh' x) = d
dFromS d = DeltaFromS d
dFromS :: forall y z target.
STensorKind z -> Delta target y -> Delta target z
dFromS stk (DeltaSFromR _sh d)
| y2 <- ftkDelta d
, Just Refl <- sameSTK stk (ftkToSTK y2) = d
dFromS stk (DeltaSFromX _sh d)
| y2 <- ftkDelta d
, Just Refl <- sameSTK stk (ftkToSTK y2) = d
dFromS stk d = DeltaFromS stk d

dSFromR :: forall sh x target.
(KnownShS sh, KnownNat (Rank sh), KnownSTK x)
=> Delta target (TKR2 (Rank sh) x)
ShS sh -> Delta target (TKR2 (Rank sh) x)
-> Delta target (TKS2 sh x)
dSFromR (DeltaFromS @y d) =
case sameKnownSTS @y @(TKS2 sh x) of
dSFromR sh (DeltaFromS (STKR _ x) d) = case ftkDelta d of
y2 -> case sameSTK (ftkToSTK y2) (STKS sh x) of
Just Refl -> d
_ -> error "sfromR: different shapes in DeltaSFromR(DeltaFromS)"
dSFromR d = DeltaSFromR d
dSFromR sh d = DeltaSFromR sh d

dSFromX :: forall sh sh' x target.
(KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', KnownSTK x)
=> Delta target (TKX2 sh' x)
dSFromX :: forall sh sh' x target. Rank sh ~ Rank sh'
=> ShS sh -> Delta target (TKX2 sh' x)
-> Delta target (TKS2 sh x)
dSFromX (DeltaFromS @y d) =
case sameKnownSTS @y @(TKS2 sh x) of
dSFromX sh (DeltaFromS (STKX _ x) d) = case ftkDelta d of
y2 -> case sameSTK (ftkToSTK y2) (STKS sh x) of
Just Refl -> d
_ -> error "sfromR: different shapes in DeltaSFromX(DeltaFromS)"
dSFromX d = DeltaSFromX d
dSFromX sh d = DeltaSFromX sh d

-- This hack is needed to recover shape from tensors,
-- in particular in case of numeric literals and also for forward derivative.
Expand All @@ -193,18 +191,18 @@ fromPrimalADVal a = dDnotShared a (DeltaZero $ tftk knownSTK a)
-- constructed using multiple applications of the `dDnotShared` operation.
-- The resulting term may not have sharing information inside,
-- but is ready to be shared as a whole.
ensureToplevelSharing :: KnownSTK z => ADVal f z -> ADVal f z
ensureToplevelSharing :: ADVal f z -> ADVal f z
ensureToplevelSharing (D u u') = dD u u'

scaleNotShared :: Num (f z)
scaleNotShared :: (Num (f z), Show (f z))
=> f z -> ADVal f z -> ADVal f z
scaleNotShared !a (D u u') = dDnotShared (a * u) (dScale a u')

addNotShared :: forall f z. Num (f z)
=> ADVal f z -> ADVal f z -> ADVal f z
addNotShared (D u u') (D v v') = dDnotShared (u + v) (dAdd u' v')

multNotShared :: forall f z. Num (f z)
multNotShared :: forall f z. (Num (f z), Show (f z))
=> ADVal f z -> ADVal f z -> ADVal f z
multNotShared (D u u') (D v v') =
dDnotShared (u * v) (dAdd (dScale v u') (dScale u v'))
Expand Down
Loading

0 comments on commit 124da73

Please sign in to comment.