Skip to content

Commit

Permalink
Merge pull request #24866 from JuliaLang/rf/rand/faster-dict-set
Browse files Browse the repository at this point in the history
faster rand! for Dict, Set, BitSet
  • Loading branch information
rfourquet authored Dec 4, 2017
2 parents 9b1a56e + 199073a commit c979996
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 23 deletions.
56 changes: 33 additions & 23 deletions base/random/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,45 +244,55 @@ rand(rng::AbstractRNG, sp::SamplerSimple{<:AbstractArray,<:Sampler}) =
@inbounds return sp[][rand(rng, sp.data)]


## random values from Dict, Set, BitSet
## random values from Dict

for x in (1, Inf) # eval because of ambiguity otherwise
for T in (Dict, Set, BitSet)
@eval Sampler(::AbstractRNG, t::$T, ::Val{$x}) = SamplerTrivial(t)
end
function Sampler(rng::AbstractRNG, t::Dict, ::Repetition)
isempty(t) && throw(ArgumentError("collection must be non-empty"))
# we use Val(Inf) below as rand is called repeatedly internally
# even for generating only one random value from t
SamplerSimple(t, Sampler(rng, linearindices(t.slots), Val(Inf)))
end

function rand(rng::AbstractRNG, sp::SamplerTrivial{<:Dict})
isempty(sp[]) && throw(ArgumentError("collection must be non-empty"))
rsp = Sampler(rng, 1:length(sp[].slots))
function rand(rng::AbstractRNG, sp::SamplerSimple{<:Dict,<:Sampler})
while true
i = rand(rng, rsp)
i = rand(rng, sp.data)
Base.isslotfilled(sp[], i) && @inbounds return (sp[].keys[i] => sp[].vals[i])
end
end

rand(rng::AbstractRNG, sp::SamplerTrivial{<:Set}) = rand(rng, sp[].dict).first
## random values from Set

Sampler(rng::AbstractRNG, t::Set, n::Repetition) = SamplerTag{Set}(Sampler(rng, t.dict, n))

rand(rng::AbstractRNG, sp::SamplerTag{Set,<:Sampler}) = rand(rng, sp.data).first

function rand(rng::AbstractRNG, sp::SamplerTrivial{BitSet})
isempty(sp[]) && throw(ArgumentError("collection must be non-empty"))
# sp[] can be empty while sp[].bits is not, so we cannot rely on the
# length check in Sampler below
rsp = Sampler(rng, 1:length(sp[].bits))
## random values from BitSet

function Sampler(rng::AbstractRNG, t::BitSet, n::Repetition)
isempty(t) && throw(ArgumentError("collection must be non-empty"))
SamplerSimple(t, Sampler(rng, linearindices(t.bits), Val(Inf)))
end

function rand(rng::AbstractRNG, sp::SamplerSimple{BitSet,<:Sampler})
while true
n = rand(rng, rsp)
n = rand(rng, sp.data)
@inbounds b = sp[].bits[n]
b && return n
end
end

## random values from Associative/AbstractSet

# avoid linear complexity for repeated calls
# we defer to _Sampler to avoid ambiguities with a call like Sampler(rng, Set(1), Val(1))
Sampler(rng::AbstractRNG, t::Union{Associative,AbstractSet}, n::Repetition) =
_Sampler(rng, t, n)

# avoid linear complexity for repeated calls
_Sampler(rng::AbstractRNG, t::Union{Associative,AbstractSet}, n::Val{Inf}) =
Sampler(rng, collect(t), n)

# when generating only one element, avoid the call to collect
Sampler(::AbstractRNG, t::Union{Associative,AbstractSet}, ::Val{1}) =
_Sampler(::AbstractRNG, t::Union{Associative,AbstractSet}, ::Val{1}) =
SamplerTrivial(t)

function nth(iter, n::Integer)::eltype(iter)
Expand All @@ -299,22 +309,22 @@ rand(rng::AbstractRNG, sp::SamplerTrivial{<:Union{Associative,AbstractSet}}) =

# we use collect(str), which is most of the time more efficient than specialized methods
# (except maybe for very small arrays)
Sampler(rng::AbstractRNG, str::AbstractString, n::Repetition) = Sampler(rng, collect(str), n)
Sampler(rng::AbstractRNG, str::AbstractString, n::Val{Inf}) = Sampler(rng, collect(str), n)

# when generating only one char from a string, the specialized method below
# is usually more efficient
Sampler(::AbstractRNG, str::AbstractString, ::Val{1}) = SamplerTrivial(str)
Sampler(rng::AbstractRNG, str::AbstractString, ::Val{1}) =
SamplerSimple(str, Sampler(rng, 1:_endof(str), Val(Inf)))

isvalid_unsafe(s::String, i) = !Base.is_valid_continuation(Base.@gc_preserve s unsafe_load(pointer(s), i))
isvalid_unsafe(s::AbstractString, i) = isvalid(s, i)
_endof(s::String) = sizeof(s)
_endof(s::AbstractString) = endof(s)

function rand(rng::AbstractRNG, sp::SamplerTrivial{<:AbstractString})::Char
function rand(rng::AbstractRNG, sp::SamplerSimple{<:AbstractString,<:Sampler})::Char
str = sp[]
sp_pos = Sampler(rng, 1:_endof(str))
while true
pos = rand(rng, sp_pos)
pos = rand(rng, sp.data)
isvalid_unsafe(str, pos) && return str[pos]
end
end
6 changes: 6 additions & 0 deletions base/random/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ end

Base.getindex(sp::SamplerSimple) = sp.self

# simple sampler carrying a (type) tag T and data
struct SamplerTag{T,S} <: Sampler
data::S
SamplerTag{T}(s::S) where {T,S} = new{T,S}(s)
end


### machinery for generation with Sampler

Expand Down

0 comments on commit c979996

Please sign in to comment.