-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathBasic.hs
90 lines (77 loc) · 2.62 KB
/
Basic.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
{-# LANGUAGE RankNTypes #-}
-- |
-- Module : Control.Monad.Bayes.Traced.Basic
-- Description : Distributions on full execution traces of full programs
-- Copyright : (c) Adam Scibior, 2015-2020
-- License : MIT
-- Maintainer : leonhard.markert@tweag.io
-- Stability : experimental
-- Portability : GHC
module Control.Monad.Bayes.Traced.Basic
( Traced,
hoist,
marginal,
-- mhStep,
-- mh,
)
where
import Control.Applicative (liftA2)
import Control.Monad.Bayes.Class
( MonadDistribution (random),
MonadFactor (..),
MonadMeasure,
)
import Control.Monad.Bayes.Density.Free (Density)
import Control.Monad.Bayes.Traced.Common
( Trace (..),
bind,
mhTrans',
scored,
singleton,
)
import Control.Monad.Bayes.Weighted (Weighted)
import Data.Functor.Identity (Identity)
import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList)
-- | Tracing monad that records the random choices a program makes.
--
-- A value holds two views of the same program. The 'model' view runs in the
-- fixed monad @Weighted (Density Identity)@, which does not depend on @m@.
-- The 'traceDist' view runs in @m@ and produces a recorded 'Trace' that
-- carries the program's output.
data Traced m a = Traced
  { model :: Weighted (Density Identity) a
  -- ^ The program as a density model, used to run it with a modified trace.
  , traceDist :: m (Trace a)
  -- ^ The program run in @m@, recording its trace and output.
  }
-- | Mapping over a 'Traced' computation maps over the density model and over
-- the output of the recorded 'Trace'. Only a 'Functor' on @m@ is required,
-- since nothing here sequences effects in @m@.
instance Functor m => Functor (Traced m) where
  fmap f (Traced m d) = Traced (fmap f m) (fmap (fmap f) d)

-- | 'pure' is a computation with an empty recording ('pure' for 'Trace').
-- '<*>' combines the two density models, and lifts the 'Trace' '<*>' over @m@
-- to combine the recorded traces. Only an 'Applicative' on @m@ is required
-- ('pure' and 'liftA2').
instance Applicative m => Applicative (Traced m) where
  pure x = Traced (pure x) (pure (pure x))
  (Traced mf df) <*> (Traced mx dx) = Traced (mf <*> mx) (liftA2 (<*>) df dx)

-- | Sequencing binds the density model and the trace distribution separately.
-- The trace side is glued together with 'bind', which needs the full 'Monad'
-- on @m@ because the continuation depends on the recorded output.
instance Monad m => Monad (Traced m) where
  (Traced mx dx) >>= f = Traced my dy
    where
      -- continue the density model with the model view of the next step
      my = mx >>= model . f
      -- continue the recording with the trace view of the next step
      dy = dx `bind` (traceDist . f)
-- | Primitive random draws. The density model and the trace side each call
-- 'random' in their own monad. The trace side then wraps the drawn value
-- with 'singleton' to record it.
instance MonadDistribution m => MonadDistribution (Traced m) where
  random =
    Traced
      { model = random,
        traceDist = singleton <$> random
      }

-- | Conditioning. The weight is applied to the density model and to @m@. The
-- recorded trace is then built from that same weight with 'scored'.
instance MonadFactor m => MonadFactor (Traced m) where
  score weight =
    Traced
      { model = score weight,
        traceDist = do
          score weight
          pure (scored weight)
      }

-- | 'Traced' is a 'MonadMeasure' whenever @m@ is. No methods are defined
-- here; the instance relies on the two instances above.
instance MonadMeasure m => MonadMeasure (Traced m)
-- | Apply a polymorphic transformation to the monad in which the trace is
-- recorded, leaving the density model untouched.
--
-- The 'model' component always has type @Weighted (Density Identity) a@ and
-- never mentions @m@. The transformation may therefore change the monad,
-- from @m@ to @n@. Instantiating @n ~ m@ gives back the previous, more
-- restrictive type, so existing callers keep working.
hoist :: (forall x. m x -> n x) -> Traced m a -> Traced n a
hoist nat (Traced densityModel recorded) = Traced densityModel (nat recorded)
-- | Run the recording side of the program in @m@ and keep only its output.
-- The trace and the density model are discarded.
marginal :: Monad m => Traced m a -> m a
marginal Traced {traceDist = recorded} = output <$> recorded
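-- NOTE: 'mhStep' and 'mh' are currently disabled. They are commented out
-- here and in the export list, and the imports of 'mhTrans'' and of
-- 'NE.toList' / '(:|)' appear to be used only by them. Plain @--@ comments
-- are used rather than Haddock @-- |@ markers, because a Haddock comment that
-- is not attached to any declaration can trigger warnings (or parse errors on
-- older GHCs) when the module is built with @-haddock@. 'mh' calls the
-- imported 'mhTrans'', as 'mhStep' does.
--
-- A single step of the Trace Metropolis-Hastings algorithm.
--
-- mhStep :: MonadDistribution m => Traced m a -> Traced m a
-- mhStep (Traced m d) = Traced m d'
--   where
--     d' = d >>= mhTrans' m
--
-- Full run of the Trace Metropolis-Hastings algorithm with a specified
-- number of steps.
--
-- mh :: MonadDistribution m => Int -> Traced m a -> m [a]
-- mh n (Traced m d) = fmap (map output . NE.toList) (f n)
--   where
--     f k
--       | k <= 0 = fmap (:| []) d
--       | otherwise = do
--           (x :| xs) <- f (k - 1)
--           y <- mhTrans' m x
--           return (y :| x : xs)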