From f4dc84d7a0db4d2fd98658752c209c1242fa0b32 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 26 Apr 2021 11:54:49 +0200 Subject: [PATCH 1/2] Define `Random.Sampler` --- src/genericrand.jl | 4 ++++ test/samplers.jl | 16 ++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/genericrand.jl b/src/genericrand.jl index d38f9b6e85..f54c2748e7 100644 --- a/src/genericrand.jl +++ b/src/genericrand.jl @@ -55,3 +55,7 @@ with this `sampler` method, which would be used for batch sampling. The general fallback is `sampler(d::Distribution) = d`. """ sampler(s::Sampleable) = s + +# Random API +Random.Sampler(::Type{<:AbstractRNG}, s::Sampleable, ::Val{1}) = s +Random.Sampler(::Type{<:AbstractRNG}, s::Sampleable, ::Val{Inf}) = sampler(s) diff --git a/test/samplers.jl b/test/samplers.jl index 4ae17941aa..0555b1d639 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -100,4 +100,20 @@ import Distributions: test_samples(S(d), d, n_tsamples, rng=rng) end end + + @testset "Random.Sampler" begin + for dist in ( + Binomial(5, 0.3), + Exponential(2.0), + Gamma(0.1, 1.0), + Gamma(2.0, 1.0), + MatrixNormal(3, 4), + MvNormal(3, 1.0), + Normal(1.5, 2.0), + Poisson(0.5), + ) + @test Random.Sampler(rng, dist, Val(1)) == dist + @test Random.Sampler(rng, dist) == sampler(dist) + end + end end From 06f3a7845fe4e4136574c3736cc49657d569b7ab Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 26 Apr 2021 11:55:14 +0200 Subject: [PATCH 2/2] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d4e1d9c1d3..85c9050a8a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.24.18" +version = "0.24.19" [deps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"