Remove constraints from trev and friends
Mikolaj committed Feb 12, 2025
1 parent 8d76558 commit 0f9cd0c
Showing 8 changed files with 162 additions and 169 deletions.
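The change swaps KnownSTK class constraints on trev and friends for explicit evidence passed as values: trev now takes an STensorKind z argument, and the adaptors thread the argument's FullTensorKind (conventionally renamed xftk) instead of recovering it from a constraint. A minimal, self-contained sketch of that general pattern, with toy stand-ins (STK, KnownSTK, width, widthK are hypothetical, not horde-ad's actual definitions):

{-# LANGUAGE GADTs, ScopedTypeVariables, TypeApplications #-}
module ConstraintToArgument where

-- Toy singleton, standing in for STensorKind.
data STK z where
  SScalar :: STK Double
  SPair   :: STK a -> STK b -> STK (a, b)

-- Toy class, standing in for KnownSTK.
class KnownSTK z where
  knownSTK :: STK z

instance KnownSTK Double where
  knownSTK = SScalar

instance (KnownSTK a, KnownSTK b) => KnownSTK (a, b) where
  knownSTK = SPair knownSTK knownSTK

-- After the refactor: evidence is an ordinary argument, so a caller that
-- already holds an STK (say, computed from a FullTensorKind) needs no
-- class constraint at all.
width :: STK z -> Int
width SScalar     = 1
width (SPair a b) = width a + width b

-- The old constrained interface stays recoverable as a thin wrapper.
widthK :: forall z proxy. KnownSTK z => proxy z -> Int
widthK _ = width (knownSTK @z)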
2 changes: 1 addition & 1 deletion bench/common/BenchProdTools.hs
@@ -218,4 +218,4 @@ inspect $ hasNoTypeClassesExcept 'revRankedTProd [''KnownNat, ''KnownSTK, ''Base

{- with --ghc-options="-fpolymorphic-specialisation"
additional classes appear (at the end): -}
- inspect $ hasNoTypeClassesExcept 'revRankedTProd [''KnownNat, ''KnownSTK, ''BaseTensor, ''GoodScalar, ''AstSpan, ''Num, ''Show, ''Ord, ''Typeable, ''IfDifferentiable, ''Eq, ''NFData, ''Default.Default, ''Nested.Elt, ''Nested.PrimElt, ''Nested.KnownElt, ''Nested.NumElt, ''Nested.KnownShS, ''Boolean, ''EqF, ''OrdF, ''AllTargetShow, ''ShareTensor, ''LetTensor, ''(~), ''Nested.Storable, ''Nested.KnownShX, ''WithDict, ''RealFrac, ''PermC, ''RealFloatF, ''Nested.FloatElt, ''IntegralF, ''Integral, ''Numeric, ''IsList]
+ inspect $ hasNoTypeClassesExcept 'revRankedTProd [''KnownNat, ''KnownSTK, ''BaseTensor, ''GoodScalar, ''AstSpan, ''Num, ''Show, ''Ord, ''Typeable, ''IfDifferentiable, ''Eq, ''NFData, ''Default.Default, ''Nested.Elt, ''Nested.PrimElt, ''Nested.KnownElt, ''Nested.NumElt, ''Nested.KnownShS, ''Boolean, ''EqF, ''OrdF, ''AllTargetShow, ''ShareTensor, ''LetTensor, ''(~), ''Nested.Storable, ''Nested.KnownShX, ''WithDict, ''RealFrac, ''PermC, ''RealFloatF, ''Nested.FloatElt, ''IntegralF, ''Integral, ''Numeric, ''IsList, ''AdaptableTarget]
97 changes: 52 additions & 45 deletions src/HordeAd/Core/Engine.hs
@@ -19,7 +19,7 @@ module HordeAd.Core.Engine
import Prelude

import Data.Int (Int64)
- import Data.Maybe (fromMaybe, isJust)
+ import Data.Maybe (isJust)
import GHC.TypeLits (KnownNat)
import Type.Reflection (Typeable)

@@ -60,7 +60,7 @@ import HordeAd.Core.Unwind
-- from different levels of differentiation if it's done multiple times.
rev
:: forall astvals z.
- ( X astvals ~ X (Value astvals), KnownSTK (X astvals), KnownSTK z
+ ( X astvals ~ X (Value astvals), KnownSTK (X astvals)
, AdaptableTarget (AstTensor AstMethodLet FullSpan) astvals
, AdaptableTarget RepN (Value astvals) )
=> (astvals -> AstTensor AstMethodLet FullSpan z)
@@ -79,7 +79,7 @@ rev f vals = revDtMaybe f vals Nothing
-- tensor codomain.
revDt
:: forall astvals z.
- ( X astvals ~ X (Value astvals), KnownSTK (X astvals), KnownSTK z
+ ( X astvals ~ X (Value astvals), KnownSTK (X astvals)
, AdaptableTarget (AstTensor AstMethodLet FullSpan) astvals
, AdaptableTarget RepN (Value astvals) )
=> (astvals -> AstTensor AstMethodLet FullSpan z)
@@ -91,7 +91,7 @@ revDt f vals dt = revDtMaybe f vals (Just dt)

revDtMaybe
:: forall astvals z.
- ( X astvals ~ X (Value astvals), KnownSTK (X astvals), KnownSTK z
+ ( X astvals ~ X (Value astvals), KnownSTK (X astvals)
, AdaptableTarget (AstTensor AstMethodLet FullSpan) astvals
, AdaptableTarget RepN (Value astvals) )
=> (astvals -> AstTensor AstMethodLet FullSpan z)
@@ -105,8 +105,8 @@ revDtMaybe f vals0 mdt =
g !hv = tlet hv $ \ !hvShared ->
f $ fromTarget hvShared
valsTarget = toTarget vals0
- ftk = tftk knownSTK valsTarget
- artifact = fst $ revProduceArtifact (isJust mdt) g emptyEnv ftk
+ xftk = tftk (knownSTK @(X astvals)) valsTarget
+ artifact = fst $ revProduceArtifact (isJust mdt) g emptyEnv xftk
in fromTargetAD $ fst $ revEvalArtifact artifact valsTarget mdt
{- TODO
{-# SPECIALIZE revDtMaybe
@@ -121,19 +121,17 @@
-}

revArtifactAdapt
- :: forall astvals z.
-    ( KnownSTK (X astvals)
-    , AdaptableTarget (AstTensor AstMethodLet FullSpan) astvals )
+ :: forall astvals z. AdaptableTarget (AstTensor AstMethodLet FullSpan) astvals
=> Bool
-> (astvals -> AstTensor AstMethodLet FullSpan z)
-> FullTensorKind (X astvals)
-> (AstArtifactRev (X astvals) z, Delta (AstRaw PrimalSpan) z )
- revArtifactAdapt hasDt f ftk =
+ revArtifactAdapt hasDt f xftk =
let g :: AstTensor AstMethodLet FullSpan (X astvals)
-> AstTensor AstMethodLet FullSpan z
g !hv = tlet hv $ \ !hvShared ->
f $ fromTarget hvShared
- in revProduceArtifact hasDt g emptyEnv ftk
+ in revProduceArtifact hasDt g emptyEnv xftk
{- TODO
{-# SPECIALIZE revArtifactAdapt
:: ( KnownNat n
@@ -170,18 +168,23 @@ forwardPassByApplication g hVectorPrimal _var _hVector =
in g varInputs

revEvalArtifact
- :: forall x z. (KnownSTK x, KnownSTK z)
+ :: forall x z. KnownSTK x
=> AstArtifactRev x z
-> RepN x
-> Maybe (RepN (ADTensorKind z))
-> (RepN (ADTensorKind x), RepN z)
{-# INLINE revEvalArtifact #-}
- revEvalArtifact AstArtifactRev{..} parameters mdt
-   | Dict <- lemKnownSTKOfAD (knownSTK @z) =
-     let oneAtF = constantTarget 1 $ adFTK $ ftkAst artPrimalRev
-         dt = fromMaybe oneAtF mdt
+ revEvalArtifact AstArtifactRev{..} parameters edt =
+   let aftk = adFTK $ ftkAst artPrimalRev
env = extendEnv artVarDomainRev parameters emptyEnv
-       envDt = extendEnv artVarDtRev dt env
+       envDt =
+         withKnownSTK (ftkToSTK aftk) $
+         case edt of
+           Nothing ->
+             let oneAtF = constantTarget 1 aftk
+             in extendEnv artVarDtRev oneAtF env
+           Just dt ->
+             extendEnv artVarDtRev dt env
gradient = interpretAst envDt artDerivativeRev
primal = interpretAst env artPrimalRev
in (gradient, primal)
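The rewritten revEvalArtifact drops KnownSTK z by reading the kind off the artifact (ftkToSTK aftk) and bridging it back into a constraint with withKnownSTK. A sketch of such bridging, reusing the toy STK/KnownSTK from the sketch above (withKnownSTK' is a hypothetical analogue, not the real withKnownSTK):

{-# LANGUAGE GADTs, RankNTypes #-}
-- Reify value-level evidence into a class constraint for a local scope;
-- inside the continuation, KnownSTK z methods are available again.
withKnownSTK' :: STK z -> (KnownSTK z => r) -> r
withKnownSTK' SScalar     k = k
withKnownSTK' (SPair a b) k = withKnownSTK' a (withKnownSTK' b k)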
@@ -208,15 +211,16 @@ fwd
-> Value astvals -- morally (ADTensorKind astvals)
-> RepN (ADTensorKind z)
fwd f vals ds =
- let g :: AstTensor AstMethodLet FullSpan (X astvals) -> AstTensor AstMethodLet FullSpan z
+ let g :: AstTensor AstMethodLet FullSpan (X astvals)
+       -> AstTensor AstMethodLet FullSpan z
g !hv = tlet hv $ \ !hvShared ->
f $ fromTarget hvShared
valsTarget = toTarget vals
- ftk = tftk knownSTK valsTarget
- artifact = fst $ fwdProduceArtifact g emptyEnv ftk
+ xftk = tftk (knownSTK @(X astvals)) valsTarget
+ artifact = fst $ fwdProduceArtifact g emptyEnv xftk
dsTarget = toTarget ds
in fst $ fwdEvalArtifact @_ @z artifact valsTarget
-        $ toADTensorKindShared knownSTK dsTarget
+        $ toADTensorKindShared (knownSTK @(X astvals)) dsTarget

fwdEvalArtifact
:: forall x z. KnownSTK x
@@ -225,16 +229,17 @@
-> RepN (ADTensorKind x)
-> (RepN (ADTensorKind z), RepN z)
{-# INLINE fwdEvalArtifact #-}
- fwdEvalArtifact AstArtifactFwd{..} parameters ds
-   | Dict <- lemKnownSTKOfAD (knownSTK @x) =
-     if adFTK (tftk (knownSTK @x) parameters)
-        == tftk (knownSTK @(ADTensorKind x)) ds then
-       let env = extendEnv artVarDomainFwd parameters emptyEnv
-           envD = extendEnv artVarDsFwd ds env
-           derivative = interpretAst envD artDerivativeFwd
-           primal = interpretAst env artPrimalFwd
-       in (derivative, primal)
-     else error "fwdEvalArtifact: forward derivative input and sensitivity arguments should have same shapes"
+ fwdEvalArtifact AstArtifactFwd{..} parameters ds =
+   let xstk = knownSTK @x
+       astk = adSTK xstk
+   in if adFTK (tftk xstk parameters) == tftk astk ds then
+        let env = extendEnv artVarDomainFwd parameters emptyEnv
+            envD = withKnownSTK astk $
+                   extendEnv artVarDsFwd ds env
+            derivative = interpretAst envD artDerivativeFwd
+            primal = interpretAst env artPrimalFwd
+        in (derivative, primal)
+      else error "fwdEvalArtifact: forward derivative input and sensitivity arguments should have same shape"


-- * Old gradient adaptors, with constant and fixed inputs and dt
@@ -255,40 +260,41 @@ crev
-> DValue advals
-> DValue advals
{-# INLINE crev #-}
- crev f vals = crevDtMaybe f vals Nothing
+ crev f vals = crevDtEither f vals (Left (knownSTK @z))

-- | This version additionally takes the sensitivity parameter.
crevDt
:: forall advals z.
- ( X advals ~ X (DValue advals), KnownSTK (X advals), KnownSTK z
+ ( X advals ~ X (DValue advals), KnownSTK (X advals)
, AdaptableTarget (ADVal RepN) advals
, AdaptableTarget RepN (DValue advals) )
=> (advals -> ADVal RepN z)
-> DValue advals
-> RepN (ADTensorKind z)
-> DValue advals
{-# INLINE crevDt #-}
- crevDt f vals dt = crevDtMaybe f vals (Just dt)
+ crevDt f vals dt = crevDtEither f vals (Right dt)

- crevDtMaybe
+ crevDtEither
:: forall advals z.
- ( X advals ~ X (DValue advals), KnownSTK (X advals), KnownSTK z
+ ( X advals ~ X (DValue advals), KnownSTK (X advals)
, AdaptableTarget (ADVal RepN) advals
, AdaptableTarget RepN (DValue advals) )
=> (advals -> ADVal RepN z)
-> DValue advals
- -> Maybe (RepN (ADTensorKind z))
+ -> Either (STensorKind z) (RepN (ADTensorKind z))
-> DValue advals -- morally (ADTensorKind advals)
- {-# INLINE crevDtMaybe #-}
- crevDtMaybe f vals mdt =
+ {-# INLINE crevDtEither #-}
+ crevDtEither f vals edt =
let g :: ADVal RepN (X advals) -> ADVal RepN z
g = f . fromTarget
+     xftk = tftk (knownSTK @(X advals)) valsTarget
valsTarget = toTarget vals
- in fromTargetAD $ fst $ crevOnHVector mdt g valsTarget
+ in fromTargetAD $ fst $ crevOnHVector edt g xftk valsTarget
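crev and crevDt now funnel into crevDtEither: Left (knownSTK @z) asks for the default cotangent of ones to be built from the kind, while Right dt supplies the cotangent explicitly. A tiny sketch of resolving such a seed (resolveSeed and onesOfKind are hypothetical names, not horde-ad's API):

-- Left carries only the kind, from which a default cotangent of ones is
-- built; Right carries an explicitly supplied dt.
resolveSeed :: Either kind dt -> (kind -> dt) -> dt
resolveSeed edt onesOfKind = either onesOfKind id edt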

{-
{-# SPECIALIZE crevOnHVector
- :: Maybe (RepN TKUntyped)
+ :: Either (RepN TKUntyped)
-> (ADVal RepN TKUntyped
-> ADVal RepN TKUntyped)
-> RepN TKUntyped
@@ -308,12 +314,13 @@ cfwd
-> DValue advals -- morally (ADTensorKind advals)
-> RepN (ADTensorKind z)
cfwd f vals ds =
-   let g :: ADVal RepN (X advals) -> ADVal RepN z
-       g = f . fromTarget
+   let xftk = tftk (knownSTK @(X advals)) valsTarget
        valsTarget = toTarget vals
+       g :: ADVal RepN (X advals) -> ADVal RepN z
+       g = f . fromTarget
dsTarget = toTarget ds
- in fst $ cfwdOnHVector valsTarget g
-        $ toADTensorKindShared knownSTK dsTarget
+ in fst $ cfwdOnHVector xftk valsTarget g
+        $ toADTensorKindShared (knownSTK @(X advals)) dsTarget



85 changes: 36 additions & 49 deletions src/HordeAd/Core/Ops.hs
@@ -1650,16 +1650,16 @@ class ( Num (IntOf target)
-> target (BuildTensorKind k eShs)
-> target (TKProduct accShs (BuildTensorKind k bShs))
tmapAccumR proxy !k !accShs !bShs !eShs f acc0 es =
- let shs = FTKProduct accShs eShs
+ let xftk = FTKProduct accShs eShs
fl :: forall f. ADReady f
=> f (TKProduct accShs eShs)
-> f (TKProduct accShs bShs)
fl !args = tlet args $ \ !args1 ->
f (tproject1 args1) (tproject2 args1)
in tmapAccumRDer proxy k accShs bShs eShs
-       (tlambda @target shs (HFun fl))
-       (tfwd @target shs $ HFun fl)
-       (trevDt @target shs $ HFun fl)
+       (tlambda @target xftk (HFun fl))
+       (tfwd @target xftk $ HFun fl)
+       (trevDt @target xftk $ HFun fl)
acc0 es
tmapAccumRDer
:: (KnownSTK accShs, KnownSTK bShs, KnownSTK eShs)
@@ -1696,16 +1696,16 @@
-> target (BuildTensorKind k eShs)
-> target (TKProduct accShs (BuildTensorKind k bShs))
tmapAccumL proxy !k !accShs !bShs !eShs f acc0 es =
- let shs = FTKProduct accShs eShs
+ let xftk = FTKProduct accShs eShs
fl :: forall f. ADReady f
=> f (TKProduct accShs eShs)
-> f (TKProduct accShs bShs)
fl !args = tlet args $ \ !args1 ->
f (tproject1 args1) (tproject2 args1)
in tmapAccumLDer proxy k accShs bShs eShs
-       (tlambda @target shs (HFun fl))
-       (tfwd @target shs $ HFun fl)
-       (trevDt @target shs $ HFun fl)
+       (tlambda @target xftk (HFun fl))
+       (tfwd @target xftk $ HFun fl)
+       (trevDt @target xftk $ HFun fl)
acc0 es
tmapAccumLDer
:: (KnownSTK accShs, KnownSTK bShs, KnownSTK eShs)
@@ -1793,68 +1793,57 @@
-- use the let operations and also their signatures mention @ADReady@,
-- so it's awkward to put the methods into @BaseTensor@,
-- which shouldn't know about lets, etc.
- rrev :: forall x r n.
-         (KnownSTK x, KnownSTK r, KnownNat n)
+ rrev :: forall x r n. (KnownSTK r, KnownNat n)
=> (forall f. ADReady f => f x -> f (TKR2 n r))
-> FullTensorKind x
-> target x
-> target (ADTensorKind x)
- rrev f ftk | Dict <- lemKnownSTKOfAD (knownSTK @x) =
-   \ !es -> tApply (trev @target ftk $ HFun f) es
+ rrev f xftk | Dict <- lemKnownSTKOfAD (ftkToSTK xftk) =
+   \ !es -> tApply (trev @target xftk (HFun f) (knownSTK @(TKR2 n r))) es
-- We can't get sh from anywhere, so this is not possible:
-- rrev f shs es = rrevDt f shs es (rreplicate0N sh 1)
- rrevDt :: forall x r n.
-           (KnownSTK x, KnownSTK r, KnownNat n)
+ rrevDt :: forall x r n. (KnownSTK r, KnownNat n)
=> (forall f. ADReady f => f x -> f (TKR2 n r))
-> FullTensorKind x
-> target x
-> target (ADTensorKind (TKR2 n r)) -- ^ incoming cotangent (dt)
-> target (ADTensorKind x)
- rrevDt f ftk | Dict <- lemKnownSTKOfAD (knownSTK @x)
-              , Dict <- lemKnownSTKOfAD (knownSTK @(TKR2 n r)) =
-   \ !es !dt -> tApply (trevDt @target ftk $ HFun f)
+ rrevDt f xftk | Dict <- lemKnownSTKOfAD (ftkToSTK xftk) =
+   \ !es !dt -> tApply (trevDt @target xftk $ HFun f)
(tpair dt es)
- rfwd :: forall x r n.
-         (KnownSTK x, KnownSTK r, KnownNat n)
+ rfwd :: forall x r n. (KnownSTK r, KnownNat n)
=> (forall f. ADReady f => f x -> f (TKR2 n r))
-> FullTensorKind x
-> target x
-> target (ADTensorKind x) -- ^ incoming tangent (ds)
-> target (ADTensorKind (TKR2 n r))
- rfwd f ftk | Dict <- lemKnownSTKOfAD (knownSTK @x)
-            , Dict <- lemKnownSTKOfAD (knownSTK @(TKR2 n r)) =
-   \ !es !ds -> tApply (tfwd @target ftk $ HFun f)
+ rfwd f xftk | Dict <- lemKnownSTKOfAD (knownSTK @(TKR2 n r)) =
+   \ !es !ds -> tApply (tfwd @target xftk $ HFun f)
(tpair ds es)
- srev :: forall x r sh.
-         ( KnownSTK x, KnownSTK r, KnownShS sh
-         , ADTensorKind (TKS2 sh r) ~ TKS2 sh r )
+ srev :: forall x r sh. (KnownSTK r, KnownShS sh)
=> (forall f. ADReady f => f x -> f (TKS2 sh r))
-> FullTensorKind x
-> target x
-> target (ADTensorKind x)
- srev f ftk | Dict <- lemKnownSTKOfAD (knownSTK @x) =
-   \ !es -> tApply (trev @target ftk $ HFun f) es
- srevDt :: forall x r sh.
-           (KnownSTK x, KnownSTK r, KnownShS sh)
+ srev f xftk | Dict <- lemKnownSTKOfAD (ftkToSTK xftk) =
+   \ !es -> tApply (trev @target xftk (HFun f) (knownSTK @(TKS2 sh r))) es
+ srevDt :: forall x r sh. (KnownSTK r, KnownShS sh)
=> (forall f. ADReady f => f x -> f (TKS2 sh r))
-> FullTensorKind x
-> target x
-> target (ADTensorKind (TKS2 sh r)) -- ^ incoming cotangent (dt)
-> target (ADTensorKind x)
- srevDt f ftk | Dict <- lemKnownSTKOfAD (knownSTK @x)
-              , Dict <- lemKnownSTKOfAD (knownSTK @(TKS2 sh r)) =
-   \ !es !dt -> tApply (trevDt @target ftk $ HFun f)
+ srevDt f xftk | Dict <- lemKnownSTKOfAD (ftkToSTK xftk) =
+   \ !es !dt -> tApply (trevDt @target xftk $ HFun f)
(tpair dt es)
- sfwd :: forall x r sh.
-         (KnownSTK x, KnownSTK r, KnownShS sh)
+ sfwd :: forall x r sh. (KnownSTK r, KnownShS sh)
=> (forall f. ADReady f => f x -> f (TKS2 sh r))
-> FullTensorKind x
-> target x
-> target (ADTensorKind x) -- ^ incoming tangent (ds)
-> target (ADTensorKind (TKS2 sh r))
- sfwd f ftk | Dict <- lemKnownSTKOfAD (knownSTK @x)
-            , Dict <- lemKnownSTKOfAD (knownSTK @(TKS2 sh r)) =
-   \ !es !ds -> tApply (tfwd @target ftk $ HFun f)
+ sfwd f xftk | Dict <- lemKnownSTKOfAD (knownSTK @(TKS2 sh r)) =
+   \ !es !ds -> tApply (tfwd @target xftk $ HFun f)
(tpair ds es)
-- If the result of the argument function is not a scalar,
-- the result of this operation is the gradient of a function that additionally
Expand All @@ -1866,22 +1855,20 @@ class ( Num (IntOf target)
-- These methods (and tlambda) are exactly what is needed as arguments
-- of tmapAccumRDer.
trev
-   :: (KnownSTK x, KnownSTK z)
-   => FullTensorKind x -- shape of a and da
-   -> HFun x z -- a |-> b
-   -> HFunOf target x (ADTensorKind x) -- a |-> da
+   :: FullTensorKind x -- shape of x and dx
+   -> HFun x z -- x |-> z
+   -> STensorKind z
+   -> HFunOf target x (ADTensorKind x) -- x |-> dx
trevDt
-   :: (KnownSTK x, KnownSTK z)
-   => FullTensorKind x -- shape of a and da
-   -> HFun x z -- a |-> b
+   :: FullTensorKind x -- shape of x and dx
+   -> HFun x z -- x |-> z
-> HFunOf target (TKProduct (ADTensorKind z) x) (ADTensorKind x)
-      -- [db, a] |-> da
+      -- [dz, x] |-> dx
tfwd
-   :: (KnownSTK x, KnownSTK z)
-   => FullTensorKind x -- shape of a and da
-   -> HFun x z -- a |-> b
+   :: FullTensorKind x -- shape of x and dx
+   -> HFun x z -- x |-> z
-> HFunOf target (TKProduct (ADTensorKind x) x) (ADTensorKind z)
-      -- [da, a] |-> db
+      -- [dx, x] |-> dz
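The renamed comments spell out the shapes of the three derivative transforms. A scalar finite-difference analogue, as a rough sketch only (it assumes x ~ z ~ Double and ADTensorKind t ~ t; for scalar functions the forward and reverse maps look alike, differing only in which side carries the seed):

eps :: Double
eps = 1e-6

-- [dx, x] |-> dz : push a tangent dx through f at x.
tfwdScalar :: (Double -> Double) -> (Double, Double) -> Double
tfwdScalar f (dx, x) = dx * (f (x + eps) - f x) / eps

-- [dz, x] |-> dx : pull a cotangent dz back through f at x.
trevDtScalar :: (Double -> Double) -> (Double, Double) -> Double
trevDtScalar f (dz, x) = dz * (f (x + eps) - f x) / eps

-- x |-> dx : trev seeds dz = 1 itself, which is why the real trev now
-- takes an STensorKind z argument: to build that default seed.
trevScalar :: (Double -> Double) -> Double -> Double
trevScalar f x = trevDtScalar f (1, x)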

tunit :: BaseTensor target
=> target TKUnit
(Diffs for the remaining 5 changed files are not shown here.)
