Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WrappedContext and separation between sampling and evaluation #249

Closed
wants to merge 42 commits into from
Closed
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
62dafc9
added WrappedContext and impls for PrefixContext and MiniBatchContext
torfjelde May 20, 2021
13efcc4
formatting
torfjelde May 20, 2021
14f9211
DefaultContext replaced with SampleContext and EvaluateContext
torfjelde May 21, 2021
58331eb
fixed impl for dot_tilde
torfjelde May 21, 2021
21379bf
make get_and_set! used in dot_assume always overwrite
torfjelde May 21, 2021
367a86e
fixed implementations for Prior and Likelihood
torfjelde May 21, 2021
cc2e8e6
be explicit about use of EvaluateContext in logprior and loglikelihood
torfjelde May 21, 2021
7ac6c63
fixed constructor for Likelihood and Prior
torfjelde May 21, 2021
360c333
now passing also allow passing override value to assume
torfjelde May 21, 2021
5408cd0
dont mutate VarInfo variable in dot_assume with EvaluateContext
torfjelde May 21, 2021
07a00e6
rename _tilde to tilde_primitive
torfjelde May 23, 2021
7a96e46
refactoring of context_implementations
torfjelde May 23, 2021
c61585e
added unwrap_right_vn and unwrap_right_left_vn thanks to @devmotion
torfjelde May 23, 2021
e3d6515
renamed SampleContext and EvaluateContext
torfjelde May 23, 2021
9d43c7b
separate contexts into separate files
torfjelde May 23, 2021
0c354ca
Apply suggestions from code review
torfjelde May 23, 2021
c63a4e5
added some convenience to MiniBatchContext constructor
torfjelde May 23, 2021
8577968
formatting
torfjelde May 23, 2021
5b42b23
missed a rename of SampleContext
torfjelde May 23, 2021
320d2ed
fixed typo of unwrap_right_left_vns
torfjelde May 23, 2021
9e7691a
removed usage of get_vns_and_dist since we have unwrap_right_left_vns
torfjelde May 23, 2021
6358ee0
use value passed to dot_assume rather than extracting from var_info
torfjelde May 23, 2021
95613d2
fixed constructor of MiniBatchContext
torfjelde May 23, 2021
024e75e
found some more leftover typos
torfjelde May 23, 2021
0446194
moved includes in context_implementations.jl to end of file
torfjelde May 23, 2021
49bd18a
added definition of Model to avoid StackOverflowError
torfjelde May 23, 2021
f9bb04f
removed the now redundant setval_and_resample!
torfjelde May 23, 2021
5bee7f2
fixed prob_macro
torfjelde May 23, 2021
87272f4
updated sampler.jl to work with new contexts
torfjelde May 23, 2021
e7d2344
fixed dot_tilde implementation for LikelihoodContext and PriorContext
torfjelde May 23, 2021
586e5c8
add support for replacing value without mutating VarInfo to Likelihoo…
torfjelde May 23, 2021
03fa1fa
formatting
torfjelde May 23, 2021
32796cb
correct implementation for LikelihoodContext and PriorContext
torfjelde May 23, 2021
fd3f317
SamplingContext will now mutate values in VarInfo, even if the values…
torfjelde May 23, 2021
0a3fe75
remove unnecessary type-specification
torfjelde May 23, 2021
f02a510
renamed tilde_assume and others to tilde_assume! and similars
torfjelde May 25, 2021
fa804d4
formatting
torfjelde May 25, 2021
1e5864b
updated PointwiseLikelihoodContext to the new context approach
torfjelde May 25, 2021
1e1b2e6
fixed a typo in dot_tilde_observe! for PointwiseLikelihoodContext
torfjelde May 28, 2021
7cbe0cc
added missing rewrap impls and simplified constructors
torfjelde May 29, 2021
8a44e8a
add rng and sampler to contexts
torfjelde May 29, 2021
ecf72ad
formatting
torfjelde May 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ export AbstractVarInfo,
SampleFromPrior,
SampleFromUniform,
# Contexts
DefaultContext,
EvaluationContext,
SamplingContext,
LikelihoodContext,
PriorContext,
MiniBatchContext,
Expand Down
66 changes: 55 additions & 11 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,49 @@ end
check_tilde_rhs(x::Distribution) = x
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x

"""
unwrap_right_vn(right, vn)

Return the unwrapped distribution on the right-hand side and variable name on the left-hand
side of a `~` expression such as `x ~ Normal()`.

This is used mainly to unwrap `NamedDist` distributions.
"""
unwrap_right_vn(right, vn) = right, vn
unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name)

"""
unwrap_right_left_vns(context, right, left, vns)

Return the unwrapped distributions on the right-hand side and values and variable names on the
left-hand side of a `.~` expression such as `x .~ Normal()`.

This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the
variables.
"""
unwrap_right_left_vns(right, left, vns) = right, left, vns
function unwrap_right_left_vns(right::NamedDist, left, vns)
return unwrap_right_left_vns(right.dist, left, right.name)
end
function unwrap_right_left_vns(
right::MultivariateDistribution, left::AbstractMatrix, vn::VarName
)
vns = map(axes(left, 2)) do i
return VarName(vn, (vn.indexing..., Tuple(i)))
end
return unwrap_right_left_vns(right, left, vns)
end
function unwrap_right_left_vns(
right::Union{Distribution,AbstractArray{<:Distribution}},
left::AbstractArray,
vn::VarName,
)
vns = map(CartesianIndices(left)) do i
return VarName(vn, (vn.indexing..., Tuple(i)))
end
return unwrap_right_left_vns(right, left, vns)
end

#################
# Main Compiler #
#################
Expand Down Expand Up @@ -242,7 +285,7 @@ function generate_tilde(left, right)
# If the LHS is a literal, it is always an observation
if !(left isa Symbol || left isa Expr)
return quote
$(DynamicPPL.tilde_observe)(
$(DynamicPPL.tilde_observe!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
Expand All @@ -260,17 +303,18 @@ function generate_tilde(left, right)
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume)(
$left = $(DynamicPPL.tilde_assume!)(
__rng__,
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$vn,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
$inds,
__varinfo__,
)
else
$(DynamicPPL.tilde_observe)(
$(DynamicPPL.tilde_observe!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
Expand All @@ -292,7 +336,7 @@ function generate_dot_tilde(left, right)
# If the LHS is a literal, it is always an observation
if !(left isa Symbol || left isa Expr)
return quote
$(DynamicPPL.dot_tilde_observe)(
$(DynamicPPL.dot_tilde_observe!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
Expand All @@ -310,18 +354,18 @@ function generate_dot_tilde(left, right)
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume)(
$left .= $(DynamicPPL.dot_tilde_assume!)(
__rng__,
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
)...,
$inds,
__varinfo__,
)
else
$(DynamicPPL.dot_tilde_observe)(
$(DynamicPPL.dot_tilde_observe!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
Expand Down
Loading