[Merged by Bors] - Resample variable if not given in setval! #216

Closed
wants to merge 26 commits into from

torfjelde
Member

Currently, if one calls DynamicPPL._setval!(vi, vi.metadata, values, keys), then only the variables present in keys will be set, as expected, but the variables which are not present in keys will simply be left as-is. This means that we get the following behavior:

julia> using Turing

julia> @model function demo(x)
           m ~ Normal(0, 1)
           for i in eachindex(x)
               x[i] ~ Normal(m, 1)
           end
       end
demo (generic function with 1 method)

julia> m_missing = demo(fill(missing, 2));

julia> var_info_missing = DynamicPPL.VarInfo(m_missing);

julia> var_info_missing.metadata.m.vals
1-element Array{Float64,1}:
 0.7251417347423874

julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:
 1.2576791054418153
 0.764913349211408

julia> DynamicPPL.setval!(var_info_missing, (m = 0.0, ));

julia> var_info_missing.metadata.m.vals # ✓ new value
1-element Array{Float64,1}:
 0.0

julia> var_info_missing.metadata.x.vals # ✓ still the same value
2-element Array{Float64,1}:
 1.2576791054418153
 0.764913349211408

julia> m_missing(var_info_missing) # Re-run the model with new value for `m`

julia> var_info_missing.metadata.x.vals # × still the same and thus not reflecting the change in `m`!
2-element Array{Float64,1}:
 1.2576791054418153
 0.764913349211408

Personally I expected x to be resampled, since parts of the model have now changed and thus the sample x is no longer representative of a sample from the model (under the sampler used).

This PR "fixes" the above so that you get the following behavior:

julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:
 1.2576791054418153
 0.764913349211408

julia> DynamicPPL.setval!(var_info_missing, (m = 0.0, ));

julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:
 1.2576791054418153
 0.764913349211408

julia> m_missing(var_info_missing)

julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:
 -2.0493130638394947
  0.3881955730968598

This was discovered when debugging TuringLang/Turing.jl#1352, as I want to move Turing.predict over to using DynamicPPL.setval!. It also has consequences for DynamicPPL.generated_quantities, which uses DynamicPPL.setval! under the hood and thus suffers from the same issue.

There's an alternative: instead of making this the default behavior, we could add kwargs... to setval! which includes resample_missing::Bool or something. I'm also completely fine with a solution like that 👍

@devmotion
Member

I think the new behaviour might be even more surprising (and I wonder if it actually breaks something else?). When introducing _setval!, my intention was actually that it works in the same way as setindex! but with collections of variable names and their values - so it is intentional that it just updates these values but does not touch anything else.

My experience with setindex!, setval! etc. is that they cause a lot of confusion and bugs since they all update the VarInfo object but in slightly different ways and there's too much "magic" (e.g., sometimes - also depending on the sampler - the unconstrained values are updated, sometimes not). So I think it would be good to have dedicated functions for just updating the variables, and keep that separate from any potential resampling. Therefore I would prefer a completely separate function instead of some "magic" keyword arguments.

If this behaviour is only used in predict, maybe it would make sense to just change predict and explicitly mark the non-updated variables to be deleted? (BTW I guess it is a bit inefficient with the current design of VarInfo due to all the linearization/ranges stuff, but wouldn't it be even more reasonable/correct to completely delete the other variables? If the other variables would not be resampled when rerunning the model (e.g., due to an early return statement in the model or some stochastic control flow...) the old outdated values would still exist in the VarInfo object)

@torfjelde
Member Author

torfjelde commented Mar 23, 2021

> so it is intentional that it just updates these values but does not touch anything else

Makes sense 👍

> Therefore I would prefer a completely separate function instead of some "magic" keyword arguments.

So I don't disagree with not wanting to introduce magic kwargs, but this would mean that you'd just copy-paste the current implementation of _setval! and change it to also setflag!; this doesn't seem like a good idea IMO. Would it make sense to separate out the logic of matching vn in a given keys argument and then methods like _setval! will call this internally? Then we can introduce a setval_or_resample! or something that will also use this method internally to determine which variables to update and which to sample.

> If this behaviour is only used in predict

It also affects generated_quantities (which is located here in DynamicPPL), since the "expected" behavior of generated_quantities would be to regenerate quantities according to the set values, and thus if you don't set all of them you get something stochastic.

And the reason why I made this PR in the first place is that fixes to exactly the same functionality were implemented in DynamicPPL, but since we had a "manual" implementation in predict, those fixes of course didn't propagate, i.e. we have to fix the same bug twice. IMO this is not a sustainable approach.

And even so, I do think there should be an "easy" way of telling DynamicPPL "Hey, can you resample these variables for me when I run the model next?"

@devmotion
Member

Sure, regardless of how it is done in the end, we should avoid copy-pasting stuff (and fixing bugs twice).

One of my concerns was just that the proposed change here is not sufficient to achieve what you want - you would like to resample the other variables but with the change here they are just marked to be resampled when you run the model the next time. So you have to rerun the model separately to get the desired result. Also it seems this could become problematic if you update the VarInfo object multiple times - the "other" variables of the last updates would be marked to be resampled, regardless of the previous updates. And, as mentioned above, even after rerunning the model the VarInfo object might not be "clean" - if we did not try to resample the to-be-resampled variables (e.g., due to an early return statement), the old values will still be present.

Therefore to me it seemed natural to split it up into two functions: one that only updates values (as intended originally) and one that reruns the model with a specific subset of variables being fixed (not completely sure what the optimal implementation would be but maybe it would be more natural to provide the fixed variable names as part of the (existing) contexts than to set a "del" flag; it seems one would still have to remove the other variables from VarInfo before sampling to get a "clean" object).
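
The stale-value concern above can be illustrated with a small, self-contained sketch. It does not use DynamicPPL: `toy_model!`, the `vals` dictionary, and the `del` set are hypothetical stand-ins for a model, the stored values, and the "del" flags, assumed only for illustration.

```julia
using Random

# Hypothetical toy: values live in a Dict, and `del` holds the names flagged
# for resampling. A flagged variable is resampled only if execution reaches it.
function toy_model!(rng, vals::Dict{Symbol,Float64}, del::Set{Symbol})
    m = vals[:m]
    m < 0 && return vals           # early return: `x` is never visited
    if :x in del
        vals[:x] = m + randn(rng)  # resample x given the current m
        delete!(del, :x)
    end
    return vals
end

rng = MersenneTwister(1)
vals = Dict(:m => 0.7, :x => 1.2)
del = Set{Symbol}()

vals[:m] = -1.0   # "set" m ...
push!(del, :x)    # ... and flag x for resampling on the next run
toy_model!(rng, vals, del)

# x was flagged but never reached, so both the old value and the flag remain.
@show vals[:x] (:x in del)
```

In this toy setting the only way to get a "clean" container would be to remove `x` before rerunning, which is the point raised in the comment.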

@torfjelde
Member Author

torfjelde commented Mar 24, 2021

> One of my concerns was just that the proposed change here is not sufficient to achieve what you want - you would like to resample the other variables but with the change here they are just marked to be resampled when you run the model the next time. So you have to rerun the model separately to get the desired result. Also it seems this could become problematic if you update the VarInfo object multiple times - the "other" variables of the last updates would be marked to be resampled, regardless of the previous updates. And, as mentioned above, even after rerunning the model the VarInfo object might not be "clean" - if we did not try to resample the to-be-resampled variables (e.g., due to an early return statement), the old values will still be present.

Good points 👍 But could we actually handle this for TypedVarInfo though?
Also I want to point out that I still don't intend to tell users to start using setval! or setval_or_resample!; this is mostly for internal functionality, e.g. generated_quantities, and so we can ensure that the behavior above doesn't happen, e.g. setting VarInfo multiple times.

> Therefore to me it seemed natural to split it up into two functions: one that only updates values (as intended originally) and one that reruns the model with a specific subset of variables being fixed (not completely sure what the optimal implementation would be but maybe it would be more natural to provide the fixed variable names as part of the (existing) contexts than to set a "del" flag; it seems one would still have to remove the other variables from VarInfo before sampling to get a "clean" object).

Agree; if this can be done, I'm all for it! But for the context you'd have to then provide both:

  • symbols to be sampled
  • symbols to be set

But AFAIK this doesn't actually solve the problem you raised above since in functions such as generated_quantities we're dependent on running the model once to get the variables present in the model and thus in the case of early returns you'd still have the issue of potentially missing variables.

To actually fix that problem you'd have to add an additional field to the context:

  • metadata for all symbols

Right? And how we'd go about providing a way of doing that I don't know...

TL;DR: It seems like solving the problem in full generality is veeeery difficult, and so I think I'm leaning towards adding an internal method called setval_or_resample! that will simply:

  1. Set the values present in keys
  2. Prepare values not present in keys to be resampled, either by replacing them with missing (but this will also not work for all kinds of indexing due to limitations in how we decide whether a variable is missing or not) or by setting the del flag (which should work for all kinds of indexing but has the issue that old values will still be present if we hit an early return or something).

This assumes:

  1. Model is static (could specialize this to only work for TypedVarInfo)
  2. All symbols are present in the model, i.e. no early return statements hit.
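
A minimal sketch of the two steps listed above, with the matching logic shared between a plain setter and a set-or-resample variant. Everything here is a toy: `ToyEntry`, `ToyVarInfo`, `apply!`, and both kernels are hypothetical names, names are matched by exact string equality only, and none of it is the code that ended up in src/varinfo.jl.

```julia
# Toy stand-in for a VarInfo: each variable stores a value and a "del" flag.
mutable struct ToyEntry
    value::Float64
    del::Bool   # true => resample on the next model run
end

const ToyVarInfo = Dict{String,ToyEntry}

# Shared matching logic: call `kernel!` once for every variable in `vi`.
function apply!(kernel!, vi::ToyVarInfo, values, keys)
    keys_strings = map(string, keys)   # convert once, outside the kernel
    for vn in Base.keys(vi)
        kernel!(vi, vn, values, keys_strings)
    end
    return vi
end

# Kernel 1: set matching variables and leave everything else untouched.
function setval_kernel!(vi, vn, values, keys_strings)
    i = findfirst(==(vn), keys_strings)
    if i !== nothing
        vi[vn].value = values[i]
        vi[vn].del = false
    end
    return nothing
end

# Kernel 2: set matching variables and flag all others for resampling.
function setval_or_resample_kernel!(vi, vn, values, keys_strings)
    i = findfirst(==(vn), keys_strings)
    if i !== nothing
        vi[vn].value = values[i]
        vi[vn].del = false
    else
        vi[vn].del = true
    end
    return nothing
end

toy_setval!(vi, nt::NamedTuple) = apply!(setval_kernel!, vi, values(nt), keys(nt))
toy_setval_or_resample!(vi, nt::NamedTuple) =
    apply!(setval_or_resample_kernel!, vi, values(nt), keys(nt))

vi = ToyVarInfo("m" => ToyEntry(0.7, false), "x" => ToyEntry(1.2, false))
toy_setval!(vi, (m = 0.0,))              # x is left as-is
x_flag_after_setval = vi["x"].del
toy_setval_or_resample!(vi, (m = 0.0,))  # x is flagged for resampling
```

The flag only takes effect on the next model run, so the two assumptions listed above (static model, no early return) still apply to this sketch.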

@torfjelde
Member Author

Btw, do you have any thoughts on this @devmotion ? I think we at least need a fix to the issue I referenced above, as the current implementation will silently give you the wrong answer.

@devmotion
Member

Sure, we should fix the issue 👍 Let's just apply your suggestion with an internal setval_or_resample! function? And probably good to open an issue about the remaining problems that we discussed above, just so we don't forget about it.

@torfjelde
Member Author

torfjelde commented Mar 27, 2021

Awesome!:) I'll make an issue and a PR 👍

EDIT: Doh. The PR is already here 🙃 Haha, thought I was commenting on an issue.

@torfjelde
Member Author

Added the method as discussed + docstrings for the two methods (and doctests)

@torfjelde
Member Author

I'm cleaning up the implementation (want to share more code between setval! and setval_and_resample!) and I'll address the comments on the docstring too.

torfjelde and others added 2 commits March 28, 2021 18:09
@devmotion
Member

Is the PR ready for another round of review?

@torfjelde
Member Author

Yep!:)

torfjelde and others added 3 commits April 2, 2021 18:29
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Member

@devmotion devmotion left a comment

I think we're almost done 🙂 I noticed some additional things that should be fixed, can you also bump the version? And feel free to drop the ::Function type constraints.

test/varinfo.jl Outdated

### `setval` ###
DynamicPPL.setval!(vicopy, (m = zeros(5),))
# Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)`
Member

Is this something we could fix? Should we open an issue?

Member Author

It happens because we only try to locate variables present in vi in keys, not also the other direction, i.e. locate variables present in keys in vi.

So in this particular case we have vi.metadata.m.vns containing [m[1], m[2], ...] while keys(x) is just (m, ).

I guess maybe we could also match the other way if we can't find a vi -> keys mapping? Could this lead to some weird behavior though?
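
To make the direction issue concrete, here is a toy sketch of string-based matching. `toy_subsumes` is a hypothetical helper written for this illustration (it is assumed to behave like a prefix check on the indexing bracket) and should not be read as the package's `subsumes` or `subsumes_string`.

```julia
# Toy rule: "m" subsumes "m" and "m[1]", but "m[1]" does not subsume "m".
toy_subsumes(u::String, v::String) = u == v || startswith(v, u * "[")

vns_in_vi = ["m[1]", "m[2]"]   # variable names stored in the VarInfo
user_keys = ["m"]              # keys supplied by the caller

# vi -> keys: for each stored name, look for a key that it subsumes.
forward = [findfirst(k -> toy_subsumes(vn, k), user_keys) for vn in vns_in_vi]

# keys -> vi: for each key, collect the stored names that it subsumes.
reverse_match = [filter(vn -> toy_subsumes(k, vn), vns_in_vi) for k in user_keys]
```

Under this toy rule the forward search finds nothing for either `m[1]` or `m[2]`, while the reverse search maps the key `m` to both of them.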

Member Author

That would also require a rewrite again I believe, since the current impl of _apply! iterates over vns 😕

Member Author

Maybe leave for now since it won't do worse than the current impl and make an issue?
Kind of eager to get this out as I know people are waiting for a fix to this (+ generated_quantities might be silently doing something wrong in some models).

Member Author

One thing though: should we maybe warn the user if they provide a key that isn't present in VarInfo? At least we won't do something unexpected without telling them.

Member

I just wanted to understand what the actual limitation is here. So I suggest:

  • add a short explanation in the comment (IMO it's a bit unclear what limitation of subsumes or rather subsumes_string this refers to)
  • add a warning if a user provides a key that is not present in VarInfo (I like the idea and I think it is better to warn than to fail silently)

Member Author

Done! Check if you think this is okay:)

src/varinfo.jl Outdated
Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`.
"""
function _apply!(kernel!, vi::AbstractVarInfo, values, keys)
    indices_seen = Set(1:length(keys))
Member

Would it maybe be sufficient to not keep track of the indices but instead return and count the number of keys that were used? Then one could just check in the end if the total number of used keys is equal to length(keys).

Member Author

That's what I did initially, but it kind of sucks that the user then doesn't get any information about the missing variables. I think it's a bit difficult for someone not too familiar with Turing's internals to understand which variable is actually causing the issue?

But it does have a performance cost 😕

You choose! I'm for whatever:)

Member

An alternative would be to compute the indices separately, using subsumes_string, only if the length does not match the number of counts. That should not impact performance (much) in the common case where all keys are used but the warning message would be more informative.

Another performance improvement might be possible by not calling map(string, keys) in every call of the kernel function but only once in the outer function.

Member Author

@torfjelde torfjelde Apr 2, 2021

Eeh, I'm struggling a bit to see exactly how we're to compute the indices separately at the moment 😕 I def believe there's a way, but I am veery tired (omw to bed atm) so I think I just need some sleep and then I can have another look 😅

I see what you mean by doing map(string, keys) outside of the kernel! call though 👍 Pushing that change now.

Member

@devmotion devmotion Apr 2, 2021

I imagined one could use something like

string_vns = map(string, Base.keys(vi))
unused_keys = filter(keys_strings) do key
    !any(Base.Fix2(subsumes_string, key), string_vns)
end

Member Author

Doesn't exist for TypedVarInfo, but I can add the following (IMO we should actually have this, even though it's the preferred way to iterate through the varnames):

@generated function keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names}
    expr = Expr(:call)
    push!(expr.args, :vcat)

    for n in names
        push!(expr.args, :(vi.metadata.$n.vns))
    end

    return expr
end

Then I can just do what you suggest yeah:)
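
The same `@generated` pattern can be tried outside DynamicPPL. In the sketch below, `ToyMetadata` and `toy_keys` are hypothetical stand-ins; only the idea of building a `vcat` call from the `NamedTuple` field names is taken from the snippet above.

```julia
# Hypothetical stand-in for a per-symbol metadata entry.
struct ToyMetadata
    vns::Vector{String}
end

# Builds the expression `vcat(md.m.vns, md.x.vns, ...)` at compile time,
# using the field names carried in the NamedTuple's type.
@generated function toy_keys(md::NamedTuple{names}) where {names}
    expr = Expr(:call, :vcat)
    for n in names
        push!(expr.args, :(md.$n.vns))
    end
    return expr
end

md = (m = ToyMetadata(["m"]), x = ToyMetadata(["x[1]", "x[2]"]))
all_vns = toy_keys(md)
```

Note that `vcat` allocates a new vector on every call, which is the allocation discussed later in the review.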

Member Author

I made that change. Kind of don't like the "private" method _find_missing_keys but, as we spoke about before, we should aim to find a better way than this string-based one anyways in the not-too-distant future.
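
A rough sketch of the approach discussed in this thread: count the keys that were used, and only search for the unused ones when the count does not match. The helper names and the toy subsumption rule are assumptions made for illustration; the real `_find_missing_keys` may differ.

```julia
# Toy subsumption rule (hypothetical): "x" subsumes "x" and "x[1]".
toy_subsumes(u::String, v::String) = u == v || startswith(v, u * "[")

# Keys that no stored variable name subsumes.
function toy_find_missing_keys(string_vns, keys_strings)
    return filter(keys_strings) do key
        !any(Base.Fix2(toy_subsumes, key), string_vns)
    end
end

function toy_check_keys(string_vns, keys_strings, num_used::Int)
    # Cheap common case: every key was used, so skip the search entirely.
    num_used == length(keys_strings) && return String[]
    missing_keys = toy_find_missing_keys(string_vns, keys_strings)
    isempty(missing_keys) || @warn "keys not present in VarInfo" missing_keys
    return missing_keys
end

string_vns = ["m", "x[1]", "x[2]"]
keys_strings = ["m", "y"]
unused = toy_check_keys(string_vns, keys_strings, 1)
```

The extra pass over the variable names only happens on the warning path, so the common case where all keys match stays cheap.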

- keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs)
+ Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs)
+
+ @generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names}
Member

We should add some tests for keys I guess.

Member Author

Added. Also added a collectmaybe method to deal with the case where keys are AbstractSet, in which case map is not defined.
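
For readers unfamiliar with the `map` restriction, a helper of this kind could look roughly as follows. `toy_collectmaybe` is a guess at the shape of such a method, not the definition added in the PR.

```julia
# Pass most collections through unchanged, but materialize sets,
# because `map` is not defined for `AbstractSet`.
toy_collectmaybe(x) = x
toy_collectmaybe(x::AbstractSet) = collect(x)

d = Dict("m" => 1, "x" => 2)
ks = keys(d)                              # a KeySet, which is an AbstractSet
strs = map(string, toy_collectmaybe(ks))  # works; `map(string, ks)` would throw
```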

Member

@devmotion devmotion left a comment

LGTM. We can reiterate and improve it in other PRs (e.g., according to the docstring keys returns an iterator - maybe a non-allocating version can be implemented for TypedVarInfo as well?). Just bump the version and run bors 🙂

@torfjelde
Member Author

torfjelde commented Apr 3, 2021

> (e.g., according to the docstring keys returns an iterator - maybe a non-allocating version can be implemented for TypedVarInfo as well?)

Yeaaah noticed but e.g. the current impl doesn't implement a non-allocating iterator (it's a KeySet I believe?) and so I was like "Eh, I guess it's whatever. I'll see what David says". But as you said, prob better to improve both in another PR:)

KeySet is of course non-allocating, duh. But yeah, another PR 😅

Version bumped!

@torfjelde
Member Author

bors r+

bors bot pushed a commit that referenced this pull request Apr 3, 2021
@bors bors bot changed the title Resample variable if not given in setval! [Merged by Bors] - Resample variable if not given in setval! Apr 3, 2021
@bors bors bot closed this Apr 3, 2021
@bors bors bot deleted the tor/minor-change-to-setval branch April 3, 2021 18:28