-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Merged by Bors] - Sibling PR of introduction of Setfield.jl in AbstractPPL.jl #295
Changes from 36 commits
1388502
678ef1d
ddf761c
b867d00
6ad5d95
4bf663f
e4922c9
ab4b384
9dadb3a
26216e3
405d52c
505d690
9f8c47b
0a47953
0c51329
54f8c89
d3fe07c
8134a1a
9d3c1dd
5c4bf0e
4121b0e
a38168a
51e7426
0a3655d
f82579e
9aa7298
b130db9
1ae04b3
475da88
4af7e30
e03ef4e
9b79b61
6dd6de9
4c7e882
4c325c3
fa228d8
edad225
553ae9b
00d8411
a99a2b1
472629d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ function isassumption(expr::Union{Symbol,Expr}) | |
vn = gensym(:vn) | ||
|
||
return quote | ||
let $vn = $(varname(expr)) | ||
let $vn = $(AbstractPPL.drop_escape(varname(expr))) | ||
if $(DynamicPPL.contextual_isassumption)(__context__, $vn) | ||
# Considered an assumption by `__context__` which means either: | ||
# 1. We hit the default implementation, e.g. using `DefaultContext`, | ||
|
@@ -133,17 +133,17 @@ variables. | |
|
||
# Example | ||
```jldoctest; setup=:(using Distributions, LinearAlgebra) | ||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); string(vns[end]) | ||
"x[:,2]" | ||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end] | ||
x[:,2] | ||
|
||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); string(vns[end]) | ||
"x[:][1,2]" | ||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end] | ||
x[1,2] | ||
|
||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); string(vns[end]) | ||
"x[1][3]" | ||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns[end] | ||
x[:][1,2] | ||
|
||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2, 3), @varname(x)); string(vns[end]) | ||
"x[1,2,3]" | ||
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns[end] | ||
x[1][3] | ||
``` | ||
""" | ||
unwrap_right_left_vns(right, left, vns) = right, left, vns | ||
|
@@ -158,7 +158,7 @@ function unwrap_right_left_vns( | |
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, | ||
# and we therefore add the `Colon()` below. | ||
vns = map(axes(left, 2)) do i | ||
return VarName(vn, (vn.indexing..., (Colon(), i))) | ||
return vn β Setfield.IndexLens((Colon(), i)) | ||
phipsgabler marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
return unwrap_right_left_vns(right, left, vns) | ||
end | ||
|
@@ -168,7 +168,7 @@ function unwrap_right_left_vns( | |
vn::VarName, | ||
) | ||
vns = map(CartesianIndices(left)) do i | ||
return VarName(vn, (vn.indexing..., Tuple(i))) | ||
return vn β Setfield.IndexLens(Tuple(i)) | ||
end | ||
return unwrap_right_left_vns(right, left, vns) | ||
end | ||
|
@@ -317,6 +317,10 @@ function generate_mainbody!(mod, found, expr::Expr, warn) | |
# Do not touch interpolated expressions | ||
expr.head === :$ && return expr.args[1] | ||
|
||
# Do we don't want escaped expressions because we unfortunately | ||
# escape the entire body afterwards. | ||
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because we do the wrong thing and escape the entire body of the method in This is a hack to essentially ensure that any escaping will be removed. Note that this doesn't break anything because before we couldn't even use escaped expressions within |
||
|
||
# If it's a macro, we expand it | ||
if Meta.isexpr(expr, :macrocall) | ||
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) | ||
|
@@ -349,38 +353,36 @@ function generate_mainbody!(mod, found, expr::Expr, warn) | |
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) | ||
end | ||
|
||
function generate_tilde_literal(left, right) | ||
# If the LHS is a literal, it is always an observation | ||
return quote | ||
$(DynamicPPL.tilde_observe!)( | ||
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ | ||
) | ||
end | ||
end | ||
|
||
""" | ||
generate_tilde(left, right) | ||
|
||
Generate an `observe` expression for data variables and `assume` expression for parameter | ||
variables. | ||
""" | ||
function generate_tilde(left, right) | ||
# If the LHS is a literal, it is always an observation | ||
if isliteral(left) | ||
return quote | ||
$(DynamicPPL.tilde_observe!)( | ||
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ | ||
) | ||
end | ||
end | ||
isliteral(left) && return generate_tilde_literal(left, right) | ||
|
||
# Otherwise it is determined by the model or its value, | ||
# if the LHS represents an observation | ||
@gensym vn inds isassumption | ||
@gensym vn isassumption | ||
|
||
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact | ||
# that in DynamicPPL we the entire function body. Instead we should be | ||
# more selective with our escape. Until that's the case, we remove them all. | ||
return quote | ||
$vn = $(varname(left)) | ||
$inds = $(vinds(left)) | ||
$vn = $(AbstractPPL.drop_escape(varname(left))) | ||
$isassumption = $(DynamicPPL.isassumption(left)) | ||
if $isassumption | ||
$left = $(DynamicPPL.tilde_assume!)( | ||
__context__, | ||
$(DynamicPPL.unwrap_right_vn)( | ||
$(DynamicPPL.check_tilde_rhs)($right), $vn | ||
)..., | ||
$inds, | ||
__varinfo__, | ||
) | ||
$(generate_tilde_assume(left, right, vn)) | ||
else | ||
# If `vn` is not in `argnames`, we need to make sure that the variable is defined. | ||
if !$(DynamicPPL.inargnames)($vn, __model__) | ||
|
@@ -392,44 +394,46 @@ function generate_tilde(left, right) | |
$(DynamicPPL.check_tilde_rhs)($right), | ||
$(maybe_view(left)), | ||
$vn, | ||
$inds, | ||
__varinfo__, | ||
) | ||
end | ||
end | ||
end | ||
|
||
function generate_tilde_assume(left, right, vn) | ||
expr = :( | ||
$left = $(DynamicPPL.tilde_assume!)( | ||
__context__, | ||
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., | ||
__varinfo__, | ||
) | ||
) | ||
|
||
return if left isa Expr | ||
AbstractPPL.drop_escape( | ||
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) | ||
) | ||
else | ||
return expr | ||
end | ||
end | ||
|
||
""" | ||
generate_dot_tilde(left, right) | ||
|
||
Generate the expression that replaces `left .~ right` in the model body. | ||
""" | ||
function generate_dot_tilde(left, right) | ||
# If the LHS is a literal, it is always an observation | ||
if isliteral(left) | ||
return quote | ||
$(DynamicPPL.dot_tilde_observe!)( | ||
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ | ||
) | ||
end | ||
end | ||
isliteral(left) && return generate_tilde_literal(left, right) | ||
|
||
# Otherwise it is determined by the model or its value, | ||
# if the LHS represents an observation | ||
@gensym vn inds isassumption | ||
@gensym vn isassumption | ||
return quote | ||
$vn = $(varname(left)) | ||
$inds = $(vinds(left)) | ||
$vn = $(AbstractPPL.drop_escape(varname(left))) | ||
$isassumption = $(DynamicPPL.isassumption(left)) | ||
if $isassumption | ||
$left .= $(DynamicPPL.dot_tilde_assume!)( | ||
__context__, | ||
$(DynamicPPL.unwrap_right_left_vns)( | ||
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn | ||
)..., | ||
$inds, | ||
__varinfo__, | ||
) | ||
$(generate_dot_tilde_assume(left, right, vn)) | ||
else | ||
# If `vn` is not in `argnames`, we need to make sure that the variable is defined. | ||
if !$(DynamicPPL.inargnames)($vn, __model__) | ||
|
@@ -441,13 +445,27 @@ function generate_dot_tilde(left, right) | |
$(DynamicPPL.check_tilde_rhs)($right), | ||
$(maybe_view(left)), | ||
$vn, | ||
$inds, | ||
__varinfo__, | ||
) | ||
end | ||
end | ||
end | ||
|
||
function generate_dot_tilde_assume(left, right, vn) | ||
# We don't need to use `Setfield.@set` here since | ||
# `.=` is always going to be inplace + needs `left` to | ||
# be something that supports `.=`. | ||
return :( | ||
$left .= $(DynamicPPL.dot_tilde_assume!)( | ||
__context__, | ||
$(DynamicPPL.unwrap_right_left_vns)( | ||
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn | ||
)..., | ||
__varinfo__, | ||
) | ||
) | ||
end | ||
|
||
const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} | ||
hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) | ||
hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll move the updates to the benchmarks to a separate PR if this PR delays.