Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimization for discrete uniform distributions of equal size #17

Merged
merged 5 commits into from
Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
name = "ExactOptimalTransport"
uuid = "24df6009-d856-477c-ac5c-91f668376b31"
authors = ["JuliaOptimalTransport"]
version = "0.1.1"
version = "0.1.2"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Expand All @@ -16,6 +17,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[compat]
Distances = "0.9.0, 0.10"
Distributions = "0.24, 0.25"
FillArrays = "0.12"
MathOptInterface = "0.9"
PDMats = "0.10, 0.11"
QuadGK = "2"
Expand Down
1 change: 1 addition & 0 deletions src/ExactOptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module ExactOptimalTransport
using Distances
using MathOptInterface
using Distributions
using FillArrays
using PDMats
using QuadGK
using StatsBase: StatsBase
Expand Down
99 changes: 56 additions & 43 deletions src/exact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,30 +263,38 @@ a sparse matrix.
See also: [`ot_cost`](@ref), [`emd`](@ref)
"""
function ot_plan(_, μ::DiscreteNonParametric, ν::DiscreteNonParametric)
# unpack the probabilities of the two distributions
# Unpack the probabilities of the two distributions
# Note: support of `DiscreteNonParametric` is sorted
μprobs = probs(μ)
νprobs = probs(ν)

# create the iterator
# note: support of `DiscreteNonParametric` is sorted
iter = Discrete1DOTIterator(μprobs, νprobs)

# create arrays for the indices of the two histograms and the optimal flow between the
# corresponding points
n = length(iter)
I = Vector{Int}(undef, n)
J = Vector{Int}(undef, n)
W = Vector{Base.promote_eltype(μprobs, νprobs)}(undef, n)

# compute the sparse optimal transport plan
@inbounds for (idx, (i, j, w)) in enumerate(iter)
I[idx] = i
J[idx] = j
W[idx] = w
T = Base.promote_eltype(μprobs, νprobs)

return if μprobs isa FillArrays.AbstractFill &&
νprobs isa FillArrays.AbstractFill &&
length(μprobs) == length(νprobs)
# Special case: discrete uniform distributions of the same "size"
k = length(μprobs)
sparse(1:k, 1:k, T(first(μprobs)), k, k)
else
# Generic case
# Create the iterator
iter = Discrete1DOTIterator(μprobs, νprobs)

# create arrays for the indices of the two histograms and the optimal flow between the
# corresponding points
n = length(iter)
I = Vector{Int}(undef, n)
J = Vector{Int}(undef, n)
W = Vector{T}(undef, n)

# compute the sparse optimal transport plan
@inbounds for (idx, (i, j, w)) in enumerate(iter)
I[idx] = i
J[idx] = j
W[idx] = w
end
sparse(I, J, W, length(μprobs), length(νprobs))
end
γ = sparse(I, J, W, length(μprobs), length(νprobs))

return γ
end

"""
Expand All @@ -305,45 +313,50 @@ A pre-computed optimal transport `plan` may be provided.
See also: [`ot_plan`](@ref), [`emd2`](@ref)
"""
function ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric; plan=nothing)
return _ot_cost(c, μ, ν, plan)
end

# compute cost from scratch if no plan is provided
function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, ::Nothing)
# unpack the probabilities of the two distributions
# Extract support and probabilities of discrete distributions
# Note: support of `DiscreteNonParametric` is sorted
μsupport = support(μ)
νsupport = support(ν)
μprobs = probs(μ)
νprobs = probs(ν)

return if μprobs isa FillArrays.AbstractFill &&
νprobs isa FillArrays.AbstractFill &&
length(μprobs) == length(νprobs)
# Special case: discrete uniform distributions of the same "size"
# In this case we always just compute `sum(c.(μsupport .- νsupport))` and scale it
# We use pairwise summation and avoid allocations
# (https://github.com/JuliaLang/julia/pull/31020)
T = Base.promote_eltype(μprobs, νprobs)
T(first(μprobs)) *
sum(Broadcast.instantiate(Broadcast.broadcasted(c, μsupport, νsupport)))
else
# Generic case
_ot_cost(c, μsupport, μprobs, νsupport, νprobs, plan)
end
end

# compute cost from scratch if no plan is provided
function _ot_cost(c, μsupport, μprobs, νsupport, νprobs, ::Nothing)
# create the iterator
# note: support of `DiscreteNonParametric` is sorted
iter = Discrete1DOTIterator(μprobs, νprobs)

# compute the cost
μsupport = support(μ)
νsupport = support(ν)
cost = sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in iter)

return cost
return sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in iter)
end

# if a sparse plan is provided, we just iterate through the non-zero entries
function _ot_cost(
c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, plan::SparseMatrixCSC
)
function _ot_cost(c, μsupport, _, νsupport, _, plan::SparseMatrixCSC)
# extract non-zero flows
I, J, W = findnz(plan)

# compute the cost
μsupport = support(μ)
νsupport = support(ν)
cost = sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in zip(I, J, W))

return cost
return sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in zip(I, J, W))
end

# fallback: compute cost matrix (probably often faster to compute cost from scratch)
function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, plan)
return dot(plan, StatsBase.pairwise(c, support(μ), support(ν)))
function _ot_cost(c, μsupport, _, νsupport, _, plan)
return dot(plan, StatsBase.pairwise(c, μsupport, νsupport))
end

################
Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ end
"""
discretemeasure(
support::AbstractVector,
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)),
probs::AbstractVector{<:Real}=FillArrays.Fill(inv(length(support)), length(support)),
)

Construct a finite discrete probability measure with `support` and corresponding
Expand Down Expand Up @@ -42,13 +42,13 @@ using KernelFunctions
"""
function discretemeasure(
support::AbstractVector{<:Real},
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)),
probs::AbstractVector{<:Real}=Fill(inv(length(support)), length(support)),
)
return DiscreteNonParametric(support, probs)
end
function discretemeasure(
support::AbstractVector,
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)),
probs::AbstractVector{<:Real}=Fill(inv(length(support)), length(support)),
)
return FiniteDiscreteMeasure{typeof(support),typeof(probs)}(support, probs)
end
Expand Down
122 changes: 72 additions & 50 deletions test/exact.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ExactOptimalTransport

using Distances
using FillArrays
using PythonOT: PythonOT
using Tulip
using MathOptInterface
Expand Down Expand Up @@ -110,56 +111,77 @@ Random.seed!(100)
end

@testset "discrete case" begin
# random source and target marginal
m = 30
μprobs = normalize!(rand(m), 1)
μsupport = randn(m)
μ = DiscreteNonParametric(μsupport, μprobs)

n = 50
νprobs = normalize!(rand(n), 1)
νsupport = randn(n)
ν = DiscreteNonParametric(νsupport, νprobs)

# compute OT plan
γ = @inferred(ot_plan(euclidean, μ, ν))
@test γ isa SparseMatrixCSC
@test size(γ) == (m, n)
@test vec(sum(γ; dims=2)) ≈ μ.p
@test vec(sum(γ; dims=1)) ≈ ν.p

# consistency checks
I, J, W = findnz(γ)
@test all(w > zero(w) for w in W)
@test sum(W) ≈ 1
@test sort(unique(I)) == 1:m
@test sort(unique(J)) == 1:n
@test sort(I .+ J) == 2:(m + n)

# compute OT cost
c = @inferred(ot_cost(euclidean, μ, ν))

# compare with computation with explicit cost matrix
# DiscreteNonParametric sorts the support automatically, here we have to sort
# manually
C = pairwise(Euclidean(), μsupport', νsupport'; dims=2)
c2 = emd2(μprobs, νprobs, C, Tulip.Optimizer())
@test c2 ≈ c rtol = 1e-5

# compare with POT
# disabled currently since https://github.com/PythonOT/POT/issues/169 causes bounds
# error
# @test γ ≈ POT.emd_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")
# @test c ≈ POT.emd2_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")

# do not use the probabilities of μ and ν to ensure that the provided plan is
# used
μ2 = DiscreteNonParametric(μsupport, reverse(μprobs))
ν2 = DiscreteNonParametric(νsupport, reverse(νprobs))
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=γ))
@test c2 ≈ c
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=Matrix(γ)))
@test c2 ≈ c
# different random sources and target marginals:
# non-uniform + different size, uniform + different size, uniform + equal size
for (μ, ν) in (
(
DiscreteNonParametric(randn(30), normalize!(rand(30), 1)),
DiscreteNonParametric(randn(50), normalize!(rand(50), 1)),
),
(
DiscreteNonParametric(randn(30), Fill(1 / 30, 30)),
DiscreteNonParametric(randn(50), Fill(1 / 50, 50)),
),
(
DiscreteNonParametric(randn(30), Fill(1 / 30, 30)),
DiscreteNonParametric(randn(30), Fill(1 / 30, 30)),
),
)
# extract support, probabilities, and "size"
μsupport = support(μ)
μprobs = probs(μ)
m = length(μprobs)

νsupport = support(ν)
νprobs = probs(ν)
n = length(νprobs)

# compute OT plan
γ = @inferred(ot_plan(euclidean, μ, ν))
@test γ isa SparseMatrixCSC
@test size(γ) == (m, n)
@test vec(sum(γ; dims=2)) ≈ μ.p
@test vec(sum(γ; dims=1)) ≈ ν.p

# consistency checks
I, J, W = findnz(γ)
@test all(w > zero(w) for w in W)
@test sum(W) ≈ 1
@test sort(unique(I)) == 1:m
@test sort(unique(J)) == 1:n
@test sort(I .+ J) == if μprobs isa Fill && νprobs isa Fill && m == n
# Optimized version for special case (discrete uniform + equal size)
2:2:(m + n)
else
# Generic case (not optimized)
2:(m + n)
end

# compute OT cost
c = @inferred(ot_cost(euclidean, μ, ν))

# compare with computation with explicit cost matrix
# DiscreteNonParametric sorts the support automatically, here we have to sort
# manually
C = pairwise(Euclidean(), μsupport', νsupport'; dims=2)
c2 = emd2(μprobs, νprobs, C, Tulip.Optimizer())
@test c2 ≈ c rtol = 1e-5

# compare with POT
# disabled currently since https://github.com/PythonOT/POT/issues/169 causes bounds
# error
# @test γ ≈ POT.emd_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")
# @test c ≈ POT.emd2_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")

# do not use the probabilities of μ and ν to ensure that the provided plan is
# used
μ2 = DiscreteNonParametric(μsupport, reverse(μprobs))
ν2 = DiscreteNonParametric(νsupport, reverse(νprobs))
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=γ))
@test c2 ≈ c
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=Matrix(γ)))
@test c2 ≈ c
end
end
end

Expand Down