Skip to content

Commit

Permalink
Fix the constructor of DiscreteNonParametric
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Oct 2, 2024
1 parent a1010e4 commit b5e13af
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 8 deletions.
22 changes: 15 additions & 7 deletions src/univariate/discrete/discretenonparametric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,29 @@ struct DiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractV

function DiscreteNonParametric{T,P,Ts,Ps}(xs::Ts, ps::Ps; check_args::Bool=true) where {
T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}}
check_args || return new{T,P,Ts,Ps}(xs, ps)
@check_args(
DiscreteNonParametric,
(length(xs) == length(ps), "length of support and probability vector must be equal"),
(ps, isprobvec(ps), "vector is not a probability vector"),
(xs, allunique(xs), "support must contain only unique elements"),
(xs, issorted_allunique(xs), "support must be sorted and contain only unique elements"),
)
sort_order = sortperm(xs)
new{T,P,Ts,Ps}(xs[sort_order], ps[sort_order])
new{T,P,Ts,Ps}(xs, ps)
end
end

DiscreteNonParametric(vs::AbstractVector{T}, ps::AbstractVector{P}; check_args::Bool=true) where {
T<:Real,P<:Real} =
DiscreteNonParametric{T,P,typeof(vs),typeof(ps)}(vs, ps; check_args=check_args)
function DiscreteNonParametric(xs::AbstractVector{T}, ps::AbstractVector{P}; check_args::Bool=true) where {T<:Real,P<:Real}
# We always sort the support unless it can be deduced from the type of the support that it is sorted.
# Sorting can be skipped for all inputs by using the inner constructor.
if xs isa AbstractUnitRange
sortedxs = xs
sortedps = ps
else
sort_order = sortperm(xs)
sortedxs = xs[sort_order]
sortedps = ps[sort_order]
end
return DiscreteNonParametric{T,P,typeof(sortedxs),typeof(sortedps)}(sortedxs, sortedps; check_args=check_args)
end

Base.eltype(::Type{<:DiscreteNonParametric{T}}) where T = T

Expand Down
17 changes: 17 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ isunitvec(v::AbstractVector) = (norm(v) - 1.0) < 1.0e-12
isprobvec(p::AbstractVector{<:Real}) =
all(x -> x zero(x), p) && isapprox(sum(p), one(eltype(p)))

issorted_allunique(xs::AbstractUnitRange{<:Real}) = true
function issorted_allunique(xs::AbstractVector{<:Real})
xi_state = iterate(xs)
if xi_state === nothing
return true
end
xi, state = xi_state
while (xj_state = iterate(xs, state)) !== nothing
xj, state = xj_state
if xj <= xi
return false
end
xi = xj
end
return true
end

# get a type wide enough to represent all a distributions's parameters
# (if the distribution is parametric)
# if the distribution is not parametric, we need this to be a float so that
Expand Down
17 changes: 17 additions & 0 deletions test/univariate/discrete/categorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,21 @@ end
@test count(==(1e8), priorities[iat]) >= 13
end

@testset "AbstractVector" begin
# issue #1084
P = abs.(randn(5,4,2))
p = view(P,:,1,1)
p ./= sum(p)
d = @inferred(Categorical(p))
@test d isa Categorical{Float64, typeof(p)}
@test d.p === p

# #1832
x = rand(3,5)
x ./= sum(x; dims=1)
c = Categorical.(eachcol(x))
@test c isa Vector{<:Categorical}
@test all(ci.p isa SubArray for ci in c)
end

end
23 changes: 22 additions & 1 deletion test/univariate/discrete/discretenonparametric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,25 @@ end
# Different types
@test DiscreteNonParametric(1:2, [0.5, 0.5]) == DiscreteNonParametric([1, 2], [0.5f0, 0.5f0])
@test DiscreteNonParametric(1:2, [0.5, 0.5]) DiscreteNonParametric([1, 2], [0.5f0, 0.5f0])
end
end

@testset "AbstractVector (issue #1084)" begin
P = abs.(randn(5,4,2))
p = view(P,:,1,1)
p ./= sum(p)

d = @inferred(DiscreteNonParametric(Base.OneTo(5), p))
@test d isa DiscreteNonParametric
@test d.p === p
d = @inferred(DiscreteNonParametric(1:5, p))
@test d isa DiscreteNonParametric
@test d.p === p
d = @inferred(DiscreteNonParametric(1:1:5, p))
@test d isa DiscreteNonParametric
@test d.p !== p
@test d.p == p
d = @inferred(DiscreteNonParametric([1, 2, 3, 4, 5], p))
@test d isa DiscreteNonParametric
@test d.p !== p
@test d.p == p
end

0 comments on commit b5e13af

Please sign in to comment.