-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Merged by Bors] - Sibling PR of introduction of Setfield.jl in AbstractPPL.jl #295
Changes from 5 commits
1388502
678ef1d
ddf761c
b867d00
6ad5d95
4bf663f
e4922c9
ab4b384
9dadb3a
26216e3
405d52c
505d690
9f8c47b
0a47953
0c51329
54f8c89
d3fe07c
8134a1a
9d3c1dd
5c4bf0e
4121b0e
a38168a
51e7426
0a3655d
f82579e
9aa7298
b130db9
1ae04b3
475da88
4af7e30
e03ef4e
9b79b61
6dd6de9
4c7e882
4c325c3
fa228d8
edad225
553ae9b
00d8411
a99a2b1
472629d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,7 +19,7 @@ function isassumption(expr::Union{Symbol,Expr}) | |
vn = gensym(:vn) | ||
|
||
return quote | ||
let $vn = $(varname(expr)) | ||
let $vn = $(varname(expr, true)) | ||
# This branch should compile nicely in all cases except for partial missing data | ||
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` | ||
if !$(DynamicPPL.inargnames)($vn, __model__) || | ||
|
@@ -38,7 +38,7 @@ isassumption(expr) = :(false) | |
|
||
# If we're working with, say, a `Symbol`, then we're not going to `view`. | ||
maybe_view(x) = x | ||
maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@view($x))) | ||
maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@views($x))) | ||
|
||
# If the result of a `view` is a zero-dim array then it's just a | ||
# single element. Likely the rest is expecting type `eltype(x)`, hence | ||
|
@@ -90,6 +90,28 @@ left-hand side of a `.~` expression such as `x .~ Normal()`. | |
|
||
This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the | ||
variables. | ||
|
||
# Examples | ||
```jldoctest | ||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); vns | ||
2-element Vector{VarName{:x, Setfield.IndexLens{Tuple{Colon, Int64}}}}: | ||
x[:,1] | ||
x[:,2] | ||
|
||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns | ||
1Γ2 Matrix{VarName{:x, Setfield.IndexLens{Tuple{Int64, Int64}}}}: | ||
x[1,1] x[1,2] | ||
|
||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns | ||
1Γ2 Matrix{VarName{:x, Setfield.ComposedLens{Setfield.IndexLens{Tuple{Colon}}, Setfield.IndexLens{Tuple{Int64, Int64}}}}}: | ||
x[:][1,1] x[:][1,2] | ||
|
||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns | ||
3-element Vector{VarName{:x, Setfield.ComposedLens{Setfield.IndexLens{Tuple{Int64}}, Setfield.IndexLens{Tuple{Int64}}}}}: | ||
x[1][1] | ||
x[1][2] | ||
x[1][3] | ||
``` | ||
""" | ||
unwrap_right_left_vns(right, left, vns) = right, left, vns | ||
function unwrap_right_left_vns(right::NamedDist, left, vns) | ||
|
@@ -103,7 +125,7 @@ function unwrap_right_left_vns( | |
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, | ||
# and we therefore add the `Colon()` below. | ||
vns = map(axes(left, 2)) do i | ||
return VarName(vn, (vn.indexing..., Colon(), Tuple(i))) | ||
return vn β Setfield.IndexLens((Colon(), i)) | ||
phipsgabler marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
return unwrap_right_left_vns(right, left, vns) | ||
end | ||
|
@@ -113,7 +135,7 @@ function unwrap_right_left_vns( | |
vn::VarName, | ||
) | ||
vns = map(CartesianIndices(left)) do i | ||
return VarName(vn, (vn.indexing..., Tuple(i))) | ||
return vn β Setfield.IndexLens(Tuple(i)) | ||
end | ||
return unwrap_right_left_vns(right, left, vns) | ||
end | ||
|
@@ -271,6 +293,10 @@ function generate_mainbody!(mod, found, expr::Expr, warn) | |
# Do not touch interpolated expressions | ||
expr.head === :$ && return expr.args[1] | ||
|
||
# Do we don't want escaped expressions because we unfortunately | ||
# escape the entire body afterwards. | ||
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because we do the wrong thing and escape the entire body of the method in This is a hack to essentially ensure that any escaping will be removed. Note that this doesn't break anything because before we couldn't even use escaped expressions within |
||
|
||
# If it's a macro, we expand it | ||
if Meta.isexpr(expr, :macrocall) | ||
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) | ||
|
@@ -303,95 +329,161 @@ function generate_mainbody!(mod, found, expr::Expr, warn) | |
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) | ||
end | ||
|
||
function generate_tilde_literal(left, right) | ||
# If the LHS is a literal, it is always an observation | ||
return quote | ||
$(DynamicPPL.tilde_observe!)( | ||
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ | ||
) | ||
end | ||
end | ||
|
||
""" | ||
generate_tilde(left, right) | ||
|
||
Generate an `observe` expression for data variables and `assume` expression for parameter | ||
variables. | ||
""" | ||
function generate_tilde(left, right) | ||
# If the LHS is a literal, it is always an observation | ||
if isliteral(left) | ||
return quote | ||
$(DynamicPPL.tilde_observe!)( | ||
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ | ||
) | ||
end | ||
end | ||
isliteral(left) && return generate_tilde_literal(left, right) | ||
|
||
# Otherwise it is determined by the model or its value, | ||
# if the LHS represents an observation | ||
@gensym vn inds isassumption | ||
@gensym vn isassumption | ||
|
||
return quote | ||
$vn = $(varname(left)) | ||
$inds = $(vinds(left)) | ||
$isassumption = $(DynamicPPL.isassumption(left)) | ||
$vn = $(remove_escape(varname(left, true))) | ||
$isassumption = $(remove_escape(DynamicPPL.isassumption(left))) | ||
if $isassumption | ||
$left = $(DynamicPPL.tilde_assume!)( | ||
__context__, | ||
$(DynamicPPL.unwrap_right_vn)( | ||
$(DynamicPPL.check_tilde_rhs)($right), $vn | ||
)..., | ||
$inds, | ||
__varinfo__, | ||
) | ||
$(generate_tilde_assume(left, right, vn)) | ||
else | ||
$(DynamicPPL.tilde_observe!)( | ||
__context__, | ||
$(DynamicPPL.check_tilde_rhs)($right), | ||
$(maybe_view(left)), | ||
$vn, | ||
$inds, | ||
__varinfo__, | ||
) | ||
end | ||
end | ||
end | ||
|
||
function generate_tilde_assume(left::Symbol, right, vn) | ||
return quote | ||
$left = $(DynamicPPL.tilde_assume!)( | ||
__context__, | ||
$(DynamicPPL.unwrap_right_vn)( | ||
$(DynamicPPL.check_tilde_rhs)($right), $vn | ||
)..., | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
__varinfo__, | ||
) | ||
end | ||
end | ||
|
||
function generate_tilde_assume(left::Expr, right, vn) | ||
expr = :( | ||
$left = $(DynamicPPL.tilde_assume!)( | ||
__context__, | ||
$(DynamicPPL.unwrap_right_vn)( | ||
$(DynamicPPL.check_tilde_rhs)($right), $vn | ||
)..., | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
__varinfo__, | ||
) | ||
) | ||
|
||
return remove_escape(setmacro(identity, expr, overwrite=true)) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
end | ||
|
||
""" | ||
generate_dot_tilde(left, right) | ||
|
||
Generate the expression that replaces `left .~ right` in the model body. | ||
""" | ||
function generate_dot_tilde(left, right) | ||
# If the LHS is a literal, it is always an observation | ||
if isliteral(left) | ||
return quote | ||
$(DynamicPPL.dot_tilde_observe!)( | ||
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ | ||
) | ||
end | ||
end | ||
isliteral(left) && return generate_tilde_literal(left, right) | ||
|
||
# Otherwise it is determined by the model or its value, | ||
# if the LHS represents an observation | ||
@gensym vn inds isassumption | ||
@gensym vn isassumption | ||
return quote | ||
$vn = $(varname(left)) | ||
$inds = $(vinds(left)) | ||
$vn = $(varname(left, true)) | ||
$isassumption = $(DynamicPPL.isassumption(left)) | ||
if $isassumption | ||
$left .= $(DynamicPPL.dot_tilde_assume!)( | ||
__context__, | ||
$(DynamicPPL.unwrap_right_left_vns)( | ||
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn | ||
)..., | ||
$inds, | ||
__varinfo__, | ||
) | ||
$(generate_dot_tilde_assume(left, right, vn)) | ||
else | ||
$(DynamicPPL.dot_tilde_observe!)( | ||
__context__, | ||
$(DynamicPPL.check_tilde_rhs)($right), | ||
$(maybe_view(left)), | ||
$vn, | ||
$inds, | ||
__varinfo__, | ||
) | ||
end | ||
end | ||
end | ||
|
||
function generate_dot_tilde_assume(left::Symbol, right, vn) | ||
return :( | ||
$left .= $(DynamicPPL.dot_tilde_assume!)( | ||
__context__, | ||
$(DynamicPPL.unwrap_right_left_vns)( | ||
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn | ||
)..., | ||
__varinfo__, | ||
) | ||
) | ||
end | ||
|
||
function generate_dot_tilde_assume(left::Expr, right, vn) | ||
expr = :( | ||
$left .= $(DynamicPPL.dot_tilde_assume!)( | ||
__context__, | ||
$(DynamicPPL.unwrap_right_left_vns)( | ||
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn | ||
)..., | ||
__varinfo__, | ||
) | ||
) | ||
|
||
return remove_escape(setmacro(identity, expr, overwrite=true)) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
|
||
# HACK: This is unfortunate. It's a consequence of the fact that in | ||
# DynamicPPL we the entire function body. Instead we should be | ||
# more selective with our escape. Until that's the case, we remove them all. | ||
remove_escape(x) = x | ||
function remove_escape(expr::Expr) | ||
Meta.isexpr(expr, :escape) && return remove_escape(expr.args[1]) | ||
return Expr(expr.head, map(x -> remove_escape(x), expr.args)...) | ||
end | ||
|
||
# TODO: Make PR to Setfield.jl to use `gensym` for the `lens` variable. | ||
# This seems like it should be the case anyways since it allows multiple | ||
# calls to `setmacro` without any cost to the current functionality. | ||
function setmacro(lenstransform, ex::Expr; overwrite::Bool=false) | ||
@assert ex.head isa Symbol | ||
@assert length(ex.args) == 2 | ||
ref, val = ex.args | ||
obj, lens = Setfield.parse_obj_lens(ref) | ||
lens_var = gensym("lens") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is essentially copy-paste from Setfields' implementation, but we add this in here. I believe this will also be sorted out if we fix the "escape EVERYTHING" in If not we should just make a PR to Setfield.jl. Using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Regarding my above comment about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree with that, though my intention is to make a PR to Setfield.jl now that we seem to be going in the direction of using it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
dst = overwrite ? obj : gensym("_") | ||
val = esc(val) | ||
ret = if ex.head == :(=) | ||
quote | ||
$lens_var = ($lenstransform)($lens) | ||
$dst = $(Setfield.set)($obj, $lens_var, $val) | ||
end | ||
else | ||
op = get_update_op(ex.head) | ||
f = :($(Setfield._UpdateOp)($op,$val)) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
quote | ||
$lens_var = ($lenstransform)($lens) | ||
$dst = $(Setfield.modify)($f, $obj, $lens_var) | ||
end | ||
end | ||
ret | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
|
||
const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} | ||
hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) | ||
hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@views
is now needed since we can have property-access, etc. inx
.