Skip to content

Commit

Permalink
Improve Bernoulli and Binomial (#1557)
Browse files Browse the repository at this point in the history
* Improve Bernoulli

- replace `1:length(x)` with `eachindex(x)` so that it works with more
  arrays, for instance `OffsetArrays`
- initialize the counters `c0` and `c1` with `zero(eltype(w))` so that
  they are type-stable (previously they started as the integer `0` and
  became `Float64` once a weight was added)
- fix the creation of the domain error: the zero-argument method
  `DomainError()` does not exist
- make `BernoulliStats` generic so that it can collect counters of more
  general type, for instance `ForwardDiff.Dual`

Update src/univariate/discrete/bernoulli.jl
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

Improve Binomial (porting the changes from Bernoulli)

* Additional fixes

* Add tests

* Bump version

Co-authored-by: David Widmann <david.widmann@it.uu.se>
  • Loading branch information
FedericoStra and devmotion authored May 27, 2022
1 parent 02a8838 commit a350622
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 99 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Distributions"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
authors = ["JuliaStats"]
version = "0.25.60"
version = "0.25.61"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
32 changes: 14 additions & 18 deletions src/univariate/discrete/bernoulli.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,44 +113,40 @@ rand(rng::AbstractRNG, d::Bernoulli) = rand(rng) <= succprob(d)

#### MLE fitting

struct BernoulliStats <: SufficientStats
cnt0::Float64
cnt1::Float64

BernoulliStats(c0::Real, c1::Real) = new(Float64(c0), Float64(c1))
struct BernoulliStats{C<:Real} <: SufficientStats
cnt0::C
cnt1::C
end

BernoulliStats(c0::Real, c1::Real) = BernoulliStats(promote(c0, c1)...)

fit_mle(::Type{<:Bernoulli}, ss::BernoulliStats) = Bernoulli(ss.cnt1 / (ss.cnt0 + ss.cnt1))

function suffstats(::Type{<:Bernoulli}, x::AbstractArray{T}) where T<:Integer
n = length(x)
function suffstats(::Type{<:Bernoulli}, x::AbstractArray{<:Integer})
c0 = c1 = 0
for i = 1:n
@inbounds xi = x[i]
for xi in x
if xi == 0
c0 += 1
elseif xi == 1
c1 += 1
else
throw(DomainError())
throw(DomainError(xi, "samples must be 0 or 1"))
end
end
BernoulliStats(c0, c1)
end

function suffstats(::Type{<:Bernoulli}, x::AbstractArray{T}, w::AbstractArray{Float64}) where T<:Integer
n = length(x)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions."))
c0 = c1 = 0
for i = 1:n
@inbounds xi = x[i]
@inbounds wi = w[i]
function suffstats(::Type{<:Bernoulli}, x::AbstractArray{<:Integer}, w::AbstractArray{<:Real})
length(x) == length(w) || throw(DimensionMismatch("inconsistent argument dimensions"))
z = zero(eltype(w))
c0 = c1 = z + z # possibly widened and different from `z`, e.g., if `z = true`
for (xi, wi) in zip(x, w)
if xi == 0
c0 += wi
elseif xi == 1
c1 += wi
else
throw(DomainError())
throw(DomainError(xi, "samples must be 0 or 1"))
end
end
BernoulliStats(c0, c1)
Expand Down
42 changes: 20 additions & 22 deletions src/univariate/discrete/binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,31 +156,29 @@ end

#### Fit model

struct BinomialStats <: SufficientStats
ns::Float64 # the total number of successes
ne::Float64 # the number of experiments
struct BinomialStats{N<:Real} <: SufficientStats
ns::N # the total number of successes
ne::N # the number of experiments
n::Int # the number of trials in each experiment

BinomialStats(ns::Real, ne::Real, n::Integer) = new(ns, ne, n)
end

function suffstats(::Type{<:Binomial}, n::Integer, x::AbstractArray{T}) where T<:Integer
ns = zero(T)
for i = 1:length(x)
@inbounds xi = x[i]
0 <= xi <= n || throw(DomainError())
BinomialStats(ns::Real, ne::Real, n::Integer) = BinomialStats(promote(ns, ne)..., Int(n))

function suffstats(::Type{<:Binomial}, n::Integer, x::AbstractArray{<:Integer})
z = zero(eltype(x))
ns = z + z # possibly widened and different from `z`, e.g., if `z = true`
for xi in x
0 <= xi <= n || throw(DomainError(xi, "samples must be between 0 and $n"))
ns += xi
end
BinomialStats(ns, length(x), n)
end

function suffstats(::Type{<:Binomial}, n::Integer, x::AbstractArray{T}, w::AbstractArray{Float64}) where T<:Integer
ns = 0.
ne = 0.
for i = 1:length(x)
@inbounds xi = x[i]
@inbounds wi = w[i]
0 <= xi <= n || throw(DomainError())
function suffstats(::Type{<:Binomial}, n::Integer, x::AbstractArray{<:Integer}, w::AbstractArray{<:Real})
z = zero(eltype(x)) * zero(eltype(w))
ns = ne = z + z # possibly widened and different from `z`, e.g., if `z = true`
for (xi, wi) in zip(x, w)
0 <= xi <= n || throw(DomainError(xi, "samples must be between 0 and $n"))
ns += xi * wi
ne += wi
end
Expand All @@ -190,14 +188,14 @@ end
const BinomData = Tuple{Int, AbstractArray}

suffstats(::Type{<:Binomial}, data::BinomData) = suffstats(Binomial, data...)
suffstats(::Type{<:Binomial}, data::BinomData, w::AbstractArray{Float64}) = suffstats(Binomial, data..., w)
suffstats(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = suffstats(Binomial, data..., w)

fit_mle(::Type{<:Binomial}, ss::BinomialStats) = Binomial(ss.n, ss.ns / (ss.ne * ss.n))

fit_mle(::Type{<:Binomial}, n::Integer, x::AbstractArray{T}) where {T<:Integer} = fit_mle(Binomial, suffstats(Binomial, n, x))
fit_mle(::Type{<:Binomial}, n::Integer, x::AbstractArray{T}, w::AbstractArray{Float64}) where {T<:Integer} = fit_mle(Binomial, suffstats(Binomial, n, x, w))
fit_mle(::Type{<:Binomial}, n::Integer, x::AbstractArray{<:Integer}) = fit_mle(Binomial, suffstats(Binomial, n, x))
fit_mle(::Type{<:Binomial}, n::Integer, x::AbstractArray{<:Integer}, w::AbstractArray{<:Real}) = fit_mle(Binomial, suffstats(Binomial, n, x, w))
fit_mle(::Type{<:Binomial}, data::BinomData) = fit_mle(Binomial, suffstats(Binomial, data))
fit_mle(::Type{<:Binomial}, data::BinomData, w::AbstractArray{Float64}) = fit_mle(Binomial, suffstats(Binomial, data, w))
fit_mle(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = fit_mle(Binomial, suffstats(Binomial, data, w))

fit(::Type{<:Binomial}, data::BinomData) = fit_mle(Binomial, data)
fit(::Type{<:Binomial}, data::BinomData, w::AbstractArray{Float64}) = fit_mle(Binomial, data, w)
fit(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = fit_mle(Binomial, data, w)
122 changes: 64 additions & 58 deletions test/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#

using Distributions
using OffsetArrays
using Test, Random, LinearAlgebra


Expand Down Expand Up @@ -33,33 +34,35 @@ end


@testset "Testing fit for Bernoulli" begin
for func in funcs, dist in (Bernoulli, Bernoulli{Float64})
w = func[1](n0)
x = func[2](dist(0.7), n0)

ss = suffstats(dist, x)
@test isa(ss, Distributions.BernoulliStats)
@test ss.cnt0 == n0 - count(t->t != 0, x)
@test ss.cnt1 == count(t->t != 0, x)

ss = suffstats(dist, x, w)
@test isa(ss, Distributions.BernoulliStats)
@test ss.cnt0 ≈ sum(w[x .== 0])
@test ss.cnt1 ≈ sum(w[x .== 1])

d = fit(dist, x)
p = count(t->t != 0, x) / n0
@test isa(d, dist)
@test mean(d) ≈ p

d = fit(dist, x, w)
p = sum(w[x .== 1]) / sum(w)
@test isa(d, dist)
@test mean(d) ≈ p

d = fit(dist, func[2](dist(0.7), N))
@test isa(d, dist)
@test isapprox(mean(d), 0.7, atol=0.01)
for rng in ((), (rng,)), D in (Bernoulli, Bernoulli{Float64})
v = rand(rng..., n0)
z = rand(rng..., D(0.7), n0)
for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2))
ss = @inferred suffstats(D, x)
@test ss isa Distributions.BernoulliStats
@test ss.cnt0 == n0 - count(t->t != 0, z)
@test ss.cnt1 == count(t->t != 0, z)

ss = @inferred suffstats(D, x, w)
@test ss isa Distributions.BernoulliStats
@test ss.cnt0 ≈ sum(v[z .== 0])
@test ss.cnt1 ≈ sum(v[z .== 1])

d = @inferred fit(D, x)
@test d isa D
@test mean(d) ≈ count(t->t != 0, z) / n0

d = @inferred fit(D, x, w)
@test d isa D
@test mean(d) ≈ sum(v[z .== 1]) / sum(v)
end

z = rand(rng..., D(0.7), N)
for x in (z, OffsetArray(z, -N ÷ 2))
d = @inferred fit(D, x)
@test d isa D
@test mean(d) ≈ 0.7 atol=0.01
end
end
end

Expand All @@ -79,37 +82,40 @@ end
end

@testset "Testing fit for Binomial" begin
for func in funcs, dist in (Binomial, Binomial{Float64})
w = func[1](n0)

x = func[2](dist(100, 0.3), n0)

ss = suffstats(dist, (100, x))
@test isa(ss, Distributions.BinomialStats)
@test ss.ns ≈ sum(x)
@test ss.ne == n0
@test ss.n == 100

ss = suffstats(dist, (100, x), w)
@test isa(ss, Distributions.BinomialStats)
@test ss.ns ≈ dot(Float64[xx for xx in x], w)
@test ss.ne ≈ sum(w)
@test ss.n == 100

d = fit(dist, (100, x))
@test isa(d, dist)
@test ntrials(d) == 100
@test succprob(d) ≈ sum(x) / (n0 * 100)

d = fit(dist, (100, x), w)
@test isa(d, dist)
@test ntrials(d) == 100
@test succprob(d) ≈ dot(x, w) / (sum(w) * 100)

d = fit(dist, 100, func[2](dist(100, 0.3), N))
@test isa(d, dist)
@test ntrials(d) == 100
@test isapprox(succprob(d), 0.3, atol=0.01)
for rng in ((), (rng,)), D in (Binomial, Binomial{Float64})
v = rand(rng..., n0)
z = rand(rng..., D(100, 0.3), n0)
for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2))
ss = @inferred suffstats(D, (100, x))
@test ss isa Distributions.BinomialStats
@test ss.ns ≈ sum(z)
@test ss.ne == n0
@test ss.n == 100

ss = @inferred suffstats(D, (100, x), w)
@test ss isa Distributions.BinomialStats
@test ss.ns ≈ dot(z, v)
@test ss.ne ≈ sum(v)
@test ss.n == 100

d = @inferred fit(D, (100, x))
@test d isa D
@test ntrials(d) == 100
@test succprob(d) ≈ sum(z) / (n0 * 100)

d = @inferred fit(D, (100, x), w)
@test d isa D
@test ntrials(d) == 100
@test succprob(d) ≈ dot(z, v) / (sum(v) * 100)
end

z = rand(rng..., D(100, 0.3), N)
for x in (z, OffsetArray(z, -N ÷ 2))
d = @inferred fit(D, 100, x)
@test d isa D
@test ntrials(d) == 100
@test succprob(d) ≈ 0.3 atol=0.01
end
end
end

Expand Down

2 comments on commit a350622

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/61155

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.25.61 -m "<description of version>" a350622d442c8cd4ec610dce975621ba053d03d9
git push origin v0.25.61

Please sign in to comment.