Skip to content

Commit

Permalink
Define Random.Sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Apr 26, 2021
1 parent a36c613 commit f4dc84d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/genericrand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@ with this `sampler` method, which would be used for batch sampling.
The general fallback is `sampler(d::Distribution) = d`.
"""
sampler(s::Sampleable) = s

# Random API
Random.Sampler(::Type{<:AbstractRNG}, s::Sampleable, ::Val{1}) = s
Random.Sampler(::Type{<:AbstractRNG}, s::Sampleable, ::Val{Inf}) = sampler(s)
16 changes: 16 additions & 0 deletions test/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,20 @@ import Distributions:
test_samples(S(d), d, n_tsamples, rng=rng)
end
end

@testset "Random.Sampler" begin
for dist in (
Binomial(5, 0.3),
Exponential(2.0),
Gamma(0.1, 1.0),
Gamma(2.0, 1.0),
MatrixNormal(3, 4),
MvNormal(3, 1.0),
Normal(1.5, 2.0),
Poisson(0.5),
)
@test Random.Sampler(rng, dist, Val(1)) == dist
@test Random.Sampler(rng, dist) == sampler(dist)
end
end
end

0 comments on commit f4dc84d

Please sign in to comment.