Skip to content

Commit

Permalink
Fix known stats and minimize diff
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jan 13, 2024
1 parent 27ea06f commit de449b2
Showing 1 changed file with 26 additions and 25 deletions.
51 changes: 26 additions & 25 deletions src/univariate/continuous/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct Normal{T<:Real} <: ContinuousUnivariateDistribution
Normal{T}::T, σ::T) where {T<:Real} = new{T}(µ, σ)
end

function Normal::T, σ::T; check_args::Bool=true) where {T <: Real}
function Normal::T, σ::T; check_args::Bool=true) where {T<:Real}
@check_args Normal (σ, σ >= zero(σ))
return Normal{T}(μ, σ)
end
Expand Down Expand Up @@ -120,7 +120,7 @@ rand!(rng::AbstractRNG, d::Normal, A::AbstractArray{<:Real}) = A .= muladd.(d.σ

#### Fitting

struct NormalStats{T} <: SufficientStats
struct NormalStats{T<:Real} <: SufficientStats
s::T # (weighted) sum of x
m::T # (weighted) mean of x
s2::T # (weighted) sum of (x - μ)^2
Expand All @@ -131,31 +131,30 @@ struct NormalStats{T} <: SufficientStats
end
end

function suffstats(::Type{<:Normal}, x::AbstractArray{<:Real})
function suffstats(::Type{<:Normal}, x::AbstractArray{T}) where T<:Real
n = length(x)

# compute s
s = zero(T) + zero(T)
s = zero(T)
for i in eachindex(x)
@inbounds s += x[i]
end
m = s / n

# compute s2
s2 = zero(m)
s2 = zero(T)
for i in eachindex(x)
@inbounds s2 += abs2(x[i] - m)
end

NormalStats(s, m, s2, n)
end

function suffstats(::Type{<:Normal}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
n = length(x)

function suffstats(::Type{<:Normal}, x::AbstractArray{T1}, w::AbstractArray{T2}) where {T1<:Real,T2<:Real}
T = promote_type(T1, T2)
# compute s
tw = 0.0
s = 0.0 * zero(T)
tw = zero(T)
s = zero(T)
for i in eachindex(x, w)
@inbounds wi = w[i]
@inbounds s += wi * x[i]
Expand All @@ -164,7 +163,7 @@ function suffstats(::Type{<:Normal}, x::AbstractArray{<:Real}, w::AbstractArray{
m = s / tw

# compute s2
s2 = zero(m)
s2 = zero(T)
for i in eachindex(x, w)
@inbounds s2 += w[i] * abs2(x[i] - m)
end
Expand All @@ -174,11 +173,11 @@ end

# Cases where μ or σ is known

struct NormalKnownMu{T} <: IncompleteDistribution
struct NormalKnownMu{T<:Real} <: IncompleteDistribution
μ::T
end

struct NormalKnownMuStats{T} <: SufficientStats
struct NormalKnownMuStats{T<:Real} <: SufficientStats
μ::T # known mean
s2::T # (weighted) sum of (x - μ)^2
tw::T # total sample weight
Expand All @@ -188,19 +187,21 @@ struct NormalKnownMuStats{T} <: SufficientStats
end
end

function suffstats(g::NormalKnownMu, x::AbstractArray{<:Real})
function suffstats(g::NormalKnownMu{T0}, x::AbstractArray{T1}) where {T0,T1<:Real}
T = promote_type(T0, T1)
μ = g.μ
s2 = zero(T) + zero(μ)
s2 = zero(T)
for i in eachindex(x)
@inbounds s2 += abs2(x[i] - μ)
end
NormalKnownMuStats(g.μ, s2, length(x))
end

function suffstats(g::NormalKnownMu, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
function suffstats(g::NormalKnownMu{T0}, x::AbstractArray{T1}, w::AbstractArray{T2}) where {T0,T1<:Real,T2<:Real}
T = promote_type(T0, T1, T2)
μ = g.μ
s2 = 0.0 * abs2(zero(T) - zero(μ))
tw = 0.0
s2 = zero(T)
tw = zero(T)
for i in eachindex(x, w)
@inbounds wi = w[i]
@inbounds s2 += abs2(x[i] - μ) * wi
Expand All @@ -209,15 +210,15 @@ function suffstats(g::NormalKnownMu, x::AbstractArray{<:Real}, w::AbstractArray{
NormalKnownMuStats(g.μ, s2, tw)
end

struct NormalKnownSigma{T} <: IncompleteDistribution
struct NormalKnownSigma{T<:Real} <: IncompleteDistribution
σ::T
function NormalKnownSigma::T) where {T}
σ > 0 || throw(ArgumentError("σ must be a positive value."))
return new{T}(σ)
end
end

struct NormalKnownSigmaStats{T} <: SufficientStats
struct NormalKnownSigmaStats{T<:Real} <: SufficientStats
σ::T # known std.dev
sx::T # (weighted) sum of x
tw::T # total sample weight
Expand All @@ -228,7 +229,7 @@ struct NormalKnownSigmaStats{T} <: SufficientStats
end

function suffstats(g::NormalKnownSigma, x::AbstractArray{<:Real})
NormalKnownSigmaStats(g.σ, sum(x), Float64(length(x)))
NormalKnownSigmaStats(g.σ, sum(x), length(x))
end

function suffstats(g::NormalKnownSigma, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
Expand All @@ -252,12 +253,12 @@ function fit_mle(
fit_mle(D, suffstats(Normal, x))
else
g = NormalKnownSigma(sigma)
fit_mle(g, suffstats(g, x))
convert(D, fit_mle(g, suffstats(g, x)))
end
else
if isnothing(sigma)
g = NormalKnownMu(mu)
fit_mle(g, suffstats(g, x))
convert(D, fit_mle(g, suffstats(g, x)))
else
D(mu, sigma)
end
Expand All @@ -273,12 +274,12 @@ function fit_mle(
fit_mle(D, suffstats(Normal, x, w))
else
g = NormalKnownSigma(sigma)
fit_mle(g, suffstats(g, x, w))
convert(D, fit_mle(g, suffstats(g, x, w)))
end
else
if isnothing(sigma)
g = NormalKnownMu(mu)
fit_mle(g, suffstats(g, x, w))
convert(D, fit_mle(g, suffstats(g, x, w)))
else
D(mu, sigma)
end
Expand Down

0 comments on commit de449b2

Please sign in to comment.