Skip to content

Commit

Permalink
Fix rand(::Beta) inconsistencies
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Oct 2, 2024
1 parent 0ea5502 commit 17154a2
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 22 deletions.
45 changes: 25 additions & 20 deletions src/univariate/continuous/beta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,16 @@ struct BetaSampler{T<:Real, S1 <: Sampleable{Univariate,Continuous},
s2::S2
end

function sampler(d::Beta{T}) where T
(α, β) = params(d)
if (α 1.0) && (β 1.0)
function sampler(d::Beta)
α, β = params(d)
if α 1 && β 1
return BetaSampler(false, inv(α), inv(β),
sampler(Uniform()), sampler(Uniform()))
sampler(Uniform(zero(α), oneunit(α))),
sampler(Uniform(zero(β), oneunit(β))))
else
return BetaSampler(true, inv(α), inv(β),
sampler(Gamma(α, one(T))),
sampler(Gamma(β, one(T))))
sampler(Gamma(α, oneunit))),
sampler(Gamma(β, oneunit))))
end
end

Expand All @@ -160,11 +161,11 @@ function rand(rng::AbstractRNG, s::BetaSampler)
= s.
= s.
while true
u = rand(rng) # the Uniform sampler just calls rand()
v = rand(rng)
u = rand(rng, s.s1) # the Uniform sampler just calls rand()
v = rand(rng, s.s2)
x = u^
y = v^
if x + y one(x)
if x + y 1
if (x + y > 0)
return x / (x + y)
else
Expand All @@ -180,16 +181,20 @@ function rand(rng::AbstractRNG, s::BetaSampler)
end
end

function rand(rng::AbstractRNG, d::Beta{T}) where T
(α, β) = params(d)
if 1.0) && 1.0)
function rand(rng::AbstractRNG, d::Beta)
α, β = params(d)
if α 1 && β 1
= inv(α)
= inv(β)
Tu = typeof(float(iα))
Tv = typeof(float(iβ))
while true
u = rand(rng)
v = rand(rng)
x = u^inv(α)
y = v^inv(β)
if x + y one(x)
if (x + y > 0)
u = rand(rng, Tu)
v = rand(rng, Tv)
x = u^
y = v^
if x + y 1
if x + y > 0
return x / (x + y)
else
logX = log(u) / α
Expand All @@ -202,8 +207,8 @@ function rand(rng::AbstractRNG, d::Beta{T}) where T
end
end
else
g1 = rand(rng, Gamma(α, one(T)))
g2 = rand(rng, Gamma(β, one(T)))
g1 = rand(rng, Gamma(α, oneunit)))
g2 = rand(rng, Gamma(β, oneunit)))
return g1 / (g1 + g2)
end
end
Expand Down
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const tests = [
"univariate/discrete/poisson",
"univariate/discrete/soliton",
"univariate/continuous/skewnormal",
"univariate/continuous/beta",
"univariate/continuous/chi",
"univariate/continuous/chisq",
"univariate/continuous/erlang",
Expand Down Expand Up @@ -129,8 +130,6 @@ const tests = [
# "samplers/vonmisesfisher",
# "show",
# "truncated/loguniform",
# "univariate/continuous/beta",
# "univariate/continuous/beta",
# "univariate/continuous/betaprime",
# "univariate/continuous/biweight",
# "univariate/continuous/cosine",
Expand Down
19 changes: 19 additions & 0 deletions test/univariate/continuous/beta.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Distributions
using Test

@testset "beta.jl" begin
# issue #1907
@testset "rand consistency" begin
for T in (Float32, Float64)
@test @inferred(rand(Beta(T(1), T(1)))) isa T
@test @inferred(rand(Beta(T(4//5), T(4//5)))) isa T
@test @inferred(rand(Beta(T(1), T(2)))) isa T
@test @inferred(rand(Beta(T(2), T(1)))) isa T

@test @inferred(eltype(rand(Beta(T(1), T(1)), 2))) === T
@test @inferred(eltype(rand(Beta(T(4//5), T(4//5)), 2))) === T
@test @inferred(eltype(rand(Beta(T(1), T(2)), 2))) === T
@test @inferred(eltype(rand(Beta(T(2), T(1)), 2))) === T
end
end
end

0 comments on commit 17154a2

Please sign in to comment.