Skip to content

Commit

Permalink
Add optional boundary to probability simplex (#608)
Browse files Browse the repository at this point in the history
* Add optional boundary to probability simplex

* use :Open and :Closed

* addressing code review

* cover last line

* make a test test the line that needs a test
  • Loading branch information
mateuszbaran authored May 21, 2023
1 parent 440519b commit 133eef8
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Manifolds"
uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
authors = ["Seth Axen <seth.axen@gmail.com>", "Mateusz Baran <mateuszbaran89@gmail.com>", "Ronny Bergmann <manopt@ronnybergmann.net>", "Antoine Levitt <antoine.levitt@gmail.com>"]
version = "0.8.61"
version = "0.8.62"

[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Expand Down
68 changes: 55 additions & 13 deletions src/manifolds/ProbabilitySimplex.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@doc raw"""
ProbabilitySimplex{n} <: AbstractDecoratorManifold{𝔽}
ProbabilitySimplex{n,boundary} <: AbstractDecoratorManifold{𝔽}
The (relative interior of the) probability simplex is the set
````math
Expand All @@ -8,6 +8,13 @@ The (relative interior of) the probability simplex is the set
````
where $\mathbb{1}=(1,…,1)^{\mathrm{T}}∈ ℝ^{n+1}$ denotes the vector containing only ones.
If `boundary` is set to `:open`, then the object represents an open simplex. Otherwise,
that is when `boundary` is set to `:closed`, the boundary is also included:
````math
\hat{Δ}^n := \biggl\{ p ∈ ℝ^{n+1}\ \big|\ p_i \geq 0 \text{ for all } i=1,…,n+1,
\text{ and } ⟨\mathbb{1},p⟩ = \sum_{i=1}^{n+1} p_i = 1\biggr\},
````
This set is also called the unit simplex or standard simplex.
The tangent space is given by
Expand All @@ -23,15 +30,28 @@ where $\mathcal N \subset 2𝕊^n$ is given by $\varphi(p) = 2\sqrt{p}$.
This implementation follows the notation in [^ÅströmPetraSchmitzerSchnörr2017].
# Constructor
ProbabilitySimplex(n::Int; boundary::Symbol=:open)
[^ÅströmPetraSchmitzerSchnörr2017]:
> F. Åström, S. Petra, B. Schmitzer, C. Schnörr: “Image Labeling by Assignment”,
> Journal of Mathematical Imaging and Vision, 58(2), pp. 221–238, 2017.
> doi: [10.1007/s10851-016-0702-4](https://doi.org/10.1007/s10851-016-0702-4)
> arxiv: [1603.05285](https://arxiv.org/abs/1603.05285).
"""
struct ProbabilitySimplex{n} <: AbstractDecoratorManifold{ℝ} end
struct ProbabilitySimplex{n,boundary} <: AbstractDecoratorManifold{ℝ} end

ProbabilitySimplex(n::Int) = ProbabilitySimplex{n}()
function ProbabilitySimplex(n::Int; boundary::Symbol=:open)
    # `boundary` ends up as a type parameter, so only the two supported modes may
    # pass this guard.
    if !(boundary === :open || boundary === :closed)
        msg = "boundary can only be set to :open or :closed; received $boundary"
        throw(ArgumentError(msg))
    end
    return ProbabilitySimplex{n,boundary}()
end

"""
FisherRaoMetric <: AbstractMetric
Expand Down Expand Up @@ -91,8 +111,14 @@ Check whether `p` is a valid point on the [`ProbabilitySimplex`](@ref) `M`, i.e.
the embedding with positive entries (nonnegative entries if the boundary is `:closed`) that sum to one.
The tolerance for the last test can be set using the `kwargs...`.
"""
function check_point(M::ProbabilitySimplex, p; kwargs...)
if minimum(p) <= 0
function check_point(M::ProbabilitySimplex{n,boundary}, p; kwargs...) where {n,boundary}
if boundary === :closed && minimum(p) < 0
return DomainError(
minimum(p),
"The point $(p) does not lie on the $(M) since it has negative entries.",
)
end
if boundary === :open && minimum(p) <= 0
return DomainError(
minimum(p),
"The point $(p) does not lie on the $(M) since it has nonpositive entries.",
Expand Down Expand Up @@ -190,18 +216,34 @@ injectivity_radius(M::ProbabilitySimplex) = 0
injectivity_radius(M::ProbabilitySimplex, ::AbstractRetractionMethod) = 0

@doc raw"""
inner(M::ProbabilitySimplex,p,X,Y)
inner(M::ProbabilitySimplex, p, X, Y)
Compute the inner product of two tangent vectors `X`, `Y` from the tangent space $T_pΔ^n$ at
`p`. The formula reads
````math
g_p(X,Y) = \sum_{i=1}^{n+1}\frac{X_iY_i}{p_i}
````
When `M` includes the boundary, the coordinates where ``p_i`` is equal to 0 are skipped in
the sum, see Proposition 2.1 in [^AyJostLeSchwachhöfer2017].
[^AyJostLeSchwachhöfer2017]:
> N. Ay, J. Jost, H. V. Le, and L. Schwachhöfer, Information Geometry. in Ergebnisse der
> Mathematik und ihrer Grenzgebiete. 3. Folge / A Series of Modern Surveys in
> Mathematics. Springer International Publishing, 2017.
> doi: [10.1007/978-3-319-56478-4](https://doi.org/10.1007/978-3-319-56478-4)
"""
function inner(::ProbabilitySimplex, p, X, Y)
function inner(::ProbabilitySimplex{n,boundary}, p, X, Y) where {n,boundary}
    # Accumulate in the promoted element type so the result type does not depend on
    # the values of `p`, `X` and `Y`.
    d = zero(Base.promote_eltype(p, X, Y))
    if boundary === :closed
        # On the closed simplex `p[i]` may be zero; those coordinates are skipped so
        # that no division by zero enters the sum.
        @inbounds for i in eachindex(p, X, Y)
            if p[i] > 0
                d += X[i] * Y[i] / p[i]
            end
        end
    else
        # Open simplex: every coordinate contributes X[i] * Y[i] / p[i].
        @inbounds for i in eachindex(p, X, Y)
            d += X[i] * Y[i] / p[i]
        end
    end
    return d
end
Expand Down Expand Up @@ -379,16 +421,16 @@ function riemannian_gradient!(M::ProbabilitySimplex, X, p, Y; kwargs...)
return X
end

function Base.show(io::IO, ::ProbabilitySimplex{n}) where {n}
return print(io, "ProbabilitySimplex($(n))")
function Base.show(io::IO, ::ProbabilitySimplex{n,boundary}) where {n,boundary}
    # `repr` keeps the leading colon of the symbol, so the printed text is a valid
    # constructor call, e.g. `ProbabilitySimplex(2; boundary=:open)`.
    return print(io, "ProbabilitySimplex($(n); boundary=$(repr(boundary)))")
end

@doc raw"""
zero_vector(M::ProbabilitySimplex,p)
zero_vector(M::ProbabilitySimplex, p)
Return the zero tangent vector in the tangent space at the point `p` on the
[`ProbabilitySimplex`](@ref) `M`, i.e. its representation by the zero vector in the embedding.
"""
zero_vector(::ProbabilitySimplex, ::Any)

zero_vector!(M::ProbabilitySimplex, v, p) = fill!(v, 0)
function zero_vector!(::ProbabilitySimplex, X, p)
    # Overwrite `X` in place with zeros; the base point `p` is not used.
    return fill!(X, 0)
end
12 changes: 12 additions & 0 deletions test/manifolds/probability_simplex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,16 @@ include("../utils.jl")
riemannian_gradient!(M, Z, p, Y)
@test X == Z
end

@testset "Simplex with boundary" begin
    Mb = ProbabilitySimplex(2; boundary=:closed)
    # A point on the boundary: its first coordinate is exactly zero.
    p = [0, 0.5, 0.5]
    X = [0, 1, -1]
    Y = [0, 2, -2]
    @test is_point(Mb, p)
    # The same point must be rejected by the default (open) simplex.
    @test !is_point(ProbabilitySimplex(2), p)
    @test_throws DomainError is_point(Mb, p .- 1, true)
    # The coordinate with p[1] == 0 is skipped: 1*2/0.5 + (-1)*(-2)/0.5 = 8.
    @test inner(Mb, p, X, Y) ≈ 8

    @test_throws ArgumentError ProbabilitySimplex(2; boundary=:tomato)
end
end

2 comments on commit 133eef8

@mateuszbaran
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/83990

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.62 -m "<description of version>" 133eef8013b0ba0397447ae0b5bfb184bb020e01
git push origin v0.8.62

Please sign in to comment.