From ef42afb215cb08304dd64e309da8002c333caef0 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 17 May 2023 11:35:46 +0200 Subject: [PATCH] Fix inference failures (#1722) * Fix test failures * Add tests * Bump version --- Project.toml | 2 +- src/truncated/normal.jl | 14 +++++++------- test/truncated/normal.jl | 2 ++ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 497eefb4f9..2d891419ea 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.25.92" +version = "0.25.93" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/truncated/normal.jl b/src/truncated/normal.jl index 6b79b5dae3..a3ff33e1e1 100644 --- a/src/truncated/normal.jl +++ b/src/truncated/normal.jl @@ -12,9 +12,9 @@ TruncatedNormal ### statistics -function mode(d::Truncated{<:Normal{<:Real},Continuous}) +function mode(d::Truncated{<:Normal{<:Real},Continuous,T}) where {T<:Real} μ = mean(d.untruncated) - return clamp(μ, extrema(d)...) + return T(clamp(μ, extrema(d)...)) end modes(d::Truncated{<:Normal{<:Real},Continuous}) = [mode(d)] @@ -89,7 +89,7 @@ function _tnvar(a::Real, b::Real) end end -function mean(d::Truncated{<:Normal{<:Real},Continuous}) +function mean(d::Truncated{<:Normal{<:Real},Continuous,T}) where {T<:Real} d0 = d.untruncated μ = mean(d0) σ = std(d0) @@ -99,21 +99,21 @@ function mean(d::Truncated{<:Normal{<:Real},Continuous}) lower, upper = extrema(d) a = (lower - μ) / σ b = (upper - μ) / σ - return μ + _tnmom1(a, b) * σ + return T(μ + _tnmom1(a, b) * σ) end end -function var(d::Truncated{<:Normal{<:Real},Continuous}) +function var(d::Truncated{<:Normal{<:Real},Continuous,T}) where {T<:Real} d0 = d.untruncated μ = mean(d0) σ = std(d0) if iszero(σ) - return σ + return T(σ) else lower, upper = extrema(d) a = (lower - μ) / σ b = (upper - μ) / σ - return _tnvar(a, b) * σ^2 + return T(_tnvar(a, b) * σ^2) end end diff --git a/test/truncated/normal.jl b/test/truncated/normal.jl index 7ed4b8f900..9d287c0262 100644 --- a/test/truncated/normal.jl +++ b/test/truncated/normal.jl @@ -31,8 +31,10 @@ rng = MersenneTwister(123) # Type stability for T in (Float32, Float64) t = truncated(Normal(T(1.5), T(4.1)), 0, 1) + m = @inferred mode(t) μ = @inferred mean(t) σ = @inferred std(t) + @test m === T(1) @test μ ≈ 0.50494725270783081889610661619986770485973643194141 @test μ isa T @test σ ≈ 0.28836356398830993140576947440881738258157196701554