Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UPDATED: Multinomial distribution rand does not support Float32 probability vectors #1738

Merged
merged 11 commits into from
Jun 30, 2023
2 changes: 1 addition & 1 deletion src/samplers/multinomial.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function multinom_rand!(rng::AbstractRNG, n::Int, p::AbstractVector{Float64},
function multinom_rand!(rng::AbstractRNG, n::Int, p::AbstractVector{<:Real},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAICT one has to update the body of the function as well to avoid type stability issues.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More concretely, it must not contain any Float64 literals.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't is true that the output type of the function can only be x::AbstractVector{eltype(x)} ?
I'm adding the type stability tests and it looks ok but I wanted to make sure since I'm pretty new to Julia.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are internal type instabilities which can cause slow downs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For instance, rp might change its type inside of the loop.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took care of this aspect.
While doing so I noticed that Binomial is restricted to using Float64, so it may be slower/more accurate than necessary.
This also would affect Multinomial.
This is not needed for me, but it can be a good addition in future.

x::AbstractVector{<:Real})
k = length(p)
length(x) == k || throw(DimensionMismatch("Invalid argument dimension."))
Expand Down
7 changes: 6 additions & 1 deletion test/multivariate/multinomial.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Tests for Multinomial

using Distributions, Random, StaticArrays
using Distributions, Random, StaticArrays, ForwardDiff
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed for the changes below it seems:

Suggested change
using Distributions, Random, StaticArrays, ForwardDiff
using Distributions, Random, StaticArrays

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll remove it happily, as I don't personally think it really belongs here.
But at least for completeness sake, here is the source of this addition:
#1033 (comment)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems it was added because someone wanted dual numbers to be tested as well - but later it was figured out that for other reasons even with the changes in the PR they don't work and the tests were removed again.

So it seems fine to remove this import.

using Test


Expand Down Expand Up @@ -210,3 +210,8 @@ p_v = [0.1, 0.4, 0.3, 0.8]
@test_throws DomainError Multinomial(10, p_v)
@test_throws DomainError Multinomial(10, p_v; check_args=true)
Multinomial(10, p_v; check_args=false) # should not warn

# check different prob vector types
p = [0.2, 0.4, 0.3, 0.1]
@test (rand(Multinomial(10, p)); true)
@test (rand(Multinomial(10, convert.(Float32, p))); true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use better tests that test correctness instead of only "not erroring"? For instance by looping over Float64 and Float32 probabilities in the tests above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I follow.
What kind of correctness do you mean?
As for type stability, I'm looking at adding some @inferred tests here for that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That the samples follow the desired distribution - basically, the same tests as for Float64.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer running the existing tests with Float32 as well over a few additional tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

took care of this ✔️