subset and merge for VarInfo (clean version) #544

Merged
merged 28 commits into from
Oct 19, 2023
Changes from 16 commits
Commits
028a81a
added `subset` which can extract a subset of the varinfo
torfjelde Oct 8, 2023
caa6e25
added testing of `subset` for `VarInfo`
torfjelde Oct 8, 2023
cac7fa8
formatting
torfjelde Oct 8, 2023
5e41c4f
added implementation of `merge` for `VarInfo` and tests for it
torfjelde Oct 8, 2023
d5a2631
more tests
torfjelde Oct 8, 2023
0ade696
formatting
torfjelde Oct 8, 2023
db21844
improved merge_metadata for NamedTuple inputs
torfjelde Oct 9, 2023
1dbca4c
added proper handling of the `vals` in `subset`
torfjelde Oct 9, 2023
b67288f
added docs for `subset` and `merge`
torfjelde Oct 9, 2023
e43029e
added `subset` and `merge` to documentation
torfjelde Oct 9, 2023
cd4033d
formatting
torfjelde Oct 9, 2023
8f47dfe
made merge and subset part of the AbstractVarInfo interface
torfjelde Oct 13, 2023
aba9008
added implementations `subset` and `merge` for `SimpleVarInfo`
torfjelde Oct 13, 2023
3b621ae
follow standard merge semantics where the right one takes precedence
torfjelde Oct 13, 2023
2c2c90b
added proper testing of merge and subset for SimpleVarInfo too
torfjelde Oct 13, 2023
5c1ece3
forgotten inclusion in previous commit
torfjelde Oct 13, 2023
cfff96c
Update src/simple_varinfo.jl
torfjelde Oct 13, 2023
ed5d948
remove two-argument impl of merge
torfjelde Oct 13, 2023
00c36cf
formatting
torfjelde Oct 13, 2023
cf02816
forgot to add more formatting
torfjelde Oct 13, 2023
d02cb61
Merge branch 'master' into torfjelde/subset-and-merge
torfjelde Oct 13, 2023
7f01ada
removed 2-arg version of merge for abstract varinfo in favour of 3-ar…
torfjelde Oct 13, 2023
14105e0
allow inclusion of threadsafe varinfo in setup_varinfos
torfjelde Oct 13, 2023
c164d32
more tests for thread safe varinfo
torfjelde Oct 13, 2023
743162a
bugfixes for link and invlink methods when using thread safe varinfo
torfjelde Oct 13, 2023
dc9ad94
attempt at fixing docs
torfjelde Oct 13, 2023
2f320e6
fixed missing test coverage
torfjelde Oct 14, 2023
d3a9b56
formatting
torfjelde Oct 14, 2023
2 changes: 2 additions & 0 deletions docs/src/api.md
@@ -255,6 +255,8 @@ DynamicPPL.reconstruct
#### Utils

```@docs
Base.merge(::VarInfo, ::VarInfo)
DynamicPPL.subset
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
DynamicPPL.varname_leaves
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
@@ -47,6 +47,7 @@ export AbstractVarInfo,
    SimpleVarInfo,
    push!!,
    empty!!,
    subset,
    getlogp,
    setlogp!!,
    acclogp!!,
159 changes: 159 additions & 0 deletions src/abstract_varinfo.jl
@@ -53,6 +53,27 @@ struct StaticTransformation{F} <: AbstractTransformation
    bijector::F
end

"""
    merge_transformations(transformation_left, transformation_right)

Merge two transformations.

The main use of this is in [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref).
"""
function merge_transformations(::NoTransformation, ::NoTransformation)
    return NoTransformation()
end
function merge_transformations(::DynamicTransformation, ::DynamicTransformation)
    return DynamicTransformation()
end
function merge_transformations(left::StaticTransformation, right::StaticTransformation)
    return StaticTransformation(merge_bijectors(left.bijector, right.bijector))
end

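function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform)
    # `merge_bijector` (singular) is not defined anywhere in this diff. Assuming `bs` is a
    # `NamedTuple` of per-variable bijectors, `Base.merge` lets entries from `right` win.
    return Bijectors.NamedTransform(merge(left.bs, right.bs))
end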
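
The `merge_transformations` methods above are only defined for pairs of the same kind, so mixing kinds falls through to a `MethodError`. The sketch below illustrates that dispatch pattern with hypothetical stand-in types (`ToyNoTransformation` and friends are not DynamicPPL or Bijectors types), and it assumes the static case stores one function per variable in a `NamedTuple`:

```julia
# Hypothetical stand-ins for the transformation types in the diff above.
abstract type ToyTransformation end
struct ToyNoTransformation <: ToyTransformation end
struct ToyDynamicTransformation <: ToyTransformation end
struct ToyStaticTransformation{F} <: ToyTransformation
    bijector::F
end

# Pairs of the same kind merge to that kind.
merge_toy_transformations(::ToyNoTransformation, ::ToyNoTransformation) =
    ToyNoTransformation()
merge_toy_transformations(::ToyDynamicTransformation, ::ToyDynamicTransformation) =
    ToyDynamicTransformation()

# Assumption: the static case holds a `NamedTuple`, so merging reduces to
# `Base.merge`, where the right-hand entry wins when both sides share a name.
function merge_toy_transformations(
    left::ToyStaticTransformation{<:NamedTuple},
    right::ToyStaticTransformation{<:NamedTuple},
)
    return ToyStaticTransformation(merge(left.bijector, right.bijector))
end

left = ToyStaticTransformation((s=exp, m=identity))
right = ToyStaticTransformation((m=log,))
merged = merge_toy_transformations(left, right)
```

Here `merged.bijector` keeps `s` from the left and takes `m` from the right. No method exists for, say, a `ToyNoTransformation` paired with a `ToyDynamicTransformation`, so that call raises a `MethodError` rather than silently picking one side.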

"""
    default_transformation(model::Model[, vi::AbstractVarInfo])

@@ -337,6 +358,144 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP
    return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)}))
end

# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert
# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which
# might result in a `Vector{Any}`.
"""
    subset(varinfo::AbstractVarInfo, vns::AbstractVector{<:VarName})
Member

If we didn't already have so many `getindex` methods, I would have thought that `getindex` would be a natural name for this function. But maybe it's still an option?

Then we could have `getindex(::AbstractVarInfo, ::AbstractVector{<:VarName}) -> AbstractVarInfo` and `getindex(::T, ::VarName) -> typeof_varname_variate`, similar to `[1,2,3][[1,3]] == [1, 3]` and `[1,2,3][2] == 2`.

Member Author

I'd really like this yes, but I also really don't want to touch getindex in this codebase 😅

Happy to make this a long-term goal or something though!


Subset a `varinfo` to only contain the variables `vns`.

!!! warning
    The ordering of the variables in the resulting `varinfo` is _not_
    guaranteed to follow the ordering of the variables in `varinfo`.
    Hence care must be taken, in particular when used in conjunction with
    other methods which use the vector-representation of the `varinfo`,
    e.g. `getindex(varinfo, sampler)`.

# Examples
```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL)
julia> @model function demo()
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           x = Vector{Float64}(undef, 2)
           x[1] ~ Normal(m, sqrt(s))
           x[2] ~ Normal(m, sqrt(s))
       end
demo (generic function with 2 methods)

julia> model = demo();

julia> varinfo = VarInfo(model);

julia> keys(varinfo)
4-element Vector{VarName}:
s
m
x[1]
x[2]

julia> for (i, vn) in enumerate(keys(varinfo))
           varinfo[vn] = i
       end

julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
1.0
2.0
3.0
4.0

julia> # Extract one with only `m`.
       varinfo_subset1 = subset(varinfo, [@varname(m),]);


julia> keys(varinfo_subset1)
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
m

julia> varinfo_subset1[@varname(m)]
2.0

julia> # Extract one with both `s` and `x[2]`.
       varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]);

julia> keys(varinfo_subset2)
2-element Vector{VarName}:
s
x[2]

julia> varinfo_subset2[[@varname(s), @varname(x[2])]]
2-element Vector{Float64}:
1.0
4.0
```

`subset` is particularly useful when combined with [`merge(varinfo_left::VarInfo, varinfo_right::VarInfo)`](@ref):

```jldoctest varinfo-subset
julia> # Merge the two.
       varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2);

julia> keys(varinfo_subset_merged)
3-element Vector{VarName}:
m
s
x[2]

julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]]
3-element Vector{Float64}:
1.0
2.0
4.0

julia> # Merge the two with the original.
       varinfo_merged = merge(varinfo, varinfo_subset_merged);

julia> keys(varinfo_merged)
4-element Vector{VarName}:
s
m
x[1]
x[2]

julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
1.0
2.0
3.0
4.0
```

# Notes

## Type-stability

!!! warning
    This function is only type-stable when `vns` contains only varnames
    with the same symbol. For example, `[@varname(m[1]), @varname(m[2])]` will
    be type-stable, but `[@varname(m[1]), @varname(x)]` will not be.
"""
function subset end
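
The type-stability note in the docstring comes down to how Julia picks the element type of a vector. The sketch below uses a hypothetical `ToyVarName` (not the real `VarName`) that, like `VarName`, carries its symbol as a type parameter. It only shows the element-type widening; how `subset` uses that information internally is not shown in this part of the diff.

```julia
# Hypothetical stand-in for `VarName`: the symbol lives in the type parameter.
struct ToyVarName{sym}
    index::Int
end

same_symbol = [ToyVarName{:m}(1), ToyVarName{:m}(2)]
mixed_symbols = [ToyVarName{:m}(1), ToyVarName{:x}(1)]

# One symbol: the element type is the concrete `ToyVarName{:m}`.
isconcretetype(eltype(same_symbol))    # true

# Two symbols: the element type widens to the abstract `ToyVarName`,
# so the symbols can no longer be read off the vector's type.
isconcretetype(eltype(mixed_symbols))  # false
```

This matches the doctest output above, where a vector mixing `s` and `x[2]` prints as `Vector{VarName}` while the single-symbol case prints a fully parameterized element type.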

"""
    merge(varinfo, other_varinfos...)

Merge varinfos into one, giving precedence to the right-most varinfo when sensible.

This is particularly useful when combined with [`subset(varinfo, vns)`](@ref).

See docstring of [`subset(varinfo, vns)`](@ref) for examples.
"""
function Base.merge(varinfo::AbstractVarInfo, varinfo_others::AbstractVarInfo...)
    return merge(Base.merge(varinfo, first(varinfo_others)), Base.tail(varinfo_others)...)
Member

It's possible that length(varinfo_others) == 0, and then the function would error it seems?

Member Author

Ah true; I'll address this.

I meant to add tests for merge with more than two inputs too, but didn't get around to it. Will do that too.

end
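
The empty `varinfo_others` case raised in the review thread can be avoided by folding over the remaining arguments. This is a minimal sketch with a hypothetical `ToyVarInfo` (a plain `Dict` wrapper, not a DynamicPPL type) and a hypothetical `merge_toy` function. It shows one possible approach, not necessarily the fix adopted in the PR.

```julia
# Hypothetical stand-in for a varinfo; not a DynamicPPL type.
struct ToyVarInfo
    values::Dict{Symbol,Float64}
end

# Pairwise merge: entries from the right-hand side take precedence.
function merge_toy(left::ToyVarInfo, right::ToyVarInfo)
    return ToyVarInfo(merge(left.values, right.values))
end

# Variadic merge: fold left to right. With `init = varinfo`, an empty
# `others` tuple simply returns `varinfo` instead of erroring.
function merge_toy(varinfo::ToyVarInfo, others::ToyVarInfo...)
    return foldl(merge_toy, others; init=varinfo)
end

a = ToyVarInfo(Dict(:s => 1.0, :m => 2.0))
b = ToyVarInfo(Dict(:m => 20.0))
c = ToyVarInfo(Dict(:x => 3.0))

merged = merge_toy(a, b, c)  # `m` comes from `b`, `x` from `c`
alone = merge_toy(a)         # no other inputs: returns `a` itself
```

Because the two-argument method is more specific than the variadic one, each step of the fold dispatches to the pairwise merge and the recursion cannot loop back into the variadic method.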

# Avoid `StackOverflowError` if implementation is missing.
function Base.merge(varinfo::AbstractVarInfo, varinfo_other::AbstractVarInfo)
    throw(MethodError(Base.merge, (varinfo, varinfo_other)))
end
torfjelde marked this conversation as resolved.

# Transformations
"""
    istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:VarName}}])
42 changes: 42 additions & 0 deletions src/simple_varinfo.jl
@@ -419,6 +419,48 @@ function Base.eltype(
    return V
end

# `subset`
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
    return Setfield.@set varinfo.values = _subset(varinfo.values, vns)
end

function _subset(x::AbstractDict, vns)
    # NOTE: This requires `vns` to be explicitly present in `x`.
    if any(!Base.Fix1(haskey, x), vns)
        error(
            "Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " *
            "For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " *
            "`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not."
        )
    end
    C = ConstructionBase.constructorof(typeof(x))
    return C(vn => x[vn] for vn in vns)
end
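
The dictionary case above insists that every requested key is literally present. A rough, self-contained sketch of the same check-then-rebuild pattern, using plain `Symbol` keys and a hypothetical `toy_subset` helper in place of `VarName`s and `_subset`:

```julia
# Hypothetical helper mirroring the check above, with `Symbol` keys.
function toy_subset(x::AbstractDict, wanted)
    # Refuse to guess: every requested key must be an explicit key of `x`.
    absent = filter(k -> !haskey(x, k), wanted)
    isempty(absent) || error("keys not present in the dictionary: $(absent)")
    # Rebuild a dictionary containing only the requested entries.
    return Dict(k => x[k] for k in wanted)
end

full = Dict(:s => 1.0, :m => 2.0, :x => 3.0)
part = toy_subset(full, [:s, :x])
```

Unlike the code in the diff, this sketch always returns a `Dict` rather than reconstructing the input's own dictionary type through `ConstructionBase.constructorof`.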

function _subset(x::NamedTuple, vns)
    # NOTE: Here we can only handle `vns` that contain the `IdentityLens`.
    if any(!==(Setfield.IdentityLens()) ∘ getlens, vns)
        error(
            "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " *
            "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not."
        )
    end

    syms = map(getsym, vns)
    # `Fix1` fixes the first argument, giving `getindex(x, sym)`;
    # `Fix2(getindex, x)` would call `getindex(sym, x)` instead.
    return NamedTuple{(syms...,)}((map(Base.Fix1(getindex, x), syms)...,))
torfjelde marked this conversation as resolved.
end
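
For the `NamedTuple` case, selecting fields by symbol can be tried out in isolation. The sketch below uses made-up values and compares the explicit construction with the `NamedTuple{names}(nt)` constructor from Base, which selects fields by name; whether that shorthand would suit `_subset` is not something the PR discusses.

```julia
vals = (s=1.0, m=2.0, x=[3.0, 4.0])
syms = (:s, :x)

# Explicit construction: look up each symbol, then rebuild with those names.
explicit = NamedTuple{syms}(map(sym -> getindex(vals, sym), syms))

# Base shorthand: select the named fields directly from `vals`.
shorthand = NamedTuple{syms}(vals)
```

Both forms give a `NamedTuple` with fields `s` and `x` only. Since the names are runtime values here, neither form is type-stable in this setting, which is in line with the type-stability warning in the `subset` docstring.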

# `merge`
function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
    values = merge(varinfo_left.values, varinfo_right.values)
    logp = getlogp(varinfo_right)
    transformation = merge_transformations(
        varinfo_left.transformation,
        varinfo_right.transformation,
    )
    return SimpleVarInfo(values, logp, transformation)
end
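
The `SimpleVarInfo` merge above takes values from both sides but the log-probability from the right side only. A small sketch with a hypothetical `ToySimpleVarInfo` (not the real `SimpleVarInfo`) makes that behaviour concrete:

```julia
# Hypothetical stand-in holding only values and a log-probability.
struct ToySimpleVarInfo{V}
    values::V
    logp::Float64
end

# Values are merged with right precedence; `logp` is copied from the right.
function toy_merge(left::ToySimpleVarInfo, right::ToySimpleVarInfo)
    return ToySimpleVarInfo(merge(left.values, right.values), right.logp)
end

left = ToySimpleVarInfo((s=1.0, m=2.0), -1.0)
right = ToySimpleVarInfo((m=20.0, x=3.0), -5.0)
merged = toy_merge(left, right)
```

In this sketch `merged.values` is `(s = 1.0, m = 20.0, x = 3.0)` and `merged.logp` is `-5.0`. The log-probability is not recomputed, so after a merge it may not correspond to the combined set of values until the model is evaluated again.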

# Context implementations
# NOTE: Evaluations, i.e. those without `rng` are shared with other
# implementations of `AbstractVarInfo`.