From f4dc84d7a0db4d2fd98658752c209c1242fa0b32 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 26 Apr 2021 11:54:49 +0200 Subject: [PATCH] Define `Random.Sampler` --- src/genericrand.jl | 4 ++++ test/samplers.jl | 16 ++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/genericrand.jl b/src/genericrand.jl index d38f9b6e8..f54c2748e 100644 --- a/src/genericrand.jl +++ b/src/genericrand.jl @@ -55,3 +55,7 @@ with this `sampler` method, which would be used for batch sampling. The general fallback is `sampler(d::Distribution) = d`. """ sampler(s::Sampleable) = s + +# Random API +Random.Sampler(::Type{<:AbstractRNG}, s::Sampleable, ::Val{1}) = s +Random.Sampler(::Type{<:AbstractRNG}, s::Sampleable, ::Val{Inf}) = sampler(s) diff --git a/test/samplers.jl b/test/samplers.jl index 4ae17941a..0555b1d63 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -100,4 +100,20 @@ import Distributions: test_samples(S(d), d, n_tsamples, rng=rng) end end + + @testset "Random.Sampler" begin + for dist in ( + Binomial(5, 0.3), + Exponential(2.0), + Gamma(0.1, 1.0), + Gamma(2.0, 1.0), + MatrixNormal(3, 4), + MvNormal(3, 1.0), + Normal(1.5, 2.0), + Poisson(0.5), + ) + @test Random.Sampler(rng, dist, Val(1)) == dist + @test Random.Sampler(rng, dist) == sampler(dist) + end + end end