-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
subset
and merge
for VarInfo
(clean version)
#544
Changes from 16 commits
028a81a
caa6e25
cac7fa8
5e41c4f
d5a2631
0ade696
db21844
1dbca4c
b67288f
e43029e
cd4033d
8f47dfe
aba9008
3b621ae
2c2c90b
5c1ece3
cfff96c
ed5d948
00c36cf
cf02816
d02cb61
7f01ada
14105e0
c164d32
743162a
dc9ad94
2f320e6
d3a9b56
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,6 +47,7 @@ export AbstractVarInfo, | |
SimpleVarInfo, | ||
push!!, | ||
empty!!, | ||
subset, | ||
getlogp, | ||
setlogp!!, | ||
acclogp!!, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,6 +53,27 @@ struct StaticTransformation{F} <: AbstractTransformation | |
bijector::F | ||
end | ||
|
||
""" | ||
merge_transformations(transformation_left, transformation_right) | ||
|
||
Merge two transformations. | ||
|
||
The main use of this is in [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref). | ||
""" | ||
function merge_transformations(::NoTransformation, ::NoTransformation) | ||
return NoTransformation() | ||
end | ||
function merge_transformations(::DynamicTransformation, ::DynamicTransformation) | ||
return DynamicTransformation() | ||
end | ||
function merge_transformations(left::StaticTransformation, right::StaticTransformation) | ||
return StaticTransformation(merge_bijectors(left.bijector, right.bijector)) | ||
end | ||
|
||
function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform) | ||
return Bijectors.NamedTransform(merge_bijector(left.bs, right.bs)) | ||
end | ||
|
||
""" | ||
default_transformation(model::Model[, vi::AbstractVarInfo]) | ||
|
||
|
@@ -337,6 +358,144 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP | |
return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)})) | ||
end | ||
|
||
# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert | ||
# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which | ||
# might result in a `Vector{Any}`. | ||
""" | ||
subset(varinfo::AbstractVarInfo, vns::AbstractVector{<:VarName}) | ||
|
||
Subset a `varinfo` to only contain the variables `vns`. | ||
|
||
!!! warning | ||
The ordering of the variables in the resulting `varinfo` is _not_ | ||
guaranteed to follow the ordering of the variables in `varinfo`. | ||
Hence care must be taken, in particular when used in conjunction with | ||
other methods which uses the vector-representation of the `varinfo`, | ||
e.g. `getindex(varinfo, sampler)`. | ||
|
||
# Examples | ||
```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL) | ||
julia> @model function demo() | ||
s ~ InverseGamma(2, 3) | ||
m ~ Normal(0, sqrt(s)) | ||
x = Vector{Float64}(undef, 2) | ||
x[1] ~ Normal(m, sqrt(s)) | ||
x[2] ~ Normal(m, sqrt(s)) | ||
end | ||
demo (generic function with 2 methods) | ||
|
||
julia> model = demo(); | ||
|
||
julia> varinfo = VarInfo(model); | ||
|
||
julia> keys(varinfo) | ||
4-element Vector{VarName}: | ||
s | ||
m | ||
x[1] | ||
x[2] | ||
|
||
julia> for (i, vn) in enumerate(keys(varinfo)) | ||
varinfo[vn] = i | ||
end | ||
|
||
julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] | ||
4-element Vector{Float64}: | ||
1.0 | ||
2.0 | ||
3.0 | ||
4.0 | ||
|
||
julia> # Extract one with only `m`. | ||
varinfo_subset1 = subset(varinfo, [@varname(m),]); | ||
|
||
|
||
julia> keys(varinfo_subset1) | ||
1-element Vector{VarName{:m, Setfield.IdentityLens}}: | ||
m | ||
|
||
julia> varinfo_subset1[@varname(m)] | ||
2.0 | ||
|
||
julia> # Extract one with both `s` and `x[2]`. | ||
varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]); | ||
|
||
julia> keys(varinfo_subset2) | ||
2-element Vector{VarName}: | ||
s | ||
x[2] | ||
|
||
julia> varinfo_subset2[[@varname(s), @varname(x[2])]] | ||
2-element Vector{Float64}: | ||
1.0 | ||
4.0 | ||
``` | ||
|
||
`subset` is particularly useful when combined with [`merge(varinfo_left::VarInfo, varinfo_right::VarInfo)`](@ref) | ||
|
||
```jldoctest varinfo-subset | ||
julia> # Merge the two. | ||
varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2); | ||
|
||
julia> keys(varinfo_subset_merged) | ||
3-element Vector{VarName}: | ||
m | ||
s | ||
x[2] | ||
|
||
julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] | ||
3-element Vector{Float64}: | ||
1.0 | ||
2.0 | ||
4.0 | ||
|
||
julia> # Merge the two with the original. | ||
varinfo_merged = merge(varinfo, varinfo_subset_merged); | ||
|
||
julia> keys(varinfo_merged) | ||
4-element Vector{VarName}: | ||
s | ||
m | ||
x[1] | ||
x[2] | ||
|
||
julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] | ||
4-element Vector{Float64}: | ||
1.0 | ||
2.0 | ||
3.0 | ||
4.0 | ||
``` | ||
|
||
# Notes | ||
|
||
## Type-stability | ||
|
||
!!! warning | ||
This function is only type-stable when `vns` contains only varnames | ||
with the same symbol. For exmaple, `[@varname(m[1]), @varname(m[2])]` will | ||
be type-stable, but `[@varname(m[1]), @varname(x)]` will not be. | ||
""" | ||
function subset end | ||
|
||
""" | ||
merge(varinfo, other_varinfos...) | ||
|
||
Merge varinfos into one, giving precedence to the right-most varinfo when sensible. | ||
|
||
This is particularly useful when combined with [`subset(varinfo, vns)`](@ref). | ||
|
||
See docstring of [`subset(varinfo, vns)`](@ref) for examples. | ||
""" | ||
function Base.merge(varinfo::AbstractVarInfo, varinfo_others::AbstractVarInfo...) | ||
return merge(Base.merge(varinfo, first(varinfo_others)), Base.tail(varinfo_others)...) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's possible that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah true; I'll address this. I meant to add tests for |
||
end | ||
|
||
# Avoid `StackoverFlowError` if implementation is missing. | ||
function Base.merge(varinfo::AbstractVarInfo, varinfo_other::AbstractVarInfo) | ||
throw(MethodError(Base.merge, (varinfo, varinfo_other))) | ||
end | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Transformations | ||
""" | ||
istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we would not already have so many
getindex
methods, I would have thought thatgetindex
would be a natural name for this function. But maybe it's still an option?Then we could have
getindex(::AbstractVarInfo, ::AbstractVector{<:VarName}) -> AbstractVarInfo
andgetindex(::T, ::VarName) -> typeof_varname_variate
, similar to[1,2,3][[1,3]] = [1, 3]
and[1,2,3][2] = 2
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd really like this yes, but I also really don't want to touch
getindex
in this codebase 😅Happy to make this a long-term goal or something though!