Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace ZeroVector with Zeros from FillArrays #1020

Merged
merged 2 commits into from
Dec 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["JuliaStats"]
version = "0.21.9"

[deps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand All @@ -15,6 +16,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
FillArrays = "0.8"
PDMats = "0.9"
QuadGK = "2"
SpecialFunctions = "0.8"
Expand Down
2 changes: 2 additions & 0 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import Base: size, length, convert, show, getindex, rand, vec, inv
import Base: sum, maximum, minimum, extrema, +, -, ==
import Base.Math: @horner

using FillArrays

using LinearAlgebra, Printf
import LinearAlgebra: rank

Expand Down
2 changes: 1 addition & 1 deletion src/multivariate/mvlognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ end

#Constructors mirror the ones for MvNormmal
MvLogNormal(μ::AbstractVector,Σ::AbstractPDMat) = MvLogNormal(MvNormal(μ,Σ))
MvLogNormal(Σ::AbstractPDMat) = MvLogNormal(MvNormal(ZeroVector(eltype(Σ),dim(Σ)),Σ))
MvLogNormal(Σ::AbstractPDMat) = MvLogNormal(MvNormal(Zeros{eltype(Σ)}(dim(Σ)),Σ))
MvLogNormal(μ::AbstractVector,Σ::Matrix) = MvLogNormal(MvNormal(μ,Σ))
MvLogNormal(μ::AbstractVector,σ::Vector) = MvLogNormal(MvNormal(μ,σ))
MvLogNormal(μ::AbstractVector,s::Real) = MvLogNormal(MvNormal(μ,s))
Expand Down
21 changes: 9 additions & 12 deletions src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ struct MvNormal{Cov<:AbstractPDMat,Mean<:AbstractVector} <: AbstractMvNormal
end
```

Here, the mean vector can be an instance of any `AbstractVector`, including `ZeroVector`.
The latter is simply an empty type indicating a vector filled with zeros. The covariance can be
Here, the mean vector can be an instance of any `AbstractVector`. The covariance can be
of any subtype of `AbstractPDMat`. Particularly, one can use `PDMat` for full covariance,
`PDiagMat` for diagonal covariance, and `ScalMat` for the isotropic covariance -- those
in the form of ``\\sigma \\mathbf{I}``. (See the Julia package
Expand All @@ -62,17 +61,17 @@ const IsoNormal = MvNormal{ScalMat, Vector{Float64}}
const DiagNormal = MvNormal{PDiagMat, Vector{Float64}}
const FullNormal = MvNormal{PDMat, Vector{Float64}}

const ZeroMeanIsoNormal = MvNormal{ScalMat, ZeroVector{Float64}}
const ZeroMeanDiagNormal = MvNormal{PDiagMat, ZeroVector{Float64}}
const ZeroMeanFullNormal = MvNormal{PDMat, ZeroVector{Float64}}
const ZeroMeanIsoNormal{Axes} = MvNormal{ScalMat, Zeros{Float64,1,Axes}}
const ZeroMeanDiagNormal{Axes} = MvNormal{PDiagMat, Zeros{Float64,1,Axes}}
const ZeroMeanFullNormal{Axes} = MvNormal{PDMat, Zeros{Float64,1,Axes}}
```
"""
abstract type AbstractMvNormal <: ContinuousMultivariateDistribution end

### Generic methods (for all AbstractMvNormal subtypes)

insupport(d::AbstractMvNormal, x::AbstractVector) =
length(d) == length(x) && allfinite(x)
length(d) == length(x) && all(isfinite, x)

mode(d::AbstractMvNormal) = mean(d)
modes(d::AbstractMvNormal) = [mean(d)]
Expand Down Expand Up @@ -181,9 +180,9 @@ const IsoNormal = MvNormal{Float64,ScalMat{Float64},Vector{Float64}}
const DiagNormal = MvNormal{Float64,PDiagMat{Float64,Vector{Float64}},Vector{Float64}}
const FullNormal = MvNormal{Float64,PDMat{Float64,Matrix{Float64}},Vector{Float64}}

const ZeroMeanIsoNormal = MvNormal{Float64,ScalMat{Float64},ZeroVector{Float64}}
const ZeroMeanDiagNormal = MvNormal{Float64,PDiagMat{Float64,Vector{Float64}},ZeroVector{Float64}}
const ZeroMeanFullNormal = MvNormal{Float64,PDMat{Float64,Matrix{Float64}},ZeroVector{Float64}}
const ZeroMeanIsoNormal{Axes} = MvNormal{Float64,ScalMat{Float64},Zeros{Float64,1,Axes}}
const ZeroMeanDiagNormal{Axes} = MvNormal{Float64,PDiagMat{Float64,Vector{Float64}},Zeros{Float64,1,Axes}}
const ZeroMeanFullNormal{Axes} = MvNormal{Float64,PDMat{Float64,Matrix{Float64}},Zeros{Float64,1,Axes}}

### Construction
function MvNormal(μ::AbstractVector{T}, Σ::AbstractPDMat{T}) where {T<:Real}
Expand All @@ -196,9 +195,7 @@ function MvNormal(μ::AbstractVector, Σ::AbstractPDMat)
MvNormal(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, Σ))
end

function MvNormal(Σ::Cov) where {T, Cov<:AbstractPDMat{T}}
MvNormal{T,Cov,ZeroVector{T}}(ZeroVector(T, dim(Σ)), Σ)
end
MvNormal(Σ::AbstractPDMat) = MvNormal(Zeros{eltype(Σ)}(dim(Σ)), Σ)

MvNormal(μ::AbstractVector{<:Real}, Σ::Matrix{<:Real}) = MvNormal(μ, PDMat(Σ))
MvNormal(μ::AbstractVector{<:Real}, Σ::Union{Symmetric{<:Real}, Hermitian{<:Real}}) = MvNormal(μ, PDMat(Σ))
Expand Down
24 changes: 12 additions & 12 deletions src/multivariate/mvnormalcanon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ which is also a subtype of `AbstractMvNormal` to represent a multivariate normal
canonical parameters. Particularly, `MvNormalCanon` is defined as:

```julia
struct MvNormalCanon{P<:AbstractPDMat,V<:Union{Vector,ZeroVector}} <: AbstractMvNormal
struct MvNormalCanon{P<:AbstractPDMat,V<:AbstractVector} <: AbstractMvNormal
μ::V # the mean vector
h::V # potential vector, i.e. inv(Σ) * μ
J::P # precision matrix, i.e. inv(Σ)
Expand All @@ -33,9 +33,9 @@ const FullNormalCanon = MvNormalCanon{PDMat, Vector{Float64}}
const DiagNormalCanon = MvNormalCanon{PDiagMat, Vector{Float64}}
const IsoNormalCanon = MvNormalCanon{ScalMat, Vector{Float64}}

const ZeroMeanFullNormalCanon = MvNormalCanon{PDMat, ZeroVector{Float64}}
const ZeroMeanDiagNormalCanon = MvNormalCanon{PDiagMat, ZeroVector{Float64}}
const ZeroMeanIsoNormalCanon = MvNormalCanon{ScalMat, ZeroVector{Float64}}
const ZeroMeanFullNormalCanon{Axes} = MvNormalCanon{PDMat, Zeros{Float64,1}}
const ZeroMeanDiagNormalCanon{Axes} = MvNormalCanon{PDiagMat, Zeros{Float64,1}}
const ZeroMeanIsoNormalCanon{Axes} = MvNormalCanon{ScalMat, Zeros{Float64,1,Axes}}
```

A multivariate distribution with canonical parameterization can be constructed using a common constructor `MvNormalCanon` as:
Expand Down Expand Up @@ -74,9 +74,9 @@ const FullNormalCanon = MvNormalCanon{Float64, PDMat{Float64,Matrix{Float64}},Ve
const DiagNormalCanon = MvNormalCanon{Float64,PDiagMat{Float64,Vector{Float64}},Vector{Float64}}
const IsoNormalCanon = MvNormalCanon{Float64,ScalMat{Float64},Vector{Float64}}

const ZeroMeanFullNormalCanon = MvNormalCanon{Float64,PDMat{Float64,Matrix{Float64}},ZeroVector{Float64}}
const ZeroMeanDiagNormalCanon = MvNormalCanon{Float64,PDiagMat{Float64,Vector{Float64}},ZeroVector{Float64}}
const ZeroMeanIsoNormalCanon = MvNormalCanon{Float64,ScalMat{Float64},ZeroVector{Float64}}
const ZeroMeanFullNormalCanon{Axes} = MvNormalCanon{Float64,PDMat{Float64,Matrix{Float64}},Zeros{Float64,1,Axes}}
const ZeroMeanDiagNormalCanon{Axes} = MvNormalCanon{Float64,PDiagMat{Float64,Vector{Float64}},Zeros{Float64,1,Axes}}
const ZeroMeanIsoNormalCanon{Axes} = MvNormalCanon{Float64,ScalMat{Float64},Zeros{Float64,1,Axes}}


### Constructors
Expand All @@ -100,9 +100,9 @@ function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{S}, J::P) where
MvNormalCanon(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))
end

function MvNormalCanon(J::P) where P<:AbstractPDMat
z = ZeroVector(eltype(J), dim(J))
MvNormalCanon{eltype(J),P,ZeroVector{eltype(J)}}(z, z, J)
function MvNormalCanon(J::AbstractPDMat)
z = Zeros{eltype(J)}(dim(J))
MvNormalCanon(z, z, J)
end

function MvNormalCanon(h::AbstractVector{T}, J::P) where {T<:Real, P<:AbstractPDMat}
Expand Down Expand Up @@ -143,13 +143,13 @@ end

meanform(d::MvNormalCanon) = MvNormal(d.μ, inv(d.J))
# meanform{C, T<:Real}(d::MvNormalCanon{T,C,Vector{T}}) = MvNormal(d.μ, inv(d.J))
# meanform{C, T<:Real}(d::MvNormalCanon{T,C,ZeroVector{T}}) = MvNormal(inv(d.J))
# meanform{C, T<:Real}(d::MvNormalCanon{T,C,Zeros{T}}) = MvNormal(inv(d.J))

function canonform(d::MvNormal{T,C,<:AbstractVector{T}}) where {C, T<:Real}
J = inv(d.Σ)
return MvNormalCanon(d.μ, J * collect(d.μ), J)
end
canonform(d::MvNormal{T,C,ZeroVector{T}}) where {C, T<:Real} = MvNormalCanon(inv(d.Σ))
canonform(d::MvNormal{T,C,Zeros{T}}) where {C, T<:Real} = MvNormalCanon(inv(d.Σ))

### Basic statistics

Expand Down
12 changes: 6 additions & 6 deletions src/multivariate/mvtdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ function GenericMvTDist(df::T, μ::Mean, Σ::Cov, zmean::Bool) where {Cov<:Abstr
GenericMvTDist{R, typeof(S), typeof(m)}(R(df), d, zmean, m, S)
end

GenericMvTDist(df::Real, μ::Mean, Σ::Cov) where {Cov<:AbstractPDMat, Mean<:AbstractVector} =
GenericMvTDist(df, μ, Σ, allzeros(μ))
GenericMvTDist(df::Real, μ::AbstractVector, Σ::AbstractPDMat) =
GenericMvTDist(df, μ, Σ, all(iszero, μ))

function GenericMvTDist(df::T, Σ::Cov) where {Cov<:AbstractPDMat, T<:Real}
R = Base.promote_eltype(T, Σ)
GenericMvTDist(df, zeros(R,dim(Σ)), Σ, true)
function GenericMvTDist(df::Real, Σ::AbstractPDMat)
R = Base.promote_eltype(df, Σ)
GenericMvTDist(df, Zeros{R}(dim(Σ)), Σ, true)
end

GenericMvTDist{T,Cov,Mean}(df, μ, Σ) where {T,Cov,Mean} =
Expand Down Expand Up @@ -115,7 +115,7 @@ end
# evaluation (for GenericMvTDist)

insupport(d::AbstractMvTDist, x::AbstractVector{T}) where {T<:Real} =
length(d) == length(x) && allfinite(x)
length(d) == length(x) && all(isfinite, x)

function sqmahal(d::GenericMvTDist, x::AbstractVector{T}) where T<:Real
z = d.zeromean ? x : x - d.μ
Expand Down
73 changes: 6 additions & 67 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,82 +9,21 @@ macro check_args(D, cond)
end
end

## a type to indicate zero vector
"""
An immutable vector of zeros of type T
"""
struct ZeroVector{T} <: AbstractVector{T}
len::Int
end

ZeroVector(::Type{T}, n::Int) where {T} = ZeroVector{T}(n)

Base.length(v::ZeroVector) = v.len
Base.size(v::ZeroVector) = (v.len,)
Base.getindex(v::ZeroVector{T}, i) where {T} = zero(T)

Base.Vector(v::ZeroVector{T}) where {T} = zeros(T, v.len)
Base.convert(::Type{Vector{T}}, v::ZeroVector{T}) where {T} = Vector(v)
Base.convert(::Type{<:Vector}, v::ZeroVector{T}) where {T} = Vector(v)

Base.convert(::Type{ZeroVector{T}}, v::ZeroVector) where {T} = ZeroVector{T}(length(v))

Base.broadcast(::Union{typeof(+),typeof(-)}, x::AbstractArray, v::ZeroVector) = x
Base.broadcast(::typeof(+), v::ZeroVector, x::AbstractArray) = x
Base.broadcast(::typeof(-), v::ZeroVector, x::AbstractArray) = -x

Base.broadcast(::Union{typeof(+),typeof(-)}, x::Number, v::ZeroVector) = fill(x, v.len)
Base.broadcast(::typeof(+), v::ZeroVector, x::Number) = fill(x, v.len)
Base.broadcast(::typeof(-), v::ZeroVector, x::Number) = fill(-x, v.len)
Base.broadcast(::typeof(*), v::ZeroVector, ::Number) = v

##### Utility functions

isunitvec(v::AbstractVector{T}) where {T} = (norm(v) - 1.0) < 1.0e-12

function allfinite(x::AbstractArray{T}) where {T<:Real}
for i in eachindex(x)
if !isfinite(x[i])
return false
end
end
return true
end

function allzeros(x::AbstractArray{T}) where {T<:Real}
for i in eachindex(x)
if !(x[i] == zero(T))
return false
end
end
return true
end

allzeros(x::ZeroVector) = true
isunitvec(v::AbstractVector) = (norm(v) - 1.0) < 1.0e-12

allnonneg(xs::AbstractArray{<:Real}) = all(x -> x >= 0, xs)

isprobvec(p::AbstractVector{T}) where {T<:Real} =
allnonneg(p) && isapprox(sum(p), one(T))
isprobvec(p::AbstractVector{<:Real}) =
all(x -> x ≥ zero(x), p) && isapprox(sum(p), one(eltype(p)))

pnormalize!(v::AbstractVector{<:Real}) = (v ./= sum(v); v)

add!(x::AbstractArray, y::AbstractVector) = broadcast!(+, x, x, y)
add!(x::AbstractVecOrMat, y::ZeroVector) = x
add!(x::AbstractArray, y::Zeros) = x

function multiply!(x::AbstractArray, c::Number)
for i in eachindex(x)
@inbounds x[i] *= c
end
return x
end
multiply!(x::AbstractArray, c::Number) = (x .*= c; x)

function exp!(x::AbstractArray)
for i in eachindex(x)
@inbounds x[i] = exp(x[i])
end
return x
end
exp!(x::AbstractArray) = (x .= exp.(x); x)

# get a type wide enough to represent all a distributions's parameters
# (if the distribution is parametric)
Expand Down
4 changes: 2 additions & 2 deletions test/convolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ end

for (d1, d2) in Iterators.product(dist_list, dist_list)
d3 = convolve(d1, d2)
@test isa(d3, IsoNormal) || isa(d3, DiagNormal)
@test d3 isa Union{IsoNormal,DiagNormal,ZeroMeanIsoNormal,ZeroMeanDiagNormal}
@test d3.μ == d1.μ .+ d2.μ
@test Matrix(d3.Σ) == Matrix(d1.Σ + d2.Σ) # isequal not defined for PDMats
end
Expand Down Expand Up @@ -188,7 +188,7 @@ end

for (d1, d2) in Iterators.product(dist_list, dist_list)
d3 = convolve(d1, d2)
@test isa(d3, FullNormal)
@test d3 isa Union{FullNormal,ZeroMeanFullNormal}
@test d3.μ == d1.μ .+ d2.μ
@test d3.Σ.mat == d1.Σ.mat + d2.Σ.mat # isequal not defined for PDMats
end
Expand Down
8 changes: 6 additions & 2 deletions test/mvlognormal.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Tests on Multivariate LogNormal distributions

using Distributions, PDMats
using Distributions, FillArrays, PDMats
using LinearAlgebra, Random, Test


Expand All @@ -18,7 +18,11 @@ function test_mvlognormal(g::MvLogNormal, n_tsamples::Int=10^6,
e = entropy(g)
@test partype(g) == Float64
@test isa(mn, Vector{Float64})
@test isa(md, Vector{Float64})
if g.normal.μ isa Zeros{Float64,1}
@test md isa Fill{Float64,1}
else
@test md isa Vector{Float64}
end
@test isa(mo, Vector{Float64})
@test isa(s, Vector{Float64})
@test isa(S, Matrix{Float64})
Expand Down
19 changes: 0 additions & 19 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,11 @@ r = RealInterval(1.5, 4.0)
A = rand(1:10, 5, 5)
B = rand(Float32, 4)
C = 1//2
Z = Distributions.ZeroVector(Float64, 5)
L = rand(Float32, 4, 4)
D = PDMats.PDMat(L * L')
@test typeof(convert(Distributions.ZeroVector{Float32}, Z)) == Distributions.ZeroVector{Float32}

for v in (15, π, 0x33, 14.0)
@test Z .* v == Z
end

for idx in eachindex(Z)
@test Z[idx] == zero(eltype(typeof(Z)))
end

# Ensure that utilities functions works with abstract arrays

@test Distributions.allfinite(GenericArray([-1, 0, Inf])) == false
@test Distributions.allfinite(GenericArray([0, 0, 0]))

@test Distributions.allzeros(GenericArray([-1, 0, 1])) == false
@test Distributions.allzeros(GenericArray([0, 0, 0]))

@test Distributions.allnonneg(GenericArray([-1, 0, 1])) == false
@test Distributions.allnonneg(GenericArray([0, 0, 0]))

@test isprobvec(GenericArray([1, 1, 1])) == false
@test isprobvec(GenericArray([1/3, 1/3, 1/3]))

Expand Down