diff --git a/src/deprecates.jl b/src/deprecates.jl index 5c0cc11b1..9b36f4c60 100644 --- a/src/deprecates.jl +++ b/src/deprecates.jl @@ -47,6 +47,7 @@ end @deprecate Wishart(df::Real, S::Matrix, warn::Bool) Wishart(df, S) @deprecate Wishart(df::Real, S::Cholesky, warn::Bool) Wishart(df, S) -# Deprecate 3 arguments expectation -@deprecate expectation(distr::DiscreteUnivariateDistribution, g::Function, epsilon::Real) expectation(distr, g; epsilon=epsilon) false -@deprecate expectation(distr::ContinuousUnivariateDistribution, g::Function, epsilon::Real) expectation(distr, g) false +# Deprecate the 3-argument expectation methods and the ones taking the function as the second argument +@deprecate expectation(distr::DiscreteUnivariateDistribution, g::Function, epsilon::Real) expectation(g, distr; epsilon=epsilon) false +@deprecate expectation(distr::ContinuousUnivariateDistribution, g::Function, epsilon::Real) expectation(g, distr) false +@deprecate expectation(distr::Union{UnivariateDistribution,MultivariateDistribution}, g::Function; kwargs...) expectation(g, distr; kwargs...) false diff --git a/src/functionals.jl b/src/functionals.jl index 57e72c88b..2fc2ee47a 100644 --- a/src/functionals.jl +++ b/src/functionals.jl @@ -1,9 +1,9 @@ -function expectation(distr::ContinuousUnivariateDistribution, g::Function; kwargs...) +function expectation(g, distr::ContinuousUnivariateDistribution; kwargs...) return first(quadgk(x -> pdf(distr, x) * g(x), extrema(distr)...; kwargs...)) end ## Assuming that discrete distributions only take integer values. -function expectation(distr::DiscreteUnivariateDistribution, g::Function; epsilon::Real=1e-10) +function expectation(g, distr::DiscreteUnivariateDistribution; epsilon::Real=1e-10) mindist, maxdist = extrema(distr) # We want to avoid taking values up to infinity minval = isfinite(mindist) ? 
mindist : quantile(distr, epsilon) @@ -11,7 +11,7 @@ function expectation(distr::DiscreteUnivariateDistribution, g::Function; epsilon return sum(x -> pdf(distr, x) * g(x), minval:maxval) end -function expectation(distr::MultivariateDistribution, g::Function; nsamples::Int=100, rng::AbstractRNG=GLOBAL_RNG) +function expectation(g, distr::MultivariateDistribution; nsamples::Int=100, rng::AbstractRNG=GLOBAL_RNG) nsamples > 0 || throw(ArgumentError("number of samples should be > 0")) # We use a function barrier to work around type instability of `sampler(dist)` return mcexpectation(rng, g, sampler(distr), nsamples) @@ -27,9 +27,8 @@ mcexpectation(rng, f, sampler, n) = sum(f, rand(rng, sampler) for _ in 1:n) / n # end function kldivergence(P::Distribution{V}, Q::Distribution{V}; kwargs...) where {V<:VariateForm} - function logdiff(x) + return expectation(P; kwargs...) do x logp = logpdf(P, x) return (logp > oftype(logp, -Inf)) * (logp - logpdf(Q, x)) end - expectation(P, logdiff; kwargs...) -end \ No newline at end of file +end diff --git a/test/binomial.jl b/test/binomial.jl index 10dedb8ce..9fd3251f2 100644 --- a/test/binomial.jl +++ b/test/binomial.jl @@ -23,8 +23,8 @@ for (p, n) in [(0.6, 10), (0.8, 6), (0.5, 40), (0.04, 20), (1., 100), (0., 10), end # Test calculation of expectation value for Binomial distribution -@test Distributions.expectation(Binomial(6), identity) ≈ 3.0 -@test Distributions.expectation(Binomial(10, 0.2), x->-x) ≈ -2.0 +@test Distributions.expectation(identity, Binomial(6)) ≈ 3.0 +@test Distributions.expectation(x -> -x, Binomial(10, 0.2)) ≈ -2.0 # Test mode @test Distributions.mode(Binomial(100, 0.4)) == 40 diff --git a/test/functionals.jl b/test/functionals.jl index 9c8007ad3..16cbadee3 100644 --- a/test/functionals.jl +++ b/test/functionals.jl @@ -24,13 +24,18 @@ end @testset "Expectations" begin # univariate distributions for d in (Normal(), Poisson(2.0), Binomial(10, 0.4)) - @test Distributions.expectation(d, identity) ≈ mean(d) atol=1e-3 
- @test @test_deprecated(Distributions.expectation(d, identity, 1e-10)) ≈ mean(d) atol=1e-3 + m = Distributions.expectation(identity, d) + @test m ≈ mean(d) atol=1e-3 + @test Distributions.expectation(x -> (x - mean(d))^2, d) ≈ var(d) atol=1e-3 + + @test @test_deprecated(Distributions.expectation(d, identity, 1e-10)) == m + @test @test_deprecated(Distributions.expectation(d, identity)) == m end # multivariate distribution d = MvNormal([1.5, -0.5], I) - @test Distributions.expectation(d, identity; nsamples=10_000) ≈ mean(d) atol=1e-2 + @test Distributions.expectation(identity, d; nsamples=10_000) ≈ mean(d) atol=5e-2 + @test @test_deprecated(Distributions.expectation(d, identity; nsamples=10_000)) ≈ mean(d) atol=5e-2 end @testset "KL divergences" begin diff --git a/test/loguniform.jl b/test/loguniform.jl index 649c147b9..11975780a 100644 --- a/test/loguniform.jl +++ b/test/loguniform.jl @@ -59,7 +59,7 @@ import Random x = rand(rng, dist) @test cdf(u, log(x)) ≈ cdf(dist, x) - @test @inferred(entropy(dist)) ≈ Distributions.expectation(dist, x->-logpdf(dist,x)) + @test @inferred(entropy(dist)) ≈ Distributions.expectation(x->-logpdf(dist,x), dist) end @test kldivergence(LogUniform(1,2), LogUniform(1,2)) ≈ 0 atol=100eps(Float64)