Skip to content

Commit

Permalink
Respect type of parameters in fit for Bernoulli, Binomial, and …
Browse files Browse the repository at this point in the history
…`Uniform` (#1558)

* Respect parameter type in `fit` for `Bernoulli` and `Binomial`

* Respect parameter type in `fit` for `Uniform`
  • Loading branch information
devmotion authored May 30, 2022
1 parent a350622 commit f889f9e
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 32 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Distributions"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
authors = ["JuliaStats"]
version = "0.25.61"
version = "0.25.62"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
6 changes: 3 additions & 3 deletions src/univariate/continuous/uniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ External links
struct Uniform{T<:Real} <: ContinuousUnivariateDistribution
a::T
b::T
Uniform{T}(a::T, b::T) where {T <: Real} = new{T}(a, b)
Uniform{T}(a::Real, b::Real) where {T <: Real} = new{T}(a, b)
end

function Uniform(a::T, b::T; check_args::Bool=true) where {T <: Real}
Expand Down Expand Up @@ -125,11 +125,11 @@ _rand!(rng::AbstractRNG, d::Uniform, A::AbstractArray{<:Real}) =

#### Fitting

function fit_mle(::Type{<:Uniform}, x::AbstractArray{<:Real})
function fit_mle(::Type{T}, x::AbstractArray{<:Real}) where {T<:Uniform}
if isempty(x)
throw(ArgumentError("x cannot be empty."))
end
return Uniform(extrema(x)...)
return T(extrema(x)...)
end

# ChainRules definitions
Expand Down
4 changes: 2 additions & 2 deletions src/univariate/discrete/bernoulli.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ External links:
struct Bernoulli{T<:Real} <: DiscreteUnivariateDistribution
p::T

Bernoulli{T}(p::T) where {T <: Real} = new{T}(p)
Bernoulli{T}(p::Real) where {T <: Real} = new{T}(p)
end

function Bernoulli(p::Real; check_args::Bool=true)
Expand Down Expand Up @@ -120,7 +120,7 @@ end

BernoulliStats(c0::Real, c1::Real) = BernoulliStats(promote(c0, c1)...)

fit_mle(::Type{<:Bernoulli}, ss::BernoulliStats) = Bernoulli(ss.cnt1 / (ss.cnt0 + ss.cnt1))
fit_mle(::Type{T}, ss::BernoulliStats) where {T<:Bernoulli} = T(ss.cnt1 / (ss.cnt0 + ss.cnt1))

function suffstats(::Type{<:Bernoulli}, x::AbstractArray{<:Integer})
c0 = c1 = 0
Expand Down
18 changes: 9 additions & 9 deletions src/univariate/discrete/binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,15 @@ end

const BinomData = Tuple{Int, AbstractArray}

suffstats(::Type{<:Binomial}, data::BinomData) = suffstats(Binomial, data...)
suffstats(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = suffstats(Binomial, data..., w)
suffstats(::Type{T}, data::BinomData) where {T<:Binomial} = suffstats(T, data...)
suffstats(::Type{T}, data::BinomData, w::AbstractArray{<:Real}) where {T<:Binomial} = suffstats(T, data..., w)

fit_mle(::Type{<:Binomial}, ss::BinomialStats) = Binomial(ss.n, ss.ns / (ss.ne * ss.n))
fit_mle(::Type{T}, ss::BinomialStats) where {T<:Binomial} = T(ss.n, ss.ns / (ss.ne * ss.n))

fit_mle(::Type{<:Binomial}, n::Integer, x::AbstractArray{<:Integer}) = fit_mle(Binomial, suffstats(Binomial, n, x))
fit_mle(::Type{<:Binomial}, n::Integer, x::AbstractArray{<:Integer}, w::AbstractArray{<:Real}) = fit_mle(Binomial, suffstats(Binomial, n, x, w))
fit_mle(::Type{<:Binomial}, data::BinomData) = fit_mle(Binomial, suffstats(Binomial, data))
fit_mle(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = fit_mle(Binomial, suffstats(Binomial, data, w))
fit_mle(::Type{T}, n::Integer, x::AbstractArray{<:Integer}) where {T<:Binomial}= fit_mle(T, suffstats(T, n, x))
fit_mle(::Type{T}, n::Integer, x::AbstractArray{<:Integer}, w::AbstractArray{<:Real}) where {T<:Binomial} = fit_mle(T, suffstats(T, n, x, w))
fit_mle(::Type{T}, data::BinomData) where {T<:Binomial} = fit_mle(T, suffstats(T, data))
fit_mle(::Type{T}, data::BinomData, w::AbstractArray{<:Real}) where {T<:Binomial} = fit_mle(T, suffstats(T, data, w))

fit(::Type{<:Binomial}, data::BinomData) = fit_mle(Binomial, data)
fit(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = fit_mle(Binomial, data, w)
fit(::Type{T}, data::BinomData) where {T<:Binomial} = fit_mle(T, data)
fit(::Type{T}, data::BinomData, w::AbstractArray{<:Real}) where {T<:Binomial} = fit_mle(T, data, w)
40 changes: 23 additions & 17 deletions test/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ end


@testset "Testing fit for Bernoulli" begin
for rng in ((), (rng,)), D in (Bernoulli, Bernoulli{Float64})
for rng in ((), (rng,)), D in (Bernoulli, Bernoulli{Float64}, Bernoulli{Float32})
v = rand(rng..., n0)
z = rand(rng..., D(0.7), n0)
z = rand(rng..., Bernoulli(0.7), n0)
for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2))
ss = @inferred suffstats(D, x)
@test ss isa Distributions.BernoulliStats
Expand All @@ -57,7 +57,7 @@ end
@test mean(d) sum(v[z .== 1]) / sum(v)
end

z = rand(rng..., D(0.7), N)
z = rand(rng..., Bernoulli(0.7), N)
for x in (z, OffsetArray(z, -N ÷ 2))
d = @inferred fit(D, x)
@test d isa D
Expand All @@ -82,9 +82,9 @@ end
end

@testset "Testing fit for Binomial" begin
for rng in ((), (rng,)), D in (Binomial, Binomial{Float64})
for rng in ((), (rng,)), D in (Binomial, Binomial{Float64}, Binomial{Float32})
v = rand(rng..., n0)
z = rand(rng..., D(100, 0.3), n0)
z = rand(rng..., Binomial(100, 0.3), n0)
for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2))
ss = @inferred suffstats(D, (100, x))
@test ss isa Distributions.BinomialStats
Expand All @@ -109,7 +109,7 @@ end
@test succprob(d) dot(z, v) / (sum(v) * 100)
end

z = rand(rng..., D(100, 0.3), N)
z = rand(rng..., Binomial(100, 0.3), N)
for x in (z, OffsetArray(z, -N ÷ 2))
d = @inferred fit(D, 100, x)
@test d isa D
Expand Down Expand Up @@ -291,18 +291,24 @@ end
end

@testset "Testing fit for Uniform" begin
for func in funcs, dist in (Uniform, Uniform{Float64})
x = func[2](dist(1.2, 5.8), n0)
d = fit(dist, x)
@test isa(d, dist)
@test 1.2 <= minimum(d) <= maximum(d) <= 5.8
@test minimum(d) == minimum(x)
@test maximum(d) == maximum(x)
for rng in ((), (rng,)), D in (Uniform, Uniform{Float64}, Uniform{Float32})
z = rand(rng..., Uniform(1.2, 5.8), n0)
for x in (z, OffsetArray(z, -n0 ÷ 2))
d = fit(D, x)
@test d isa D
@test 1.2 <= minimum(d) <= maximum(d) <= 5.8
@test minimum(d) == partype(d)(minimum(z))
@test maximum(d) == partype(d)(maximum(z))
end

d = fit(dist, func[2](dist(1.2, 5.8), N))
@test 1.2 <= minimum(d) <= maximum(d) <= 5.8
@test isapprox(minimum(d), 1.2, atol=0.02)
@test isapprox(maximum(d), 5.8, atol=0.02)
z = rand(rng..., Uniform(1.2, 5.8), N)
for x in (z, OffsetArray(z, -N ÷ 2))
d = fit(D, x)
@test d isa D
@test 1.2 <= minimum(d) <= maximum(d) <= 5.8
@test minimum(d) 1.2 atol=0.02
@test maximum(d) 5.8 atol=0.02
end
end
end

Expand Down

2 comments on commit f889f9e

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/61316

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.25.62 -m "<description of version>" f889f9e56b0243d770c195b3eee8baef4880bd2e
git push origin v0.25.62

Please sign in to comment.