Faster evaluation: SimpleVarInfo #267

Merged
merged 28 commits into tor/immutable-varinfo-support from tor/simple-varinfo-v2 on Aug 14, 2021

Conversation

torfjelde
Member

@torfjelde torfjelde commented Jun 18, 2021

This PR introduces a SimpleVarInfo, which is a barebones AbstractVarInfo that only holds the parameters in a NamedTuple (or ComponentArray with very little work) and the logp.

The aim is to allow a significant speed-up in the evaluation of static and simple models by dropping a bunch of the conveniences that come with VarInfo. For example, linking isn't interleaved with the VarInfo; instead we just construct a Bijector once and use it in the log-density computation (sketched below).

Used above:

  • static: variables, sizes, etc. do not change between runs.
  • simple: no indexing, no partially missing data, etc.

This PR supersedes #242.
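
For intuition, here is a rough sketch of the "construct a Bijector once" idea, written directly against Distributions.jl and Bijectors.jl rather than this PR's code; the helper name logdensity_unconstrained is made up for illustration.

using Distributions, Bijectors

dist = InverseGamma(2, 3)
b    = bijector(dist)   # maps the support of `dist` to ℝ; constructed once up front
binv = inverse(b)       # `inv(b)` on older Bijectors.jl versions

# Log-density of an unconstrained value `y`, including the Jacobian correction.
logdensity_unconstrained(y) = logpdf(dist, binv(y)) + logabsdetjac(binv, y)

logdensity_unconstrained(0.3)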

Examples

A couple of simple examples
julia> @model function demo(x)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, s)

           x = ismissing(x) ? Vector{Float64}(undef, 2) : x
           x .~ Normal(m, s)

           return (; s, m, x, logp = getlogp(__varinfo__))
       end;

julia> m = demo([1.0, 1.0]);

julia> svi = SimpleVarInfo((s = 1.0, m = 0.0))
SimpleVarInfo{NamedTuple{(:s, :m), Tuple{Float64, Float64}}, Float64}((s = 1.0, m = 0.0), Base.RefValue{Float64}(0.0))

julia> m(svi, DefaultContext())
(s = 1.0, m = 0.0, x = [1.0, 1.0], logp = -4.5595910222777984)

julia> vi = VarInfo(m);

julia> @benchmark $m($vi, $(DefaultContext()))
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.288 μs (0.00% GC)
  median time:      1.377 μs (0.00% GC)
  mean time:        1.514 μs (0.00% GC)
  maximum time:     4.380 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> @benchmark $m($svi, $(DefaultContext()))
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     55.902 ns (0.00% GC)
  median time:      56.314 ns (0.00% GC)
  mean time:        60.289 ns (0.00% GC)
  maximum time:     123.256 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     979

julia> m = demo(missing);

julia> svi = DynamicPPL.SimpleVarInfo{Float64}((s = 1.0, m = 0.0, x = [1.0, 1.0]));

julia> m(svi, DefaultContext())
(s = 1.0, m = 0.0, x = [1.0, 1.0], logp = -4.5595910222777984)

julia> vi = VarInfo(m);

julia> @benchmark $m($vi, $(DefaultContext()))
BenchmarkTools.Trial: 
  memory estimate:  480 bytes
  allocs estimate:  5
  --------------
  minimum time:     2.061 μs (0.00% GC)
  median time:      2.213 μs (0.00% GC)
  mean time:        2.418 μs (1.83% GC)
  maximum time:     445.382 μs (99.18% GC)
  --------------
  samples:          10000
  evals/sample:     9

julia> @benchmark $m($svi, $(DefaultContext()))
BenchmarkTools.Trial: 
  memory estimate:  384 bytes
  allocs estimate:  4
  --------------
  minimum time:     163.438 ns (0.00% GC)
  median time:      187.497 ns (0.00% GC)
  mean time:        230.662 ns (12.62% GC)
  maximum time:     6.815 μs (96.73% GC)
  --------------
  samples:          10000
  evals/sample:     719

And

@model function prior()
    m ~ Normal()
end

@model function demo(x)
    @submodel prefix2 m ~ prior()
    @submodel prior()
    x ~ MvNormal(m * ones(length(x)), 1.0)

    return (; m, logp = getlogp(__varinfo__))
end

m = demo([1.0, 1.0]);

vi = VarInfo(m);
svi = SimpleVarInfo(vi);

using BenchmarkTools
@benchmark $m($vi, $DefaultContext())
# BenchmarkTools.Trial: 
#   memory estimate:  288 bytes
#   allocs estimate:  3
#   --------------
#   minimum time:     1.200 μs (0.00% GC)
#   median time:      1.302 μs (0.00% GC)
#   mean time:        1.325 μs (0.00% GC)
#   maximum time:     5.421 μs (0.00% GC)
#   --------------
#   samples:          10000
#   evals/sample:     10
@benchmark $m($svi, $DefaultContext())
# BenchmarkTools.Trial: 
#   memory estimate:  288 bytes
#   allocs estimate:  3
#   --------------
#   minimum time:     111.117 ns (0.00% GC)
#   median time:      118.714 ns (0.00% GC)
#   mean time:        148.049 ns (16.18% GC)
#   maximum time:     4.723 μs (96.89% GC)
#   --------------
#   samples:          10000
#   evals/sample:     923

Things to discuss

  1. Should we also have limited support for sampling?
    • One immediate issue is method-ambiguity since tilde_assume and others currently do not specify the type of the vi argument.
  2. Can we loosen the restrictions on this, i.e. make it more general, without sacrificing performance?
  3. Should the restrictions, e.g. the model being static, be implemented in a trait-like manner, as we discussed on Slack? E.g. isstatic(m::Model)::Bool, where we could also introduce some convenience methods to make it easier to add such overloads (a rough sketch follows below).
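
Something like the following, as a rough sketch (the trait names and choose_varinfo are hypothetical, not part of this PR):

using DynamicPPL

# Hypothetical trait types; the conservative default treats every model as dynamic.
struct IsStatic end
struct IsDynamic end

staticness(::Model) = IsDynamic()
isstatic(model::Model) = staticness(model) isa IsStatic

# Dispatching on the trait, e.g. to pick a varinfo implementation
# (using the `SimpleVarInfo(::VarInfo)` constructor from the examples above).
choose_varinfo(model::Model) = choose_varinfo(staticness(model), model)
choose_varinfo(::IsStatic, model) = SimpleVarInfo(VarInfo(model))
choose_varinfo(::IsDynamic, model) = VarInfo(model)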

Potential limitations

Parallelism

Problem

Say we have a model including

@threads for i = 1:10
    x[i] ~ Normal()
end

which will become

@threads for i = 1:10
    x[i], __varinfo__ = tilde_assume(...)
end

If __varinfo__ is mutable, then we're fine since the ThreadedWrapper will work.

In contrast, an immutable __varinfo__ will fail in this case, since the updates to the __varinfo__ variable will be thread-local. This doesn't mean we can't support multithreading, since the user can write it out by hand, but it could break with user expectations.

Solution?

We should probably warn the user that if they're using an immutable AbstractVarInfo, they should be a bit careful regarding stuff like this?
We could potentially just point the user to some documentation which describes the full issue, e.g. "be careful, see HERE if you need elaboration".

@torfjelde torfjelde marked this pull request as draft June 18, 2021 13:07
Comment on lines 112 to 133
x = map(enumerate(md.ranges)) do (i, r)
    reconstruct(md.dists[i], md.vals[r])
end

# TODO: Doesn't support batches of `MultivariateDistribution`?
length(x) == 1 ? x[1] : x
Member Author

This is where the indexing goes wrong. Constructing a SimpleVarInfo from TypedVarInfo is somewhat non-trivial if we allow indexing. E.g.

julia> @model function demo3(m)
           m[:, 1] ~ MvNormal(size(m, 1), 1.0)
           m[:, 2] ~ MvNormal(size(m, 1), 1.0)
           return m
       end
demo3 (generic function with 1 method)

julia> setmissings(m::Model, missings...) = Model{missings}(m.name, m.f, m.args, m.defaults);

julia> m3 = setmissings(demo3(rand(3, 2)), :m);

julia> m3()
3×2 Matrix{Float64}:
  0.264148  1.30146
 -0.974804  0.0924204
  1.44424   1.3519

julia> vi = VarInfo(m3);

julia> svi = SimpleVarInfo(vi);

julia> svi.θ.m
2-element Vector{Vector{Float64}}:
 [-0.8340592002751136, 0.5670953725697917, -0.5460275730331128]
 [0.12464384286023088, -1.2118644862064083, 2.3386842884350765]

julia> svi[@varname(m[:, 1])] # (×) since `svi.θ` is vec of vecs, the indexing produces the wrong result
2-element Vector{Vector{Float64}}:
 [-0.8340592002751136, 0.5670953725697917, -0.5460275730331128]
 [0.12464384286023088, -1.2118644862064083, 2.3386842884350765]

Member

Can we de-linearise the matrix variable m while converting vi::VarInfo to svi::SimpleVarInfo? That is, we store m::Matrix with its original shape in SimpleVarInfo. This de-linearise operation is also needed in MCMCChains when we group matrix variables, if I remember correctly.

Member Author

@torfjelde torfjelde Jun 18, 2021

My knee-jerk reaction is that ordering will be an issue. This is not an issue for VarInfo because there you'll correctly associate vals[1] with vns[1]; it doesn't matter whether vns[1] is actually m[:, 2]. MCMCChains similarly doesn't care about this. In contrast, in SimpleVarInfo we do care if we want to ensure that indexing works 😕

Member Author

There might be a better way of doing the indexing though! Not ruling out a solution yet; just saying that it's not that easy, I think.

Member Author

E.g. if we start using view (see #272 ) we could potentially support more scenarios by allocating parent :)

@yebai
Member

yebai commented Jun 18, 2021

Very interesting approach - the amount of code is much less than I would expect, which is a positive surprise. A quick thought on supporting sampling: could we use the Setfield package to mutate named tuples, such that sampling can be added?

@torfjelde
Member Author

Very interesting approach - the amount of code is much less than I would expect, which is a positive surprise. A quick thought on supporting sampling: could we use the Setfield package to mutate named tuples, such that sampling can be added?

That won't work since this will only create a new instance of SimpleVarInfo, which won't propagate to the callee since we don't return the context from tilde_*!.

But we could just mutate arrays in place and potentially use Ref for univariates (ugly though and will lead to type-instabilities in gradients when using Zygote, but we already have those thanks to svi.logp being a Ref: FluxML/Zygote.jl#999). But even then we'd run into method ambiguities for all the *_tilde_* methods, I believe (since vi is one of the later arguments). We could of course fix this, but it would make this very breaking vs. just not having support for sampling, which guarantees being non-breaking.

Might be a non-breaking way around this though that I'm missing!
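
A minimal standalone illustration of why @set on its own doesn't propagate anything back to the caller (toy types, nothing DynamicPPL-specific):

using Setfield

struct ToySVI{T}
    θ::T
end

function try_update(vi)
    vi = @set vi.θ = (m = 1.0,)  # builds a *new* ToySVI and rebinds only the local `vi`
    return vi.θ.m
end

vi = ToySVI((m = 0.0,))
try_update(vi)  # 1.0 inside the callee ...
vi.θ.m          # ... but still 0.0 out here, because the new object was never handed back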

@torfjelde
Member Author

Pinging o great @devmotion ! Would love your thoughts on this 👍

@yebai
Member

yebai commented Jun 18, 2021

potentially use Ref for univariates (ugly though and will lead to type-instabilities in gradients when using Zygote, but we already have those thanks to svi.logp being a Ref: FluxML/Zygote.jl#999),

Does a simple mutable struct avoid type-instabilities with Zygote, e.g.:

mutable struct F{T<:Number}
    v::T
end

then, we can use F instead of Ref for all univariates.

Also, it wouldn't hurt to have two versions of SimpleVarInfo, one for sampling, the other for evaluation. Then we can disable mutation in the evaluation version of SimpleVarInfo. This introduces some redundancy, but it might be possible to define only one parametric SimpleVarInfo type. We can then specialise this type for sampling and evaluation modes accordingly.
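
A sketch of the "one parametric type, two modes" idea, purely for illustration (none of these names exist in DynamicPPL):

# The mode is a type parameter, so evaluation- vs. sampling-specific behaviour
# can be selected by dispatch on the same struct.
struct ModalVarInfo{mode,T,L}
    θ::T
    logp::L
end

const EvaluationVarInfo = ModalVarInfo{:evaluation}
const SamplingVarInfo   = ModalVarInfo{:sampling}

mutation_allowed(::ModalVarInfo{:sampling})   = true
mutation_allowed(::ModalVarInfo{:evaluation}) = false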

@devmotion
Member

That won't work since this will only create a new instance of SimpleVarInfo, which won't propagate to the callee since we don't return the context from tilde_*!.

We don't need the context but the SimpleVarInfo object, it seems? What about changing tilde_assume! to tilde_assume!! that returns both the value and the varinfo object and assigning the latter to __varinfo__ in the model? For VarInfo we would update varinfo only in

function tilde_assume!(context, right, vn, inds, vi)
    value, logp = tilde_assume(context, right, vn, inds, vi)
    acclogp!(vi, logp)
    return value
end
(this might be the main issue?) and add an additional return statement, and for SimpleVarInfo we would return a new SimpleVarInfo object. Alternatively, if it is better to update the VarInfo object in leaf contexts, we could propagate it upwards instead of, or in addition to, value.
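
For concreteness, a sketch of the !! variant being described; acclogp!! is hypothetical shorthand here for "mutate in place if possible, otherwise return an updated copy":

function tilde_assume!!(context, right, vn, inds, vi)
    value, logp = tilde_assume(context, right, vn, inds, vi)
    # For a mutable VarInfo this would mutate and return `vi` itself;
    # for an immutable SimpleVarInfo it would return a new object.
    vi = acclogp!!(vi, logp)
    return value, vi
end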

@torfjelde
Member Author

We don't need the context but the SimpleVarInfo object, it seems?

Sorry, that was a typo; I meant returning varinfo :)

What about changing tilde_assume! to tilde_assume!! that returns both the value and the varinfo object and assigning the latter to __varinfo__ in the model?

Exactly. This is what I had in mind when I mentioned this. This would also probably make things even faster as it would remove the need to use Ref.

But I didn't immediately suggest this because wouldn't this potentially make models type-unstable due to re-assignment of __varinfo__ everywhere in the model?

Alternatively, if it is better to update the VarInfo object in leaf contexts, we could propagate it upwards instead or in addition to value.

We briefly discussed this in some other PR, remember? It has issues because you end up having to do stuff like:

lp_old = getlogp(vi)
val, lp = tilde_assume(...)
setlogp!(vi, weight * (getlogp(vi) - lp_old))

or something for MiniBatchContext, and potentially more issues.

@devmotion
Member

But I didn't immediately suggest this because wouldn't this potentially make models type-unstable due to re-assignment of __varinfo__ everywhere in the model?

One would have to check this but hopefully the compiler would be smart enough to realize that the same object is returned for VarInfo and to figure out the type in the other case as well. Maybe it breaks for more complicated models but then one still has the function barrier when calling the next tilde_assume or tilde_observe statement at least.

@torfjelde
Member Author

The annoying bit from a user perspective is that this means you don't have the same amount of control over the type-stability as you did before 😕

Maybe this is a bit of overkill, but what if we add an immutability trait to AbstractVarInfo and do:

if isimmutable(__varinfo__)
    $left, __varinfo__ = tilde_assume(...)
else
    $left, _ = tilde_assume(...)
end

?

@torfjelde
Member Author

torfjelde commented Jun 19, 2021

Another annoyance is that we're going to have more allocations, since we can no longer do $left .= dot_tilde_assume(...) but instead need to do $tmp, __varinfo__ = dot_tilde_assume!!(...); $left .= $tmp or something along those lines?

EDIT: Could do this update in dot_tilde_assume!! though.
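
A sketch of the EDIT above, i.e. pushing the broadcasted assignment into dot_tilde_assume!! itself so the generated model code stays a single assignment (acclogp!! is again a hypothetical "mutate or copy" helper):

function dot_tilde_assume!!(context, right, left, vn, inds, vi)
    value, logp = dot_tilde_assume(context, right, left, vn, inds, vi)
    left .= value  # the in-place broadcast happens inside the tilde function ...
    return left, acclogp!!(vi, logp)  # ... so the model body can do `$left, __varinfo__ = ...`
end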

@devmotion
Member

Maybe this is a bit of overkill, but what if we add an immutability trait to AbstractVarInfo and do:

I think this would be quite confusing (tilde_assume! does not have consistent return types, even the number is different). I still think the bangbang version should work fine if varinfo is mutated, so I don't think it would be more type stable. But that's something one has to check.

EDIT: Could do this update in dot_tilde_assume!! though.

Yes, this was what I assumed one has to do.

@torfjelde
Member Author

torfjelde commented Jun 19, 2021

I think this would be quite confusing (tilde_assume! does not have consistent return types, even the number is different).

Not sure what you mean here; tilde_assume! doesn't have consistent number of return values? AFAIK in this case it would always return value, vi, no?

I still think the bangbang version should work fine if varinfo is mutated, so I don't think it would be more type stable. But that's something one has to check.

But for the mutable versions we preserve the type-stability that we currently have, since we don't need to capture the returned __varinfo__. I'm referring to those complex models you mentioned where the compiler might struggle with reassignment of __varinfo__. In those cases the $left, _ = ... wouldn't have any type-inference issues, i.e. preserving the type-stability that we currently have.

@devmotion
Member

Not sure what you mean here; tilde_assume! doesn't have consistent number of return values? AFAIK in this case it would always return value, vi, no?

Ah my bad, I missed that there was still a second return value that was just discarded in your example.

I'm referring to those complex models you mentioned where the compiler might struggle with reassignment of __varinfo__. In those cases the $left, _ = ... wouldn't have any type-inference issues, i.e. preserving the type-stability that we currently have.

I still hope that it's not an issue even for complex models 🙂 And I think it should work even for immutable structs, but in this case maybe the compiler is happier with simpler models. But this is just what I would expect from the compiler, of course, one would have to back this up with some examples.

@torfjelde
Member Author

torfjelde commented Jun 19, 2021

Ah my bad, I missed that there was still a second return value that was just discarded in your example.

So how do you feel about it after realizing that?:) Good or bad idea?

It's nice because it ensures that we at least preserve the existing behaviour.

I still hope that it's not an issue even for complex models 🙂 And I think it should work even for immutable structs, but in this case maybe the compiler is happier with simpler models. But this is just what I would expect from the compiler, of course, one would have to back this up with some examples.

I agree that in most cases it's going to be completely fine, but there will be cases that currently are type-stable which won't be afterwards, so it would be nice if we could avoid messing up the current behaviour.

Also, one more thing that's an issue: @submodel will break for immutables no matter what 😕

EDIT: Unless of course we start messing with the return-value of the model, i.e. adding a ::Val{__issubmodel__} argument to the evaluator or something and

return if __issubmodel__
    $retvalue, __varinfo__
else
    $retvalue
end

It's a bit hacky, but maybe not the worst idea ever?

EDIT 2: Could also just use a context to specify that it's a submodel. But would still require messing with the return-value.
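
A sketch of the context-based variant from EDIT 2 (SubmodelContext and issubmodel are made up here; this is not what the PR implements):

using DynamicPPL

struct SubmodelContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext
    childcontext::C
end

issubmodel(::DynamicPPL.AbstractContext) = false
issubmodel(::SubmodelContext) = true

# The evaluator would then branch on the context instead of an extra `Val` argument,
# conceptually:
#   return issubmodel(context) ? ($retvalue, __varinfo__) : $retvalue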

@devmotion
Member

there will be cases that currently are type-stable which won't be afterwards, so it would be nice if we could avoid messing up the current behaviour.

Do you have an example? I think it won't introduce any type instabilities for VarInfo, unless there is a concrete example where it becomes unstable 🙂

EDIT: Unless of course we start messing with the return-value of the model, i.e. adding a ::Val{__issubmodel__} argument to the evaluator or something and

I think it would be cleaner to have a separate submodel_evaluate function.

@torfjelde
Member Author

Do you have an example? I think it won't introduce any type instabilities for VarInfo, unless there is a concrete example where it becomes unstable 🙂

I'll give it a go and see if there are any example models which fail 👍

I think it would be cleaner to have a separate submodel_evaluate function.

You mean adding a separate method to Model?

@devmotion
Member

I thought one could just implement an alternative to

DynamicPPL.jl/src/model.jl

Lines 156 to 161 in 892b971

@generated function _evaluate(
    model::Model{_F,argnames}, varinfo, context
) where {_F,argnames}
    unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames]
    return :(model.f(model, varinfo, context, $(unwrap_args...)))
end
but of course that won't work since f does not return the VarInfo object.

Mostly I am not a fan of adding a special argument (or also the trait approach for immutable varinfos) since IMO it makes the code messier just to handle a special case. We put quite some effort into simplifying how the evaluator works and what the macro expands to, so it would be unfortunate to add all these things to handle some non-standard cases 😞

My initial suggestion would have been to maybe always return the return value and the varinfo object from the evaluator (and in _evaluate) but this won't work by just adding an additional statement at the end of the evaluator f: users may exit the function before reaching the end so one would have to handle also user-provided return statements. I assume this could work but requires a bit more changes.

@torfjelde
Member Author

torfjelde commented Jun 19, 2021

Mostly I am not a fan of adding a special argument (or also the trait approach for immutable varinfos) since IMO it makes the code messier just to handle a special case. We put quite some effort into simplifying how the evaluator works and what the macro expands to, so it would be unfortunate to add all these things to handle some non-standard cases 😞

Just for the record, I'm also not a big fan of that 😅 But it's difficult to see how it could be done otherwise. If we want to make an AbstractVarInfo that is immutable (which might be a great idea for performance reasons too), we need to mess with the return-values if we want support for submodels. And submodels are such a significant feature that IMO it's worth it (and we only change the return-values if the model is actually executed as a submodel, not otherwise).

EDIT: And regarding the immutability trait: I'm also fine with not introducing it initially, and then we can do so if it indeed turns out that it introduces a bunch of type-instabilities.

My initial suggestion would have been to maybe always return the return value and the varinfo object from the evaluator (and in _evaluate) but this won't work by just adding an additional statement at the end of the evaluator f: users may exit the function before reaching the end so one would have to handle also user-provided return statements. I assume this could work but requires a bit more changes.

Yeah, it would require us to traverse and replace all return-values. But I'm fine with always returning the varinfo too, if you prefer that; I was just thinking that not returning varinfo in the outer-most model wouldn't change anything from a user perspective, which is nice.

@devmotion
Member

I'm fine with adding the return values, I hope it doesn't require too many changes (I guess some additional handling in the function that builds the model body, a way to capture the final value and a final return statement) 🙂

I was just thinking not returning varinfo in the outer-most model wouldn't change anything from a user-perspective, which is nice.

Always returning both the value and the varinfo object won't necessarily change anything from a user perspective - users don't call model.f directly. We could return both from _evaluate (this fixes the submodel issue) but still only return the value from evaluate_threadsafe and evaluate_threadunsafe. However, the problem with this approach would be that the new samples for immutable varinfo would not be available to users, so in fact we would have to return also the varinfo object at least in this case. One could add an implementation of evaluate_threadsafe etc. for SimpleVarInfo that returns both but then again I think that this inconsistency could be problematic for downstream packages...

@torfjelde
Member Author

I'm fine with adding the return values, I hope it doesn't require too many changes (I guess some additional handling in the function that builds the model body, a way to capture the final value and a final return statement) 🙂

👍

Always returning both the value and the varinfo object won't necessarily change anything from a user perspective - users don't call model.f directly. We could return both from _evaluate (this fixes the submodel issue) but still only return the value from evaluate_threadsafe and evaluate_threadunsafe. However, the problem with this approach would be that the new samples for immutable varinfo would not be available to users, so in fact we would have to return also the varinfo object at least in this case. One could add an implementation of evaluate_threadsafe etc. for SimpleVarInfo that returns both but then again I think that this inconsistency could be problematic for downstream packages...

Aaah good point; no way to retrieve the samples then. With that in mind, I agree that always adding it is a good idea 👍

Btw, how do we actually replace all the return-values? Do we need to use IRTools for this? Seems like macros won't cut it?

@devmotion
Member

Btw, how do we actually replace all the return-values? Do we need to use IRTools for this? Seems like macros won't cut it?

Hmm yeah with the current macro approach we could only replace the high-level return statements that are explicit in the model function but not the ones in some function that is called inside the function body.

Another thing (not really related to SimpleVarInfo but I thought about this again when I saw the implementation): should logp be part of the varinfo objects? Or just be an accumulator in LikelihoodContext, PriorContext and JointContext?

@torfjelde
Member Author

Hmm yeah with the current macro approach we could only replace the high-level return statements that are explicit in the model function but not the ones in some function that is called inside the function body.

But how do you do this? E.g. what about statements such as

map(x) do x
    return x
end

?

Another thing (not really related to SimpleVarInfo but I thought about this again when I saw the implementation): should logp be part of the varinfo objects? Or just be an accumulator in LikelihoodContext, PriorContext and JointContext?

I don't like this, nor do I quite see the motivation behind this 😕 I think if we start considering moving it away from varinfo, then I'm more in favour of a __logp__ variable that holds it, rather than putting it in the context.

@torfjelde
Member Author

torfjelde commented Jun 19, 2021

It works at least:

julia> using DynamicPPL, Distributions

julia> import IRTools

julia> function add_varinfo_return!(ir, argidx=3)
           for block in IRTools.blocks(ir)
               for b in IRTools.branches(block)
                   if IRTools.isreturn(b)
                       # Add the new return-statement
                       retval = IRTools.push!(block, IRTools.xcall(:tuple, b.args[1], IRTools.var(argidx)))
                       IRTools.return!(block, retval)
                   end
               end
           end

           return ir
       end
add_varinfo_return! (generic function with 2 methods)

julia> function new_model(m::Model)
           # Get the IR representation.
           ir = IRTools.Inner.code_ir(m.f, Tuple{Model, DynamicPPL.AbstractVarInfo, DynamicPPL.AbstractContext, map(typeof, m.args)...})
           # Replace return-values.
           add_varinfo_return!(ir)
           # Create resulting function.
           mf = IRTools.func(ir)
           # Re-create new model.
           # TODO: Why do I need this first `nothing` argument?
           # NOTE: For some reason `Base.Fix1` doesn't work?
           function evaluator(args...)
               return mf(nothing, args...)
           end

           return Model(m.name, evaluator, m.args, m.defaults)
       end
new_model (generic function with 1 method)

julia> @model function demo(x)
           m ~ Normal()
           if m < 0
               return nothing
           end
           x ~ Normal(m, 1)

           z = map(1:3) do i
               i * x
           end

           return (; m, x, z)
       end
demo (generic function with 1 method)

julia> m = demo(1.0)
Model{var"#12#14", (:x,), (), (), Tuple{Float64}, Tuple{}}(:demo, var"#12#14"(), (x = 1.0,), NamedTuple())

julia> m()
(m = 0.06386506414921082, x = 1.0, z = [1.0, 2.0, 3.0])

julia> m()

julia> m()
(m = 1.2132951487644734, x = 1.0, z = [1.0, 2.0, 3.0])

julia> mnew = new_model(m)
Model{var"#evaluator#11"{IRTools.Inner.var"###286"}, (:x,), (), (), Tuple{Float64}, Tuple{}}(:demo, var"#evaluator#11"{IRTools.Inner.var"###286"}(IRTools.Inner.var"##286"), (x = 1.0,), NamedTuple())

julia> mnew()
((m = 1.1227582487925878, x = 1.0, z = [1.0, 2.0, 3.0]), VarInfo (1 variable (m), dimension 1; logp: -2.476))

julia> mnew()
((m = 0.5147520341831271, x = 1.0, z = [1.0, 2.0, 3.0]), VarInfo (1 variable (m), dimension 1; logp: -2.088))

julia> mnew()
(nothing, VarInfo (1 variable (m), dimension 1; logp: -1.118))

EDIT: A significant issue with the above is that it doesn't capture re-assignment since we replace return-values with the variable from the arguments, duh 😅

@devmotion
Member

But how do you do this? E.g. what about statements such as

Exactly, one must not replace or modify return statements in such examples.

nor do I quite see the motivation behind this 😕

The main motivation would be to separate the data structure for saving and retrieving variables from the one that accumulates the log density. This would allow a more orthogonal design. E.g., currently the only difference between VarInfo and ThreadSafeVarInfo is that the latter uses a thread-safe way for accumulating the log density. Instead one would have a joint/prior/likelihood context that accumulates the log density in a threadsafe way and could combine them with different varinfo types (currently ThreadSafeVarInfo is designed for VarInfo).

Another motivation would be that it means different things and hence there is no clear meaning to varinfo.logp. If it's an accumulator in the LikelihoodContext, then it's clear that it's the log-likelihood.

I'm more in favour of a __logp__ variable that holds it, rather than putting it in the context.

I think I disagree with this suggestion, I would rather keep it as part of the varinfo objects or possibly of the contexts. It's not needed in every model execution, e.g., the pointwise_loglikelihoods don't ever make use of it IIRC. And hence it would make the model and its function signature more complicated even if it's not needed.
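
A rough sketch of the accumulator idea, for concreteness (hypothetical type and function names, not an existing DynamicPPL API):

using DynamicPPL

# A likelihood context that carries its own accumulator; a thread-safe variant could
# hold one accumulator per thread instead of a single Ref.
struct AccumLikelihoodContext{T<:Real} <: DynamicPPL.AbstractContext
    loglikelihood::Base.RefValue{T}
end
AccumLikelihoodContext() = AccumLikelihoodContext(Ref(0.0))

accumulate!(ctx::AccumLikelihoodContext, logp) = (ctx.loglikelihood[] += logp; ctx)
getaccum(ctx::AccumLikelihoodContext) = ctx.loglikelihood[]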

@torfjelde
Member Author

Exactly, one must not replace or modify return statements in such examples.

Soooo ideas? 🙃 Realized it's a bit difficult to do on an IR-level because we're working with SSA 😅

The main motivation would be to separate the data structure for saving and retrieving variables from the one that accumulates the log density. This would allow a more orthogonal design. E.g., currently the only difference between VarInfo and ThreadSafeVarInfo is that the latter uses a thread-safe way for accumulating the log density. Instead one would have a joint/prior/likelihood context that accumulates the log density in a threadsafe way and could combine them with different varinfo types (currently ThreadSafeVarInfo is designed for VarInfo).

Hmm, maybe. IMO this should be a separate discussion though 😕

I think I disagree with this suggestion, I would rather keep it as part of the varinfo objects or possibly of the contexts. It's not needed in every model execution, e.g., the pointwise_loglikelihoods don't ever make use of it IIRC. And hence it would make the model and its function signature more complicated even if it's not needed.

True 👍

@torfjelde
Member Author

torfjelde commented Jun 19, 2021

Maybe something as simple as

replace_returns(e) = e
replace_returns(e::Symbol) = e
function replace_returns(e::Expr)
    if Meta.isexpr(e, :function) || Meta.isexpr(e, :->)
        return e
    end

    if Meta.isexpr(e, :return)
        retval = if length(e.args) > 1
            Expr(:tuple, e.args...)
        else
            e.args[1]
        end
        return quote
            return $retval, __varinfo__
        end
    end

    return Expr(e.head, map(x -> replace_returns(x), e.args)...)
end

will do a good enough job? Is there anything else we're missing there?

Also, should we have this as a "submodel" evaluator instead?

EDIT: This requires users to be explicit about their return-statements though. Could maybe extend this to also change the last statement, if it's not a return?
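
A quick check of the helper above on a toy expression (this relies on the replace_returns definitions in the previous block):

ex = quote
    f = x -> begin
        return x  # inside an anonymous function: left untouched
    end
    if m < 0
        return nothing
    end
    return (; m, x)
end

replace_returns(ex)
# Both top-level `return`s now become `return (<value>, __varinfo__)`, while the
# `return` inside the `->` body is preserved as-is.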

@devmotion
Member

Hmm, maybe. IMO this should be a separate discussion though 😕

I agree. I wanted to mention it since this PR introduces an additional varinfo struct with another logp field.

Could maybe extend this to also change the last statement, if it's not a return?

I think we have to change it, otherwise it will fail in such cases?

Comment on lines 76 to 85
function push!!(
    vi::SimpleVarInfo{Nothing}, vn::VarName{sym,Tuple{}}, value, dist::Distribution
) where {sym}
    @set vi.θ = NamedTuple{(sym,)}((value,))
end
function push!!(
    vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution
) where {sym}
    @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,)))
end
Member Author

Any good ideas of how we can support cases where we actually have indexing, i.e. VarName{sym} rather than VarName{sym, Tuple{}}? @devmotion ? 👼

Might be useful to exploit parent if we have #272 ?

Member

I guess everyone has already had their own ideas about "trie-style" VarInfo, but mine was the following (similar to your views idea, I think): store the complete data in one big array (or perhaps one per type). Then put a trie-like thing on top of that, referencing only views into it. So we'd have something like, conceptually,

backup = [1.0, 2.0, 3.0, 4.0]
vi = (x = view(backup, 1:2),
      y = (var"1" = view(backup, 3),
           var"3" = view(backup, 4)))

Of course there are a lot of possibilities for implementing this better (like a dictionary for y, in case there are a lot of indices).

Member Author

The issue is that var"1" won't be a good idea once the indices are dynamic (this I guess is what you're referring to when saying that it can be implemented using a dictionary).

But for SimpleVarInfo, I'm happy to only support "simple" expressions meaning expressions of the form x and x[...], not x[1][...], or equivalently VarName{<:Any,Tuple{}} and VarName{<:Any, Tuple{<:Tuple}}.

Now that everything is a View, this could be implemented by just preallocating the equivalent of parent and then inserting into this.
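
A tiny standalone illustration of the "preallocate a parent array and hand out views" idea (not this PR's code):

parent_storage = Vector{Float64}(undef, 4)
θ = (x = view(parent_storage, 1:2), y = view(parent_storage, 3:4))

# Writing through the views fills the shared parent storage in place.
θ.x .= [1.0, 2.0]
θ.y .= [3.0, 4.0]
parent_storage  # [1.0, 2.0, 3.0, 4.0]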

Member Author

Another thing: nested indexing is kind of "useless" if we have something like #275 😕


# Context implementations
function tilde_assume!!(context, right, vn, inds, vi::SimpleVarInfo)
    value, logp, vi_new = tilde_assume(context, right, vn, inds, vi)
Member Author

Notice how here we also take the vi_new.

Ideally this is what we should be doing overall if we're going to support immutable varinfos. But if we force this, we'll end up breaking a lot of downstream samplers since it requires changing assume statements to also return vi.

As an intermediate step, it might be worth just overloading tilde_assume!! as I have done here, plus assume as I have done below. This also brings up another annoyance though: vi should really be at the beginning of the arguments of all the tilde-statements (and ideally assume too, but we can delay this) to minimize method ambiguities while allowing us to have some special behaviour for the different implementations of AbstractVarInfo.

bors bot pushed a commit that referenced this pull request Jul 17, 2021
Looking at `to_namedtuple_expr` it seems like it's leftover code to be compatible with Julia 1.2. Given that DPPL only supports Julia 1.3 or higher I've simplified the implementation, doing away with the `namedtuple` method.

EDIT: This also fixed a bug I ran into when using Zygote + `@submodel` in #267. Struggling to come up with an MWE of the bug, but merging this fixed it. This was also my original motivation for making this change, as Zygote was complaining about the `namedtuple` for some reason (something about "non-differentiable getfield").

Co-authored-by: Hong Ge <hg344@cam.ac.uk>
@yebai
Member

yebai commented Jul 20, 2021

Is this ready for a second look?

@torfjelde
Member Author

Not yet, no.

@torfjelde torfjelde merged commit a72594f into tor/immutable-varinfo-support Aug 14, 2021
@torfjelde torfjelde deleted the tor/simple-varinfo-v2 branch August 14, 2021 00:24