Full update of weighted index by assigning weights #1194
Conversation
Thanks for the PR!
There are a few comments below. Likely the `new` constructor should be updated slightly too.
Could we have some benchmarks please comparing (1) replacing with a new instance, (2) `assign_weights`, and (3) `assign_weights_unchecked`.
src/distributions/weighted_index.rs
```rust
@@ -130,6 +130,72 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
        })
    }

    /// Updates all weights by recalculating the index, without changing the number of weights.
    ///
    /// **NOTE:** if `weights` contains invalid elements (for example, `f64::NAN` in the case of
```
NaN will fail the `w >= &zero` check. What wouldn't be caught is +inf (or a sum of weights which overflows to +inf). Possibly we should add a check for this (`total_weight.is_finite()`).
Despite this comment, the cases which are caught are identical to those of the `new` constructor. Possibly both need updating.
Yeah, I noticed that as well. There is also `is_normal()`, but subnormal values are probably not of concern.
The bigger problem however is that at the moment `WeightedIndex` is valid for all `X: SampleUniform`, which includes integers, while `is_finite` only applies to floats.
Good point... I don't think we have any way of dealing with this. The debug asserts in `impl UniformSampler for UniformFloat<$ty>` will catch this, but it doesn't seem ideal.
src/distributions/weighted_index.rs
```rust
/// partially updated index is undefined. It is the user's responsibility to not sample from
/// the index upon encountering an error. The index may be used again after assigning a new set
/// of weights that do not result in an error.
pub fn assign_weights(&mut self, weights: &[X]) -> Result<(), WeightedError>
```
`new` takes an iterator while this method takes a slice, which is inconsistent. There's no real reason we can't use an iterator here (should be benchmarked, but I suspect perf. will be very similar). @vks do you think we should use an iterator for consistency?
But if we do, we have an additional choice: require `ExactSizeIterator` or just test we finish with the right length? I think I favour using `ExactSizeIterator`, but I haven't thought a lot about it (it's also the more restricted choice: potentially we could switch away from it later if required).
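A rough sketch of the `ExactSizeIterator` variant under discussion, using a plain struct and hypothetical names (`Index`, `AssignError`) rather than rand's actual types:

```rust
// Hypothetical error type; rand uses `WeightedError` instead.
#[derive(Debug, PartialEq)]
enum AssignError {
    LenMismatch,
}

// Minimal stand-in for the cumulative-weights storage.
struct Index {
    cumulative_weights: Vec<f64>,
}

impl Index {
    // With `ExactSizeIterator`, the length can be validated before any
    // element of the index is mutated.
    fn assign_weights<I>(&mut self, weights: I) -> Result<(), AssignError>
    where
        I: ExactSizeIterator<Item = f64>,
    {
        if weights.len() != self.cumulative_weights.len() {
            return Err(AssignError::LenMismatch);
        }
        // Lockstep zip: overwrite each slot with the running total.
        let mut total = 0.0;
        for (c, w) in self.cumulative_weights.iter_mut().zip(weights) {
            total += w;
            *c = total;
        }
        Ok(())
    }
}

fn main() {
    let mut idx = Index { cumulative_weights: vec![0.0; 3] };
    assert!(idx.assign_weights([1.0, 2.0, 3.0].iter().cloned()).is_ok());
    assert_eq!(idx.cumulative_weights, vec![1.0, 3.0, 6.0]);
    // Wrong length is rejected up front.
    assert_eq!(
        idx.assign_weights([1.0].iter().cloned()),
        Err(AssignError::LenMismatch)
    );
}
```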
Yes, that's a good idea, as that's strictly more general. Since I am zipping internally anyway, this should not change anything with regard to slices and vectors.
Implemented the change, but leaving the conversation open because of the question.
I think using an iterator makes sense, unless the slice optimizes better.
I have implemented the changes. Benchmarks show that assignment via exact-size iterator gives a nice little speedbump, roughly 3x to 5x:

```
test weighted_index_assignment       ... bench:     27 ns/iter (+/- 0)
test weighted_index_assignment_large ... bench:    395 ns/iter (+/- 10)
test weighted_index_creation         ... bench:     97 ns/iter (+/- 0)
test weighted_index_creation_large   ... bench:  2,079 ns/iter (+/- 18)
test weighted_index_modification     ... bench:     26 ns/iter (+/- 0)
```
Excellent.
I'll let @vks take a look before merging.
A custom trait would be the right choice, but overall I think it's better to leave this as it is: a public trait adds another complication to the API for little gain, while a private (sealed) trait restricts usage to std types. Note that for integer types there's already a check in debug builds: overflow when the sum gets too large. Again, this is not ideal, but whether it is worth checking for overflow is questionable.
I think we haven't addressed what to do with the index if assignment fails mid-update. Right now it's going to be filled with garbage if it encounters NaN, so one probably shouldn't sample from it. :-D
A partial update could also be handled by setting the length of the weights to zero; this should make all other calls panic.
If any of the weights is NaN or inf, then
Not the most elegant handler, but still sufficient in my opinion. Perhaps the docs should mention that a NaN/inf weight will result in a panic.
That's not actually the case. :-( We are returning early, so the sampler as well as the total weight as stored in the weighted index itself are not updated until the very end of the function.
Ugh; you're right. We could just not return early (use a
A panic is still not ideal in this context, but it's what
In the mean-time, perhaps the right thing to do here is to use a sealed trait (a "pub" trait in a private module) to enable the checks we need. In the future we may be able to drop the trait, which won't be a breaking change. Caveat: using the fixed bounds-detection in
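The sealed-trait pattern mentioned here can be sketched as follows; the trait and method names (`Sealed`, `is_finite_weight`, `check_total`) are illustrative, not rand's API:

```rust
// A "pub" trait inside a private module: downstream crates can neither
// name nor implement it, so we can extend or drop it later freely.
mod private {
    pub trait Sealed {
        fn is_finite_weight(&self) -> bool;
    }
}

impl private::Sealed for f64 {
    fn is_finite_weight(&self) -> bool {
        self.is_finite()
    }
}

// Integers are trivially finite, so the check is a no-op for them.
impl private::Sealed for u32 {
    fn is_finite_weight(&self) -> bool {
        true
    }
}

// Public API can bound on the sealed trait without exposing it.
fn check_total<W: private::Sealed>(total: &W) -> bool {
    total.is_finite_weight()
}

fn main() {
    assert!(check_total(&1.5f64));
    assert!(!check_total(&f64::INFINITY));
    assert!(check_total(&42u32));
}
```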
If #1195 is implemented we can avoid the need for an extra trait bound. For now I suggest getting this PR merged without depending on that, however.
I need further clarification. The issues we are discussing are somewhat orthogonal, and I think the sealed trait part warrants its own PR.
@dhardy Did I understand 2. correctly, or did you want me to do a completely different trait?
@SuperFluffy you are concerned with the state of
Turns out the way I am now:
If you are happy with these changes I can squash the commits.
The newest changes were not good for perf:

Again enforcing assignment to be equal length to do a lockstep zip between weights iterator and cumulative weights.
There are four options:

1. Use `std::panic::catch_unwind`. This is "not recommended for a general try/catch mechanism" and comes with various warnings, but may be okay.
2. Change `Sampler::new` to return `Result` on error, but this is beyond the scope of this PR: Error handling of distributions::Uniform::new #1195.
3. Use another trait bound to let us directly check the weight is finite. This fixes the specific case of `f32`/`f64` but not necessarily user-extensions to the `Uniform` distribution.
4. Simply state that if the method fails, results of sampling the distribution are undefined (within certain bounds).
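For option 1, a minimal standalone illustration of `std::panic::catch_unwind` (not code from this PR; the panic stands in for a constructor rejecting invalid input):

```rust
use std::panic;

fn main() {
    // Silence the default panic message so the output stays clean.
    panic::set_hook(Box::new(|_| {}));
    // A stand-in for a constructor that panics on invalid input,
    // e.g. a uniform sampler given a bad range; purely illustrative.
    let result = panic::catch_unwind(|| {
        panic!("invalid weights");
    });
    // The panic is caught and surfaces as an `Err`.
    assert!(result.is_err());
}
```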
```rust
        return Err(WeightedError::AllWeightsZero);
    };

    self.weight_distribution = X::Sampler::new(zero, total_weight.clone());
```
There's still a problem: this panics if `total_weight` is +inf, and we don't catch panics.
Right, but here `WeightedIndex::new` suffers from the same issue. So if we want to address this, both `assign_new_weights` and `new` should be changed in a new PR, I think.
Alright, I went with 4., mentioning that the results of sampling the distribution are undefined. Regarding your other comment about
Okay, I think I'm happy with this now, but I'll let @vks take another look before merging.
@vks Addressed all your points and squashed.
Thanks! The new code is unfortunately not compatible with Rust 1.36:
```
error[E0277]: `[f64; 4]` is not an iterator
  --> src/distributions/weighted_index.rs:475:29
   |
475 | let mut distr = WeightedIndex::new([1.0f64, 2.0, 3.0, 0.0]).unwrap();
   |                 ^^^^^^^^^^^^^^^^^^ borrow the array with `&` or call `.iter()` on it to iterate over it
   |
   = help: the trait `core::iter::Iterator` is not implemented for `[f64; 4]`
   = note: arrays are not iterators, but slices like the following are: `&[1, 2, 3]`
   = note: required because of the requirements on the impl of `core::iter::IntoIterator` for `[f64; 4]`

error[E0277]: the trait bound `[f64; 3]: core::iter::ExactSizeIterator` is not satisfied
  --> src/distributions/weighted_index.rs:476:29
   |
476 | let res = distr.assign_new_weights([1.0f64, 2.0, 3.0]);
   |           ^^^^^^^^^^^^^^^^^^ the trait `core::iter::ExactSizeIterator` is not implemented for `[f64; 3]`

error[E0277]: `[f64; 4]` is not an iterator
  --> src/distributions/weighted_index.rs:480:29
   |
480 | let mut distr = WeightedIndex::new([1.0f64, 2.0, 3.0, 0.0]).unwrap();
   |                 ^^^^^^^^^^^^^^^^^^ borrow the array with `&` or call `.iter()` on it to iterate over it
   |
   = help: the trait `core::iter::Iterator` is not implemented for `[f64; 4]`
   = note: arrays are not iterators, but slices like the following are: `&[1, 2, 3]`
   = note: required because of the requirements on the impl of `core::iter::IntoIterator` for `[f64; 4]`
note: required by `distributions::weighted_index::WeightedIndex::<X>::new`
  --> src/distributions/weighted_index.rs:97:5
   |
97 | / pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
98 | | where
99 | |     I: IntoIterator,
100| |     I::Item: SampleBorrow<X>,
...  |
131| |     })
132| | }
   | |_____^

error[E0599]: no associated item named `NAN` found for type `f64` in the current scope
  --> src/distributions/weighted_index.rs:481:67
   |
481 | let res = distr.assign_new_weights([1.0f64, 2.0, f64::NAN, 0.0]);
   |                                                   ^^^ associated item not found in `f64`
   |
   = help: items from traits can only be used if the trait is in scope
   = note: the following trait is implemented but not in scope, perhaps add a `use` for it:
           `use core::num::dec2flt::rawfp::RawFloat;`

error[E0277]: `[u32; 4]` is not an iterator
  --> src/distributions/weighted_index.rs:485:29
   |
485 | let mut distr = WeightedIndex::new([1u32, 2, 3, 0]).unwrap();
   |                 ^^^^^^^^^^^^^^^^^^ borrow the array with `&` or call `.iter()` on it to iterate over it
   |
   = help: the trait `core::iter::Iterator` is not implemented for `[u32; 4]`
   = note: arrays are not iterators, but slices like the following are: `&[1, 2, 3]`
   = note: required because of the requirements on the impl of `core::iter::IntoIterator` for `[u32; 4]`

error[E0277]: the trait bound `[u32; 4]: core::iter::ExactSizeIterator` is not satisfied
  --> src/distributions/weighted_index.rs:486:29
   |
486 | let res = distr.assign_new_weights([0u32, 0, 0, 0]);
   |           ^^^^^^^^^^^^^^^^^^ the trait `core::iter::ExactSizeIterator` is not implemented for `[u32; 4]`

error: aborting due to 6 previous errors
```
- `f64::NAN` can be replaced with `core::f64::NAN`.
- Iterating over slices instead of arrays should fix the other errors.
@dhardy Do you think we can start to merge breaking changes for rand 0.9?

@vks I guess that depends on whether there are any significant non-breaking changes in master or expected to be merged soon. I don't know but can check tomorrow. If not, then I think we can start merging.
BREAKING CHANGE: This commit adds a variant to `WeightedError`.
@vks Replaced
Great, thanks! For the benchmarks it's fine to use newer features, because they require nightly anyway. For the API it's more important to track the MSRV, because this may break crates depending on rand.
There is a more general API that would solve this with less new code, and avoid
The documentation already commits to

```rust
WeightedIndex<X>::into_cumulative_weights(self) -> Vec<X>;
WeightedIndex<X>::from_cumulative_weights(weights: Vec<X>) -> Result<Self, WeightedError>;
WeightedIndex<X>::from_cumulative_weights_unchecked(weights: Vec<X>) -> Self;
```

The Vecs would have
For convenience,
Then, a user like @SuperFluffy could achieve the optimized operation in question like this:

```rust
fn example(&mut self) -> Result<()> {
    let mut weights = self.weighted_index.take().into_weights();
    update_weights_somehow(&mut weights[..]);
    self.weighted_index = WeightedIndex::from_weights(weights)?;
}
```

The possibility of length mismatch is obviated here, and what becomes of the source value in the error case is explicitly the user's choice. This would also support at least one additional use case: if the user needs distributions of varying length at different times, they can reuse one Vec so that it only needs allocation when it reaches a new high-water mark. Incidentally, an example of using the new API in place of
Interesting points @kazcw. There are two caveats:

Anyway, it does make me consider something else:
I actually had written a version that clears the original vector and pushes new elements into it instead of doing a lockstep zip + assignment. This led to exactly the same performance as just calling
The
Whereas given the
As for the checked vs. unchecked question, the performance of
As far as I can see, the advantage of
Slightly weird, but pushing new elements is definitely slower. Using

```rust
self.cumulative_weights.resize(iter.len(), zero.clone());
for (w, c) in iter.zip(self.cumulative_weights.iter_mut()) {
    // ...
```

Of course, this technique can be used in
@kazcw: I think you're right that this isn't the optimal API. I will think further on it.
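The `resize` + `zip` technique above can be made into a runnable sketch over a plain `Vec<f64>` (the `assign_cumulative` helper is hypothetical, not rand's code):

```rust
// Resize the cumulative buffer to the incoming length, then overwrite
// each slot with the running total in a single lockstep pass.
fn assign_cumulative(cumulative: &mut Vec<f64>, iter: impl ExactSizeIterator<Item = f64>) {
    // Reuses the existing allocation where possible.
    cumulative.resize(iter.len(), 0.0);
    let mut total = 0.0;
    for (w, c) in iter.zip(cumulative.iter_mut()) {
        total += w;
        *c = total;
    }
}

fn main() {
    // Works even when the new weight set has a different length.
    let mut cw = vec![0.0; 2];
    assign_cumulative(&mut cw, [1.0, 2.0, 3.0].iter().cloned());
    assert_eq!(cw, vec![1.0, 3.0, 6.0]);
}
```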
It improves it a lot compared to direct pushing, but there is still a significant performance penalty on my M1 ARM machine:
I have been thinking about @kazcw's suggestion. Their argument applies not just to cumulative weights, but to normal weights as well. We can easily do:

```rust
WeightedIndex<X>::into_cumulative_weights(self) -> Vec<X>;

// Take the provided weights as they are
WeightedIndex<X>::from_cumulative_weights(weights: Vec<X>) -> Result<Self, WeightedError>;
WeightedIndex<X>::from_cumulative_weights_unchecked(weights: Vec<X>) -> Self;

// Iterate over the weights accumulating a total, and assign that total weight in each iteration
WeightedIndex<X>::from_weights(weights: Vec<X>) -> Result<Self, WeightedError>;
WeightedIndex<X>::from_weights_unchecked(weights: Vec<X>) -> Self;
```

The
@SuperFluffy As another convenience in the same vein, we might consider
So we'd have:

```rust
/// O(1), no allocations
WeightedIndex<X>::into_cumulative_weights(self) -> Vec<X>;
/// O(N), no allocations
WeightedIndex<X>::to_weights(self) -> Vec<X>;
/// O(N), no allocations
WeightedIndex<X>::from_weights(weights: Vec<X>) -> Result<Self, WeightedError>;
/// O(N), no allocations
WeightedIndex<X>::from_cumulative_weights(weights: Vec<X>) -> Result<Self, WeightedError>;
/// O(1), no allocations, if input is not cumulative then distribution will not yield meaningful samples
WeightedIndex<X>::from_cumulative_weights_unchecked(weights: Vec<X>) -> Self;
```

(Distinguishing
I had a little play with
See: SuperFluffy#1
I didn't bother with
@dhardy I also pushed my version of
I especially like that
Here are my results:
NOTE: I should get rid of the
@SuperFluffy — about your benches — most of the results are extremely small (under 60ns). While the benchmark can fairly reliably time an operation at that level, I'm not convinced that the operation is representative. E.g. the "small" benchmark says that
Looking at the "large" variants,
Because API simplicity is an important factor, and I have a feeling that we are over-optimising here (without a specific target).
@dhardy I agree with removing
If you want, I will close this PR and submit a fresh PR that contains these.
@SuperFluffy not very important whether you use a new PR. Notice that my PR reduced line count significantly by making
@SuperFluffy are you still able to work on this? It would be good to get it merged soon!
Was there anything else to resolve?
@SuperFluffy can I remind you of this? Both the above issues are now resolved in
@dhardy apologies for not having responded. I left the job where this was relevant, and it wasn't immediately relevant to me anymore. I'm going to find some time to rebase and adjust this PR. Thanks for the reminder!
Closing due to inactivity. We can re-open if someone wishes to work on this again. I have a branch related to this here, but I don't recall the motivation (probably benchmarking): https://github.com/dhardy/rand/commits/assign_weighted_index/
I need to update my weighted indices inside a hot loop. Instead of reconstructing the entire index from scratch, this commit allows updating the inner cumulative weights in-place using a slice of weights via `WeightedIndex::assign_weights`. `WeightedIndex::assign_weights_unchecked` is also provided for those cases where the user promises that all weights are valid and their sum exceeds zero.

Open questions

How to handle a partial update? If assignment fails during `WeightedIndex::assign_weights`, the index can be left in a partially updated, undefined state. The method's doc comment notes that but does not go further than that. I see the following ways to handle this:

- Trust that users of `assign_weights` read the documentation and will keep this caveat in mind.
- Roll back via `SubAssign`. This will probably require keeping around the old cumulative weights, which implies an extra allocation in the function body, which goes against the point of this new feature.
- Add a `has_errored: bool` to `WeightedIndex`, initialized to `false`. If an error is encountered during assignment, set it to `true`. If `WeightedIndex` is used for sampling with `has_errored == true`, panic. I don't recall where I have seen this, but I believe this solution is even used somewhere in the standard library.