Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the constructor of DiscreteNonParametric #1908

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions src/univariate/discrete/discretenonparametric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,47 @@ struct DiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractV

# Inner constructor: stores `xs` and `ps` exactly as given (no sorting, no copy).
#
# - `xs`: support of the distribution.
# - `ps`: probabilities associated with the entries of `xs`.
# - `check_args`: when `false`, the arguments are stored without any validation.
#   Otherwise the lengths must match, `ps` must be a probability vector, and `xs`
#   must already be sorted with no repeated entries.
function DiscreteNonParametric{T,P,Ts,Ps}(
    xs::Ts, ps::Ps; check_args::Bool=true
) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}}
    if check_args
        # `let` rebinds the arguments locally for the duration of the checks
        let xs = xs, ps = ps
            @check_args(
                DiscreteNonParametric,
                (length(xs) == length(ps), "length of support and probability vector must be equal"),
                (ps, isprobvec(ps), "vector is not a probability vector"),
                (xs, issorted_allunique(xs), "support must be sorted and contain only unique elements"),
            )
        end
    end
    return new{T,P,Ts,Ps}(xs, ps)
end
end

# Outer constructor: validates the arguments, sorts the support (permuting the
# probabilities in the same way) and forwards to the inner constructor.
#
# - `xs`: support values; they do not have to be sorted.
# - `ps`: probabilities, `ps[i]` belonging to `xs[i]`.
# - `check_args`: read by `@check_args`. The support is sorted regardless of its value.
#
# When `xs isa AbstractUnitRange` both arguments are passed on untouched; otherwise
# new arrays are produced by indexing with the sort permutation.
function DiscreteNonParametric(
    xs::AbstractVector{T}, ps::AbstractVector{P}; check_args::Bool=true
) where {T<:Real,P<:Real}
    # These checks are performed before sorting the support since we do not want to
    # throw a `BoundsError` when the lengths do not match
    let xs = xs, ps = ps
        @check_args(
            DiscreteNonParametric,
            (length(xs) == length(ps), "length of support and probability vector must be equal"),
            (ps, isprobvec(ps), "vector is not a probability vector"),
            (xs, allunique(xs), "support must contain only unique elements"),
        )
    end
    # We always sort the support unless it can be deduced from the type of the support
    # that it is sorted.
    # Sorting can be skipped for all inputs by using the inner constructor.
    if xs isa AbstractUnitRange
        sortedxs = xs
        sortedps = ps
    else
        sort_order = sortperm(xs)
        sortedxs = xs[sort_order]
        sortedps = ps[sort_order]
        # Once the support is sorted, one linear pass over neighbouring entries
        # confirms that it is strictly increasing
        let sortedxs = sortedxs
            @check_args(
                DiscreteNonParametric,
                (sortedxs, issorted_allunique(sortedxs), "support must contain only unique elements"),
            )
        end
    end
    # Everything has been validated above, so the inner constructor must not repeat it
    return DiscreteNonParametric{T,P,typeof(sortedxs),typeof(sortedps)}(
        sortedxs, sortedps; check_args=false
    )
end

# The element type of a `DiscreteNonParametric` is its first type parameter `T`.
function Base.eltype(::Type{<:DiscreteNonParametric{T}}) where {T}
    return T
end

# Conversion
Expand Down
17 changes: 17 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ isunitvec(v::AbstractVector) = (norm(v) - 1.0) < 1.0e-12
# Return `true` if `p` is a probability vector: every entry is non-negative and the
# entries sum (approximately, in the sense of `isapprox`) to one.
function isprobvec(p::AbstractVector{<:Real})
    isnonnegative(x) = x >= zero(x)
    # The sum is only computed once all entries are known to be non-negative
    all(isnonnegative, p) || return false
    return isapprox(sum(p), one(eltype(p)))
end

# Return `true` if `xs` is sorted in strictly increasing order, i.e. it is sorted and
# has no repeated entries. Empty and single-element inputs count as sorted.

# A unit range is accepted without inspecting its elements.
issorted_allunique(xs::AbstractUnitRange{<:Real}) = true

function issorted_allunique(xs::AbstractVector{<:Real})
    # Pair every element with its successor; the input is rejected as soon as a
    # successor is less than or equal to the element before it.
    successors = Iterators.drop(xs, 1)
    return !any(((prev, next),) -> next <= prev, zip(xs, successors))
end

# get a type wide enough to represent all a distributions's parameters
# (if the distribution is parametric)
# if the distribution is not parametric, we need this to be a float so that
Expand Down
17 changes: 17 additions & 0 deletions test/univariate/discrete/categorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,21 @@ end
@test count(==(1e8), priorities[iat]) >= 13
end

# `Categorical` built from probability vectors that are array views rather than `Vector`s.
@testset "AbstractVector" begin
# issue #1084
# `p` is a view into the first column of `P`, normalized in place so that it sums to one
P = abs.(randn(5,4,2))
p = view(P,:,1,1)
p ./= sum(p)
d = @inferred(Categorical(p))
@test d isa Categorical{Float64, typeof(p)}
# `===`: the distribution must hold the very same view object, not a copy
@test d.p === p

# #1832
# every column of `x` is normalized to sum to one, then one distribution is built per column
x = rand(3,5)
x ./= sum(x; dims=1)
c = Categorical.(eachcol(x))
@test c isa Vector{<:Categorical}
# the column views must be stored as `SubArray`s
@test all(ci.p isa SubArray for ci in c)
end

end
45 changes: 43 additions & 2 deletions test/univariate/discrete/discretenonparametric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ rng = MersenneTwister(123)

# The support is deliberately given out of order: the outer constructor has to sort it.
d = DiscreteNonParametric([40., 80., 120., -60.], [.4, .3, .1, .2])

# In the outer constructor, the support is always sorted, regardless of whether `check_args = false` or `check_args = true`
@test d ≈ DiscreteNonParametric([40., 80., 120., -60.], [.4, .3, .1, .2], check_args=false)
@test d ≈ DiscreteNonParametric([-60., 40., 80, 120], [.2, .4, .3, .1], check_args=false)

# Invalid probability
Expand All @@ -23,6 +24,25 @@ d = DiscreteNonParametric([40., 80., 120., -60.], [.4, .3, .1, .2])
# Invalid probability, but no arg check
# (the call only has to return without throwing; nothing is asserted about the result)
DiscreteNonParametric([40., 80, 120, -60], [.5, .3, .1, .2], check_args=false)

# Invalid support
# The fully parameterized form calls the inner constructor, which does not sort:
# there an unsorted support is rejected just like a support with repeated values.
# The outer constructor sorts first, so it only rejects repeated values.
@test_throws DomainError DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([40., 80, 120, -60], [.4, .3, .1, .2])
@test_throws DomainError DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3, .1])
@test_throws DomainError DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3, .1])

# Invalid support but no arg check
# (these calls only have to return without throwing)
DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([40., 80, 120, -60], [.4, .3, .1, .2], check_args=false)
DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3, .1], check_args=false)
DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3, .1], check_args=false)

# Mismatch between support and probabilities
# (four support values, three probabilities)
@test_throws ArgumentError DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3])
@test_throws ArgumentError DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3])

# Mismatch between support and probabilities but no arg check
# Without the length check, the outer constructor indexes the three probabilities
# with a four-element sort permutation.
@test_throws BoundsError DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3], check_args=false) # sorting errors
DiscreteNonParametric(1:4, [.2, .4, .3], check_args=false) # no sorting, hence no `BoundsError`
DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3], check_args=false)

# Generic checks on `d` using test helpers that are defined outside this view
test_range(d)
vs = Distributions.get_evalsamples(d, 0.00001)
test_evaluation(d, vs, true)
Expand Down Expand Up @@ -213,4 +233,25 @@ end
# Different types
@test DiscreteNonParametric(1:2, [0.5, 0.5]) == DiscreteNonParametric([1, 2], [0.5f0, 0.5f0])
@test DiscreteNonParametric(1:2, [0.5, 0.5]) ≈ DiscreteNonParametric([1, 2], [0.5f0, 0.5f0])
end
end

# Which supports let the outer constructor keep the caller's probability array
# (`===`) and which make it store a reordered copy (`!==` but `==`).
@testset "AbstractVector (issue #1084)" begin
# `p` is a view into the first column of `P`, normalized in place so that it sums to one
P = abs.(randn(5,4,2))
p = view(P,:,1,1)
p ./= sum(p)

# `Base.OneTo(5)` and `1:5` are unit ranges: no sorting, so `p` itself is stored
d = @inferred(DiscreteNonParametric(Base.OneTo(5), p))
@test d isa DiscreteNonParametric
@test d.p === p
d = @inferred(DiscreteNonParametric(1:5, p))
@test d isa DiscreteNonParametric
@test d.p === p
# `1:1:5` (a range with an explicit step) and a plain vector go through the sorting
# branch: the stored probabilities are equal to `p` but are a different object
d = @inferred(DiscreteNonParametric(1:1:5, p))
@test d isa DiscreteNonParametric
@test d.p !== p
@test d.p == p
d = @inferred(DiscreteNonParametric([1, 2, 3, 4, 5], p))
@test d isa DiscreteNonParametric
@test d.p !== p
@test d.p == p
end
Loading