Skip to content

Commit

Permalink
Fix for unwrap_left_right_vns (#297)
Browse files Browse the repository at this point in the history
I just noticed a bug I introduced in a recent PR when looking at #295 . This PR fixes it. I'll add tests, a sec. @yebai
  • Loading branch information
torfjelde committed Jul 30, 2021
1 parent 5472d9d commit fa88dd3
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.13.0"
version = "0.13.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
31 changes: 30 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,35 @@ left-hand side of a `.~` expression such as `x .~ Normal()`.
This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the
variables.
# Example
```jldoctest; setup=:(using Distributions)
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); map(string, vns)
2-element Vector{String}:
"x[:,1]"
"x[:,2]"
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); map(string, vns)
1×2 Matrix{String}:
"x[:][1,1]" "x[:][1,2]"
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); map(string, vns)
3-element Vector{String}:
"x[1][1]"
"x[1][2]"
"x[1][3]"
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2, 3), @varname(x)); map(string, vns)
1×2×3 Array{String, 3}:
[:, :, 1] =
"x[1,1,1]" "x[1,2,1]"
[:, :, 2] =
"x[1,1,2]" "x[1,2,2]"
[:, :, 3] =
"x[1,1,3]" "x[1,2,3]"
```
"""
unwrap_right_left_vns(right, left, vns) = right, left, vns
function unwrap_right_left_vns(right::NamedDist, left, vns)
Expand All @@ -103,7 +132,7 @@ function unwrap_right_left_vns(
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
# and we therefore add the `Colon()` below.
vns = map(axes(left, 2)) do i
return VarName(vn, (vn.indexing..., Colon(), Tuple(i)))
return VarName(vn, (vn.indexing..., (Colon(), i)))
end
return unwrap_right_left_vns(right, left, vns)
end
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "2.1, 3.0"
AbstractPPL = "0.1.4, 0.2"
AbstractPPL = "0.2"
Bijectors = "0.9.5"
Distributions = "< 0.25.11"
DistributionsAD = "0.6.3"
Expand Down

0 comments on commit fa88dd3

Please sign in to comment.