Skip to content

Commit

Permalink
Add varname tests from DPPL + format repo (#111)
Browse files Browse the repository at this point in the history
* Add varname tests from DPPL

cf. TuringLang/DynamicPPL.jl#737

* Format

* Format readme
  • Loading branch information
penelopeysm authored Dec 5, 2024
1 parent 1c5408b commit 903e0c6
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 191 deletions.
210 changes: 99 additions & 111 deletions README.md

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ using Documenter
using AbstractPPL

# Doctest setup
DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive = true)
DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true)

makedocs(;
sitename = "AbstractPPL",
modules = [AbstractPPL],
pages = ["Home" => "index.md", "API" => "api.md"],
checkdocs = :exports,
doctest = false,
sitename="AbstractPPL",
modules=[AbstractPPL],
pages=["Home" => "index.md", "API" => "api.md"],
checkdocs=:exports,
doctest=false,
)
5 changes: 2 additions & 3 deletions src/AbstractPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@ export VarName,
varname_to_string,
string_to_varname


# Abstract model functions
export AbstractProbabilisticProgram, condition, decondition, fix, unfix, logdensityof, densityof, AbstractContext, evaluate!!
export AbstractProbabilisticProgram,
condition, decondition, fix, unfix, logdensityof, densityof, AbstractContext, evaluate!!

# Abstract traces
export AbstractModelTrace


include("varname.jl")
include("abstractmodeltrace.jl")
include("abstractprobprog.jl")
Expand Down
5 changes: 0 additions & 5 deletions src/abstractprobprog.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ using AbstractMCMC
using DensityInterface
using Random


"""
AbstractProbabilisticProgram
Expand All @@ -12,7 +11,6 @@ abstract type AbstractProbabilisticProgram <: AbstractMCMC.AbstractModel end

DensityInterface.DensityKind(::AbstractProbabilisticProgram) = HasDensity()


"""
logdensityof(model, trace)
Expand All @@ -26,7 +24,6 @@ probability theory.
"""
DensityInterface.logdensityof(::AbstractProbabilisticProgram, ::AbstractModelTrace)


"""
decondition(conditioned_model)
Expand All @@ -43,7 +40,6 @@ should hold for models `m` with conditioned variables `obs`.
"""
function decondition end


"""
condition(model, observations)
Expand Down Expand Up @@ -84,7 +80,6 @@ should hold for any model `m` and parameters `params`.
"""
function fix end


"""
unfix(model)
Expand Down
115 changes: 77 additions & 38 deletions src/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ struct VarName{sym,T}

function VarName{sym}(optic=identity) where {sym}
if !is_static_optic(typeof(optic))
throw(ArgumentError("attempted to construct `VarName` with unsupported optic of type $(nameof(typeof(optic)))"))
throw(
ArgumentError(
"attempted to construct `VarName` with unsupported optic of type $(nameof(typeof(optic)))",
),
)
end
return new{sym,typeof(optic)}(optic)
end
Expand Down Expand Up @@ -168,7 +172,7 @@ end

function Base.show(io::IO, vn::VarName{sym,T}) where {sym,T}
print(io, getsym(vn))
_show_optic(io, getoptic(vn))
return _show_optic(io, getoptic(vn))
end

# modified from https://github.com/JuliaObjects/Accessors.jl/blob/01528a81fdf17c07436e1f3d99119d3f635e4c26/src/sugar.jl#L502
Expand All @@ -181,7 +185,7 @@ function _show_optic(io::IO, optic)
print(io, "")
end
shortstr = reduce(_shortstring, inner; init="")
print(io, shortstr)
return print(io, shortstr)
end

_shortstring(prev, o::IndexLens) = "$prev[$(join(map(prettify_index, o.indices), ", "))]"
Expand All @@ -207,7 +211,6 @@ Symbol("x[1][:]")
"""
Base.Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol


"""
inspace(vn::Union{VarName, Symbol}, space::Tuple)
Expand Down Expand Up @@ -244,7 +247,6 @@ inspace(vn::VarName, space::Tuple) = any(_in(vn, s) for s in space)
_in(vn::VarName, s::Symbol) = getsym(vn) == s
_in(vn::VarName, s::VarName) = subsumes(s, vn)


"""
subsumes(u::VarName, v::VarName)
Expand Down Expand Up @@ -297,8 +299,9 @@ subsumes(::typeof(identity), ::typeof(identity)) = true
subsumes(::typeof(identity), ::ALLOWED_OPTICS) = true
subsumes(::ALLOWED_OPTICS, ::typeof(identity)) = false

subsumes(t::ComposedOptic, u::ComposedOptic) =
subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner)
function subsumes(t::ComposedOptic, u::ComposedOptic)
return subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner)
end

# If `t` is still a composed lens, then there is no way it can subsume `u` since `u` is a
# leaf of the "lens-tree".
Expand All @@ -317,11 +320,12 @@ subsumes(t::PropertyLens, u::PropertyLens) = false
# FIXME: Does not support `DynamicIndexLens`.
# FIXME: Does not correctly handle cases such as `subsumes(x, x[:])`
# (but neither did old implementation).
subsumes(
function subsumes(
t::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}},
u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}}
) = subsumes_indices(t, u)

u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}},
)
return subsumes_indices(t, u)
end

"""
subsumedby(t, u)
Expand Down Expand Up @@ -444,7 +448,6 @@ subsumes_index(i::Colon, j) = true
subsumes_index(i::AbstractVector, j) = issubset(j, i)
subsumes_index(i, j) = i == j


"""
ConcretizedSlice(::Base.Slice)
Expand All @@ -455,10 +458,13 @@ struct ConcretizedSlice{T,R} <: AbstractVector{T}
range::R
end

ConcretizedSlice(s::Base.Slice{R}) where {R} = ConcretizedSlice{eltype(s.indices),R}(s.indices)
function ConcretizedSlice(s::Base.Slice{R}) where {R}
return ConcretizedSlice{eltype(s.indices),R}(s.indices)
end
Base.show(io::IO, s::ConcretizedSlice) = print(io, ":")
Base.show(io::IO, ::MIME"text/plain", s::ConcretizedSlice) =
print(io, "ConcretizedSlice(", s.range, ")")
function Base.show(io::IO, ::MIME"text/plain", s::ConcretizedSlice)
return print(io, "ConcretizedSlice(", s.range, ")")
end
Base.size(s::ConcretizedSlice) = size(s.range)
Base.iterate(s::ConcretizedSlice, state...) = Base.iterate(s.range, state...)
Base.collect(s::ConcretizedSlice) = collect(s.range)
Expand All @@ -480,8 +486,9 @@ The only purpose of this are special cases like `:`, which we want to avoid beco
`ConcretizedSlice` based on the `lowered_index`, just what you'd get with an explicit `begin:end`
"""
reconcretize_index(original_index, lowered_index) = lowered_index
reconcretize_index(original_index::Colon, lowered_index::Base.Slice) =
ConcretizedSlice(lowered_index)
function reconcretize_index(original_index::Colon, lowered_index::Base.Slice)
return ConcretizedSlice(lowered_index)
end

"""
concretize(l, x)
Expand All @@ -495,7 +502,9 @@ the result close to the original indexing.
"""
concretize(I::ALLOWED_OPTICS, x) = I
concretize(I::DynamicIndexLens, x) = concretize(IndexLens(I.f(x)), x)
concretize(I::IndexLens, x) = IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices)))
function concretize(I::IndexLens, x)
return IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices)))
end
function concretize(I::ComposedOptic, x)
x_inner = I.inner(x) # TODO: get view here
return ComposedOptic(concretize(I.outer, x_inner), concretize(I.inner, x))
Expand Down Expand Up @@ -646,11 +655,9 @@ function varname(expr::Expr, concretize=Accessors.need_dynamic_optic(expr))
end

if concretize
return :(
$(AbstractPPL.VarName){$sym}(
return :($(AbstractPPL.VarName){$sym}(
$(AbstractPPL.concretize)($optics, $sym_escaped)
)
)
))
elseif Accessors.need_dynamic_optic(expr)
error("Variable name `$(expr)` is dynamic and requires concretization!")
else
Expand All @@ -672,7 +679,7 @@ end
function _parse_obj_optic(ex)
obj, optics = _parse_obj_optics(ex)
optic = Expr(:call, Accessors.opticcompose, optics...)
obj, optic
return obj, optic
end

# Accessors doesn't have the same support for interpolation
Expand All @@ -688,7 +695,8 @@ function _parse_obj_optics(ex)
indices = Accessors.replace_underscore.(indices, collection)
dims = length(indices) == 1 ? nothing : 1:length(indices)
lindices = esc.(Accessors.lower_index.(collection, indices, dims))
optics = :($(Accessors.DynamicIndexLens)($(esc(collection)) -> ($(lindices...),)))
optics =
:($(Accessors.DynamicIndexLens)($(esc(collection)) -> ($(lindices...),)))
else
index = esc(Expr(:tuple, indices...))
optics = :($(Accessors.IndexLens)($index))
Expand All @@ -702,16 +710,20 @@ function _parse_obj_optics(ex)
elseif Meta.isexpr(property, :$, 1)
optics = :($(Accessors.PropertyLens){$(esc(property.args[1]))}())
else
throw(ArgumentError(
string("Error while parsing :($ex). Second argument to `getproperty` can only be",
"a `Symbol` or `String` literal, received `$property` instead.")
))
throw(
ArgumentError(
string(
"Error while parsing :($ex). Second argument to `getproperty` can only be",
"a `Symbol` or `String` literal, received `$property` instead.",
),
),
)
end
else
obj = esc(ex)
return obj, ()
end
obj, tuple(frontoptics..., optics)
return obj, tuple(frontoptics..., optics)
end

"""
Expand Down Expand Up @@ -778,12 +790,27 @@ Convert an index `i` to a dictionary representation.
"""
index_to_dict(i::Integer) = Dict("type" => _BASE_INTEGER_TYPE, "value" => i)
index_to_dict(v::Vector{Int}) = Dict("type" => _BASE_VECTOR_TYPE, "values" => v)
index_to_dict(r::UnitRange) = Dict("type" => _BASE_UNITRANGE_TYPE, "start" => r.start, "stop" => r.stop)
index_to_dict(r::StepRange) = Dict("type" => _BASE_STEPRANGE_TYPE, "start" => r.start, "stop" => r.stop, "step" => r.step)
index_to_dict(r::Base.OneTo{I}) where {I} = Dict("type" => _BASE_ONETO_TYPE, "stop" => r.stop)
function index_to_dict(r::UnitRange)
return Dict("type" => _BASE_UNITRANGE_TYPE, "start" => r.start, "stop" => r.stop)
end
function index_to_dict(r::StepRange)
return Dict(
"type" => _BASE_STEPRANGE_TYPE,
"start" => r.start,
"stop" => r.stop,
"step" => r.step,
)
end
function index_to_dict(r::Base.OneTo{I}) where {I}
return Dict("type" => _BASE_ONETO_TYPE, "stop" => r.stop)
end
index_to_dict(::Colon) = Dict("type" => _BASE_COLON_TYPE)
index_to_dict(s::ConcretizedSlice{T,R}) where {T,R} = Dict("type" => _CONCRETIZED_SLICE_TYPE, "range" => index_to_dict(s.range))
index_to_dict(t::Tuple) = Dict("type" => _BASE_TUPLE_TYPE, "values" => map(index_to_dict, t))
function index_to_dict(s::ConcretizedSlice{T,R}) where {T,R}
return Dict("type" => _CONCRETIZED_SLICE_TYPE, "range" => index_to_dict(s.range))
end
function index_to_dict(t::Tuple)
return Dict("type" => _BASE_TUPLE_TYPE, "values" => map(index_to_dict, t))
end

"""
dict_to_index(dict)
Expand Down Expand Up @@ -839,9 +866,17 @@ function dict_to_index(dict)
end

optic_to_dict(::typeof(identity)) = Dict("type" => "identity")
optic_to_dict(::PropertyLens{sym}) where {sym} = Dict("type" => "property", "field" => String(sym))
function optic_to_dict(::PropertyLens{sym}) where {sym}
return Dict("type" => "property", "field" => String(sym))
end
optic_to_dict(i::IndexLens) = Dict("type" => "index", "indices" => index_to_dict(i.indices))
optic_to_dict(c::ComposedOptic) = Dict("type" => "composed", "outer" => optic_to_dict(c.outer), "inner" => optic_to_dict(c.inner))
function optic_to_dict(c::ComposedOptic)
return Dict(
"type" => "composed",
"outer" => optic_to_dict(c.outer),
"inner" => optic_to_dict(c.inner),
)
end

function dict_to_optic(dict)
if dict["type"] == "identity"
Expand All @@ -857,9 +892,13 @@ function dict_to_optic(dict)
end
end

varname_to_dict(vn::VarName) = Dict("sym" => getsym(vn), "optic" => optic_to_dict(getoptic(vn)))
function varname_to_dict(vn::VarName)
return Dict("sym" => getsym(vn), "optic" => optic_to_dict(getoptic(vn)))
end

dict_to_varname(dict::Dict{<:AbstractString, Any}) = VarName{Symbol(dict["sym"])}(dict_to_optic(dict["optic"]))
function dict_to_varname(dict::Dict{<:AbstractString,Any})
return VarName{Symbol(dict["sym"])}(dict_to_optic(dict["optic"]))
end

"""
varname_to_string(vn::VarName)
Expand Down
5 changes: 1 addition & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ using Test
include("abstractprobprog.jl")
@testset "doctests" begin
DocMeta.setdocmeta!(
AbstractPPL,
:DocTestSetup,
:(using AbstractPPL);
recursive=true,
AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true
)
doctest(AbstractPPL; manual=false)
end
Expand Down
Loading

0 comments on commit 903e0c6

Please sign in to comment.