From f889f9e56b0243d770c195b3eee8baef4880bd2e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 30 May 2022 15:18:25 +0200 Subject: [PATCH] Respect type of parameters in `fit` for `Bernoulli`, `Binomial`, and `Uniform` (#1558) * Respect parameter type in `fit` for `Bernoulli` and `Binomial` * Respect parameter type in `fit` for `Uniform` --- Project.toml | 2 +- src/univariate/continuous/uniform.jl | 6 ++--- src/univariate/discrete/bernoulli.jl | 4 +-- src/univariate/discrete/binomial.jl | 18 ++++++------- test/fit.jl | 40 ++++++++++++++++------------ 5 files changed, 38 insertions(+), 32 deletions(-) diff --git a/Project.toml b/Project.toml index 21fad4cb9..354ba2837 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.25.61" +version = "0.25.62" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/univariate/continuous/uniform.jl b/src/univariate/continuous/uniform.jl index 089c64efc..f707b38be 100644 --- a/src/univariate/continuous/uniform.jl +++ b/src/univariate/continuous/uniform.jl @@ -26,7 +26,7 @@ External links struct Uniform{T<:Real} <: ContinuousUnivariateDistribution a::T b::T - Uniform{T}(a::T, b::T) where {T <: Real} = new{T}(a, b) + Uniform{T}(a::Real, b::Real) where {T <: Real} = new{T}(a, b) end function Uniform(a::T, b::T; check_args::Bool=true) where {T <: Real} @@ -125,11 +125,11 @@ _rand!(rng::AbstractRNG, d::Uniform, A::AbstractArray{<:Real}) = #### Fitting -function fit_mle(::Type{<:Uniform}, x::AbstractArray{<:Real}) +function fit_mle(::Type{T}, x::AbstractArray{<:Real}) where {T<:Uniform} if isempty(x) throw(ArgumentError("x cannot be empty.")) end - return Uniform(extrema(x)...) + return T(extrema(x)...) end # ChainRules definitions diff --git a/src/univariate/discrete/bernoulli.jl b/src/univariate/discrete/bernoulli.jl index b4aead7ee..d27bc9b50 100644 --- a/src/univariate/discrete/bernoulli.jl +++ b/src/univariate/discrete/bernoulli.jl @@ -27,7 +27,7 @@ External links: struct Bernoulli{T<:Real} <: DiscreteUnivariateDistribution p::T - Bernoulli{T}(p::T) where {T <: Real} = new{T}(p) + Bernoulli{T}(p::Real) where {T <: Real} = new{T}(p) end function Bernoulli(p::Real; check_args::Bool=true) @@ -120,7 +120,7 @@ end BernoulliStats(c0::Real, c1::Real) = BernoulliStats(promote(c0, c1)...) -fit_mle(::Type{<:Bernoulli}, ss::BernoulliStats) = Bernoulli(ss.cnt1 / (ss.cnt0 + ss.cnt1)) +fit_mle(::Type{T}, ss::BernoulliStats) where {T<:Bernoulli} = T(ss.cnt1 / (ss.cnt0 + ss.cnt1)) function suffstats(::Type{<:Bernoulli}, x::AbstractArray{<:Integer}) c0 = c1 = 0 diff --git a/src/univariate/discrete/binomial.jl b/src/univariate/discrete/binomial.jl index 5a2833a54..e8e17afdd 100644 --- a/src/univariate/discrete/binomial.jl +++ b/src/univariate/discrete/binomial.jl @@ -187,15 +187,15 @@ end const BinomData = Tuple{Int, AbstractArray} -suffstats(::Type{<:Binomial}, data::BinomData) = suffstats(Binomial, data...) -suffstats(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = suffstats(Binomial, data..., w) +suffstats(::Type{T}, data::BinomData) where {T<:Binomial} = suffstats(T, data...) +suffstats(::Type{T}, data::BinomData, w::AbstractArray{<:Real}) where {T<:Binomial} = suffstats(T, data..., w) -fit_mle(::Type{<:Binomial}, ss::BinomialStats) = Binomial(ss.n, ss.ns / (ss.ne * ss.n)) +fit_mle(::Type{T}, ss::BinomialStats) where {T<:Binomial} = T(ss.n, ss.ns / (ss.ne * ss.n)) -fit_mle(::Type{<:Binomial}, n::Integer, x::AbstractArray{<:Integer}) = fit_mle(Binomial, suffstats(Binomial, n, x)) -fit_mle(::Type{<:Binomial}, n::Integer, x::AbstractArray{<:Integer}, w::AbstractArray{<:Real}) = fit_mle(Binomial, suffstats(Binomial, n, x, w)) -fit_mle(::Type{<:Binomial}, data::BinomData) = fit_mle(Binomial, suffstats(Binomial, data)) -fit_mle(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = fit_mle(Binomial, suffstats(Binomial, data, w)) +fit_mle(::Type{T}, n::Integer, x::AbstractArray{<:Integer}) where {T<:Binomial}= fit_mle(T, suffstats(T, n, x)) +fit_mle(::Type{T}, n::Integer, x::AbstractArray{<:Integer}, w::AbstractArray{<:Real}) where {T<:Binomial} = fit_mle(T, suffstats(T, n, x, w)) +fit_mle(::Type{T}, data::BinomData) where {T<:Binomial} = fit_mle(T, suffstats(T, data)) +fit_mle(::Type{T}, data::BinomData, w::AbstractArray{<:Real}) where {T<:Binomial} = fit_mle(T, suffstats(T, data, w)) -fit(::Type{<:Binomial}, data::BinomData) = fit_mle(Binomial, data) -fit(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = fit_mle(Binomial, data, w) +fit(::Type{T}, data::BinomData) where {T<:Binomial} = fit_mle(T, data) +fit(::Type{T}, data::BinomData, w::AbstractArray{<:Real}) where {T<:Binomial} = fit_mle(T, data, w) diff --git a/test/fit.jl b/test/fit.jl index 904fd9192..50ef4499a 100644 --- a/test/fit.jl +++ b/test/fit.jl @@ -34,9 +34,9 @@ end @testset "Testing fit for Bernoulli" begin - for rng in ((), (rng,)), D in (Bernoulli, Bernoulli{Float64}) + for rng in ((), (rng,)), D in (Bernoulli, Bernoulli{Float64}, Bernoulli{Float32}) v = rand(rng..., n0) - z = rand(rng..., D(0.7), n0) + z = rand(rng..., Bernoulli(0.7), n0) for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2)) ss = @inferred suffstats(D, x) @test ss isa Distributions.BernoulliStats @@ -57,7 +57,7 @@ end @test mean(d) ≈ sum(v[z .== 1]) / sum(v) end - z = rand(rng..., D(0.7), N) + z = rand(rng..., Bernoulli(0.7), N) for x in (z, OffsetArray(z, -N ÷ 2)) d = @inferred fit(D, x) @test d isa D @@ -82,9 +82,9 @@ end end @testset "Testing fit for Binomial" begin - for rng in ((), (rng,)), D in (Binomial, Binomial{Float64}) + for rng in ((), (rng,)), D in (Binomial, Binomial{Float64}, Binomial{Float32}) v = rand(rng..., n0) - z = rand(rng..., D(100, 0.3), n0) + z = rand(rng..., Binomial(100, 0.3), n0) for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2)) ss = @inferred suffstats(D, (100, x)) @test ss isa Distributions.BinomialStats @@ -109,7 +109,7 @@ end @test succprob(d) ≈ dot(z, v) / (sum(v) * 100) end - z = rand(rng..., D(100, 0.3), N) + z = rand(rng..., Binomial(100, 0.3), N) for x in (z, OffsetArray(z, -N ÷ 2)) d = @inferred fit(D, 100, x) @test d isa D @@ -291,18 +291,24 @@ end end @testset "Testing fit for Uniform" begin - for func in funcs, dist in (Uniform, Uniform{Float64}) - x = func[2](dist(1.2, 5.8), n0) - d = fit(dist, x) - @test isa(d, dist) - @test 1.2 <= minimum(d) <= maximum(d) <= 5.8 - @test minimum(d) == minimum(x) - @test maximum(d) == maximum(x) + for rng in ((), (rng,)), D in (Uniform, Uniform{Float64}, Uniform{Float32}) + z = rand(rng..., Uniform(1.2, 5.8), n0) + for x in (z, OffsetArray(z, -n0 ÷ 2)) + d = fit(D, x) + @test d isa D + @test 1.2 <= minimum(d) <= maximum(d) <= 5.8 + @test minimum(d) == partype(d)(minimum(z)) + @test maximum(d) == partype(d)(maximum(z)) + end - d = fit(dist, func[2](dist(1.2, 5.8), N)) - @test 1.2 <= minimum(d) <= maximum(d) <= 5.8 - @test isapprox(minimum(d), 1.2, atol=0.02) - @test isapprox(maximum(d), 5.8, atol=0.02) + z = rand(rng..., Uniform(1.2, 5.8), N) + for x in (z, OffsetArray(z, -N ÷ 2)) + d = fit(D, x) + @test d isa D + @test 1.2 <= minimum(d) <= maximum(d) <= 5.8 + @test minimum(d) ≈ 1.2 atol=0.02 + @test maximum(d) ≈ 5.8 atol=0.02 + end end end