Skip to content

Commit

Permalink
Remove unnecessary namespace qualifications (#508)
Browse files Browse the repository at this point in the history
* Remove unnecessary namespace qualifications

* Update prob_macro.jl
  • Loading branch information
devmotion authored Jul 30, 2023
1 parent 5e0e562 commit 5dd7c53
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 24 deletions.
7 changes: 4 additions & 3 deletions benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using Markdown: Markdown

using LibGit2: LibGit2
using Pkg: Pkg
using Random: Random

export weave_benchmarks

Expand All @@ -33,9 +34,9 @@ function benchmark_typed_varinfo!(suite, m)
end

function typed_code(m, vi=VarInfo(m))
rng = DynamicPPL.Random.MersenneTwister(42)
spl = DynamicPPL.SampleFromPrior()
ctx = DynamicPPL.SamplingContext(rng, spl, DynamicPPL.DefaultContext())
rng = Random.MersenneTwister(42)
spl = SampleFromPrior()
ctx = SamplingContext(rng, spl, DefaultContext())

results = code_typed(m.f, Base.typesof(m, vi, ctx, m.args...))
return first(results)
Expand Down
4 changes: 2 additions & 2 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,7 @@ function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
)
DynamicPPL.logjoint(model, argvals_dict)
logjoint(model, argvals_dict)
end
end

Expand Down Expand Up @@ -1138,7 +1138,7 @@ function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
)
DynamicPPL.logprior(model, argvals_dict)
logprior(model, argvals_dict)
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/model_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ end
Mutate `out` to map each variable name in `model`/`varinfo` to its value in
`chain` at `chain_idx` and `iteration_idx`.
"""
function values_from_chain!(model::DynamicPPL.Model, chain, chain_idx, iteration_idx, out)
function values_from_chain!(model::Model, chain, chain_idx, iteration_idx, out)
return values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx, out)
end

Expand Down Expand Up @@ -196,7 +196,7 @@ julia> conditioned_model() # <= results in same values as the `first(iter)` abo
(0.5805148626851955, 0.7393275279160691)
```
"""
function value_iterator_from_chain(model::DynamicPPL.Model, chain)
function value_iterator_from_chain(model::Model, chain)
return value_iterator_from_chain(VarInfo(model), chain)
end

Expand Down
4 changes: 1 addition & 3 deletions src/prob_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ function logprior(
foreach(keys(vi.metadata)) do n
@assert n in keys(left) "Variable $n is not defined."
end
return getlogp(
last(DynamicPPL.evaluate!!(model, vi, SampleFromPrior(), PriorContext(left)))
)
return getlogp(last(evaluate!!(model, vi, SampleFromPrior(), PriorContext(left))))
end

@generated function make_prior_model(
Expand Down
4 changes: 2 additions & 2 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ end
function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names}
nt_vals = map(keys(vi)) do vn
val = vi[vn]
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
vns = collect(TestUtils.varname_leaves(vn, val))
vals = map(copy Base.Fix1(getindex, vi), vns)
(vals, map(string, vns))
end
Expand All @@ -503,7 +503,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict})
for vn in keys(vi)
# Extract the leaf varnames and values.
val = vi[vn]
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
vns = collect(TestUtils.varname_leaves(vn, val))
vals = map(copy Base.Fix1(getindex, vi), vns)

# Determine the corresponding symbol.
Expand Down
8 changes: 4 additions & 4 deletions src/submodel_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,12 @@ end
prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx)
function prefix_submodel_context(prefix, ctx)
# E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated.
return :($(DynamicPPL.PrefixContext){$(Symbol)($(esc(prefix)))}($ctx))
return :($(PrefixContext){$(Symbol)($(esc(prefix)))}($ctx))
end

function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx)
# E.g. `prefix="asd"`.
return :($(DynamicPPL.PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx))
return :($(PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx))
end

function prefix_submodel_context(prefix::Bool, ctx)
Expand Down Expand Up @@ -225,7 +225,7 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__))
return if args_assign === nothing
ctx = prefix_submodel_context(prefix, ctx)
quote
$retval, $(esc(:__varinfo__)) = $(DynamicPPL._evaluate!!)(
$retval, $(esc(:__varinfo__)) = $(_evaluate!!)(
$(esc(expr)), $(esc(:__varinfo__)), $(ctx)
)
$retval
Expand All @@ -241,7 +241,7 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__))
)
end
quote
$retval, $(esc(:__varinfo__)) = $(DynamicPPL._evaluate!!)(
$retval, $(esc(:__varinfo__)) = $(_evaluate!!)(
$(esc(R)), $(esc(:__varinfo__)), $(ctx)
)
$(esc(L)) = $retval
Expand Down
12 changes: 6 additions & 6 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -609,10 +609,10 @@ const UnivariateAssumeDemoModels = Union{
function posterior_mean(model::UnivariateAssumeDemoModels)
return (s=49 / 24, m=7 / 6)
end
function likelihood_optima(::DynamicPPL.TestUtils.UnivariateAssumeDemoModels)
function likelihood_optima(::UnivariateAssumeDemoModels)
return (s=1 / 16, m=7 / 4)
end
function posterior_optima(::DynamicPPL.TestUtils.UnivariateAssumeDemoModels)
function posterior_optima(::UnivariateAssumeDemoModels)
# TODO: Figure out exact for `s`.
return (s=0.907407, m=7 / 6)
end
Expand Down Expand Up @@ -649,7 +649,7 @@ function posterior_mean(model::MultivariateAssumeDemoModels)

return vals
end
function likelihood_optima(model::DynamicPPL.TestUtils.MultivariateAssumeDemoModels)
function likelihood_optima(model::MultivariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)

Expand All @@ -662,7 +662,7 @@ function likelihood_optima(model::DynamicPPL.TestUtils.MultivariateAssumeDemoMod

return vals
end
function posterior_optima(model::DynamicPPL.TestUtils.MultivariateAssumeDemoModels)
function posterior_optima(model::MultivariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)

Expand Down Expand Up @@ -704,7 +704,7 @@ function posterior_mean(model::MatrixvariateAssumeDemoModels)

return vals
end
function likelihood_optima(model::DynamicPPL.TestUtils.MatrixvariateAssumeDemoModels)
function likelihood_optima(model::MatrixvariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)

Expand All @@ -717,7 +717,7 @@ function likelihood_optima(model::DynamicPPL.TestUtils.MatrixvariateAssumeDemoMo

return vals
end
function posterior_optima(model::DynamicPPL.TestUtils.MatrixvariateAssumeDemoModels)
function posterior_optima(model::MatrixvariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)

Expand Down
2 changes: 1 addition & 1 deletion src/transforming.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ function dot_tilde_assume(
for (vn, ri) in zip(vns, eachcol(r))
# Only transform if `!isinverse` since `vi[vn, right]`
# already performs the inverse transformation if it's transformed.
vi = DynamicPPL.setindex!!(vi, isinverse ? ri : b(ri), vn)
vi = setindex!!(vi, isinverse ? ri : b(ri), vn)
end

return r, lp, vi
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ function varname_leaves(vn::VarName, val::AbstractArray)
I in CartesianIndices(val)
)
end
function varname_leaves(vn::DynamicPPL.VarName, val::NamedTuple)
function varname_leaves(vn::VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
lens = Setfield.PropertyLens{sym}()
varname_leaves(vn lens, get(val, lens))
Expand Down

0 comments on commit 5dd7c53

Please sign in to comment.