Skip to content

Commit

Permalink
Merge pull request #1162 from emerali/vmf-opt
Browse files Browse the repository at this point in the history
Optimize vMF sampling
  • Loading branch information
johnczito authored Oct 20, 2020
2 parents cf95808 + 8d64f7d commit a2b7f71
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 65 deletions.
118 changes: 53 additions & 65 deletions src/samplers/vonmisesfisher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,49 @@ struct VonMisesFisherSampler
b::Float64
x0::Float64
c::Float64
Q::Matrix{Float64}
v::Vector{Float64}
end

function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64)
p = length(μ)
b = _vmf_bval(p, κ)
x0 = (1.0 - b) / (1.0 + b)
c = κ * x0 + (p - 1) * log1p(-abs2(x0))
Q = _vmf_rotmat(μ)
VonMisesFisherSampler(p, κ, b, x0, c, Q)
v = _vmf_householder_vec(μ)
VonMisesFisherSampler(p, κ, b, x0, c, v)
end

function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler,
x::AbstractVector, t::AbstractVector)
# Overwrite `x` with x - 2 (v' x) v and return it.
# This is a Householder reflection of `x` when `v` has unit norm.
@inline function _vmf_rot!(v::AbstractVector, x::AbstractVector)
    # doubled projection coefficient, taken before `x` is modified
    twice_proj = 2.0 * (v' * x)
    x .-= twice_proj .* v
    return x
end


function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector)
w = _vmf_genw(rng, spl)
p = spl.p
t[1] = w
x[1] = w
s = 0.0
for i = 2:p
t[i] = ti = randn(rng)
s += abs2(ti)
@inbounds for i = 2:p
x[i] = xi = randn(rng)
s += abs2(xi)
end

# normalize t[2:p]
# normalize x[2:p]
r = sqrt((1.0 - abs2(w)) / s)
for i = 2:p
t[i] *= r
@inbounds for i = 2:p
x[i] *= r
end

# rotate
mul!(x, spl.Q, t)
return x
return _vmf_rot!(spl.v, x)
end

_rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector) =
_rand!(rng, spl, x, Vector{Float64}(undef, length(x)))

function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractMatrix)
t = Vector{Float64}(undef, size(x, 1))
for j = 1:size(x, 2)
_rand!(rng, spl, view(x,:,j), t)
@inbounds for j in axes(x, 2)
_rand!(rng, spl, view(x,:,j))
end
return x
end
Expand All @@ -56,12 +58,13 @@ end

# Compute the sampler constant `b` from the dimension `p` and concentration `κ`:
#   b = (p - 1) / (2κ + sqrt(4κ² + (p - 1)²))
function _vmf_bval(p::Int, κ::Real)
    d = p - 1
    return d / (2.0κ + sqrt(4 * abs2(κ) + abs2(d)))
end

function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ)
# generate the W value -- the key step in simulating vMF
#
# following movMF's document
#
# Draw the W value for the p == 3 case as a closed-form transform of a
# single uniform draw ξ = rand(rng):
#   w = 1 + log(ξ + (1 - ξ) * exp(-2κ)) / κ
#
# `p`, `b`, `x0` and `c` are not used in this method; they are accepted so
# that the call signature matches `_vmf_genwp`.
# The result is asserted to be a `Float64`.
function _vmf_genw3(rng::AbstractRNG, p, b, x0, c, κ)
    ξ = rand(rng)
    w = 1.0 + (log(ξ + (1.0 - ξ) * exp(-2κ)) / κ)
    return w::Float64
end

function _vmf_genwp(rng::AbstractRNG, p, b, x0, c, κ)
r = (p - 1) / 2.0
betad = Beta(r, r)
z = rand(rng, betad)
Expand All @@ -73,50 +76,35 @@ function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ)
return w::Float64
end

# Generate the W value -- the key step in simulating vMF.
#
# The p == 3 case is delegated to `_vmf_genw3` (Wenzel Jakob's document);
# every other dimension is delegated to `_vmf_genwp` (movMF's document).
function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ)
    return p == 3 ? _vmf_genw3(rng, p, b, x0, c, κ) :
                    _vmf_genwp(rng, p, b, x0, c, κ)
end


# Convenience method: unpack the fields stored on the sampler and forward
# them to the positional `_vmf_genw` method.
function _vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler)
    return _vmf_genw(rng, s.p, s.b, s.x0, s.c, s.κ)
end

function _vmf_rotmat(u::Vector{Float64})
# construct a rotation matrix Q
# s.t. Q * [1,0,...,0]^T --> u
#
# Strategy: construct a full-rank matrix
# with first column being u, and then
# perform QR factorization
#

p = length(u)
A = zeros(p, p)
copyto!(view(A,:,1), u)

# let k the be index of entry with max abs
k = 1
a = abs(u[1])
for i = 2:p
@inbounds ai = abs(u[i])
if ai > a
k = i
a = ai
end
end
function _vmf_householder_vec(μ::Vector{Float64})
# assuming μ is a unit-vector (which it should be)
# can compute v in a single pass over μ

# other columns of A will be filled with
# indicator vectors, except the one
# that activates the k-th entry
i = 1
for j = 2:p
if i == k
i += 1
end
A[i, j] = 1.0
end
p = length(μ)
v = similar(μ)
v[1] = μ[1] - 1.0
s = sqrt(-2*v[1])
v[1] /= s

# perform QR factorization
Q = Matrix(qr!(A).Q)
if dot(view(Q,:,1), u) < 0.0 # the first column was negated
for i = 1:p
@inbounds Q[i,1] = -Q[i,1]
end
@inbounds for i in 2:p
v[i] = μ[i] / s
end
return Q

return v
end
58 changes: 58 additions & 0 deletions test/vonmisesfisher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,55 @@ function gen_vmf_tdata(n::Int, p::Int,
return X
end

# Check that `Distributions._vmf_rot!`, applied with the vector `s.v` stored
# on the sampler, agrees with multiplying by the explicit Householder matrix
# H = I - 2 v v' / (v' v) built from v = μ - e₁.
#
# `p` is the dimension; `rng` is an optional RNG used for the random draws
# (the global RNG is used when it is `missing`).
function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing)
    if ismissing(rng)
        μ = randn(p)
        x = randn(p)
    else
        μ = randn(rng, p)
        x = randn(rng, p)
    end
    # the norm of the raw draw doubles as κ; μ is then scaled to unit length
    κ = norm(μ)
    μ = μ ./ κ

    s = Distributions.VonMisesFisherSampler(μ, κ)
    v = μ - vcat(1, zeros(p-1))
    H = I - 2*v*v'/(v'*v)

    # copy(x) because _vmf_rot! mutates its second argument
    @test Distributions._vmf_rot!(s.v, copy(x)) ≈ (H*x)

end



# Compare the p == 3 sampler `_vmf_genw3` against the general sampler
# `_vmf_genwp`, and against the analytical mean and standard deviation of W,
# using `ns` draws at concentration `κ`.
#
# NOTE(review): when `rng` is `missing` it is still passed straight to the
# samplers below; the visible caller only invokes this with a real RNG.
function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing)
    p = 3

    if ismissing(rng)
        μ = randn(p)
    else
        μ = randn(rng, p)
    end
    μ = μ ./ norm(μ)

    s = Distributions.VonMisesFisherSampler(μ, float(κ))

    genw3_res = [Distributions._vmf_genw3(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns]
    genwp_res = [Distributions._vmf_genwp(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns]

    # the two samplers should agree in their first two moments
    @test isapprox(mean(genw3_res), mean(genwp_res), atol=0.01)
    @test isapprox(std(genw3_res), std(genwp_res), atol=0.01/κ)

    # test mean and stdev against analytical formulas
    coth_κ = coth(κ)
    mean_w = coth_κ - 1/κ
    var_w = 1 - coth_κ^2 + 1/κ^2

    @test isapprox(mean(genw3_res), mean_w, atol=0.01)
    @test isapprox(std(genw3_res), sqrt(var_w), atol=0.01/κ)
end


function test_vonmisesfisher(p::Int, κ::Real, n::Int, ns::Int,
rng::Union{AbstractRNG, Missing} = missing)
if ismissing(rng)
Expand Down Expand Up @@ -65,6 +114,7 @@ function test_vonmisesfisher(p::Int, κ::Real, n::Int, ns::Int,
x = rand(rng, d)
end
@test norm(x) ≈ 1.0
@test insupport(d, x)

if ismissing(rng)
X = rand(d, n)
Expand All @@ -73,6 +123,7 @@ function test_vonmisesfisher(p::Int, κ::Real, n::Int, ns::Int,
end
for i = 1:n
@test norm(X[:,i]) ≈ 1.0
@test insupport(d, X[:,i])
end

# MLE
Expand Down Expand Up @@ -118,5 +169,12 @@ ns = 10^6
(5, 2.0),
(2, 2)]
test_vonmisesfisher(p, κ, n, ns, rng)
test_vmf_rot(p, rng)
end

if !ismissing(rng)
@testset "Testing genw with $key at (3, $κ)" for κ in [0.1, 0.5, 1.0, 2.0, 5.0]
test_genw3(κ, ns, rng)
end
end
end

0 comments on commit a2b7f71

Please sign in to comment.