-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Faster evaluation: SimpleVarInfo
#267
Faster evaluation: SimpleVarInfo
#267
Conversation
src/simple_varinfo.jl
Outdated
x = map(enumerate(md.ranges)) do (i, r) | ||
reconstruct(md.dists[i], md.vals[r]) | ||
end | ||
|
||
# TODO: Doesn't support batches of `MultivariateDistribution`? | ||
length(x) == 1 ? x[1] : x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is where the indexing goes wrong. Constructing a SimpleVarInfo
from TypedVarInfo
is somewhat non-trivial if we allow indexing. E.g.
julia> @model function demo3(m)
m[:, 1] ~ MvNormal(size(m, 1), 1.0)
m[:, 2] ~ MvNormal(size(m, 1), 1.0)
return m
end
demo3 (generic function with 1 method)
julia> setmissings(m::Model, missings...) = Model{missings}(m.name, m.f, m.args, m.defaults);
julia> m3 = setmissings(demo3(rand(3, 2)), :m);
julia> m3()
3×2 Matrix{Float64}:
0.264148 1.30146
-0.974804 0.0924204
1.44424 1.3519
julia> vi = VarInfo(m3);
julia> svi = SimpleVarInfo(vi);
julia> svi.θ.m
2-element Vector{Vector{Float64}}:
[-0.8340592002751136, 0.5670953725697917, -0.5460275730331128]
[0.12464384286023088, -1.2118644862064083, 2.3386842884350765]
julia> svi[@varname(m[:, 1])] # (×) since `svi.θ` is vec of vecs, the indexing produces the wrong result
2-element Vector{Vector{Float64}}:
[-0.8340592002751136, 0.5670953725697917, -0.5460275730331128]
[0.12464384286023088, -1.2118644862064083, 2.3386842884350765]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we de-linearise the matrix variable m while converting vi::VarInfo to svi::SimpleVarInfo? That is, we store m::Matrix with its original shape in SimpleVarInfo. This de-linearise operation is also needed in MCMCChains when we group matrix variables, I remember.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Me knee-jerk reaction is that ordering will be an issue. This is not an issue for VarInfo
because there you'll correctly associate vals[1]
with vns[1]
; it doesn't matter whether vns[1]
is actually m[:, 2]
. MCMCChains similarly doesn't care about this. In contrast, in SimpleVarInfo
we do if we want to ensure that indexing works 😕
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There might be a better way of doing the indexing though! Not ruling out a solution yet; just saying that it's not that easy I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
E.g. if we start using view
(see #272 ) we could potentially support more scenarios by allocating parent
:)
Very interesting approach - the amount of code is much less than I would expect, which is a positive surprise. A quick thought on supporting sampling: could we use the |
That won't work since this will only create a new instance of But we could just mutate arrays in place and potentially use Might be a non-breaking way around this though that I'm missing! |
Pinging o great @devmotion ! Would love your thoughts on this 👍 |
Does a simple mutable struct avoid type-instabilities with Zygote, e.g.: mutable struct F{T<:Number}
v::T
end then, we can use Also, it wouldn't hurt to have two versions of |
We don't need the context but the DynamicPPL.jl/src/context_implementations.jl Lines 130 to 134 in 892b971
SimpleVarInfo object. Alternatively, if it is better to update the VarInfo object in leaf contexts, we could propagate it upwards instead or in addition to value .
|
Sorry, that was a typo; I meant returning varinfo:)
Exactly. This is what I had in mind when I mentioned this. This would also probably make things even faster as it would remove the need to use But I didn't immediately suggest this because wouldn't this potentially make models type-unstable due to re-assignment of
We briefly discussed this in some other PR, remember? It has issues because you end up having to do stuff like: lp_old = getlogp(vi)
val, lp = tilde_assume(...)
setlogp!(vi, weight * (getlogp(vi) - lp_old)) or something for |
One would have to check this but hopefully the compiler would be smart enough to realize that the same object is returned for |
The annoying bit from a user-perspective is that this means that you don't have the same amount of control on the type-stability as you did before 😕 Maybe this is a bit of overkill, but what if we add a immutability-trait to if isimmutable(__varinfo__)
$left, __varinfo__ = tilde_assume(...)
else
$left, _ = tilde_assume(...)
end ? |
Annother annoyance is that we're going to have more allocations since we now longer can do EDIT: Could do this update in |
I think this would be quite confusing (
Yes, this was what I assumed one has to do. |
Not sure what you mean here;
But for the mutated versions we preserve the type-stability that we currently have, since we don't need to capture the returned |
Ah my bad, I missed that there was still a second return value that was just discarded in your example.
I still hope that it's not an issue even for complex models 🙂 And I think it should work even for immutable structs, but in this case maybe the compiler is happier with simpler models. But this is just what I would expect from the compiler, of course, one would have to back this up with some examples. |
So how do you feel about it after realizing that?:) Good or bad idea? It's nice because it ensures that we atleast preserve the existing behaviour.
I agree that in most cases it's going to be completely fine, but there will be cases that currently are type-stable which won't be afterwards, so it would be nice if we could avoid messing up the current behaviour. Also, one more thing that's an issue: EDIT: Unless of course we start messing with the return-value of the model, i.e. adding a return if __issubmodel__
$retvalue, __varinfo__
else
$retvalue
end It's a bit hacky, but maybe not the worst idea ever? EDIT 2: Could also just use a context to specify that it's a submodel. But would still require messing with the return-value. |
Do you have an example? I think it won't introduce any type instabilities for
I think it would be cleaner to have a separate |
I'll give it a go and see if there are any example models which fails 👍
You mean adding a separate method to |
I thought one could just implement an alternative to Lines 156 to 161 in 892b971
f does not return the VarInfo object.
Mostly I am not a fan of adding a special argument (or also the trait approach for immutable varinfos) since IMO it makes the code messier just to handle a special case. We put quite some effort into simplifying how the evaluator works and the macro is expanded to, so it would be unfortunate to add all these things to handle some non-standard cases 😞 My initial suggestion would have been to maybe always return the return value and the varinfo object from the evaluator (and in |
Just for the record, I'm also not a big fan of that 😅 But difficult to see how it could be done otherwise. If we want to make an EDIT: And regarding the immutablility-trait: I'm also fine with not introducing it initially, and then we can do so if it indeed turns out that it introduces a bunch of type-instabilities.
Yeah, it would require use to traverse and replace all return-values. But I'm fine with always returning the varinfo too, you prefer that; I was just thinking not returning varinfo in the outer-most model wouldn't change anything from a user-perspective, which is nice. |
I'm fine with adding the return values, I hope it doesn't require too many canges (I guess some additional handling in the function that builds the model body, a way to capture the final value and a final return statement) 🙂
Always returning both the value and the varinfo object won't necessarily change anything from a user perspective - users don't call |
👍
Aaah good point; no way to retrive the samples then. With that in mind I agree always adding is a good idea 👍 Btw, how do we actually replace all the return-values? Do we need to use IRTools for this? Seems like macros won't cut it? |
Hmm yeah with the current macro approach we could only replace the high-level Another thing (not really related to SimpleVarInfo but I thought about this again when I saw the implementation): should |
But how do you do this? E.g. what about statements such as map(x) do x
return x
end ?
I don't like this, nor do I quite see the motivation behind this 😕 I think if we start considering moving it away from varinfo, then I'm more in favour of a |
It works at least: julia> using DynamicPPL, Distributions
julia> import IRTools
julia> function add_varinfo_return!(ir, argidx=3)
for block in IRTools.blocks(ir)
for b in IRTools.branches(block)
if IRTools.isreturn(b)
# Add the new return-statement
retval = IRTools.push!(block, IRTools.xcall(:tuple, b.args[1], IRTools.var(argidx)))
IRTools.return!(block, retval)
end
end
end
return ir
end
add_varinfo_return! (generic function with 2 methods)
julia> function new_model(m::Model)
# Get the IR representation.
ir = IRTools.Inner.code_ir(m.f, Tuple{Model, DynamicPPL.AbstractVarInfo, DynamicPPL.AbstractContext, map(typeof, m.args)...})
# Replace return-values.
add_varinfo_return!(ir)
# Create resulting function.
mf = IRTools.func(ir)
# Re-create new model.
# TODO: Why do I need this first `nothing` argument?
# NOTE: For some reason `Base.Fix1` doesn't work?
function evaluator(args...)
return mf(nothing, args...)
end
return Model(m.name, evaluator, m.args, m.defaults)
end
new_model (generic function with 1 method)
julia> @model function demo(x)
m ~ Normal()
if m < 0
return nothing
end
x ~ Normal(m, 1)
z = map(1:3) do i
i * x
end
return (; m, x, z)
end
demo (generic function with 1 method)
julia> m = demo(1.0)
Model{var"#12#14", (:x,), (), (), Tuple{Float64}, Tuple{}}(:demo, var"#12#14"(), (x = 1.0,), NamedTuple())
julia> m()
(m = 0.06386506414921082, x = 1.0, z = [1.0, 2.0, 3.0])
julia> m()
julia> m()
(m = 1.2132951487644734, x = 1.0, z = [1.0, 2.0, 3.0])
julia> mnew = new_model(m)
Model{var"#evaluator#11"{IRTools.Inner.var"###286"}, (:x,), (), (), Tuple{Float64}, Tuple{}}(:demo, var"#evaluator#11"{IRTools.Inner.var"###286"}(IRTools.Inner.var"##286"), (x = 1.0,), NamedTuple())
julia> mnew()
((m = 1.1227582487925878, x = 1.0, z = [1.0, 2.0, 3.0]), VarInfo (1 variable (m), dimension 1; logp: -2.476))
julia> mnew()
((m = 0.5147520341831271, x = 1.0, z = [1.0, 2.0, 3.0]), VarInfo (1 variable (m), dimension 1; logp: -2.088))
julia> mnew()
(nothing, VarInfo (1 variable (m), dimension 1; logp: -1.118)) EDIT: A significant issue with the above is that it doesn't capture re-assignment since we replace return-values with the variable from the arguments, duh 😅 |
Exactly, one must not replace or modify return statements in such examples.
The main motivation would be to separate the data structure for saving and retrieving variables from the one that accumulates the log density. This would allow a more orthogonal design. E.g., currently the only difference between VarInfo and ThreadSafeVarInfo is that the latter uses a thread-safe way for accumulating the log density. Instead one would have a joint/prior/likelihood context that accumulates the log density in a threadsafe way and could combine them with different varinfo types (currently ThreadSafeVarInfo is designed for VarInfo). Another motivation would be that it means different things and hence there is no clear meaning to varinfo.logp. If it's an accumulator in the LikelihoodContext, then it's clear that it's the log-likelihood.
I think I disagree with this suggestion, I would rather keep it as part of the varinfo objects or possibly of the contexts. It's not needed in every model execution, e.g., the |
Soooo ideas? 🙃 Realized it's a bit difficult to do on an IR-level because we're working with SSA 😅
Hmm, maybe. IMO this should be a separate discussion though 😕
True 👍 |
Maybe something as simple as replace_returns(e) = e
replace_returns(e::Symbol) = e
function replace_returns(e::Expr)
if Meta.isexpr(e, :function) || Meta.isexpr(e, :->)
return e
end
if Meta.isexpr(e, :return)
retval = if length(e.args) > 1
Expr(:tuple, e.args...)
else
e.args[1]
end
return quote
return $retval, __varinfo__
end
end
return Expr(e.head, map(x -> replace_returns(x), e.args)...)
end will do a good enough job? Is there anything else we're missing there? Also, should have this as a "submodel" evaluator instead? EDIT: This requires users to be explicit about their return-statements though. Could maybe extend this to also change the last statement, if it's not a |
I agree. I wanted to mention it since this PR introduces an additional varinfo struct with another logp field.
I think we have to change it, otherwise it will fail in such cases? |
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
src/simple_varinfo.jl
Outdated
function push!!( | ||
vi::SimpleVarInfo{Nothing}, vn::VarName{sym,Tuple{}}, value, dist::Distribution | ||
) where {sym} | ||
@set vi.θ = NamedTuple{(sym,)}((value,)) | ||
end | ||
function push!!( | ||
vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution | ||
) where {sym} | ||
@set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any good ideas of how we can support cases where we actually have indexing, i.e. VarName{sym}
rather than VarName{sym, Tuple{}}
? @devmotion ? 👼
Might be useful to exploit parent
if we have #272 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess everyone has already had their own ideas about "trie-style" VarInfo, but mine was the following (similar to your views idea, I think): store the complete data in one big array (or perhaps one per type). Then put a trie-like thing on top of that, referencing only views into it. So we'd have something like, conceptually,
backup = [1.0, 2.0, 3.0, 4.0]
vi = (x = @view(backup, 1:2),
y = (var"1" = @view(backup, 3),
var"3" = @view(backup, 4)))
Of course there's a lot of possibilities of implementing this better (like a dictionary for y
, in case there's a lot of indices.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The issue is that var"1"
won't be a good idea once the indices are dynamic (this I guess is what you're referring to when saying that it can be implemented using a dictionary).
But for SimpleVarInfo
, I'm happy to only support "simple" expressions meaning expressions of the form x
and x[...]
, not x[1][...]
, or equivalently VarName{<:Any,Tuple{}}
and VarName{<:Any, Tuple{<:Tuple}}
.
Now that everything is a View
, this could be implemented by just preallocating equivalent of parent
and the inserting into this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another thing: nested indexing is kind of "useless" if we have something like #275 😕
|
||
# Context implementations | ||
function tilde_assume!!(context, right, vn, inds, vi::SimpleVarInfo) | ||
value, logp, vi_new = tilde_assume(context, right, vn, inds, vi) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Notice how here we also take the vi_new
.
Ideally this is what we should be doing overall if we're going to support immutable varinfos. But if we force this, we'll end up breaking a lot of downstream samplers since it requires changing assume
statements to also return vi
.
As an intermediate step, it might be worth just overloading tilde_assume!!
as I have done here + assume
as I have done below. This also brings up another annoyance though: vi
should really be at the beginning of the arguments of all the tilde-statements (and ideally assume
too, but we can delay this) to minimize method ambiguities but allow us to have some special behavior for the different impls for AbstractVarInfo
.
…PPL.jl into tor/simple-varinfo-v2
Looking at `to_namedtuple_expr` it seems like it's leftover code to be compatible with Julia 1.2. Given that DPPL only supports Julia 1.3 or higher I've simplified the implementation, doing away with the `namedtuple` method. EDIT: This also fixed a bug I ran into when using Zygote + `@submodel` in #267 . Struggling to come up with a MWE of the bug, but just adding merging this fixed it. This was also my original motivation for making this change, as Zygote was complaining about the `namedtuple` for some reason (something about "non-differentiable getfield"). Co-authored-by: Hong Ge <hg344@cam.ac.uk>
Is this ready for a second look? |
Not yet, no. |
…Tuple unless shapes are specified
This PR introduces a
SimpleVarInfo
, which is a barebonesAbstractVarInfo
that only holds the parameters in aNamedTuple
(orComponentArray
with very little work) and thelogp
.The aim is to allow significant speed up in evaluation of static and simple models by dropping a bunch of the conveniences that comes with
VarInfo
, e.g. linking isn't interleaved with theVarInfo
but instead we just construct aBijector
once and use that in thelogdensity
computation.Used above:
missing
, etc.This PR supersedes #242.
Examples
A couple of simple examples
And
Things to discuss
tilde_assume
and others currently do not specify the type of thevi
argument.isstatic(m::Model)::Bool
where we can also possible introduce some convenience methods for making it easier to add such overloads.Potential limitiations
Parallelism
Problem
Say we have a model including
which will become
If
__varinfo__
is mutable, then we're fine since theThreadedWrapper
will work.In contrast, an immutable
__varinfo__
will fail in this case, since the updates to the__varinfo__
variable will be thread-local. This doesn't mean we can't support multithreading since the user can write it out by hand, but it could break with user-expectation.Solution?
We should probably warn the user that if they're using an immutable
AbstractVarInfo
they should be a bit careful regarding stuff like this?Could potentially just point the user to some documentation which describes the full issue, e.g. "be careful, see HERE if you need elaboration".