From 62dafc9db46d7395c376fdb6de1eac1e3b70c240 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 08:13:02 +0200 Subject: [PATCH 01/42] added WrappedContext and impls for PrefixContext and MiniBatchContext --- src/contexts.jl | 74 +++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 6 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 4d4f30bdc..00441a16e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -33,6 +33,44 @@ struct LikelihoodContext{Tvars} <: AbstractContext end LikelihoodContext() = LikelihoodContext(nothing) +######################## +### Wrapped contexts ### +######################## +abstract type WrappedContext{Ctx} <: AbstractContext end + +""" + childcontext(ctx) + +Returns the child-context of `ctx`. + +Returns `nothing` if `ctx` is not a `WrappedContext`. +""" +childcontext(ctx::WrappedContext) = ctx.ctx +childcontext(ctx::AbstractContext) = nothing + +""" + unwrap(ctx::AbstractContext) + +Returns the unwrapped context from `ctx`. +""" +unwrap(ctx::WrappedContext) = unwrap(ctx.ctx) +unwrap(ctx::AbstractContext) = ctx + +""" + unwrappedtype(ctx::AbstractContext) + +Returns the type of the unwrapped context from `ctx`. +""" +unwrappedtype(ctx::AbstractContext) = typeof(ctx) +unwrappedtype(ctx::WrappedContext{LeafCtx}) where {LeafCtx} = LeafCtx + +""" + rewrap(parent::WrappedContext, leaf::AbstractContext) + +Rewraps `leaf` in `parent`. Supports nested `WrappedContext`. +""" +rewrap(::AbstractContext, leaf::AbstractContext) = leaf + """ struct MiniBatchContext{Tctx, T} <: AbstractContext ctx::Tctx @@ -45,21 +83,44 @@ The `MiniBatchContext` enables the computation of This is useful in batch-based stochastic gradient descent algorithms to be optimizing `log(prior) + log(likelihood of all the data points)` in the expectation. """ -struct MiniBatchContext{Tctx,T} <: AbstractContext - ctx::Tctx +struct MiniBatchContext{T, Ctx,LeafCtx} <: WrappedContext{LeafCtx} loglike_scalar::T + ctx::Ctx + + function MiniBatchContext(loglike_scalar, ctx::AbstractContext) + new{typeof(loglike_scalar), typeof(ctx), typeof(ctx)}(loglike_scalar, ctx) + end + + function MiniBatchContext(loglike_scalar, ctx::WrappedContext{LeafCtx}) where {LeafCtx} + new{typeof(loglike_scalar), typeof(ctx), LeafCtx}(loglike_scalar, ctx) + end end function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints / batch_size) + return MiniBatchContext(npoints / batch_size, ctx) +end + +function rewrap(parent::MiniBatchContext, leaf::AbstractContext) + return MiniBatchContext(parent.loglike_scalar, rewrap(childcontext(parent), leaf)) end -struct PrefixContext{Prefix,C} <: AbstractContext + +struct PrefixContext{Prefix,C,LeafCtx} <: WrappedContext{LeafCtx} ctx::C + + function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} + return new{Prefix, typeof(ctx), typeof(ctx)}(ctx) + end + function PrefixContext{Prefix}(ctx::WrappedContext{LeafCtx}) where {Prefix, LeafCtx} + return new{Prefix, typeof(ctx), LeafCtx}(ctx) + end end -function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(ctx)}(ctx) +PrefixContext{Prefix}() where {Prefix} = PrefixContext{Prefix}(DefaultContext()) + +function rewrap(parent::PrefixContext{Prefix}, leaf::AbstractContext) where {Prefix} + return PrefixContext{Prefix}(rewrap(childcontext(parent), leaf)) end + const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( @@ -81,3 +142,4 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + From 13efcc4d6d61e745186c22d2f0da400bf08ba096 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 08:15:42 +0200 Subject: [PATCH 02/42] formatting --- src/contexts.jl | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 00441a16e..a6434ecee 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -36,7 +36,7 @@ LikelihoodContext() = LikelihoodContext(nothing) ######################## ### Wrapped contexts ### ######################## -abstract type WrappedContext{Ctx} <: AbstractContext end +abstract type WrappedContext{LeafCtx} <: AbstractContext end """ childcontext(ctx) @@ -83,16 +83,16 @@ The `MiniBatchContext` enables the computation of This is useful in batch-based stochastic gradient descent algorithms to be optimizing `log(prior) + log(likelihood of all the data points)` in the expectation. """ -struct MiniBatchContext{T, Ctx,LeafCtx} <: WrappedContext{LeafCtx} +struct MiniBatchContext{T,Ctx,LeafCtx} <: WrappedContext{LeafCtx} loglike_scalar::T ctx::Ctx function MiniBatchContext(loglike_scalar, ctx::AbstractContext) - new{typeof(loglike_scalar), typeof(ctx), typeof(ctx)}(loglike_scalar, ctx) + return new{typeof(loglike_scalar),typeof(ctx),typeof(ctx)}(loglike_scalar, ctx) end function MiniBatchContext(loglike_scalar, ctx::WrappedContext{LeafCtx}) where {LeafCtx} - new{typeof(loglike_scalar), typeof(ctx), LeafCtx}(loglike_scalar, ctx) + return new{typeof(loglike_scalar),typeof(ctx),LeafCtx}(loglike_scalar, ctx) end end function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) @@ -103,15 +103,14 @@ function rewrap(parent::MiniBatchContext, leaf::AbstractContext) return MiniBatchContext(parent.loglike_scalar, rewrap(childcontext(parent), leaf)) end - struct PrefixContext{Prefix,C,LeafCtx} <: WrappedContext{LeafCtx} ctx::C function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return new{Prefix, typeof(ctx), typeof(ctx)}(ctx) + return new{Prefix,typeof(ctx),typeof(ctx)}(ctx) end - function PrefixContext{Prefix}(ctx::WrappedContext{LeafCtx}) where {Prefix, LeafCtx} - return new{Prefix, typeof(ctx), LeafCtx}(ctx) + function PrefixContext{Prefix}(ctx::WrappedContext{LeafCtx}) where {Prefix,LeafCtx} + return new{Prefix,typeof(ctx),LeafCtx}(ctx) end end PrefixContext{Prefix}() where {Prefix} = PrefixContext{Prefix}(DefaultContext()) @@ -120,7 +119,6 @@ function rewrap(parent::PrefixContext{Prefix}, leaf::AbstractContext) where {Pre return PrefixContext{Prefix}(rewrap(childcontext(parent), leaf)) end - const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( @@ -142,4 +140,3 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end - From 14f9211da9562b0ff9bd8afbf6ae81e5031bc3b2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 21 May 2021 14:58:41 +0200 Subject: [PATCH 03/42] DefaultContext replaced with SampleContext and EvaluateContext --- src/context_implementations.jl | 89 ++++++++++++++++++++++++++-------- src/contexts.jl | 7 ++- src/model.jl | 4 +- src/varinfo.jl | 2 +- 4 files changed, 76 insertions(+), 26 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index afc5e4da3..5ab6403c5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,8 +18,10 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -function tilde(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) - return _tilde(rng, sampler, right, vn, vi) +function tilde( + rng, ctx::Union{SampleContext,EvaluateContext}, sampler, right, vn::VarName, _, vi +) + return _tilde(rng, ctx, sampler, right, vn, vi) end function tilde(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) if ctx.vars !== nothing @@ -56,15 +58,19 @@ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) return value end -function _tilde(rng, sampler, right, vn::VarName, vi) +function _tilde(rng, ctx::SampleContext, sampler, right, vn::VarName, vi) return assume(rng, sampler, right, vn, vi) end -function _tilde(rng, sampler, right::NamedDist, vn::VarName, vi) - return _tilde(rng, sampler, right.dist, right.name, vi) +function _tilde(rng, ctx::EvaluateContext, sampler, right, vn::VarName, vi) + return assume(sampler, right, vn, vi) +end + +function _tilde(rng, ctx, sampler, right::NamedDist, vn::VarName, vi) + return _tilde(rng, ctx, sampler, right.dist, right.name, vi) end # observe -function tilde(ctx::DefaultContext, sampler, right, left, vi) +function tilde(ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vi) return _tilde(sampler, right, left, vi) end function tilde(ctx::PriorContext, sampler, right, left, vi) @@ -122,22 +128,21 @@ end function assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi ) + r = init(rng, dist, spl) if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") - r = init(rng, dist, spl) - vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) - setorder!(vi, vn, get_num_produce(vi)) - else - r = vi[vn] - end + vi[vn] = vectorize(dist, r) + setorder!(vi, vn, get_num_produce(vi)) else - r = init(rng, dist, spl) push!(vi, vn, r, dist, spl) - settrans!(vi, false, vn) end + settrans!(vi, false, vn) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) +end + +function assume( + spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi +) + r = vi[vn] return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end @@ -151,7 +156,9 @@ end # .~ functions # assume -function dot_tilde(rng, ctx::DefaultContext, sampler, right, left, vn::VarName, _, vi) +function dot_tilde( + rng, ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vn::VarName, _, vi +) vns, dist = get_vns_and_dist(right, left, vn) return _dot_tilde(rng, sampler, dist, left, vns, vi) end @@ -209,13 +216,22 @@ function get_vns_and_dist( return getvn.(CartesianIndices(var)), dist end -function _dot_tilde(rng, sampler, right, left, vns::AbstractArray{<:VarName}, vi) +function _dot_tilde( + rng, ctx::SampleContext, sampler, right, left, vns::AbstractArray{<:VarName}, vi +) return dot_assume(rng, sampler, right, vns, left, vi) end +function _dot_tilde( + rng, ctx::EvaluateContext, sampler, right, left, vns::AbstractArray{<:VarName}, vi +) + return dot_assume(sampler, right, vns, left, vi) +end + # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics function _dot_tilde( rng, + ctx, sampler::AbstractSampler, right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, left::AbstractMatrix{>:AbstractVector}, @@ -239,6 +255,21 @@ function dot_assume( var .= r return var, lp end + +function dot_assume( + spl::Union{SampleFromPrior,SampleFromUniform}, + dist::MultivariateDistribution, + vns::AbstractVector{<:VarName}, + var::AbstractMatrix, + vi, +) + @assert length(dist) == size(var, 1) + r = vi[vns] + lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) + var .= r + return var, lp +end + function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -253,6 +284,22 @@ function dot_assume( var .= r return var, lp end + +function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dists::Union{Distribution,AbstractArray{<:Distribution}}, + vns::AbstractArray{<:VarName}, + var::AbstractArray, + vi, +) + r = vi[vns] + # Make sure `r` is not a matrix for multivariate distributions + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + var .= r + return var, lp +end + function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) return error( "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement" @@ -348,7 +395,7 @@ function set_val!( end # observe -function dot_tilde(ctx::DefaultContext, sampler, right, left, vi) +function dot_tilde(ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vi) return _dot_tilde(sampler, right, left, vi) end function dot_tilde(ctx::PriorContext, sampler, right, left, vi) diff --git a/src/contexts.jl b/src/contexts.jl index a6434ecee..4e4b3cd64 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -125,7 +125,7 @@ function PrefixContext{PrefixInner}( ctx::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} if @generated - :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}( + :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( ctx.ctx )) else @@ -135,8 +135,11 @@ end function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, _prefix_seperator, Sym)))}(vn.indexing)) + return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(vn.indexing)) else VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + +struct EvaluateContext <: AbstractContext end +struct SampleContext <: AbstractContext end diff --git a/src/model.jl b/src/model.jl index 7189b590e..150208164 100644 --- a/src/model.jl +++ b/src/model.jl @@ -86,7 +86,7 @@ function (model::Model)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::AbstractContext=SampleContext(), ) if Threads.nthreads() == 1 return evaluate_threadunsafe(rng, model, varinfo, sampler, context) @@ -183,7 +183,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), DefaultContext()) + model(varinfo, SampleFromPrior(), SampleContext()) return getlogp(varinfo) end diff --git a/src/varinfo.jl b/src/varinfo.jl index e5e71eed1..4140f9a0b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -126,7 +126,7 @@ function VarInfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::AbstractContext=SampleContext(), ) varinfo = VarInfo() model(rng, varinfo, sampler, context) From 58331ebf108435324b59cb0718efc67aab90cc8f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 21 May 2021 15:18:40 +0200 Subject: [PATCH 04/42] fixed impl for dot_tilde --- src/context_implementations.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 5ab6403c5..5eefdb1d2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -160,7 +160,7 @@ function dot_tilde( rng, ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vn::VarName, _, vi ) vns, dist = get_vns_and_dist(right, left, vn) - return _dot_tilde(rng, sampler, dist, left, vns, vi) + return _dot_tilde(rng, ctx, sampler, dist, left, vns, vi) end function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) @@ -286,7 +286,6 @@ function dot_assume( end function dot_assume( - rng, spl::Union{SampleFromPrior,SampleFromUniform}, dists::Union{Distribution,AbstractArray{<:Distribution}}, vns::AbstractArray{<:VarName}, From 21379bf510c0dbffb73338374f8ea783d25a4c3d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 21 May 2021 15:28:46 +0200 Subject: [PATCH 05/42] make get_and_set! used in dot_assume always overwrite --- src/context_implementations.jl | 40 ++++++++++++---------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 5eefdb1d2..2c0f28afd 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -314,18 +314,12 @@ function get_and_set_val!( ) n = length(vns) if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") - r = init(rng, dist, spl, n) - for i in 1:n - vn = vns[i] - vi[vn] = vectorize(dist, r[:, i]) - settrans!(vi, false, vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - r = vi[vns] + r = init(rng, dist, spl, n) + for i in 1:n + vn = vns[i] + vi[vn] = vectorize(dist, r[:, i]) + settrans!(vi, false, vn) + setorder!(vi, vn, get_num_produce(vi)) end else r = init(rng, dist, spl, n) @@ -346,20 +340,14 @@ function get_and_set_val!( spl::Union{SampleFromPrior,SampleFromUniform}, ) if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") - f = (vn, dist) -> init(rng, dist, spl) - r = f.(vns, dists) - for i in eachindex(vns) - vn = vns[i] - dist = dists isa AbstractArray ? dists[i] : dists - vi[vn] = vectorize(dist, r[i]) - settrans!(vi, false, vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - r = reshape(vi[vec(vns)], size(vns)) + f = (vn, dist) -> init(rng, dist, spl) + r = f.(vns, dists) + for i in eachindex(vns) + vn = vns[i] + dist = dists isa AbstractArray ? dists[i] : dists + vi[vn] = vectorize(dist, r[i]) + settrans!(vi, false, vn) + setorder!(vi, vn, get_num_produce(vi)) end else f = (vn, dist) -> init(rng, dist, spl) From 367a86e55d4351e3febd79727a8c55f0cb39b683 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 21 May 2021 15:43:13 +0200 Subject: [PATCH 06/42] fixed implementations for Prior and Likelihood --- src/context_implementations.jl | 8 ++--- src/contexts.jl | 63 ++++++++++++++++++---------------- src/model.jl | 2 +- 3 files changed, 38 insertions(+), 35 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 2c0f28afd..98c427002 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -28,14 +28,14 @@ function tilde(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return _tilde(rng, sampler, right, vn, vi) + return _tilde(rng, childcontext(ctx), sampler, right, vn, vi) end function tilde(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return _tilde(rng, sampler, NoDist(right), vn, vi) + return _tilde(rng, childcontext(ctx), sampler, NoDist(right), vn, vi) end function tilde(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) return tilde(rng, ctx.ctx, sampler, right, left, inds, vi) @@ -171,7 +171,7 @@ function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarNam else vns, dist = get_vns_and_dist(right, left, vn) end - return _dot_tilde(rng, sampler, NoDist.(dist), left, vns, vi) + return _dot_tilde(rng, childcontext(ctx), sampler, NoDist.(dist), left, vns, vi) end function dot_tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) return dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) @@ -185,7 +185,7 @@ function dot_tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, in else vns, dist = get_vns_and_dist(right, left, vn) end - return _dot_tilde(rng, sampler, dist, left, vns, vi) + return _dot_tilde(rng, childcontext(ctx), sampler, dist, left, vns, vi) end """ diff --git a/src/contexts.jl b/src/contexts.jl index 4e4b3cd64..80e53a72e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -6,37 +6,14 @@ and parameters when running the model. """ struct DefaultContext <: AbstractContext end -""" - struct PriorContext{Tvars} <: AbstractContext - vars::Tvars - end - -The `PriorContext` enables the computation of the log prior of the parameters `vars` when -running the model. -""" -struct PriorContext{Tvars} <: AbstractContext - vars::Tvars -end -PriorContext() = PriorContext(nothing) - -""" - struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars - end - -The `LikelihoodContext` enables the computation of the log likelihood of the parameters when -running the model. `vars` can be used to evaluate the log likelihood for specific values -of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. -""" -struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars -end -LikelihoodContext() = LikelihoodContext(nothing) +abstract type PrimitiveContext <: AbstractContext end +struct EvaluateContext <: PrimitiveContext end +struct SampleContext <: PrimitiveContext end ######################## ### Wrapped contexts ### ######################## -abstract type WrappedContext{LeafCtx} <: AbstractContext end +abstract type WrappedContext{LeafCtx<:PrimitiveContext} <: AbstractContext end """ childcontext(ctx) @@ -71,6 +48,35 @@ Rewraps `leaf` in `parent`. Supports nested `WrappedContext`. """ rewrap(::AbstractContext, leaf::AbstractContext) = leaf +""" + struct PriorContext{Tvars} <: AbstractContext + vars::Tvars + end + +The `PriorContext` enables the computation of the log prior of the parameters `vars` when +running the model. +""" +struct PriorContext{Tvars, LeafCtx} <: WrappedContext{LeafCtx} + vars::Tvars + ctx::LeafCtx +end +PriorContext(vars=nothing, ctx=EvaluateContext()) = PriorContext{typeof(vars), typeof(ctx)}(vars, ctx) + +""" + struct LikelihoodContext{Tvars} <: AbstractContext + vars::Tvars + end + +The `LikelihoodContext` enables the computation of the log likelihood of the parameters when +running the model. `vars` can be used to evaluate the log likelihood for specific values +of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. +""" +struct LikelihoodContext{Tvars, LeafCtx} <: WrappedContext{LeafCtx} + vars::Tvars + ctx::LeafCtx +end +LikelihoodContext(vars=nothing, ctx=EvaluateContext()) = LikelihoodContext{typeof(vars), typeof(ctx)}(vars, ctx) + """ struct MiniBatchContext{Tctx, T} <: AbstractContext ctx::Tctx @@ -140,6 +146,3 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end - -struct EvaluateContext <: AbstractContext end -struct SampleContext <: AbstractContext end diff --git a/src/model.jl b/src/model.jl index 150208164..3290d4fe3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -183,7 +183,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), SampleContext()) + model(varinfo, SampleFromPrior(), EvaluateContext()) return getlogp(varinfo) end From cc2e8e6290e420b8dab8a59fbf70b79c2c07d776 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 21 May 2021 15:43:56 +0200 Subject: [PATCH 07/42] be explicit about use of EvaluateContext in logprior and loglikelihood --- src/model.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model.jl b/src/model.jl index 3290d4fe3..b3742011d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -195,7 +195,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), PriorContext()) + model(varinfo, SampleFromPrior(), PriorContext(nothing, EvaluateContext())) return getlogp(varinfo) end @@ -207,7 +207,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), LikelihoodContext()) + model(varinfo, SampleFromPrior(), LikelihoodContext(nothing, EvaluateContext())) return getlogp(varinfo) end From 7ac6c63143aa8acbae1adebde0198eee3efa4289 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 21 May 2021 17:00:27 +0200 Subject: [PATCH 08/42] fixed constructor for Likelihood and Prior --- src/contexts.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 80e53a72e..2dc5cb852 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -60,7 +60,7 @@ struct PriorContext{Tvars, LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars ctx::LeafCtx end -PriorContext(vars=nothing, ctx=EvaluateContext()) = PriorContext{typeof(vars), typeof(ctx)}(vars, ctx) +PriorContext(vars=nothing) = PriorContext(vars, EvaluateContext()) """ struct LikelihoodContext{Tvars} <: AbstractContext @@ -75,7 +75,7 @@ struct LikelihoodContext{Tvars, LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars ctx::LeafCtx end -LikelihoodContext(vars=nothing, ctx=EvaluateContext()) = LikelihoodContext{typeof(vars), typeof(ctx)}(vars, ctx) +LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluateContext()) """ struct MiniBatchContext{Tctx, T} <: AbstractContext From 360c33378b9c21334f8fb1f9bba7f7213de05618 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 21 May 2021 19:49:47 +0200 Subject: [PATCH 09/42] now passing also allow passing override value to assume --- src/context_implementations.jl | 71 ++++++++++++++++++++++------------ src/contexts.jl | 6 ++- 2 files changed, 51 insertions(+), 26 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 98c427002..77038d802 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -19,29 +19,37 @@ _getindex(x, inds::Tuple{}) = x # assume function tilde( - rng, ctx::Union{SampleContext,EvaluateContext}, sampler, right, vn::VarName, _, vi + rng, ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vn::VarName, _, vi ) - return _tilde(rng, ctx, sampler, right, vn, vi) + return _tilde(rng, ctx, sampler, right, left, vn, vi) end -function tilde(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) +function tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) if ctx.vars !== nothing vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return _tilde(rng, childcontext(ctx), sampler, right, vn, vi) + return _tilde(rng, childcontext(ctx), sampler, right, left, vn, vi) end -function tilde(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) +function tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return _tilde(rng, childcontext(ctx), sampler, NoDist(right), vn, vi) + return _tilde( + rng, + rewrap(childcontext(ctx), EvaluateContext()), + sampler, + NoDist(right), + left, + vn, + vi, + ) end -function tilde(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, left, inds, vi) +function tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) + return tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) end -function tilde(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) +function tilde(rng, ctx::PrefixContext, sampler, right, left, vn::VarName, inds, vi) + return tilde(rng, ctx.ctx, sampler, right, left, prefix(ctx, vn), inds, vi) end """ @@ -53,20 +61,23 @@ accumulate the log probability, and return the sampled value. Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`. """ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde(rng, ctx, sampler, right, vn, inds, vi) + value, logp = tilde(rng, ctx, sampler, right, nothing, vn, inds, vi) acclogp!(vi, logp) return value end -function _tilde(rng, ctx::SampleContext, sampler, right, vn::VarName, vi) - return assume(rng, sampler, right, vn, vi) +function _tilde(rng, ctx::SampleContext, sampler, right, left, vn::VarName, vi) + return assume(rng, sampler, right, nothing, vn, vi) +end +function _tilde(rng, ctx::EvaluateContext, sampler, right, left::Nothing, vn::VarName, vi) + return assume(sampler, right, vi[vn], vn, vi) end -function _tilde(rng, ctx::EvaluateContext, sampler, right, vn::VarName, vi) - return assume(sampler, right, vn, vi) +function _tilde(rng, ctx::EvaluateContext, sampler, right, left, vn::VarName, vi) + return assume(sampler, right, left, vn, vi) end -function _tilde(rng, ctx, sampler, right::NamedDist, vn::VarName, vi) - return _tilde(rng, ctx, sampler, right.dist, right.name, vi) +function _tilde(rng, ctx, sampler, right::NamedDist, left, vn::VarName, vi) + return _tilde(rng, ctx, sampler, right.dist, left, right.name, vi) end # observe @@ -126,7 +137,12 @@ function observe(spl::Sampler, weight) end function assume( - rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + left::Nothing, + vn::VarName, + vi, ) r = init(rng, dist, spl) if haskey(vi, vn) @@ -140,17 +156,16 @@ function assume( end function assume( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi + spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, left, vn::VarName, vi ) - r = vi[vn] - return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) + return left, Bijectors.logpdf_with_trans(dist, left, istrans(vi, vn)) end function observe( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi + spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, left, vi ) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(dist, left) end # .~ functions @@ -171,7 +186,15 @@ function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarNam else vns, dist = get_vns_and_dist(right, left, vn) end - return _dot_tilde(rng, childcontext(ctx), sampler, NoDist.(dist), left, vns, vi) + return _dot_tilde( + rng, + rewrap(childcontext(ctx), EvaluateContext()), + sampler, + NoDist.(dist), + left, + vns, + vi, + ) end function dot_tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) return dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 2dc5cb852..137dde42e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -56,11 +56,12 @@ rewrap(::AbstractContext, leaf::AbstractContext) = leaf The `PriorContext` enables the computation of the log prior of the parameters `vars` when running the model. """ -struct PriorContext{Tvars, LeafCtx} <: WrappedContext{LeafCtx} +struct PriorContext{Tvars,LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars ctx::LeafCtx end PriorContext(vars=nothing) = PriorContext(vars, EvaluateContext()) +PriorContext(ctx::AbstractContext) = PriorContext(nothing, ctx) """ struct LikelihoodContext{Tvars} <: AbstractContext @@ -71,11 +72,12 @@ The `LikelihoodContext` enables the computation of the log likelihood of the par running the model. `vars` can be used to evaluate the log likelihood for specific values of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. """ -struct LikelihoodContext{Tvars, LeafCtx} <: WrappedContext{LeafCtx} +struct LikelihoodContext{Tvars,LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars ctx::LeafCtx end LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluateContext()) +LikelihoodContext(ctx::AbstractContext) = LikelihoodContext(nothing, ctx) """ struct MiniBatchContext{Tctx, T} <: AbstractContext From 5408cd055d6087e78a90f0080b072d79bc083865 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 22 May 2021 01:20:32 +0200 Subject: [PATCH 10/42] dont mutate VarInfo variable in dot_assume with EvaluateContext --- src/context_implementations.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 77038d802..d52c0b21a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -287,9 +287,7 @@ function dot_assume( vi, ) @assert length(dist) == size(var, 1) - r = vi[vns] - lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) - var .= r + lp = sum(Bijectors.logpdf_with_trans(dist, var, istrans(vi, vns[1]))) return var, lp end @@ -315,10 +313,8 @@ function dot_assume( var::AbstractArray, vi, ) - r = vi[vns] # Make sure `r` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) - var .= r + lp = sum(Bijectors.logpdf_with_trans.(dists, var, istrans(vi, vns[1]))) return var, lp end From 07a00e6c92694fadba7eb2b4009b4b67373c041b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 08:19:39 +0200 Subject: [PATCH 11/42] rename _tilde to tilde_primitive --- src/context_implementations.jl | 42 +++++++++++++++++----------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d52c0b21a..e9e2a622f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -21,21 +21,21 @@ _getindex(x, inds::Tuple{}) = x function tilde( rng, ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vn::VarName, _, vi ) - return _tilde(rng, ctx, sampler, right, left, vn, vi) + return tilde_primitive(rng, ctx, sampler, right, left, vn, vi) end function tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) if ctx.vars !== nothing vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return _tilde(rng, childcontext(ctx), sampler, right, left, vn, vi) + return tilde_primitive(rng, childcontext(ctx), sampler, right, left, vn, vi) end function tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return _tilde( + return tilde_primitive( rng, rewrap(childcontext(ctx), EvaluateContext()), sampler, @@ -66,29 +66,29 @@ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) return value end -function _tilde(rng, ctx::SampleContext, sampler, right, left, vn::VarName, vi) +function tilde_primitive(rng, ctx::SampleContext, sampler, right, left, vn::VarName, vi) return assume(rng, sampler, right, nothing, vn, vi) end -function _tilde(rng, ctx::EvaluateContext, sampler, right, left::Nothing, vn::VarName, vi) +function tilde_primitive(rng, ctx::EvaluateContext, sampler, right, left::Nothing, vn::VarName, vi) return assume(sampler, right, vi[vn], vn, vi) end -function _tilde(rng, ctx::EvaluateContext, sampler, right, left, vn::VarName, vi) +function tilde_primitive(rng, ctx::EvaluateContext, sampler, right, left, vn::VarName, vi) return assume(sampler, right, left, vn, vi) end -function _tilde(rng, ctx, sampler, right::NamedDist, left, vn::VarName, vi) - return _tilde(rng, ctx, sampler, right.dist, left, right.name, vi) +function tilde_primitive(rng, ctx, sampler, right::NamedDist, left, vn::VarName, vi) + return tilde_primitive(rng, ctx, sampler, right.dist, left, right.name, vi) end # observe function tilde(ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vi) - return _tilde(sampler, right, left, vi) + return tilde_primitive(sampler, right, left, vi) end function tilde(ctx::PriorContext, sampler, right, left, vi) return 0 end function tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return _tilde(sampler, right, left, vi) + return tilde_primitive(sampler, right, left, vi) end function tilde(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi) @@ -126,7 +126,7 @@ function tilde_observe(ctx, sampler, right, left, vi) return left end -_tilde(sampler, right, left, vi) = observe(sampler, right, left, vi) +tilde_primitive(sampler, right, left, vi) = observe(sampler, right, left, vi) function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") @@ -175,7 +175,7 @@ function dot_tilde( rng, ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vn::VarName, _, vi ) vns, dist = get_vns_and_dist(right, left, vn) - return _dot_tilde(rng, ctx, sampler, dist, left, vns, vi) + return dot_tilde_primitive(rng, ctx, sampler, dist, left, vns, vi) end function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) @@ -186,7 +186,7 @@ function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarNam else vns, dist = get_vns_and_dist(right, left, vn) end - return _dot_tilde( + return dot_tilde_primitive( rng, rewrap(childcontext(ctx), EvaluateContext()), sampler, @@ -208,7 +208,7 @@ function dot_tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, in else vns, dist = get_vns_and_dist(right, left, vn) end - return _dot_tilde(rng, childcontext(ctx), sampler, dist, left, vns, vi) + return dot_tilde_primitive(rng, childcontext(ctx), sampler, dist, left, vns, vi) end """ @@ -239,20 +239,20 @@ function get_vns_and_dist( return getvn.(CartesianIndices(var)), dist end -function _dot_tilde( +function dot_tilde_primitive( rng, ctx::SampleContext, sampler, right, left, vns::AbstractArray{<:VarName}, vi ) return dot_assume(rng, sampler, right, vns, left, vi) end -function _dot_tilde( +function dot_tilde_primitive( rng, ctx::EvaluateContext, sampler, right, left, vns::AbstractArray{<:VarName}, vi ) return dot_assume(sampler, right, vns, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function _dot_tilde( +function dot_tilde_primitive( rng, ctx, sampler::AbstractSampler, @@ -402,13 +402,13 @@ end # observe function dot_tilde(ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vi) - return _dot_tilde(sampler, right, left, vi) + return dot_tilde_primitive(sampler, right, left, vi) end function dot_tilde(ctx::PriorContext, sampler, right, left, vi) return 0 end function dot_tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return _dot_tilde(sampler, right, left, vi) + return dot_tilde_primitive(sampler, right, left, vi) end function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, vi) @@ -443,11 +443,11 @@ function dot_tilde_observe(ctx, sampler, right, left, vi) return left end -function _dot_tilde(sampler, right, left::AbstractArray, vi) +function dot_tilde_primitive(sampler, right, left::AbstractArray, vi) return dot_observe(sampler, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function _dot_tilde( +function dot_tilde_primitive( sampler::AbstractSampler, right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, left::AbstractMatrix{>:AbstractVector}, From 7a96e461196dc50015d61b47edef24e24d444ce2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 08:24:00 +0200 Subject: [PATCH 12/42] refactoring of context_implementations --- src/context_implementations.jl | 87 ++--------------------- src/context_implementations/likelihood.jl | 43 +++++++++++ src/context_implementations/minibatch.jl | 16 +++++ src/context_implementations/prefix.jl | 7 ++ src/context_implementations/prior.jl | 27 +++++++ 5 files changed, 98 insertions(+), 82 deletions(-) create mode 100644 src/context_implementations/likelihood.jl create mode 100644 src/context_implementations/minibatch.jl create mode 100644 src/context_implementations/prefix.jl create mode 100644 src/context_implementations/prior.jl diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e9e2a622f..c3ca9c46a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -17,40 +17,17 @@ require_particles(spl::Sampler) = false _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x +include("context_implementations/prior.jl") +include("context_implementations/likelihood.jl") +include("context_implementations/minibatch.jl") +include("context_implementations/prefix.jl") + # assume function tilde( rng, ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vn::VarName, _, vi ) return tilde_primitive(rng, ctx, sampler, right, left, vn, vi) end -function tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) - end - return tilde_primitive(rng, childcontext(ctx), sampler, right, left, vn, vi) -end -function tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) - end - return tilde_primitive( - rng, - rewrap(childcontext(ctx), EvaluateContext()), - sampler, - NoDist(right), - left, - vn, - vi, - ) -end -function tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) -end -function tilde(rng, ctx::PrefixContext, sampler, right, left, vn::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, left, prefix(ctx, vn), inds, vi) -end """ tilde_assume(rng, ctx, sampler, right, vn, inds, vi) @@ -84,18 +61,6 @@ end function tilde(ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vi) return tilde_primitive(sampler, right, left, vi) end -function tilde(ctx::PriorContext, sampler, right, left, vi) - return 0 -end -function tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return tilde_primitive(sampler, right, left, vi) -end -function tilde(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi) -end -function tilde(ctx::PrefixContext, sampler, right, left, vi) - return tilde(ctx.ctx, sampler, right, left, vi) -end """ tilde_observe(ctx, sampler, right, left, vname, vinds, vi) @@ -177,39 +142,6 @@ function dot_tilde( vns, dist = get_vns_and_dist(right, left, vn) return dot_tilde_primitive(rng, ctx, sampler, dist, left, vns, vi) end -function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - var = _getindex(getfield(ctx.vars, getsym(vn)), inds) - vns, dist = get_vns_and_dist(right, var, vn) - set_val!(vi, vns, dist, var) - settrans!.(Ref(vi), false, vns) - else - vns, dist = get_vns_and_dist(right, left, vn) - end - return dot_tilde_primitive( - rng, - rewrap(childcontext(ctx), EvaluateContext()), - sampler, - NoDist.(dist), - left, - vns, - vi, - ) -end -function dot_tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) - return dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) -end -function dot_tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, getsym(vn)), inds) - vns, dist = get_vns_and_dist(right, var, vn) - set_val!(vi, vns, dist, var) - settrans!.(Ref(vi), false, vns) - else - vns, dist = get_vns_and_dist(right, left, vn) - end - return dot_tilde_primitive(rng, childcontext(ctx), sampler, dist, left, vns, vi) -end """ dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) @@ -404,15 +336,6 @@ end function dot_tilde(ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vi) return dot_tilde_primitive(sampler, right, left, vi) end -function dot_tilde(ctx::PriorContext, sampler, right, left, vi) - return 0 -end -function dot_tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return dot_tilde_primitive(sampler, right, left, vi) -end -function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, vi) -end """ dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi) diff --git a/src/context_implementations/likelihood.jl b/src/context_implementations/likelihood.jl new file mode 100644 index 000000000..c58bd51d9 --- /dev/null +++ b/src/context_implementations/likelihood.jl @@ -0,0 +1,43 @@ +function tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) + if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_primitive( + rng, + rewrap(childcontext(ctx), EvaluateContext()), + sampler, + NoDist(right), + left, + vn, + vi, + ) +end + +function tilde(ctx::LikelihoodContext, sampler, right, left, vi) + return tilde_primitive(sampler, right, left, vi) +end + +function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) + if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) + var = _getindex(getfield(ctx.vars, getsym(vn)), inds) + vns, dist = get_vns_and_dist(right, var, vn) + set_val!(vi, vns, dist, var) + settrans!.(Ref(vi), false, vns) + else + vns, dist = get_vns_and_dist(right, left, vn) + end + return dot_tilde_primitive( + rng, + rewrap(childcontext(ctx), EvaluateContext()), + sampler, + NoDist.(dist), + left, + vns, + vi, + ) +end + +function dot_tilde(ctx::LikelihoodContext, sampler, right, left, vi) + return dot_tilde_primitive(sampler, right, left, vi) +end diff --git a/src/context_implementations/minibatch.jl b/src/context_implementations/minibatch.jl new file mode 100644 index 000000000..30679b8f6 --- /dev/null +++ b/src/context_implementations/minibatch.jl @@ -0,0 +1,16 @@ +function tilde(ctx::MiniBatchContext, sampler, right, left, vi) + return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi) +end + +function tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) + return tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +end + +function dot_tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) + return dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +end + +function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi) + return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, vi) +end + diff --git a/src/context_implementations/prefix.jl b/src/context_implementations/prefix.jl new file mode 100644 index 000000000..5a9a3c4d9 --- /dev/null +++ b/src/context_implementations/prefix.jl @@ -0,0 +1,7 @@ +function tilde(rng, ctx::PrefixContext, sampler, right, left, vn::VarName, inds, vi) + return tilde(rng, ctx.ctx, sampler, right, left, prefix(ctx, vn), inds, vi) +end + +function tilde(ctx::PrefixContext, sampler, right, left, vi) + return tilde(ctx.ctx, sampler, right, left, vi) +end diff --git a/src/context_implementations/prior.jl b/src/context_implementations/prior.jl new file mode 100644 index 000000000..631136ca5 --- /dev/null +++ b/src/context_implementations/prior.jl @@ -0,0 +1,27 @@ +function tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) + if ctx.vars !== nothing + vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_primitive(rng, childcontext(ctx), sampler, right, left, vn, vi) +end + +function tilde(ctx::PriorContext, sampler, right, left, vi) + return 0 +end + +function dot_tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) + if ctx.vars !== nothing + var = _getindex(getfield(ctx.vars, getsym(vn)), inds) + vns, dist = get_vns_and_dist(right, var, vn) + set_val!(vi, vns, dist, var) + settrans!.(Ref(vi), false, vns) + else + vns, dist = get_vns_and_dist(right, left, vn) + end + return dot_tilde_primitive(rng, childcontext(ctx), sampler, dist, left, vns, vi) +end + +function dot_tilde(ctx::PriorContext, sampler, right, left, vi) + return 0 +end From c61585e7081446a81c4cdebdfd1dc5d8e43bf26c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 08:30:27 +0200 Subject: [PATCH 13/42] added unwrap_right_vn and unwrap_right_left_vn thanks to @devmotion --- src/compiler.jl | 54 ++++++++++++++++++++++++++++++---- src/context_implementations.jl | 4 --- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index bef7d11c2..31f33d272 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -52,6 +52,49 @@ end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x +""" + unwrap_right_vn(right, vn) + +Return the unwrapped distribution on the right-hand side and variable name on the left-hand +side of a `~` expression such as `x ~ Normal()`. + +This is used mainly to unwrap `NamedDist` distributions. +""" +unwrap_right_vn(right, vn) = right, vn +unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name) + +""" + unwrap_right_left_vns(context, right, left, vns) + +Return the unwrapped distributions on the right-hand side and values and variable names on the +left-hand side of a `.~` expression such as `x .~ Normal()`. + +This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the +variables. +""" +unwrap_right_left_vns(right, left, vns) = right, left, vns +function unwrap_right_left_vns(right::NamedDist, left, vns) + return unwrap_right_left_vns(right.dist, left, right.name) +end +function unwrap_right_left_vns( + right::MultivariateDistribution, left::AbstractMatrix, vn::VarName +) + vns = map(axes(left, 2)) do i + return VarName(vn, (vn.indexing..., Tuple(i))) + end + return unwrap_right_left_vns(right, left, vns) +end +function unwrap_right_left_vns( + right::Union{Distribution,AbstractArray{<:Distribution}}, + left::AbstractArray, + vn::VarName, +) + vns = map(CartesianIndices(left)) do i + return VarName(vn, (vn.indexing..., Tuple(i))) + end + return unwrap_right_left_vns(right, left, vns) +end + ################# # Main Compiler # ################# @@ -264,8 +307,9 @@ function generate_tilde(left, right) __rng__, __context__, __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $vn, + $(DynamicPPL.unwrap_right_vn)( + $(DynamicPPL.check_tilde_rhs)($right), $vn + )..., $inds, __varinfo__, ) @@ -314,9 +358,9 @@ function generate_dot_tilde(left, right) __rng__, __context__, __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - $vn, + $(DynamicPPL.unwrap_right_left_vn)( + $(DynamicPPL.check_tilde_rhs)($right), $left, $vn + )..., $inds, __varinfo__, ) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index c3ca9c46a..504cf50fb 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -53,10 +53,6 @@ function tilde_primitive(rng, ctx::EvaluateContext, sampler, right, left, vn::Va return assume(sampler, right, left, vn, vi) end -function tilde_primitive(rng, ctx, sampler, right::NamedDist, left, vn::VarName, vi) - return tilde_primitive(rng, ctx, sampler, right.dist, left, right.name, vi) -end - # observe function tilde(ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vi) return tilde_primitive(sampler, right, left, vi) From e3d6515a77b72b8007a1c8a999c5e3ef25ce584c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 11:54:37 +0200 Subject: [PATCH 14/42] renamed SampleContext and EvaluateContext --- src/DynamicPPL.jl | 3 ++- src/context_implementations.jl | 18 +++++++++--------- src/contexts.jl | 22 ++++++++++------------ src/model.jl | 8 ++++---- 4 files changed, 25 insertions(+), 26 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index acdb98183..9eef8fac8 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -76,7 +76,8 @@ export AbstractVarInfo, SampleFromPrior, SampleFromUniform, # Contexts - DefaultContext, + EvaluationContext, + SamplingContext, LikelihoodContext, PriorContext, MiniBatchContext, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 504cf50fb..4370bc6ca 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -24,7 +24,7 @@ include("context_implementations/prefix.jl") # assume function tilde( - rng, ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vn::VarName, _, vi + rng, ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vn::VarName, _, vi ) return tilde_primitive(rng, ctx, sampler, right, left, vn, vi) end @@ -43,18 +43,18 @@ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) return value end -function tilde_primitive(rng, ctx::SampleContext, sampler, right, left, vn::VarName, vi) +function tilde_primitive(rng, ctx::SamplingContext, sampler, right, left, vn::VarName, vi) return assume(rng, sampler, right, nothing, vn, vi) end -function tilde_primitive(rng, ctx::EvaluateContext, sampler, right, left::Nothing, vn::VarName, vi) +function tilde_primitive(rng, ctx::EvaluationContext, sampler, right, left::Nothing, vn::VarName, vi) return assume(sampler, right, vi[vn], vn, vi) end -function tilde_primitive(rng, ctx::EvaluateContext, sampler, right, left, vn::VarName, vi) +function tilde_primitive(rng, ctx::EvaluationContext, sampler, right, left, vn::VarName, vi) return assume(sampler, right, left, vn, vi) end # observe -function tilde(ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vi) +function tilde(ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi) return tilde_primitive(sampler, right, left, vi) end @@ -133,7 +133,7 @@ end # assume function dot_tilde( - rng, ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vn::VarName, _, vi + rng, ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vn::VarName, _, vi ) vns, dist = get_vns_and_dist(right, left, vn) return dot_tilde_primitive(rng, ctx, sampler, dist, left, vns, vi) @@ -168,13 +168,13 @@ function get_vns_and_dist( end function dot_tilde_primitive( - rng, ctx::SampleContext, sampler, right, left, vns::AbstractArray{<:VarName}, vi + rng, ctx::SamplingContext, sampler, right, left, vns::AbstractArray{<:VarName}, vi ) return dot_assume(rng, sampler, right, vns, left, vi) end function dot_tilde_primitive( - rng, ctx::EvaluateContext, sampler, right, left, vns::AbstractArray{<:VarName}, vi + rng, ctx::EvaluationContext, sampler, right, left, vns::AbstractArray{<:VarName}, vi ) return dot_assume(sampler, right, vns, left, vi) end @@ -329,7 +329,7 @@ function set_val!( end # observe -function dot_tilde(ctx::Union{SampleContext,EvaluateContext}, sampler, right, left, vi) +function dot_tilde(ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi) return dot_tilde_primitive(sampler, right, left, vi) end diff --git a/src/contexts.jl b/src/contexts.jl index 137dde42e..83895911e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -4,11 +4,9 @@ The `DefaultContext` is used by default to compute log the joint probability of the data and parameters when running the model. """ -struct DefaultContext <: AbstractContext end - abstract type PrimitiveContext <: AbstractContext end -struct EvaluateContext <: PrimitiveContext end -struct SampleContext <: PrimitiveContext end +struct EvaluationContext <: PrimitiveContext end +struct SamplingContext <: PrimitiveContext end ######################## ### Wrapped contexts ### @@ -42,11 +40,11 @@ unwrappedtype(ctx::AbstractContext) = typeof(ctx) unwrappedtype(ctx::WrappedContext{LeafCtx}) where {LeafCtx} = LeafCtx """ - rewrap(parent::WrappedContext, leaf::AbstractContext) + rewrap(parent::WrappedContext, leaf::PrimitiveContext) Rewraps `leaf` in `parent`. Supports nested `WrappedContext`. """ -rewrap(::AbstractContext, leaf::AbstractContext) = leaf +rewrap(::AbstractContext, leaf::PrimitiveContext) = leaf """ struct PriorContext{Tvars} <: AbstractContext @@ -60,7 +58,7 @@ struct PriorContext{Tvars,LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars ctx::LeafCtx end -PriorContext(vars=nothing) = PriorContext(vars, EvaluateContext()) +PriorContext(vars=nothing) = PriorContext(vars, EvaluationContext()) PriorContext(ctx::AbstractContext) = PriorContext(nothing, ctx) """ @@ -76,7 +74,7 @@ struct LikelihoodContext{Tvars,LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars ctx::LeafCtx end -LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluateContext()) +LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluationContext()) LikelihoodContext(ctx::AbstractContext) = LikelihoodContext(nothing, ctx) """ @@ -103,11 +101,11 @@ struct MiniBatchContext{T,Ctx,LeafCtx} <: WrappedContext{LeafCtx} return new{typeof(loglike_scalar),typeof(ctx),LeafCtx}(loglike_scalar, ctx) end end -function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) +function MiniBatchContext(ctx=EvaluationContext(); batch_size, npoints) return MiniBatchContext(npoints / batch_size, ctx) end -function rewrap(parent::MiniBatchContext, leaf::AbstractContext) +function rewrap(parent::MiniBatchContext, leaf::PrimitiveContext) return MiniBatchContext(parent.loglike_scalar, rewrap(childcontext(parent), leaf)) end @@ -121,9 +119,9 @@ struct PrefixContext{Prefix,C,LeafCtx} <: WrappedContext{LeafCtx} return new{Prefix,typeof(ctx),LeafCtx}(ctx) end end -PrefixContext{Prefix}() where {Prefix} = PrefixContext{Prefix}(DefaultContext()) +PrefixContext{Prefix}() where {Prefix} = PrefixContext{Prefix}(EvaluationContext()) -function rewrap(parent::PrefixContext{Prefix}, leaf::AbstractContext) where {Prefix} +function rewrap(parent::PrefixContext{Prefix}, leaf::PrimitiveContext) where {Prefix} return PrefixContext{Prefix}(rewrap(childcontext(parent), leaf)) end diff --git a/src/model.jl b/src/model.jl index b3742011d..40e2b1b80 100644 --- a/src/model.jl +++ b/src/model.jl @@ -86,7 +86,7 @@ function (model::Model)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=SampleContext(), + context::AbstractContext=SamplingContext(), ) if Threads.nthreads() == 1 return evaluate_threadunsafe(rng, model, varinfo, sampler, context) @@ -183,7 +183,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), EvaluateContext()) + model(varinfo, SampleFromPrior(), EvaluationContext()) return getlogp(varinfo) end @@ -195,7 +195,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), PriorContext(nothing, EvaluateContext())) + model(varinfo, SampleFromPrior(), PriorContext(nothing, EvaluationContext())) return getlogp(varinfo) end @@ -207,7 +207,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), LikelihoodContext(nothing, EvaluateContext())) + model(varinfo, SampleFromPrior(), LikelihoodContext(nothing, EvaluationContext())) return getlogp(varinfo) end From 9d43c7bb0355eb4733a7746b529c2766cb142176 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 12:01:14 +0200 Subject: [PATCH 15/42] separate contexts into separate files --- src/contexts.jl | 104 ++----------------------------------- src/contexts/likelihood.jl | 15 ++++++ src/contexts/minibatch.jl | 31 +++++++++++ src/contexts/prefix.jl | 37 +++++++++++++ src/contexts/prior.jl | 14 +++++ 5 files changed, 101 insertions(+), 100 deletions(-) create mode 100644 src/contexts/likelihood.jl create mode 100644 src/contexts/minibatch.jl create mode 100644 src/contexts/prefix.jl create mode 100644 src/contexts/prior.jl diff --git a/src/contexts.jl b/src/contexts.jl index 83895911e..85a76cb37 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -46,103 +46,7 @@ Rewraps `leaf` in `parent`. Supports nested `WrappedContext`. """ rewrap(::AbstractContext, leaf::PrimitiveContext) = leaf -""" - struct PriorContext{Tvars} <: AbstractContext - vars::Tvars - end - -The `PriorContext` enables the computation of the log prior of the parameters `vars` when -running the model. -""" -struct PriorContext{Tvars,LeafCtx} <: WrappedContext{LeafCtx} - vars::Tvars - ctx::LeafCtx -end -PriorContext(vars=nothing) = PriorContext(vars, EvaluationContext()) -PriorContext(ctx::AbstractContext) = PriorContext(nothing, ctx) - -""" - struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars - end - -The `LikelihoodContext` enables the computation of the log likelihood of the parameters when -running the model. `vars` can be used to evaluate the log likelihood for specific values -of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. -""" -struct LikelihoodContext{Tvars,LeafCtx} <: WrappedContext{LeafCtx} - vars::Tvars - ctx::LeafCtx -end -LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluationContext()) -LikelihoodContext(ctx::AbstractContext) = LikelihoodContext(nothing, ctx) - -""" - struct MiniBatchContext{Tctx, T} <: AbstractContext - ctx::Tctx - loglike_scalar::T - end - -The `MiniBatchContext` enables the computation of -`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the -`loglike_scalar` field, typically equal to `the number of data points / batch size`. -This is useful in batch-based stochastic gradient descent algorithms to be optimizing -`log(prior) + log(likelihood of all the data points)` in the expectation. -""" -struct MiniBatchContext{T,Ctx,LeafCtx} <: WrappedContext{LeafCtx} - loglike_scalar::T - ctx::Ctx - - function MiniBatchContext(loglike_scalar, ctx::AbstractContext) - return new{typeof(loglike_scalar),typeof(ctx),typeof(ctx)}(loglike_scalar, ctx) - end - - function MiniBatchContext(loglike_scalar, ctx::WrappedContext{LeafCtx}) where {LeafCtx} - return new{typeof(loglike_scalar),typeof(ctx),LeafCtx}(loglike_scalar, ctx) - end -end -function MiniBatchContext(ctx=EvaluationContext(); batch_size, npoints) - return MiniBatchContext(npoints / batch_size, ctx) -end - -function rewrap(parent::MiniBatchContext, leaf::PrimitiveContext) - return MiniBatchContext(parent.loglike_scalar, rewrap(childcontext(parent), leaf)) -end - -struct PrefixContext{Prefix,C,LeafCtx} <: WrappedContext{LeafCtx} - ctx::C - - function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return new{Prefix,typeof(ctx),typeof(ctx)}(ctx) - end - function PrefixContext{Prefix}(ctx::WrappedContext{LeafCtx}) where {Prefix,LeafCtx} - return new{Prefix,typeof(ctx),LeafCtx}(ctx) - end -end -PrefixContext{Prefix}() where {Prefix} = PrefixContext{Prefix}(EvaluationContext()) - -function rewrap(parent::PrefixContext{Prefix}, leaf::PrimitiveContext) where {Prefix} - return PrefixContext{Prefix}(rewrap(childcontext(parent), leaf)) -end - -const PREFIX_SEPARATOR = Symbol(".") - -function PrefixContext{PrefixInner}( - ctx::PrefixContext{PrefixOuter} -) where {PrefixInner,PrefixOuter} - if @generated - :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( - ctx.ctx - )) - else - PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) - end -end - -function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(vn.indexing)) - else - VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) - end -end +include("contexts/prior.jl") +include("contexts/likelihood.jl") +include("contexts/minibatch.jl") +include("contexts/prefix.jl") diff --git a/src/contexts/likelihood.jl b/src/contexts/likelihood.jl new file mode 100644 index 000000000..d1bc20f69 --- /dev/null +++ b/src/contexts/likelihood.jl @@ -0,0 +1,15 @@ +""" + struct LikelihoodContext{Tvars} <: AbstractContext + vars::Tvars + end + +The `LikelihoodContext` enables the computation of the log likelihood of the parameters when +running the model. `vars` can be used to evaluate the log likelihood for specific values +of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. +""" +struct LikelihoodContext{Tvars,LeafCtx} <: WrappedContext{LeafCtx} + vars::Tvars + ctx::LeafCtx +end +LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluationContext()) +LikelihoodContext(ctx::AbstractContext) = LikelihoodContext(nothing, ctx) diff --git a/src/contexts/minibatch.jl b/src/contexts/minibatch.jl new file mode 100644 index 000000000..934dbacc7 --- /dev/null +++ b/src/contexts/minibatch.jl @@ -0,0 +1,31 @@ +""" + struct MiniBatchContext{Tctx, T} <: AbstractContext + ctx::Tctx + loglike_scalar::T + end + +The `MiniBatchContext` enables the computation of +`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the +`loglike_scalar` field, typically equal to `the number of data points / batch size`. +This is useful in batch-based stochastic gradient descent algorithms to be optimizing +`log(prior) + log(likelihood of all the data points)` in the expectation. +""" +struct MiniBatchContext{T,Ctx,LeafCtx} <: WrappedContext{LeafCtx} + loglike_scalar::T + ctx::Ctx + + function MiniBatchContext(loglike_scalar, ctx::AbstractContext) + return new{typeof(loglike_scalar),typeof(ctx),typeof(ctx)}(loglike_scalar, ctx) + end + + function MiniBatchContext(loglike_scalar, ctx::WrappedContext{LeafCtx}) where {LeafCtx} + return new{typeof(loglike_scalar),typeof(ctx),LeafCtx}(loglike_scalar, ctx) + end +end +function MiniBatchContext(ctx=EvaluationContext(); batch_size, npoints) + return MiniBatchContext(npoints / batch_size, ctx) +end + +function rewrap(parent::MiniBatchContext, leaf::PrimitiveContext) + return MiniBatchContext(parent.loglike_scalar, rewrap(childcontext(parent), leaf)) +end diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl new file mode 100644 index 000000000..05263c675 --- /dev/null +++ b/src/contexts/prefix.jl @@ -0,0 +1,37 @@ +struct PrefixContext{Prefix,C,LeafCtx} <: WrappedContext{LeafCtx} + ctx::C + + function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} + return new{Prefix,typeof(ctx),typeof(ctx)}(ctx) + end + function PrefixContext{Prefix}(ctx::WrappedContext{LeafCtx}) where {Prefix,LeafCtx} + return new{Prefix,typeof(ctx),LeafCtx}(ctx) + end +end +PrefixContext{Prefix}() where {Prefix} = PrefixContext{Prefix}(EvaluationContext()) + +function rewrap(parent::PrefixContext{Prefix}, leaf::PrimitiveContext) where {Prefix} + return PrefixContext{Prefix}(rewrap(childcontext(parent), leaf)) +end + +const PREFIX_SEPARATOR = Symbol(".") + +function PrefixContext{PrefixInner}( + ctx::PrefixContext{PrefixOuter} +) where {PrefixInner,PrefixOuter} + if @generated + :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( + ctx.ctx + )) + else + PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) + end +end + +function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} + if @generated + return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(vn.indexing)) + else + VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) + end +end diff --git a/src/contexts/prior.jl b/src/contexts/prior.jl new file mode 100644 index 000000000..157daf3ec --- /dev/null +++ b/src/contexts/prior.jl @@ -0,0 +1,14 @@ +""" + struct PriorContext{Tvars} <: AbstractContext + vars::Tvars + end + +The `PriorContext` enables the computation of the log prior of the parameters `vars` when +running the model. +""" +struct PriorContext{Tvars,LeafCtx} <: WrappedContext{LeafCtx} + vars::Tvars + ctx::LeafCtx +end +PriorContext(vars=nothing) = PriorContext(vars, EvaluationContext()) +PriorContext(ctx::AbstractContext) = PriorContext(nothing, ctx) From 0c354ca154d82fb139ed87a000dd25928b9564c6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 11:02:51 +0100 Subject: [PATCH 16/42] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 22 +++++++++++++++++++--- src/context_implementations/minibatch.jl | 1 - 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 4370bc6ca..3526221c5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -24,7 +24,14 @@ include("context_implementations/prefix.jl") # assume function tilde( - rng, ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vn::VarName, _, vi + rng, + ctx::Union{SamplingContext,EvaluationContext}, + sampler, + right, + left, + vn::VarName, + _, + vi, ) return tilde_primitive(rng, ctx, sampler, right, left, vn, vi) end @@ -46,7 +53,9 @@ end function tilde_primitive(rng, ctx::SamplingContext, sampler, right, left, vn::VarName, vi) return assume(rng, sampler, right, nothing, vn, vi) end -function tilde_primitive(rng, ctx::EvaluationContext, sampler, right, left::Nothing, vn::VarName, vi) +function tilde_primitive( + rng, ctx::EvaluationContext, sampler, right, left::Nothing, vn::VarName, vi +) return assume(sampler, right, vi[vn], vn, vi) end function tilde_primitive(rng, ctx::EvaluationContext, sampler, right, left, vn::VarName, vi) @@ -133,7 +142,14 @@ end # assume function dot_tilde( - rng, ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vn::VarName, _, vi + rng, + ctx::Union{SamplingContext,EvaluationContext}, + sampler, + right, + left, + vn::VarName, + _, + vi, ) vns, dist = get_vns_and_dist(right, left, vn) return dot_tilde_primitive(rng, ctx, sampler, dist, left, vns, vi) diff --git a/src/context_implementations/minibatch.jl b/src/context_implementations/minibatch.jl index 30679b8f6..60dc0aae3 100644 --- a/src/context_implementations/minibatch.jl +++ b/src/context_implementations/minibatch.jl @@ -13,4 +13,3 @@ end function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, vi) end - From c63a4e510e94be608b73ab0fc9d71acc8a8d9469 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 12:13:23 +0200 Subject: [PATCH 17/42] added some convenience to MiniBatchContext constructor --- src/contexts/minibatch.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contexts/minibatch.jl b/src/contexts/minibatch.jl index 934dbacc7..76506ab50 100644 --- a/src/contexts/minibatch.jl +++ b/src/contexts/minibatch.jl @@ -14,11 +14,11 @@ struct MiniBatchContext{T,Ctx,LeafCtx} <: WrappedContext{LeafCtx} loglike_scalar::T ctx::Ctx - function MiniBatchContext(loglike_scalar, ctx::AbstractContext) + function MiniBatchContext(loglike_scalar, ctx::AbstractContext=EvaluationContext()) return new{typeof(loglike_scalar),typeof(ctx),typeof(ctx)}(loglike_scalar, ctx) end - function MiniBatchContext(loglike_scalar, ctx::WrappedContext{LeafCtx}) where {LeafCtx} + function MiniBatchContext(loglike_scalar, ctx::WrappedContext{LeafCtx}=EvaluationContext()) where {LeafCtx} return new{typeof(loglike_scalar),typeof(ctx),LeafCtx}(loglike_scalar, ctx) end end From 857796824d5ee3e578e23a7ccd4f5af19d3d1b8d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 12:13:52 +0200 Subject: [PATCH 18/42] formatting --- src/contexts/minibatch.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts/minibatch.jl b/src/contexts/minibatch.jl index 76506ab50..168a1abca 100644 --- a/src/contexts/minibatch.jl +++ b/src/contexts/minibatch.jl @@ -18,7 +18,9 @@ struct MiniBatchContext{T,Ctx,LeafCtx} <: WrappedContext{LeafCtx} return new{typeof(loglike_scalar),typeof(ctx),typeof(ctx)}(loglike_scalar, ctx) end - function MiniBatchContext(loglike_scalar, ctx::WrappedContext{LeafCtx}=EvaluationContext()) where {LeafCtx} + function MiniBatchContext( + loglike_scalar, ctx::WrappedContext{LeafCtx}=EvaluationContext() + ) where {LeafCtx} return new{typeof(loglike_scalar),typeof(ctx),LeafCtx}(loglike_scalar, ctx) end end From 5b42b2318d75823051432a72c8294498802c41e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 14:26:50 +0200 Subject: [PATCH 19/42] missed a rename of SampleContext --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4140f9a0b..5395260f1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -126,7 +126,7 @@ function VarInfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=SampleContext(), + context::AbstractContext=SamplingContext(), ) varinfo = VarInfo() model(rng, varinfo, sampler, context) From 320d2ed06adbc83b4f2281327e4a7f43728bc00d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 14:27:05 +0200 Subject: [PATCH 20/42] fixed typo of unwrap_right_left_vns --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 31f33d272..4bb611cca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -358,7 +358,7 @@ function generate_dot_tilde(left, right) __rng__, __context__, __sampler__, - $(DynamicPPL.unwrap_right_left_vn)( + $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn )..., $inds, From 9e7691a1a66f3a3569bb46ca8f5468d722c75fcf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 14:27:54 +0200 Subject: [PATCH 21/42] removed usage of get_vns_and_dist since we have unwrap_right_left_vns --- src/context_implementations.jl | 38 ++++++++-------------------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3526221c5..4b7c0e27b 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -142,17 +142,9 @@ end # assume function dot_tilde( - rng, - ctx::Union{SamplingContext,EvaluationContext}, - sampler, - right, - left, - vn::VarName, - _, - vi, + rng, ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vn, _, vi ) - vns, dist = get_vns_and_dist(right, left, vn) - return dot_tilde_primitive(rng, ctx, sampler, dist, left, vns, vi) + return dot_tilde_primitive(rng, ctx, sampler, right, left, vn, vi) end """ @@ -164,35 +156,23 @@ model inputs), accumulate the log probability, and return the sampled value. Falls back to `dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi)`. """ function dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) + value, logp = dot_tilde(rng, ctx, sampler, right, nothing, vn, inds, vi) acclogp!(vi, logp) return value end -function get_vns_and_dist(dist::NamedDist, var, vn::VarName) - return get_vns_and_dist(dist.dist, var, dist.name) -end -function get_vns_and_dist(dist::MultivariateDistribution, var::AbstractMatrix, vn::VarName) - getvn = i -> VarName(vn, (vn.indexing..., (Colon(), i))) - return getvn.(1:size(var, 2)), dist -end -function get_vns_and_dist( - dist::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vn::VarName -) - getvn = ind -> VarName(vn, (vn.indexing..., Tuple(ind))) - return getvn.(CartesianIndices(var)), dist +function dot_tilde_primitive(rng, ctx::SamplingContext, sampler, right, left, vns, vi) + return dot_assume(rng, sampler, right, vns, nothing, vi) end -function dot_tilde_primitive( - rng, ctx::SamplingContext, sampler, right, left, vns::AbstractArray{<:VarName}, vi -) - return dot_assume(rng, sampler, right, vns, left, vi) +function dot_tilde_primitive(rng, ctx::EvaluationContext, sampler, right, left, vns, vi) + return dot_assume(sampler, right, vns, left, vi) end function dot_tilde_primitive( - rng, ctx::EvaluationContext, sampler, right, left, vns::AbstractArray{<:VarName}, vi + rng, ctx::EvaluationContext, sampler, right, left::Nothing, vns, vi ) - return dot_assume(sampler, right, vns, left, vi) + return dot_assume(sampler, right, vns, vi[vns], vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics From 6358ee0ce822b7e48f467710f2c5de13c04e94e6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 14:28:19 +0200 Subject: [PATCH 22/42] use value passed to dot_assume rather than extracting from var_info --- src/context_implementations.jl | 10 ++++------ src/context_implementations/prefix.jl | 22 ++++++++++++++++++++-- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 4b7c0e27b..430cd97fb 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -193,14 +193,13 @@ function dot_assume( spl::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, vns::AbstractVector{<:VarName}, - var::AbstractMatrix, + var::Nothing, vi, ) @assert length(dist) == size(var, 1) r = get_and_set_val!(rng, vi, vns, dist, spl) lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) - var .= r - return var, lp + return r, lp end function dot_assume( @@ -220,14 +219,13 @@ function dot_assume( spl::Union{SampleFromPrior,SampleFromUniform}, dists::Union{Distribution,AbstractArray{<:Distribution}}, vns::AbstractArray{<:VarName}, - var::AbstractArray, + var::Nothing, vi, ) r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) - var .= r - return var, lp + return r, lp end function dot_assume( diff --git a/src/context_implementations/prefix.jl b/src/context_implementations/prefix.jl index 5a9a3c4d9..e19dc7d73 100644 --- a/src/context_implementations/prefix.jl +++ b/src/context_implementations/prefix.jl @@ -1,7 +1,25 @@ function tilde(rng, ctx::PrefixContext, sampler, right, left, vn::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, left, prefix(ctx, vn), inds, vi) + return tilde(rng, childcontext(ctx), sampler, right, left, prefix(ctx, vn), inds, vi) end function tilde(ctx::PrefixContext, sampler, right, left, vi) - return tilde(ctx.ctx, sampler, right, left, vi) + return tilde(childcontext(ctx), sampler, right, left, vi) +end + +function dot_tilde(ctx::PrefixContext, sampler, right, left, vi) + return dot_tilde(childcontext(ctx), sampler, right, left, vi) +end +function dot_tilde( + rng::Random.AbstractRNG, ctx::PrefixContext, sampler, right, left, vn, inds, vi +) + return dot_tilde( + rng, + childcontext(ctx), + sampler, + right, + left, + map(Base.Fix1(prefix, ctx), vn), + inds, + vi, + ) end From 95613d288691a60506187ce300d0b1f0289b0794 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 14:28:52 +0200 Subject: [PATCH 23/42] fixed constructor of MiniBatchContext --- src/contexts/minibatch.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/contexts/minibatch.jl b/src/contexts/minibatch.jl index 168a1abca..bf9ecff19 100644 --- a/src/contexts/minibatch.jl +++ b/src/contexts/minibatch.jl @@ -14,17 +14,17 @@ struct MiniBatchContext{T,Ctx,LeafCtx} <: WrappedContext{LeafCtx} loglike_scalar::T ctx::Ctx - function MiniBatchContext(loglike_scalar, ctx::AbstractContext=EvaluationContext()) + function MiniBatchContext(loglike_scalar, ctx::AbstractContext) return new{typeof(loglike_scalar),typeof(ctx),typeof(ctx)}(loglike_scalar, ctx) end - function MiniBatchContext( - loglike_scalar, ctx::WrappedContext{LeafCtx}=EvaluationContext() - ) where {LeafCtx} + function MiniBatchContext(loglike_scalar, ctx::WrappedContext{LeafCtx}) where {LeafCtx} return new{typeof(loglike_scalar),typeof(ctx),LeafCtx}(loglike_scalar, ctx) end end -function MiniBatchContext(ctx=EvaluationContext(); batch_size, npoints) + +MiniBatchContext(loglike_scalar) = MiniBatchContext(loglike_scalar, EvaluationContext()) +function MiniBatchContext(ctx::AbstractContext=EvaluationContext(); batch_size, npoints) return MiniBatchContext(npoints / batch_size, ctx) end From 024e75e93476bd456bbebb7efdcbcc18ee89f341 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 14:42:17 +0200 Subject: [PATCH 24/42] found some more leftover typos --- src/context_implementations/likelihood.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations/likelihood.jl b/src/context_implementations/likelihood.jl index c58bd51d9..6d4b27bb0 100644 --- a/src/context_implementations/likelihood.jl +++ b/src/context_implementations/likelihood.jl @@ -5,7 +5,7 @@ function tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, i end return tilde_primitive( rng, - rewrap(childcontext(ctx), EvaluateContext()), + rewrap(childcontext(ctx), EvaluationContext()), sampler, NoDist(right), left, @@ -29,7 +29,7 @@ function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarNam end return dot_tilde_primitive( rng, - rewrap(childcontext(ctx), EvaluateContext()), + rewrap(childcontext(ctx), EvaluationContext()), sampler, NoDist.(dist), left, From 04461944810c75acfa63a6e31b9b27c294f2a87f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 14:42:53 +0200 Subject: [PATCH 25/42] moved includes in context_implementations.jl to end of file --- src/context_implementations.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 430cd97fb..d5b29487f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -17,11 +17,6 @@ require_particles(spl::Sampler) = false _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x -include("context_implementations/prior.jl") -include("context_implementations/likelihood.jl") -include("context_implementations/minibatch.jl") -include("context_implementations/prefix.jl") - # assume function tilde( rng, @@ -407,3 +402,9 @@ function dot_observe(spl::Sampler, ::Any, ::Any, ::Any) "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement" ) end + +# includes +include("context_implementations/prior.jl") +include("context_implementations/likelihood.jl") +include("context_implementations/minibatch.jl") +include("context_implementations/prefix.jl") From 49bd18aa76aec027bf2baeb13791a4a1718fbeb2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 14:58:54 +0200 Subject: [PATCH 26/42] added definition of Model to avoid StackOverflowError --- src/model.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/model.jl b/src/model.jl index 40e2b1b80..a2a40a1b8 100644 --- a/src/model.jl +++ b/src/model.jl @@ -108,6 +108,13 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) return model(rng, VarInfo(), SampleFromPrior(), context) end +# without VarInfo and without AbstractSampler +function (model::Model)( + rng::Random.AbstractRNG, varinfo::AbstractVarInfo, context::AbstractContext +) + return model(rng, varinfo, SampleFromPrior(), context) +end + """ evaluate_threadunsafe(rng, model, varinfo, sampler, context) From f9bb04f261198cb9486b00f099e11a35c84aa8a4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 15:01:55 +0200 Subject: [PATCH 27/42] removed the now redundant setval_and_resample! --- src/varinfo.jl | 87 ------------------------------------------------- test/varinfo.jl | 31 ------------------ 2 files changed, 118 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 5395260f1..ea394a032 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1399,90 +1399,3 @@ function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) return indices end - -""" - setval_and_resample!(vi::AbstractVarInfo, x) - setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx, chain_idx) - -Set the values in `vi` to the provided values and those which are not present -in `x` or `chains` to *be* resampled. - -Note that this does *not* resample the values not provided! It will call `setflag!(vi, vn, "del")` -for variables `vn` for which no values are provided, which means that the next time we call `model(vi)` these -variables will be resampled. - -## Note -- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. - -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0` - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - -julia> var_info[@varname(x[1])] # [✓] changed -101.37363069798343 -``` - -## See also -- [`setval!`](@ref) -""" -function setval_and_resample!(vi::AbstractVarInfo, x) - return _apply!(_setval_and_resample_kernel!, vi, values(x), keys(x)) -end -function setval_and_resample!( - vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) - return _apply!( - _setval_and_resample_kernel!, - vi, - chains.value[sample_idx, :, chain_idx], - keys(chains), - ) -end - -function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural) - val = reduce(vcat, values[sorted_indices]) - setval!(vi, val, vn) - settrans!(vi, false, vn) - else - # Ensures that we'll resample the variable corresponding to `vn` if we run - # the model on `vi` again. - set_flag!(vi, vn, "del") - end - - return indices -end diff --git a/test/varinfo.jl b/test/varinfo.jl index c936ad67c..26d5ce9c9 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -202,37 +202,6 @@ DynamicPPL.setval!(vicopy, (s=42,)) @test vicopy[m_vns] == 1:5 @test vicopy[s_vns] == 42 - - ### `setval_and_resample!` ### - if model == model_mv && vi == vi_untyped - # Trying to re-run model with `MvNormal` on `vi_untyped` will call - # `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError` - # so we skip this particular case. - continue - end - - vicopy = deepcopy(vi) - DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) - model(vicopy) - # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` - if model == model_uv - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - model(vicopy) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!(vicopy, (s=42,)) - model(vicopy) - @test vicopy[m_vns] != 1:5 - @test vicopy[s_vns] == 42 end end end From 5bee7f2fa8721f3915e0f9abee33f4b654caa6a7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 15:02:42 +0200 Subject: [PATCH 28/42] fixed prob_macro --- src/prob_macro.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/prob_macro.jl b/src/prob_macro.jl index d761e9fdc..dabd897e6 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -142,11 +142,11 @@ function logprior( # When all of model args are on the lhs of |, this is also equal to the logjoint. model = make_prior_model(left, right, _model) - vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi + vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext(SamplingContext())) : _vi foreach(keys(vi.metadata)) do n @assert n in keys(left) "Variable $n is not defined." end - model(vi, SampleFromPrior(), PriorContext(left)) + model(vi, SampleFromPrior(), PriorContext(left, EvaluationContext())) return getlogp(vi) end From 87272f48552e37c2cdb7b39a531129afb2d50378 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 15:02:58 +0200 Subject: [PATCH 29/42] updated sampler.jl to work with new contexts --- src/sampler.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 5e97a64e3..c124e8eb4 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -51,7 +51,7 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - model(rng, vi, sampler) + model(rng, vi, sampler, SamplingContext()) return vi, nothing end @@ -78,9 +78,9 @@ function AbstractMCMC.step( # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled if _spl isa SampleFromUniform - model(rng, vi, SampleFromPrior()) + model(rng, vi, SampleFromPrior(), EvaluationContext()) else - model(rng, vi, _spl) + model(rng, vi, _spl, EvaluationContext()) end end From e7d2344d6488ebda093f237965b3bdd666dc0069 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 18:21:03 +0200 Subject: [PATCH 30/42] fixed dot_tilde implementation for LikelihoodContext and PriorContext --- src/context_implementations/likelihood.jl | 24 ++++++++++++++--------- src/context_implementations/prior.jl | 22 +++++++++++++-------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/src/context_implementations/likelihood.jl b/src/context_implementations/likelihood.jl index 6d4b27bb0..0a1500b50 100644 --- a/src/context_implementations/likelihood.jl +++ b/src/context_implementations/likelihood.jl @@ -18,21 +18,27 @@ function tilde(ctx::LikelihoodContext, sampler, right, left, vi) return tilde_primitive(sampler, right, left, vi) end -function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - var = _getindex(getfield(ctx.vars, getsym(vn)), inds) - vns, dist = get_vns_and_dist(right, var, vn) - set_val!(vi, vns, dist, var) - settrans!.(Ref(vi), false, vns) +function dot_tilde( + rng, + ctx::LikelihoodContext, + sampler, + right, + left, + vns::AbstractArray{<:VarName{sym}}, + inds, + vi, +) where {sym} + var = if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) + _getindex(getfield(ctx.vars, sym), inds) else - vns, dist = get_vns_and_dist(right, left, vn) + vi[vns] end return dot_tilde_primitive( rng, rewrap(childcontext(ctx), EvaluationContext()), sampler, - NoDist.(dist), - left, + NoDist.(right), + var, vns, vi, ) diff --git a/src/context_implementations/prior.jl b/src/context_implementations/prior.jl index 631136ca5..a48b3231e 100644 --- a/src/context_implementations/prior.jl +++ b/src/context_implementations/prior.jl @@ -10,16 +10,22 @@ function tilde(ctx::PriorContext, sampler, right, left, vi) return 0 end -function dot_tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, getsym(vn)), inds) - vns, dist = get_vns_and_dist(right, var, vn) - set_val!(vi, vns, dist, var) - settrans!.(Ref(vi), false, vns) +function dot_tilde( + rng, + ctx::PriorContext, + sampler, + right, + left, + vns::AbstractArray{<:VarName{sym}}, + inds, + vi, +) where {sym} + var = if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) + _getindex(getfield(ctx.vars, sym), inds) else - vns, dist = get_vns_and_dist(right, left, vn) + vi[vns] end - return dot_tilde_primitive(rng, childcontext(ctx), sampler, dist, left, vns, vi) + return dot_tilde_primitive(rng, childcontext(ctx), sampler, right, var, vns, vi) end function dot_tilde(ctx::PriorContext, sampler, right, left, vi) From 586e5c8987f8eddb0ec1b4d24a77513c99063e43 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 18:32:54 +0200 Subject: [PATCH 31/42] add support for replacing value without mutating VarInfo to LikelihoodContext and PriorContext --- src/context_implementations.jl | 5 +++++ src/context_implementations/likelihood.jl | 11 ++++++----- src/context_implementations/prior.jl | 11 ++++++----- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d5b29487f..0499333aa 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -17,6 +17,11 @@ require_particles(spl::Sampler) = false _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x +function _getvalue(nt::NamedTuple, sym::Symbol, inds = ()) + value = getfield(nt, sym) + return _getindex(value, inds) +end + # assume function tilde( rng, diff --git a/src/context_implementations/likelihood.jl b/src/context_implementations/likelihood.jl index 0a1500b50..10922a211 100644 --- a/src/context_implementations/likelihood.jl +++ b/src/context_implementations/likelihood.jl @@ -1,14 +1,15 @@ function tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + var = if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) + _getvalue(ctx.vars, getsym(vn), inds) + else + vi[vn] end return tilde_primitive( rng, rewrap(childcontext(ctx), EvaluationContext()), sampler, NoDist(right), - left, + var, vn, vi, ) @@ -29,7 +30,7 @@ function dot_tilde( vi, ) where {sym} var = if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) - _getindex(getfield(ctx.vars, sym), inds) + _getvalue(ctx.vars, sym, inds) else vi[vns] end diff --git a/src/context_implementations/prior.jl b/src/context_implementations/prior.jl index a48b3231e..e0a2869c2 100644 --- a/src/context_implementations/prior.jl +++ b/src/context_implementations/prior.jl @@ -1,9 +1,10 @@ function tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + var = if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) + _getvalue(ctx.vars, getsym(vn), inds) + else + vi[vn] end - return tilde_primitive(rng, childcontext(ctx), sampler, right, left, vn, vi) + return tilde_primitive(rng, childcontext(ctx), sampler, right, var, vn, vi) end function tilde(ctx::PriorContext, sampler, right, left, vi) @@ -21,7 +22,7 @@ function dot_tilde( vi, ) where {sym} var = if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) - _getindex(getfield(ctx.vars, sym), inds) + _getvalue(ctx.vars, getsym(vn), inds) else vi[vns] end From 03fa1fa4b57c86a951dfd43ada22d6d1c9c86dd7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 18:33:59 +0200 Subject: [PATCH 32/42] formatting --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0499333aa..1673156bd 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -17,7 +17,7 @@ require_particles(spl::Sampler) = false _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x -function _getvalue(nt::NamedTuple, sym::Symbol, inds = ()) +function _getvalue(nt::NamedTuple, sym::Symbol, inds=()) value = getfield(nt, sym) return _getindex(value, inds) end From 32796cbaa2ae41d14cd9335b965af3c161be42aa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 18:56:07 +0200 Subject: [PATCH 33/42] correct implementation for LikelihoodContext and PriorContext --- src/context_implementations/likelihood.jl | 4 ++-- src/context_implementations/prior.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/context_implementations/likelihood.jl b/src/context_implementations/likelihood.jl index 10922a211..aae710eee 100644 --- a/src/context_implementations/likelihood.jl +++ b/src/context_implementations/likelihood.jl @@ -2,7 +2,7 @@ function tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, i var = if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) _getvalue(ctx.vars, getsym(vn), inds) else - vi[vn] + left end return tilde_primitive( rng, @@ -32,7 +32,7 @@ function dot_tilde( var = if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) _getvalue(ctx.vars, sym, inds) else - vi[vns] + left end return dot_tilde_primitive( rng, diff --git a/src/context_implementations/prior.jl b/src/context_implementations/prior.jl index e0a2869c2..4843dec4d 100644 --- a/src/context_implementations/prior.jl +++ b/src/context_implementations/prior.jl @@ -2,7 +2,7 @@ function tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, var = if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) _getvalue(ctx.vars, getsym(vn), inds) else - vi[vn] + left end return tilde_primitive(rng, childcontext(ctx), sampler, right, var, vn, vi) end @@ -24,7 +24,7 @@ function dot_tilde( var = if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) _getvalue(ctx.vars, getsym(vn), inds) else - vi[vns] + left end return dot_tilde_primitive(rng, childcontext(ctx), sampler, right, var, vns, vi) end From fd3f317540663dbf29baa74573e455df7a00c2ff Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 18:56:23 +0200 Subject: [PATCH 34/42] SamplingContext will now mutate values in VarInfo, even if the values are overridden --- src/context_implementations.jl | 52 ++++++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 1673156bd..c488cf56b 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -51,7 +51,7 @@ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) end function tilde_primitive(rng, ctx::SamplingContext, sampler, right, left, vn::VarName, vi) - return assume(rng, sampler, right, nothing, vn, vi) + return assume(rng, sampler, right, left, vn, vi) end function tilde_primitive( rng, ctx::EvaluationContext, sampler, right, left::Nothing, vn::VarName, vi @@ -125,6 +125,26 @@ function assume( return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end +function assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + left, + vn::VarName, + vi, +) + r = left + if haskey(vi, vn) + vi[vn] = vectorize(dist, r) + setorder!(vi, vn, get_num_produce(vi)) + else + push!(vi, vn, r, dist, spl) + end + settrans!(vi, false, vn) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) +end + + function assume( spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, left, vn::VarName, vi ) @@ -162,7 +182,7 @@ function dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) end function dot_tilde_primitive(rng, ctx::SamplingContext, sampler, right, left, vns, vi) - return dot_assume(rng, sampler, right, vns, nothing, vi) + return dot_assume(rng, sampler, right, vns, left, vi) end function dot_tilde_primitive(rng, ctx::EvaluationContext, sampler, right, left, vns, vi) @@ -202,6 +222,20 @@ function dot_assume( return r, lp end +function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dist::MultivariateDistribution, + vns::AbstractVector{<:VarName}, + var, + vi, +) + @assert length(dist) == size(var, 1) + r = set_val!(vi, vns, dist, var) + lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) + return r, lp +end + function dot_assume( spl::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, @@ -228,6 +262,20 @@ function dot_assume( return r, lp end +function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dists::Union{Distribution,AbstractArray{<:Distribution}}, + vns::AbstractArray{<:VarName}, + var, + vi, +) + r = set_val!(vi, vns, dists, var) + # Make sure `r` is not a matrix for multivariate distributions + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return r, lp +end + function dot_assume( spl::Union{SampleFromPrior,SampleFromUniform}, dists::Union{Distribution,AbstractArray{<:Distribution}}, From 0a3fe75f0f851a8b1a8667c181c14cb9879bd4cd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 23 May 2021 19:17:12 +0200 Subject: [PATCH 35/42] remove unnecessary type-specification --- src/context_implementations.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index c488cf56b..d310dc79b 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -50,15 +50,15 @@ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) return value end -function tilde_primitive(rng, ctx::SamplingContext, sampler, right, left, vn::VarName, vi) +function tilde_primitive(rng, ctx::SamplingContext, sampler, right, left, vn, vi) return assume(rng, sampler, right, left, vn, vi) end function tilde_primitive( - rng, ctx::EvaluationContext, sampler, right, left::Nothing, vn::VarName, vi + rng, ctx::EvaluationContext, sampler, right, left::Nothing, vn, vi ) return assume(sampler, right, vi[vn], vn, vi) end -function tilde_primitive(rng, ctx::EvaluationContext, sampler, right, left, vn::VarName, vi) +function tilde_primitive(rng, ctx::EvaluationContext, sampler, right, left, vn, vi) return assume(sampler, right, left, vn, vi) end From f02a51064f7ad7ce647a889a8646233fe84cd240 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 25 May 2021 17:52:31 +0100 Subject: [PATCH 36/42] renamed tilde_assume and others to tilde_assume! and similars --- src/compiler.jl | 12 +-- src/context_implementations.jl | 111 +++++++++------------- src/context_implementations/likelihood.jl | 20 ++-- src/context_implementations/minibatch.jl | 21 ++-- src/context_implementations/prefix.jl | 21 ++-- src/context_implementations/prior.jl | 12 +-- src/loglikelihoods.jl | 6 +- 7 files changed, 96 insertions(+), 107 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 4bb611cca..901adc962 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -285,7 +285,7 @@ function generate_tilde(left, right) # If the LHS is a literal, it is always an observation if !(left isa Symbol || left isa Expr) return quote - $(DynamicPPL.tilde_observe)( + $(DynamicPPL.tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), @@ -303,7 +303,7 @@ function generate_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left = $(DynamicPPL.tilde_assume)( + $left = $(DynamicPPL.tilde_assume!)( __rng__, __context__, __sampler__, @@ -314,7 +314,7 @@ function generate_tilde(left, right) __varinfo__, ) else - $(DynamicPPL.tilde_observe)( + $(DynamicPPL.tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), @@ -336,7 +336,7 @@ function generate_dot_tilde(left, right) # If the LHS is a literal, it is always an observation if !(left isa Symbol || left isa Expr) return quote - $(DynamicPPL.dot_tilde_observe)( + $(DynamicPPL.dot_tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), @@ -354,7 +354,7 @@ function generate_dot_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume)( + $left .= $(DynamicPPL.dot_tilde_assume!)( __rng__, __context__, __sampler__, @@ -365,7 +365,7 @@ function generate_dot_tilde(left, right) __varinfo__, ) else - $(DynamicPPL.dot_tilde_observe)( + $(DynamicPPL.dot_tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d310dc79b..a29bc0fca 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -23,48 +23,37 @@ function _getvalue(nt::NamedTuple, sym::Symbol, inds=()) end # assume -function tilde( - rng, - ctx::Union{SamplingContext,EvaluationContext}, - sampler, - right, - left, - vn::VarName, - _, - vi, -) - return tilde_primitive(rng, ctx, sampler, right, left, vn, vi) -end - """ - tilde_assume(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`. """ -function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde(rng, ctx, sampler, right, nothing, vn, inds, vi) +function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + value, logp = tilde_assume(rng, ctx, sampler, right, nothing, vn, inds, vi) acclogp!(vi, logp) return value end -function tilde_primitive(rng, ctx::SamplingContext, sampler, right, left, vn, vi) - return assume(rng, sampler, right, left, vn, vi) +function tilde_assume(rng, ctx::SamplingContext, sampler, right, left, vn, inds, vi) + return assume(rng, sampler, right, left, vn, inds, vi) end -function tilde_primitive( - rng, ctx::EvaluationContext, sampler, right, left::Nothing, vn, vi +function tilde_assume( + rng, ctx::EvaluationContext, sampler, right, left::Nothing, vn, inds, vi ) - return assume(sampler, right, vi[vn], vn, vi) + return assume(sampler, right, vi[vn], vn, inds, vi) end -function tilde_primitive(rng, ctx::EvaluationContext, sampler, right, left, vn, vi) - return assume(sampler, right, left, vn, vi) +function tilde_assume(rng, ctx::EvaluationContext, sampler, right, left, vn, inds, vi) + return assume(sampler, right, left, vn, inds, vi) end # observe -function tilde(ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi) - return tilde_primitive(sampler, right, left, vi) +function tilde_observe( + ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi +) + return observe(sampler, right, left, vi) end """ @@ -73,11 +62,11 @@ end Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + logp = tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end @@ -88,16 +77,14 @@ end Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)`. +Falls back to `_tilde_observe(ctx, sampler, right, left, vi)`. """ -function tilde_observe(ctx, sampler, right, left, vi) - logp = tilde(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, sampler, right, left, vi) + logp = tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end -tilde_primitive(sampler, right, left, vi) = observe(sampler, right, left, vi) - function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -111,7 +98,8 @@ function assume( spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, left::Nothing, - vn::VarName, + vn, + inds, vi, ) r = init(rng, dist, spl) @@ -130,7 +118,8 @@ function assume( spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, left, - vn::VarName, + vn, + inds, vi, ) r = left @@ -144,9 +133,8 @@ function assume( return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end - function assume( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, left, vn::VarName, vi + spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, left, vn, inds, vi ) return left, Bijectors.logpdf_with_trans(dist, left, istrans(vi, vn)) end @@ -161,48 +149,41 @@ end # .~ functions # assume -function dot_tilde( - rng, ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vn, _, vi -) - return dot_tilde_primitive(rng, ctx, sampler, right, left, vn, vi) -end - """ - dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. """ -function dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde(rng, ctx, sampler, right, nothing, vn, inds, vi) +function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(rng, ctx, sampler, right, nothing, vn, inds, vi) acclogp!(vi, logp) return value end -function dot_tilde_primitive(rng, ctx::SamplingContext, sampler, right, left, vns, vi) +function dot_tilde_assume(rng, ctx::SamplingContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end -function dot_tilde_primitive(rng, ctx::EvaluationContext, sampler, right, left, vns, vi) +function dot_tilde_assume(rng, ctx::EvaluationContext, sampler, right, left, vns, inds, vi) return dot_assume(sampler, right, vns, left, vi) end -function dot_tilde_primitive( - rng, ctx::EvaluationContext, sampler, right, left::Nothing, vns, vi -) +function dot_tilde_assume(rng, ctx::EvaluationContext, sampler, right, left::Nothing, vns, inds, vi) return dot_assume(sampler, right, vns, vi[vns], vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function dot_tilde_primitive( +function dot_tilde_assume( rng, ctx, sampler::AbstractSampler, right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, left::AbstractMatrix{>:AbstractVector}, vn::AbstractVector{<:VarName}, + inds, vi, ) return throw(DimensionMismatch(AMBIGUITY_MSG)) @@ -371,44 +352,40 @@ function set_val!( end # observe -function dot_tilde(ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi) - return dot_tilde_primitive(sampler, right, left, vi) -end - """ - dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) + logp = dot_tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe(ctx, sampler, right, left, vi) + dot_tilde_observe!(ctx, sampler, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. """ -function dot_tilde_observe(ctx, sampler, right, left, vi) - logp = dot_tilde(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, sampler, right, left, vi) + logp = dot_tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end -function dot_tilde_primitive(sampler, right, left::AbstractArray, vi) +function dot_tilde_observe(ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi) return dot_observe(sampler, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function dot_tilde_primitive( +function dot_observe( sampler::AbstractSampler, right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, left::AbstractMatrix{>:AbstractVector}, diff --git a/src/context_implementations/likelihood.jl b/src/context_implementations/likelihood.jl index aae710eee..72e33fb70 100644 --- a/src/context_implementations/likelihood.jl +++ b/src/context_implementations/likelihood.jl @@ -1,25 +1,28 @@ -function tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) +function tilde_assume( + rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi +) var = if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) _getvalue(ctx.vars, getsym(vn), inds) else left end - return tilde_primitive( + return tilde_assume( rng, rewrap(childcontext(ctx), EvaluationContext()), sampler, NoDist(right), var, vn, + inds, vi, ) end -function tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return tilde_primitive(sampler, right, left, vi) +function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) + return tilde_observe(sampler, right, left, vi) end -function dot_tilde( +function dot_tilde_assume( rng, ctx::LikelihoodContext, sampler, @@ -34,17 +37,18 @@ function dot_tilde( else left end - return dot_tilde_primitive( + return dot_tilde_assume( rng, rewrap(childcontext(ctx), EvaluationContext()), sampler, NoDist.(right), var, vns, + inds, vi, ) end -function dot_tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return dot_tilde_primitive(sampler, right, left, vi) +function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) + return dot_tilde_observe(sampler, right, left, vi) end diff --git a/src/context_implementations/minibatch.jl b/src/context_implementations/minibatch.jl index 60dc0aae3..6e37bc2b2 100644 --- a/src/context_implementations/minibatch.jl +++ b/src/context_implementations/minibatch.jl @@ -1,15 +1,20 @@ -function tilde(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi) +function tilde_assume( + rng, ctx::MiniBatchContext, sampler, right, left, vn, inds, vi +) + return tilde_assume(rng, childcontext(ctx), sampler, right, left, vn, inds, vi) end -function tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) + return ctx.loglike_scalar * tilde_observe(childcontext(ctx), sampler, right, left, vi) end -function dot_tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) - return dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume( + rng, ctx::MiniBatchContext, sampler, right, left, vn, inds, vi +) + return dot_tilde_assume(rng, childcontext(ctx), sampler, right, left, vn, inds, vi) end -function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) + return ctx.loglike_scalar * + dot_tilde_observe(childcontext(ctx), sampler, right, left, vi) end diff --git a/src/context_implementations/prefix.jl b/src/context_implementations/prefix.jl index e19dc7d73..c2e8b1741 100644 --- a/src/context_implementations/prefix.jl +++ b/src/context_implementations/prefix.jl @@ -1,18 +1,17 @@ -function tilde(rng, ctx::PrefixContext, sampler, right, left, vn::VarName, inds, vi) - return tilde(rng, childcontext(ctx), sampler, right, left, prefix(ctx, vn), inds, vi) +function tilde_assume(rng, ctx::PrefixContext, sampler, right, left, vn, inds, vi) + return tilde_assume( + rng, childcontext(ctx), sampler, right, left, prefix(ctx, vn), inds, vi + ) end -function tilde(ctx::PrefixContext, sampler, right, left, vi) - return tilde(childcontext(ctx), sampler, right, left, vi) +function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) + return tilde_observe(childcontext(ctx), sampler, right, left, vi) end -function dot_tilde(ctx::PrefixContext, sampler, right, left, vi) - return dot_tilde(childcontext(ctx), sampler, right, left, vi) -end -function dot_tilde( +function dot_tilde_assume( rng::Random.AbstractRNG, ctx::PrefixContext, sampler, right, left, vn, inds, vi ) - return dot_tilde( + return dot_tilde_assume( rng, childcontext(ctx), sampler, @@ -23,3 +22,7 @@ function dot_tilde( vi, ) end + +function dot_tilde_observe(ctx::PrefixContext, sampler, right, left, vi) + return dot_tilde_observe(childcontext(ctx), sampler, right, left, vi) +end diff --git a/src/context_implementations/prior.jl b/src/context_implementations/prior.jl index 4843dec4d..60b7ba8c4 100644 --- a/src/context_implementations/prior.jl +++ b/src/context_implementations/prior.jl @@ -1,17 +1,17 @@ -function tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) +function tilde_assume(rng, ctx::PriorContext, sampler, right, left, vn, inds, vi) var = if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) _getvalue(ctx.vars, getsym(vn), inds) else left end - return tilde_primitive(rng, childcontext(ctx), sampler, right, var, vn, vi) + return tilde_assume(rng, childcontext(ctx), sampler, right, var, vn, inds, vi) end -function tilde(ctx::PriorContext, sampler, right, left, vi) +function tilde_observe(ctx::PriorContext, sampler, right, left, vi) return 0 end -function dot_tilde( +function dot_tilde_assume( rng, ctx::PriorContext, sampler, @@ -26,9 +26,9 @@ function dot_tilde( else left end - return dot_tilde_primitive(rng, childcontext(ctx), sampler, right, var, vns, vi) + return dot_tilde_assume(rng, childcontext(ctx), sampler, right, var, vns, inds, vi) end -function dot_tilde(ctx::PriorContext, sampler, right, left, vi) +function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) return 0 end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 89672127a..ea514ff94 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -52,11 +52,11 @@ function Base.push!( return ctx.loglikelihoods[vn] = logp end -function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) +function tilde_assume!(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) + return tilde_assume!(rng, ctx.ctx, sampler, right, vn, inds, vi) end -function dot_tilde_assume( +function dot_tilde_assume!( rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi ) value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) From fa804d47c0b36bba18a8198e2c8b0f92faa1c18f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 25 May 2021 17:57:04 +0100 Subject: [PATCH 37/42] formatting --- src/context_implementations.jl | 8 ++++++-- src/context_implementations/minibatch.jl | 8 ++------ src/loglikelihoods.jl | 12 +++++------- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a29bc0fca..d860efb4c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -171,7 +171,9 @@ function dot_tilde_assume(rng, ctx::EvaluationContext, sampler, right, left, vns return dot_assume(sampler, right, vns, left, vi) end -function dot_tilde_assume(rng, ctx::EvaluationContext, sampler, right, left::Nothing, vns, inds, vi) +function dot_tilde_assume( + rng, ctx::EvaluationContext, sampler, right, left::Nothing, vns, inds, vi +) return dot_assume(sampler, right, vns, vi[vns], vi) end @@ -381,7 +383,9 @@ function dot_tilde_observe!(ctx, sampler, right, left, vi) return left end -function dot_tilde_observe(ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi) +function dot_tilde_observe( + ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi +) return dot_observe(sampler, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics diff --git a/src/context_implementations/minibatch.jl b/src/context_implementations/minibatch.jl index 6e37bc2b2..ccdf01ca1 100644 --- a/src/context_implementations/minibatch.jl +++ b/src/context_implementations/minibatch.jl @@ -1,6 +1,4 @@ -function tilde_assume( - rng, ctx::MiniBatchContext, sampler, right, left, vn, inds, vi -) +function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vn, inds, vi) return tilde_assume(rng, childcontext(ctx), sampler, right, left, vn, inds, vi) end @@ -8,9 +6,7 @@ function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * tilde_observe(childcontext(ctx), sampler, right, left, vi) end -function dot_tilde_assume( - rng, ctx::MiniBatchContext, sampler, right, left, vn, inds, vi -) +function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vn, inds, vi) return dot_tilde_assume(rng, childcontext(ctx), sampler, right, left, vn, inds, vi) end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index ea514ff94..7e6cf2c30 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -52,19 +52,17 @@ function Base.push!( return ctx.loglikelihoods[vn] = logp end -function tilde_assume!(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) - return tilde_assume!(rng, ctx.ctx, sampler, right, vn, inds, vi) +function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) end -function dot_tilde_assume!( +function dot_tilde_assume( rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi ) - value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_tilde_assume(rng, childcontext(ctx), sampler, right, left, vn, inds, vi) end -function tilde_observe( +function tilde_observe!( ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi ) # This is slightly unfortunate since it is not completely generic... From 1e5864b06f22cad2598041bd28d54ddffa812e3f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 25 May 2021 21:45:42 +0100 Subject: [PATCH 38/42] updated PointwiseLikelihoodContext to the new context approach --- src/context_implementations/likelihood.jl | 4 +-- src/loglikelihoods.jl | 41 +++++++++++++++++++---- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/src/context_implementations/likelihood.jl b/src/context_implementations/likelihood.jl index 72e33fb70..40849ce4c 100644 --- a/src/context_implementations/likelihood.jl +++ b/src/context_implementations/likelihood.jl @@ -19,7 +19,7 @@ function tilde_assume( end function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return tilde_observe(sampler, right, left, vi) + return tilde_observe(childcontext(ctx), sampler, right, left, vi) end function dot_tilde_assume( @@ -50,5 +50,5 @@ function dot_tilde_assume( end function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return dot_tilde_observe(sampler, right, left, vi) + return dot_tilde_observe(childcontext(ctx), sampler, right, left, vi) end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 7e6cf2c30..17b3865e4 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -1,5 +1,5 @@ # Context version -struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext +struct PointwiseLikelihoodContext{A,Ctx,LeafCtx} <: WrappedContext{LeafCtx} loglikelihoods::A ctx::Ctx end @@ -7,7 +7,9 @@ end function PointwiseLikelihoodContext( likelihoods=Dict{VarName,Vector{Float64}}(), ctx::AbstractContext=LikelihoodContext() ) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx) + return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx),unwrappedtype(ctx)}( + likelihoods, ctx + ) end function Base.push!( @@ -52,8 +54,20 @@ function Base.push!( return ctx.loglikelihoods[vn] = logp end -function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) +# tilde_assume( +# ::Random._GLOBAL_RNG, +# ::PointwiseLikelihoodContext{Dict{String, Vector{Float64}}, LikelihoodContext{Nothing, EvaluationContext}}, +# ::SampleFromPrior, +# ::InverseGamma{Float64}, +# ::Nothing, +# ::VarName{:s, Tuple{}}, +# ::Tuple{}, +# ::TypedVarInfo{NamedTuple{(:s, :m), Tuple{DynamicPPL.Metadata{Dict{VarName{:s, Tuple{}}, Int64}, Vector{InverseGamma{Float64}}, Vector{VarName{:s, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{VarName{:m, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:m, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64} +# ) +function tilde_assume( + rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi +) + return tilde_assume(rng, childcontext(ctx), sampler, right, left, vn, inds, vi) end function dot_tilde_assume( @@ -68,7 +82,22 @@ function tilde_observe!( # This is slightly unfortunate since it is not completely generic... # Ideally we would call `tilde_observe` recursively but then we don't get the # loglikelihood value. - logp = tilde(ctx.ctx, sampler, right, left, vi) + logp = tilde_observe(childcontext(ctx), sampler, right, left, vi) + acclogp!(vi, logp) + + # track loglikelihood value + push!(ctx, vname, logp) + + return left +end + +function dot_tilde_observe!( + ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi +) + # This is slightly unfortunate since it is not completely generic... + # Ideally we would call `tilde_observe` recursively but then we don't get the + # loglikelihood value. + logp = tilde_observe(childcontext(ctx), sampler, right, left, vi) acclogp!(vi, logp) # track loglikelihood value @@ -155,7 +184,7 @@ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters # Update the values - setval_and_resample!(vi, chain, sample_idx, chain_idx) + setval!(vi, chain, sample_idx, chain_idx) # Execute model model(vi, spl, ctx) From 1e1b2e6fab1a520c895dad62918576fb47b3fc1f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 28 May 2021 23:11:00 +0100 Subject: [PATCH 39/42] fixed a typo in dot_tilde_observe! for PointwiseLikelihoodContext --- src/loglikelihoods.jl | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 17b3865e4..e4dd0b789 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -54,16 +54,6 @@ function Base.push!( return ctx.loglikelihoods[vn] = logp end -# tilde_assume( -# ::Random._GLOBAL_RNG, -# ::PointwiseLikelihoodContext{Dict{String, Vector{Float64}}, LikelihoodContext{Nothing, EvaluationContext}}, -# ::SampleFromPrior, -# ::InverseGamma{Float64}, -# ::Nothing, -# ::VarName{:s, Tuple{}}, -# ::Tuple{}, -# ::TypedVarInfo{NamedTuple{(:s, :m), Tuple{DynamicPPL.Metadata{Dict{VarName{:s, Tuple{}}, Int64}, Vector{InverseGamma{Float64}}, Vector{VarName{:s, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{VarName{:m, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:m, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64} -# ) function tilde_assume( rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi ) @@ -97,7 +87,7 @@ function dot_tilde_observe!( # This is slightly unfortunate since it is not completely generic... # Ideally we would call `tilde_observe` recursively but then we don't get the # loglikelihood value. - logp = tilde_observe(childcontext(ctx), sampler, right, left, vi) + logp = dot_tilde_observe(childcontext(ctx), sampler, right, left, vi) acclogp!(vi, logp) # track loglikelihood value From 7cbe0cc5d5809fd80ba8b33fa14d0374550450af Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 29 May 2021 14:42:22 +0100 Subject: [PATCH 40/42] added missing rewrap impls and simplified constructors --- src/contexts/likelihood.jl | 12 ++++++++++-- src/contexts/minibatch.jl | 8 +++----- src/contexts/prefix.jl | 5 +---- src/contexts/prior.jl | 10 ++++++++-- src/loglikelihoods.jl | 6 ++++++ 5 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/contexts/likelihood.jl b/src/contexts/likelihood.jl index d1bc20f69..d2b673af9 100644 --- a/src/contexts/likelihood.jl +++ b/src/contexts/likelihood.jl @@ -7,9 +7,17 @@ The `LikelihoodContext` enables the computation of the log likelihood of the par running the model. `vars` can be used to evaluate the log likelihood for specific values of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. """ -struct LikelihoodContext{Tvars,LeafCtx} <: WrappedContext{LeafCtx} +struct LikelihoodContext{Tvars,Ctx,LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars - ctx::LeafCtx + ctx::Ctx + + function LikelihoodContext(vars, ctx) + return new{typeof(vars),typeof(ctx),unwrappedtype(ctx)}(vars, ctx) + end end LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluationContext()) LikelihoodContext(ctx::AbstractContext) = LikelihoodContext(nothing, ctx) + +function rewrap(parent::LikelihoodContext, leaf::PrimitiveContext) + return LikelihoodContext(parent.vars, rewrap(childcontext(parent), leaf)) +end diff --git a/src/contexts/minibatch.jl b/src/contexts/minibatch.jl index bf9ecff19..a0c257b84 100644 --- a/src/contexts/minibatch.jl +++ b/src/contexts/minibatch.jl @@ -15,11 +15,9 @@ struct MiniBatchContext{T,Ctx,LeafCtx} <: WrappedContext{LeafCtx} ctx::Ctx function MiniBatchContext(loglike_scalar, ctx::AbstractContext) - return new{typeof(loglike_scalar),typeof(ctx),typeof(ctx)}(loglike_scalar, ctx) - end - - function MiniBatchContext(loglike_scalar, ctx::WrappedContext{LeafCtx}) where {LeafCtx} - return new{typeof(loglike_scalar),typeof(ctx),LeafCtx}(loglike_scalar, ctx) + return new{typeof(loglike_scalar),typeof(ctx),unwrappedtype(ctx)}( + loglike_scalar, ctx + ) end end diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl index 05263c675..63da8cbba 100644 --- a/src/contexts/prefix.jl +++ b/src/contexts/prefix.jl @@ -2,10 +2,7 @@ struct PrefixContext{Prefix,C,LeafCtx} <: WrappedContext{LeafCtx} ctx::C function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return new{Prefix,typeof(ctx),typeof(ctx)}(ctx) - end - function PrefixContext{Prefix}(ctx::WrappedContext{LeafCtx}) where {Prefix,LeafCtx} - return new{Prefix,typeof(ctx),LeafCtx}(ctx) + return new{Prefix,typeof(ctx),unwrappedtype(ctx)}(ctx) end end PrefixContext{Prefix}() where {Prefix} = PrefixContext{Prefix}(EvaluationContext()) diff --git a/src/contexts/prior.jl b/src/contexts/prior.jl index 157daf3ec..451752872 100644 --- a/src/contexts/prior.jl +++ b/src/contexts/prior.jl @@ -6,9 +6,15 @@ The `PriorContext` enables the computation of the log prior of the parameters `vars` when running the model. """ -struct PriorContext{Tvars,LeafCtx} <: WrappedContext{LeafCtx} +struct PriorContext{Tvars,Ctx,LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars - ctx::LeafCtx + ctx::Ctx + + PriorContext(vars, ctx) = new{typeof(vars),typeof(ctx),unwrappedtype(ctx)}(vars, ctx) end PriorContext(vars=nothing) = PriorContext(vars, EvaluationContext()) PriorContext(ctx::AbstractContext) = PriorContext(nothing, ctx) + +function rewrap(parent::PriorContext, leaf::PrimitiveContext) + return PriorContext(parent.vars, rewrap(childcontext(parent), leaf)) +end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index e4dd0b789..976023969 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -12,6 +12,12 @@ function PointwiseLikelihoodContext( ) end +function rewrap(parent::PointwiseLikelihoodContext, leaf::PrimitiveContext) + return PointwiseLikelihoodContext( + parent.loglikelihoods, rewrap(childcontext(parent), leaf) + ) +end + function Base.push!( ctx::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, logp::Real ) From 8a44e8a9c14ba66ec7e7ad61706b370e9cc17e70 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 29 May 2021 16:44:48 +0100 Subject: [PATCH 41/42] add rng and sampler to contexts --- src/compiler.jl | 30 +++----- src/context_implementations.jl | 88 +++++++++++------------ src/context_implementations/likelihood.jl | 39 +++------- src/context_implementations/minibatch.jl | 17 +++-- src/context_implementations/prefix.jl | 27 +++---- src/context_implementations/prior.jl | 19 ++--- src/contexts.jl | 13 +++- src/loglikelihoods.jl | 21 +++--- src/model.jl | 26 +++---- 9 files changed, 115 insertions(+), 165 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 901adc962..908849519 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -286,11 +286,7 @@ function generate_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.tilde_observe!)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -304,9 +300,7 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn )..., @@ -316,7 +310,6 @@ function generate_tilde(left, right) else $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -337,11 +330,7 @@ function generate_dot_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -355,9 +344,7 @@ function generate_dot_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn )..., @@ -367,7 +354,6 @@ function generate_dot_tilde(left, right) else $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -398,10 +384,8 @@ function build_output(modelinfo, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( [ - :(__rng__::$(Random.AbstractRNG)), :(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), ], modelinfo[:allargs_exprs], @@ -411,7 +395,15 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = modelinfo[:body] + evaluatordef[:body] = quote + # in case someone accessed these + if __context__ isa $(DynamicPPL.SamplingContext) + __rng__ = __context__.rng + __sampler__ = __context__.sampler + end + + $(modelinfo[:body]) + end ## Build the model function. diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d860efb4c..344bc0816 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -24,63 +24,59 @@ end # assume """ - tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(ctx, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`. +Falls back to `tilde(ctx, right, vn, inds, vi)`. """ -function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(rng, ctx, sampler, right, nothing, vn, inds, vi) +function tilde_assume!(ctx, right, vn, inds, vi) + value, logp = tilde_assume(ctx, right, nothing, vn, inds, vi) acclogp!(vi, logp) return value end -function tilde_assume(rng, ctx::SamplingContext, sampler, right, left, vn, inds, vi) - return assume(rng, sampler, right, left, vn, inds, vi) +function tilde_assume(ctx::SamplingContext, right, left, vn, inds, vi) + return assume(ctx.rng, ctx.sampler, right, left, vn, inds, vi) end -function tilde_assume( - rng, ctx::EvaluationContext, sampler, right, left::Nothing, vn, inds, vi -) - return assume(sampler, right, vi[vn], vn, inds, vi) +function tilde_assume(ctx::EvaluationContext, right, left::Nothing, vn, inds, vi) + return assume(ctx.sampler, right, vi[vn], vn, inds, vi) end -function tilde_assume(rng, ctx::EvaluationContext, sampler, right, left, vn, inds, vi) - return assume(sampler, right, left, vn, inds, vi) +function tilde_assume(ctx::EvaluationContext, right, left, vn, inds, vi) + return assume(ctx.sampler, right, left, vn, inds, vi) end # observe -function tilde_observe( - ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi -) - return observe(sampler, right, left, vi) +function tilde_observe(ctx::Union{SamplingContext,EvaluationContext}, right, left, vi) + return observe(ctx.sampler, right, left, vi) end """ - tilde_observe(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe(ctx, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `_tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vname, vinds, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, sampler, right, left, vi) + tilde_observe(ctx, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `_tilde_observe(ctx, sampler, right, left, vi)`. +Falls back to `_tilde_observe(ctx, right, left, vi)`. """ -function tilde_observe!(ctx, sampler, right, left, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end @@ -150,31 +146,29 @@ end # assume """ - dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(ctx, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(ctx, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(rng, ctx, sampler, right, nothing, vn, inds, vi) +function dot_tilde_assume!(ctx, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(ctx, right, nothing, vn, inds, vi) acclogp!(vi, logp) return value end -function dot_tilde_assume(rng, ctx::SamplingContext, sampler, right, left, vns, inds, vi) - return dot_assume(rng, sampler, right, vns, left, vi) +function dot_tilde_assume(ctx::SamplingContext, right, left, vns, inds, vi) + return dot_assume(ctx.rng, ctx.sampler, right, vns, left, vi) end -function dot_tilde_assume(rng, ctx::EvaluationContext, sampler, right, left, vns, inds, vi) - return dot_assume(sampler, right, vns, left, vi) +function dot_tilde_assume(ctx::EvaluationContext, right, left, vns, inds, vi) + return dot_assume(ctx.sampler, right, vns, left, vi) end -function dot_tilde_assume( - rng, ctx::EvaluationContext, sampler, right, left::Nothing, vns, inds, vi -) - return dot_assume(sampler, right, vns, vi[vns], vi) +function dot_tilde_assume(ctx::EvaluationContext, right, left::Nothing, vns, inds, vi) + return dot_assume(ctx.sampler, right, vns, vi[vns], vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics @@ -355,38 +349,36 @@ end # observe """ - dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(ctx, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vn, inds, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe!(ctx, sampler, right, left, vi) + dot_tilde_observe!(ctx, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(ctx, right, left, vi)`. """ -function dot_tilde_observe!(ctx, sampler, right, left, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end -function dot_tilde_observe( - ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi -) - return dot_observe(sampler, right, left, vi) +function dot_tilde_observe(ctx::Union{SamplingContext,EvaluationContext}, right, left, vi) + return dot_observe(ctx.sampler, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics function dot_observe( diff --git a/src/context_implementations/likelihood.jl b/src/context_implementations/likelihood.jl index 40849ce4c..14054856d 100644 --- a/src/context_implementations/likelihood.jl +++ b/src/context_implementations/likelihood.jl @@ -1,36 +1,20 @@ -function tilde_assume( - rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi -) +function tilde_assume(ctx::LikelihoodContext, right, left, vn::VarName, inds, vi) var = if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) _getvalue(ctx.vars, getsym(vn), inds) else left end return tilde_assume( - rng, - rewrap(childcontext(ctx), EvaluationContext()), - sampler, - NoDist(right), - var, - vn, - inds, - vi, + rewrap(childcontext(ctx), EvaluationContext()), NoDist(right), var, vn, inds, vi ) end -function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return tilde_observe(childcontext(ctx), sampler, right, left, vi) +function tilde_observe(ctx::LikelihoodContext, right, left, vi) + return tilde_observe(childcontext(ctx), right, left, vi) end function dot_tilde_assume( - rng, - ctx::LikelihoodContext, - sampler, - right, - left, - vns::AbstractArray{<:VarName{sym}}, - inds, - vi, + ctx::LikelihoodContext, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi ) where {sym} var = if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) _getvalue(ctx.vars, sym, inds) @@ -38,17 +22,10 @@ function dot_tilde_assume( left end return dot_tilde_assume( - rng, - rewrap(childcontext(ctx), EvaluationContext()), - sampler, - NoDist.(right), - var, - vns, - inds, - vi, + rewrap(childcontext(ctx), EvaluationContext()), NoDist.(right), var, vns, inds, vi ) end -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return dot_tilde_observe(childcontext(ctx), sampler, right, left, vi) +function dot_tilde_observe(ctx::LikelihoodContext, right, left, vi) + return dot_tilde_observe(childcontext(ctx), right, left, vi) end diff --git a/src/context_implementations/minibatch.jl b/src/context_implementations/minibatch.jl index ccdf01ca1..91f5be57d 100644 --- a/src/context_implementations/minibatch.jl +++ b/src/context_implementations/minibatch.jl @@ -1,16 +1,15 @@ -function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vn, inds, vi) - return tilde_assume(rng, childcontext(ctx), sampler, right, left, vn, inds, vi) +function tilde_assume(ctx::MiniBatchContext, right, left, vn, inds, vi) + return tilde_assume(childcontext(ctx), right, left, vn, inds, vi) end -function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde_observe(childcontext(ctx), sampler, right, left, vi) +function tilde_observe(ctx::MiniBatchContext, right, left, vi) + return ctx.loglike_scalar * tilde_observe(childcontext(ctx), right, left, vi) end -function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vn, inds, vi) - return dot_tilde_assume(rng, childcontext(ctx), sampler, right, left, vn, inds, vi) +function dot_tilde_assume(ctx::MiniBatchContext, right, left, vn, inds, vi) + return dot_tilde_assume(childcontext(ctx), right, left, vn, inds, vi) end -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * - dot_tilde_observe(childcontext(ctx), sampler, right, left, vi) +function dot_tilde_observe(ctx::MiniBatchContext, right, left, vi) + return ctx.loglike_scalar * dot_tilde_observe(childcontext(ctx), right, left, vi) end diff --git a/src/context_implementations/prefix.jl b/src/context_implementations/prefix.jl index c2e8b1741..09329ad2f 100644 --- a/src/context_implementations/prefix.jl +++ b/src/context_implementations/prefix.jl @@ -1,28 +1,17 @@ -function tilde_assume(rng, ctx::PrefixContext, sampler, right, left, vn, inds, vi) - return tilde_assume( - rng, childcontext(ctx), sampler, right, left, prefix(ctx, vn), inds, vi - ) +function tilde_assume(ctx::PrefixContext, right, left, vn, inds, vi) + return tilde_assume(childcontext(ctx), right, left, prefix(ctx, vn), inds, vi) end -function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) - return tilde_observe(childcontext(ctx), sampler, right, left, vi) +function tilde_observe(ctx::PrefixContext, right, left, vi) + return tilde_observe(childcontext(ctx), right, left, vi) end -function dot_tilde_assume( - rng::Random.AbstractRNG, ctx::PrefixContext, sampler, right, left, vn, inds, vi -) +function dot_tilde_assume(ctx::PrefixContext, right, left, vn, inds, vi) return dot_tilde_assume( - rng, - childcontext(ctx), - sampler, - right, - left, - map(Base.Fix1(prefix, ctx), vn), - inds, - vi, + childcontext(ctx), right, left, map(Base.Fix1(prefix, ctx), vn), inds, vi ) end -function dot_tilde_observe(ctx::PrefixContext, sampler, right, left, vi) - return dot_tilde_observe(childcontext(ctx), sampler, right, left, vi) +function dot_tilde_observe(ctx::PrefixContext, right, left, vi) + return dot_tilde_observe(childcontext(ctx), right, left, vi) end diff --git a/src/context_implementations/prior.jl b/src/context_implementations/prior.jl index 60b7ba8c4..c259254fe 100644 --- a/src/context_implementations/prior.jl +++ b/src/context_implementations/prior.jl @@ -1,34 +1,27 @@ -function tilde_assume(rng, ctx::PriorContext, sampler, right, left, vn, inds, vi) +function tilde_assume(ctx::PriorContext, right, left, vn, inds, vi) var = if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) _getvalue(ctx.vars, getsym(vn), inds) else left end - return tilde_assume(rng, childcontext(ctx), sampler, right, var, vn, inds, vi) + return tilde_assume(childcontext(ctx), right, var, vn, inds, vi) end -function tilde_observe(ctx::PriorContext, sampler, right, left, vi) +function tilde_observe(ctx::PriorContext, right, left, vi) return 0 end function dot_tilde_assume( - rng, - ctx::PriorContext, - sampler, - right, - left, - vns::AbstractArray{<:VarName{sym}}, - inds, - vi, + ctx::PriorContext, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi ) where {sym} var = if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) _getvalue(ctx.vars, getsym(vn), inds) else left end - return dot_tilde_assume(rng, childcontext(ctx), sampler, right, var, vns, inds, vi) + return dot_tilde_assume(childcontext(ctx), right, var, vns, inds, vi) end -function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) +function dot_tilde_observe(ctx::PriorContext, right, left, vi) return 0 end diff --git a/src/contexts.jl b/src/contexts.jl index 85a76cb37..1329d6703 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -5,8 +5,17 @@ The `DefaultContext` is used by default to compute log the joint probability of and parameters when running the model. """ abstract type PrimitiveContext <: AbstractContext end -struct EvaluationContext <: PrimitiveContext end -struct SamplingContext <: PrimitiveContext end +struct EvaluationContext{S<:AbstractSampler} <: PrimitiveContext + # TODO: do we even need the sampler these days? + sampler::S +end +EvaluationContext() = EvaluationContext(SampleFromPrior()) + +struct SamplingContext{R<:Random.AbstractRNG,S<:AbstractSampler} <: PrimitiveContext + rng::R + sampler::S +end +SamplingContext(sampler=SampleFromPrior()) = SamplingContext(Random.GLOBAL_RNG, sampler) ######################## ### Wrapped contexts ### diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 976023969..6758723c6 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -61,24 +61,24 @@ function Base.push!( end function tilde_assume( - rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi + ctx::PointwiseLikelihoodContext, right, left, vn, inds, vi ) - return tilde_assume(rng, childcontext(ctx), sampler, right, left, vn, inds, vi) + return tilde_assume(childcontext(ctx), right, left, vn, inds, vi) end function dot_tilde_assume( - rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi + ctx::PointwiseLikelihoodContext, right, left, vn, inds, vi ) - return dot_tilde_assume(rng, childcontext(ctx), sampler, right, left, vn, inds, vi) + return dot_tilde_assume(childcontext(ctx), right, left, vn, inds, vi) end function tilde_observe!( - ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi + ctx::PointwiseLikelihoodContext, right, left, vname, vinds, vi ) # This is slightly unfortunate since it is not completely generic... # Ideally we would call `tilde_observe` recursively but then we don't get the # loglikelihood value. - logp = tilde_observe(childcontext(ctx), sampler, right, left, vi) + logp = tilde_observe(childcontext(ctx), right, left, vi) acclogp!(vi, logp) # track loglikelihood value @@ -88,12 +88,12 @@ function tilde_observe!( end function dot_tilde_observe!( - ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi + ctx::PointwiseLikelihoodContext, right, left, vname, vinds, vi ) # This is slightly unfortunate since it is not completely generic... # Ideally we would call `tilde_observe` recursively but then we don't get the # loglikelihood value. - logp = dot_tilde_observe(childcontext(ctx), sampler, right, left, vi) + logp = dot_tilde_observe(childcontext(ctx), right, left, vi) acclogp!(vi, logp) # track loglikelihood value @@ -173,7 +173,6 @@ Dict{VarName,Array{Float64,2}} with 4 entries: """ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} # Get the data by executing the model once - spl = SampleFromPrior() vi = VarInfo(model) ctx = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) @@ -183,7 +182,7 @@ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, spl, ctx) + model(vi, ctx) end niters = size(chain, 1) @@ -197,6 +196,6 @@ end function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) ctx = PointwiseLikelihoodContext(Dict{VarName,Float64}()) - model(varinfo, SampleFromPrior(), ctx) + model(varinfo, ctx) return ctx.loglikelihoods end diff --git a/src/model.jl b/src/model.jl index a2a40a1b8..5151283ad 100644 --- a/src/model.jl +++ b/src/model.jl @@ -86,12 +86,12 @@ function (model::Model)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=SamplingContext(), + context::AbstractContext=SamplingContext(rng, sampler), ) if Threads.nthreads() == 1 - return evaluate_threadunsafe(rng, model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, context) else - return evaluate_threadsafe(rng, model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, context) end end function (model::Model)(args...) @@ -116,7 +116,7 @@ function (model::Model)( end """ - evaluate_threadunsafe(rng, model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -125,13 +125,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(rng, model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, context) resetlogp!(varinfo) - return _evaluate(rng, model, varinfo, sampler, context) + return _evaluate(model, varinfo, context) end """ - evaluate_threadsafe(rng, model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -141,24 +141,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(rng, model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(rng, model, wrapper, sampler, context) + result = _evaluate(model, wrapper, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(rng, model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( - rng, model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) + return :(model.f(model, varinfo, context, $(unwrap_args...))) end """ From ecf72adcb943356a8e557a99eadba5c5a221dd3d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 29 May 2021 16:53:23 +0100 Subject: [PATCH 42/42] formatting --- src/loglikelihoods.jl | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 6758723c6..247c19278 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -60,21 +60,15 @@ function Base.push!( return ctx.loglikelihoods[vn] = logp end -function tilde_assume( - ctx::PointwiseLikelihoodContext, right, left, vn, inds, vi -) +function tilde_assume(ctx::PointwiseLikelihoodContext, right, left, vn, inds, vi) return tilde_assume(childcontext(ctx), right, left, vn, inds, vi) end -function dot_tilde_assume( - ctx::PointwiseLikelihoodContext, right, left, vn, inds, vi -) +function dot_tilde_assume(ctx::PointwiseLikelihoodContext, right, left, vn, inds, vi) return dot_tilde_assume(childcontext(ctx), right, left, vn, inds, vi) end -function tilde_observe!( - ctx::PointwiseLikelihoodContext, right, left, vname, vinds, vi -) +function tilde_observe!(ctx::PointwiseLikelihoodContext, right, left, vname, vinds, vi) # This is slightly unfortunate since it is not completely generic... # Ideally we would call `tilde_observe` recursively but then we don't get the # loglikelihood value. @@ -87,9 +81,7 @@ function tilde_observe!( return left end -function dot_tilde_observe!( - ctx::PointwiseLikelihoodContext, right, left, vname, vinds, vi -) +function dot_tilde_observe!(ctx::PointwiseLikelihoodContext, right, left, vname, vinds, vi) # This is slightly unfortunate since it is not completely generic... # Ideally we would call `tilde_observe` recursively but then we don't get the # loglikelihood value.