diff --git a/src/univariate/discrete/discretenonparametric.jl b/src/univariate/discrete/discretenonparametric.jl index 8e1eefab6..8f242ec23 100644 --- a/src/univariate/discrete/discretenonparametric.jl +++ b/src/univariate/discrete/discretenonparametric.jl @@ -23,21 +23,29 @@ struct DiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractV function DiscreteNonParametric{T,P,Ts,Ps}(xs::Ts, ps::Ps; check_args::Bool=true) where { T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}} - check_args || return new{T,P,Ts,Ps}(xs, ps) @check_args( DiscreteNonParametric, (length(xs) == length(ps), "length of support and probability vector must be equal"), (ps, isprobvec(ps), "vector is not a probability vector"), - (xs, allunique(xs), "support must contain only unique elements"), + (xs, issorted_allunique(xs), "support must be sorted and contain only unique elements"), ) - sort_order = sortperm(xs) - new{T,P,Ts,Ps}(xs[sort_order], ps[sort_order]) + new{T,P,Ts,Ps}(xs, ps) end end -DiscreteNonParametric(vs::AbstractVector{T}, ps::AbstractVector{P}; check_args::Bool=true) where { - T<:Real,P<:Real} = - DiscreteNonParametric{T,P,typeof(vs),typeof(ps)}(vs, ps; check_args=check_args) +function DiscreteNonParametric(xs::AbstractVector{T}, ps::AbstractVector{P}; check_args::Bool=true) where {T<:Real,P<:Real} + # We always sort the support unless it can be deduced from the type of the support that it is sorted. + # Sorting can be skipped for all inputs by using the inner constructor. + if xs isa AbstractUnitRange + sortedxs = xs + sortedps = ps + else + sort_order = sortperm(xs) + sortedxs = xs[sort_order] + sortedps = ps[sort_order] + end + return DiscreteNonParametric{T,P,typeof(sortedxs),typeof(sortedps)}(sortedxs, sortedps; check_args=check_args) +end Base.eltype(::Type{<:DiscreteNonParametric{T}}) where T = T diff --git a/src/utils.jl b/src/utils.jl index a2c9aaffa..442b863d7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -97,6 +97,23 @@ isunitvec(v::AbstractVector) = (norm(v) - 1.0) < 1.0e-12 isprobvec(p::AbstractVector{<:Real}) = all(x -> x ≥ zero(x), p) && isapprox(sum(p), one(eltype(p))) +issorted_allunique(xs::AbstractUnitRange{<:Real}) = true +function issorted_allunique(xs::AbstractVector{<:Real}) + xi_state = iterate(xs) + if xi_state === nothing + return true + end + xi, state = xi_state + while (xj_state = iterate(xs, state)) !== nothing + xj, state = xj_state + if xj <= xi + return false + end + xi = xj + end + return true +end + # get a type wide enough to represent all a distributions's parameters # (if the distribution is parametric) # if the distribution is not parametric, we need this to be a float so that diff --git a/test/univariate/discrete/categorical.jl b/test/univariate/discrete/categorical.jl index 6d87d4dc8..96425a204 100644 --- a/test/univariate/discrete/categorical.jl +++ b/test/univariate/discrete/categorical.jl @@ -137,4 +137,21 @@ end @test count(==(1e8), priorities[iat]) >= 13 end +@testset "AbstractVector" begin + # issue #1084 + P = abs.(randn(5,4,2)) + p = view(P,:,1,1) + p ./= sum(p) + d = @inferred(Categorical(p)) + @test d isa Categorical{Float64, typeof(p)} + @test d.p === p + + # #1832 + x = rand(3,5) + x ./= sum(x; dims=1) + c = Categorical.(eachcol(x)) + @test c isa Vector{<:Categorical} + @test all(ci.p isa SubArray for ci in c) +end + end diff --git a/test/univariate/discrete/discretenonparametric.jl b/test/univariate/discrete/discretenonparametric.jl index 68354a064..b81d5b52a 100644 --- a/test/univariate/discrete/discretenonparametric.jl +++ b/test/univariate/discrete/discretenonparametric.jl @@ -213,4 +213,25 @@ end # Different types @test DiscreteNonParametric(1:2, [0.5, 0.5]) == DiscreteNonParametric([1, 2], [0.5f0, 0.5f0]) @test DiscreteNonParametric(1:2, [0.5, 0.5]) ≈ DiscreteNonParametric([1, 2], [0.5f0, 0.5f0]) -end \ No newline at end of file +end + +@testset "AbstractVector (issue #1084)" begin + P = abs.(randn(5,4,2)) + p = view(P,:,1,1) + p ./= sum(p) + + d = @inferred(DiscreteNonParametric(Base.OneTo(5), p)) + @test d isa DiscreteNonParametric + @test d.p === p + d = @inferred(DiscreteNonParametric(1:5, p)) + @test d isa DiscreteNonParametric + @test d.p === p + d = @inferred(DiscreteNonParametric(1:1:5, p)) + @test d isa DiscreteNonParametric + @test d.p !== p + @test d.p == p + d = @inferred(DiscreteNonParametric([1, 2, 3, 4, 5], p)) + @test d isa DiscreteNonParametric + @test d.p !== p + @test d.p == p +end