Skip to content

Commit

Permalink
Switch constraint for a singleton in tApply and tbuild1
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Feb 12, 2025
1 parent b912fbd commit a53c7d9
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 47 deletions.
4 changes: 2 additions & 2 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ interpretAst !env = \case
-- as above so that the mixture becomes compatible; if the spans
-- agreed, the AstApply would likely be simplified before
-- getting interpreted
in tApply t2 ll2
in tApply stk t2 ll2
AstVar _ftk var ->
let var2 = mkAstVarName @FullSpan (varNameToSTK var) (varNameToAstVarId var) -- TODO
-- TODO: this unsafe call is needed for benchmark VTO1.
Expand All @@ -237,7 +237,7 @@ interpretAst !env = \case
AstBuild1 snat stk (var, v) ->
withKnownSTK stk $
let f i = interpretAst (extendEnvI var i env) v
in tbuild1 snat f
in tbuild1 snat stk f
AstConcrete (RepF ftk a) -> tconcrete ftk a

AstLet var u v -> case ftkToSTK (ftkAst u) of
Expand Down
3 changes: 2 additions & 1 deletion src/HordeAd/Core/AstPrettyPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,15 @@ printAstAux cfg d = \case
. printAst cfg 11 acc0
. showString " "
. printAst cfg 11 es
AstApply _ t ll ->
AstApply stk t ll ->
if loseRoudtrip cfg
then showParen (d > 9)
$ printAstHFunOneUnignore cfg 10 t
. showString " "
. printAst cfg 11 ll
else showParen (d > 10)
$ showString "tApply "
. shows stk
. printAstHFunOneUnignore cfg 10 t
. showString " "
. printAst cfg 11 ll
Expand Down
46 changes: 23 additions & 23 deletions src/HordeAd/Core/Ops.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1724,8 +1724,7 @@ class ( Num (IntOf target)
-> target accShs
-> target (BuildTensorKind k eShs)
-> target (TKProduct accShs (BuildTensorKind k bShs))
tApply :: KnownSTK z
=> HFunOf target x z -> target x
tApply :: STensorKind z -> HFunOf target x z -> target x
-> target z
tlambda :: FullTensorKind x -> HFun x z -> HFunOf target x z
tcond :: Boolean (BoolOf target)
Expand All @@ -1740,11 +1739,10 @@ class ( Num (IntOf target)
maxF :: (Boolean (BoolOf target), OrdF target, KnownSTK y)
=> target y -> target y -> target y
maxF u v = ifF (u >=. v) u v
tbuild1 :: forall y k. KnownSTK y
-- y comes first, because k easy to set via SNat
=> SNat k -> (IntOf target -> target y)
tbuild1 :: forall y k. -- y comes first, because k easy to set via SNat
SNat k -> STensorKind y -> (IntOf target -> target y)
-> target (BuildTensorKind k y)
tbuild1 snat@SNat f =
tbuild1 snat@SNat stk0 f =
let replSTK :: STensorKind z -> (IntOf target -> target z)
-> target (BuildTensorKind k z)
replSTK stk g = case stk of
Expand All @@ -1765,7 +1763,7 @@ class ( Num (IntOf target)
-- TODO: looks expensive, but hard to do better,
-- so let's hope g is full of variables
in tpair (replSTK stk1 f1) (replSTK stk2 f2)
in replSTK (knownSTK @y) f
in replSTK stk0 f

tprimalPart :: target y -> PrimalOf target y
tdualPart :: STensorKind y -> target y -> DualOf target y
Expand Down Expand Up @@ -1798,8 +1796,9 @@ class ( Num (IntOf target)
-> FullTensorKind x
-> target x
-> target (ADTensorKind x)
rrev f xftk | Dict <- lemKnownSTKOfAD (ftkToSTK xftk) =
\ !es -> tApply (trev @target xftk (HFun f) (knownSTK @(TKR2 n r))) es
rrev f xftk =
\ !es -> tApply (adSTK $ ftkToSTK xftk)
(trev @target xftk (HFun f) (knownSTK @(TKR2 n r))) es
-- We can't get sh from anywhere, so this is not possible:
-- rrev f shs es = rrevDt f shs es (rreplicate0N sh 1)
rrevDt :: forall x r n. (KnownSTK r, KnownNat n)
Expand All @@ -1808,43 +1807,44 @@ class ( Num (IntOf target)
-> target x
-> target (ADTensorKind (TKR2 n r)) -- ^ incoming cotangent (dt)
-> target (ADTensorKind x)
rrevDt f xftk | Dict <- lemKnownSTKOfAD (ftkToSTK xftk) =
\ !es !dt -> tApply (trevDt @target xftk $ HFun f)
(tpair dt es)
rrevDt f xftk =
\ !es !dt -> tApply (adSTK $ ftkToSTK xftk)
(trevDt @target xftk $ HFun f) (tpair dt es)
rfwd :: forall x r n. (KnownSTK r, KnownNat n)
=> (forall f. ADReady f => f x -> f (TKR2 n r))
-> FullTensorKind x
-> target x
-> target (ADTensorKind x) -- ^ incoming tangent (ds)
-> target (ADTensorKind (TKR2 n r))
rfwd f xftk | Dict <- lemKnownSTKOfAD (knownSTK @(TKR2 n r)) =
\ !es !ds -> tApply (tfwd @target xftk $ HFun f)
(tpair ds es)
rfwd f xftk =
\ !es !ds -> tApply (adSTK $ knownSTK @(TKR2 n r))
(tfwd @target xftk $ HFun f) (tpair ds es)
srev :: forall x r sh. (KnownSTK r, KnownShS sh)
=> (forall f. ADReady f => f x -> f (TKS2 sh r))
-> FullTensorKind x
-> target x
-> target (ADTensorKind x)
srev f xftk | Dict <- lemKnownSTKOfAD (ftkToSTK xftk) =
\ !es -> tApply (trev @target xftk (HFun f) (knownSTK @(TKS2 sh r))) es
srev f xftk =
\ !es -> tApply (adSTK $ ftkToSTK xftk)
(trev @target xftk (HFun f) (knownSTK @(TKS2 sh r))) es
srevDt :: forall x r sh. (KnownSTK r, KnownShS sh)
=> (forall f. ADReady f => f x -> f (TKS2 sh r))
-> FullTensorKind x
-> target x
-> target (ADTensorKind (TKS2 sh r)) -- ^ incoming cotangent (dt)
-> target (ADTensorKind x)
srevDt f xftk | Dict <- lemKnownSTKOfAD (ftkToSTK xftk) =
\ !es !dt -> tApply (trevDt @target xftk $ HFun f)
(tpair dt es)
srevDt f xftk =
\ !es !dt -> tApply (adSTK $ ftkToSTK xftk)
(trevDt @target xftk $ HFun f) (tpair dt es)
sfwd :: forall x r sh. (KnownSTK r, KnownShS sh)
=> (forall f. ADReady f => f x -> f (TKS2 sh r))
-> FullTensorKind x
-> target x
-> target (ADTensorKind x) -- ^ incoming tangent (ds)
-> target (ADTensorKind (TKS2 sh r))
sfwd f xftk | Dict <- lemKnownSTKOfAD (knownSTK @(TKS2 sh r)) =
\ !es !ds -> tApply (tfwd @target xftk $ HFun f)
(tpair ds es)
sfwd f xftk =
\ !es !ds -> tApply (adSTK $ knownSTK @(TKS2 sh r))
(tfwd @target xftk $ HFun f) (tpair ds es)
-- If the result of the argument function is not a scalar,
-- the result of this operation is the gradient of a function that additionally
-- sums all elements of the result. If all elements are equally important
Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/OpsADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ instance ( ADReadyNoLet target, ShareTensor target
(q, bs) = tunpair qbs
dual = DeltaMapAccumL k bShs eShs q es df rf acc0' es'
in dD (tpair accFin bs) dual
tApply (HFun f) = f
tApply _ (HFun f) = f
tlambda _ = id
-- Bangs are for the proper order of sharing stamps.
tcond !stk !b !u !v =
Expand Down
8 changes: 4 additions & 4 deletions src/HordeAd/Core/OpsAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where
astMapAccumRDer k bShs eShs f df rf acc0 es
tmapAccumLDer _ !k _ !bShs !eShs f df rf acc0 es =
astMapAccumLDer k bShs eShs f df rf acc0 es
tApply t ll = astApply knownSTK t ll
tApply stk t ll = astApply stk t ll
tlambda shss f =
let (var, ast) = funToAst shss $ \ !ll -> unHFun f ll
in AstLambda (var, shss, ast)
Expand Down Expand Up @@ -1317,7 +1317,7 @@ instance AstSpan s => BaseTensor (AstRaw s) where
AstRaw $ AstMapAccumRDer k bShs eShs f df rf (unAstRaw acc0) (unAstRaw es)
tmapAccumLDer _ !k _ !bShs !eShs f df rf acc0 es =
AstRaw $ AstMapAccumLDer k bShs eShs f df rf (unAstRaw acc0) (unAstRaw es)
tApply t ll = AstRaw $ AstApply knownSTK t (unAstRaw ll)
tApply stk t ll = AstRaw $ AstApply stk t (unAstRaw ll)
tlambda = tlambda @(AstTensor AstMethodLet PrimalSpan)
tcond _ !b !u !v = AstRaw $ AstCond b (unAstRaw u) (unAstRaw v)
tprimalPart t = AstRaw $ primalPart $ unAstRaw t
Expand Down Expand Up @@ -1482,7 +1482,7 @@ instance AstSpan s => BaseTensor (AstNoVectorize s) where
tmapAccumLDer _ !k !accShs !bShs !eShs f df rf acc0 es =
AstNoVectorize $ tmapAccumLDer Proxy k accShs bShs eShs f df rf
(unAstNoVectorize acc0) (unAstNoVectorize es)
tApply t ll = AstNoVectorize $ tApply t (unAstNoVectorize ll)
tApply stk t ll = AstNoVectorize $ tApply stk t (unAstNoVectorize ll)
tlambda = tlambda @(AstTensor AstMethodLet PrimalSpan)
tcond !stk !b !u !v =
AstNoVectorize $ tcond stk b (unAstNoVectorize u) (unAstNoVectorize v)
Expand Down Expand Up @@ -1695,7 +1695,7 @@ instance AstSpan s => BaseTensor (AstNoSimplify s) where
tmapAccumLDer _ !k !accShs !bShs !eShs f df rf acc0 es =
wAstNoSimplify $ tmapAccumLDer Proxy k accShs bShs eShs f df rf
(wunAstNoSimplify acc0) (wunAstNoSimplify es)
tApply t ll = wAstNoSimplify $ tApply t (wunAstNoSimplify ll)
tApply stk t ll = wAstNoSimplify $ tApply stk t (wunAstNoSimplify ll)
tlambda = tlambda @(AstRaw PrimalSpan)
tprimalPart t = wAstNoSimplify $ tprimalPart $ wunAstNoSimplify t
tfromPrimal stk t = wAstNoSimplify $ tfromPrimal stk $ wunAstNoSimplify t
Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/OpsConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ instance BaseTensor RepN where
oRtmapAccumL k accShs bShs eShs f acc0 es
tmapAccumLDer _ k accShs bShs eShs f _df _rf acc0 es =
oRtmapAccumL k accShs bShs eShs (\ !(RepN a) !(RepN b) -> RepN $ f (a, b)) acc0 es
tApply f x = RepN $ f $ unRepN x
tApply _ f x = RepN $ f $ unRepN x
tlambda _ f x = unRepN $ unHFun f $ RepN x
tcond _ b u v = if b then u else v
tprimalPart = id
Expand Down
30 changes: 15 additions & 15 deletions test/simplified/TestRevFwdFold.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1575,7 +1575,7 @@ testSin0rmapAccumRD01SN51 = do
f x0 = (\res -> ssum @_ @_ @6 (tproject1 $ tproject1 res)
+ ssum0 @_ @_ @'[6, 5, 4, 3]
(tproject2 res))
$ tbuild1 @f (SNat @6) $ \j ->
$ tbuild1 @f (SNat @6) knownSTK $ \j ->
tmapAccumR (Proxy @f) (SNat @5)
(FTKProduct (FTKS ZSS FTKScalar)
(FTKS (SNat @3 :$$ ZSS) FTKScalar))
Expand Down Expand Up @@ -1625,8 +1625,8 @@ testSin0rmapAccumRD01SN531a = do
=> f (TKS '[3] Double) -> f (TKS '[2, 2, 2, 3] Double)
f x0 = (\res -> srepl 2 - sreplicate @_ @2 (tproject1 $ tproject1 res)
- (tproject2 res))
$ tbuild1 @f (SNat @2) $ \i ->
(tbuild1 @f (SNat @2) $ \j ->
$ tbuild1 @f (SNat @2) knownSTK $ \i ->
(tbuild1 @f (SNat @2) knownSTK $ \j ->
tmapAccumR (Proxy @f) (SNat @2)
(FTKProduct (FTKS (SNat @3 :$$ ZSS) FTKScalar)
(FTKS (SNat @6 :$$ ZSS) FTKScalar))
Expand Down Expand Up @@ -1677,8 +1677,8 @@ testSin0rmapAccumRD01SN531b0 = do
=> f (TKR 0 Double)
-> f (TKR 2 Double)
f x0 = rfromS $ tproject1
$ tbuild1 @f (SNat @2) $ \_i ->
(tbuild1 @f (SNat @2) $ \_j ->
$ tbuild1 @f (SNat @2) knownSTK $ \_i ->
(tbuild1 @f (SNat @2) knownSTK $ \_j ->
tmapAccumR (Proxy @f) (SNat @0)
(FTKS ZSS FTKScalar)
ftkUnit
Expand All @@ -1699,8 +1699,8 @@ testSin0rmapAccumRD01SN531b0PP = do
=> f (TKR 0 Double)
-> f (TKR 2 Double)
f x0 = rfromS $ tproject1
$ tbuild1 @f (SNat @2) $ \_i ->
(tbuild1 @f (SNat @2) $ \_j ->
$ tbuild1 @f (SNat @2) knownSTK $ \_i ->
(tbuild1 @f (SNat @2) knownSTK $ \_j ->
tmapAccumR (Proxy @f) (SNat @0)
(FTKS ZSS FTKScalar)
ftkUnit
Expand All @@ -1724,8 +1724,8 @@ testSin0rmapAccumRD01SN531b0PPj = do
let f :: forall f. ADReady f
=> f (TKR 0 Double) -> f (TKR 2 Double)
f x0 = tlet (
(tbuild1 @f (SNat @2) $ \i ->
(tbuild1 @f (SNat @2) $ \j ->
(tbuild1 @f (SNat @2) knownSTK $ \i ->
(tbuild1 @f (SNat @2) knownSTK $ \j ->
(tmapAccumR (Proxy @f) (SNat @0)
(FTKS ZSS FTKScalar)
ftkUnit
Expand All @@ -1750,8 +1750,8 @@ testSin0rmapAccumRD01SN531bRPPj = do
let f :: forall f. ADReady f
=> f (TKR 0 Double) -> f (TKR 2 Double)
f x0 = tlet (
(tbuild1 @f (SNat @2) $ \i ->
(tbuild1 @f (SNat @2) $ \j ->
(tbuild1 @f (SNat @2) knownSTK $ \i ->
(tbuild1 @f (SNat @2) knownSTK $ \j ->
(tmapAccumR (Proxy @f) (SNat @1)
(FTKR ZSR FTKScalar)
ftkUnit
Expand All @@ -1778,8 +1778,8 @@ testSin0rmapAccumRD01SN531c = do
=> f (TKS '[] Double) -> f (TKS '[2, 2, 2] Double)
f x0 = (\res -> srepl 2 - sreplicate @_ @2 (tproject1 res)
- tproject2 res)
$ tbuild1 @f (SNat @2) $ \i ->
(tbuild1 @f (SNat @2) $ \j ->
$ tbuild1 @f (SNat @2) knownSTK $ \i ->
(tbuild1 @f (SNat @2) knownSTK $ \j ->
(tmapAccumR (Proxy @f) (SNat @2)
(FTKS ZSS FTKScalar)
(FTKS ZSS FTKScalar)
Expand All @@ -1802,8 +1802,8 @@ testSin0rmapAccumRD01SN531Slice = do
(rev' (let f :: forall f. ADReady f
=> f (TKS '[] Double) -> f (TKS '[2, 2] Double)
f x0 = tproject1
$ tbuild1 @f (SNat @2) $ \_i ->
(tbuild1 @f (SNat @2) $ \_j ->
$ tbuild1 @f (SNat @2) knownSTK $ \_i ->
(tbuild1 @f (SNat @2) knownSTK $ \_j ->
(tmapAccumR (Proxy @f) (SNat @1)
(FTKS ZSS FTKScalar)
ftkUnit
Expand Down

0 comments on commit a53c7d9

Please sign in to comment.