From a626fca88353270d68e69d1f6137b21ae8b81ff0 Mon Sep 17 00:00:00 2001 From: Ejaaz Merali Date: Sun, 23 Aug 2020 19:50:21 -0400 Subject: [PATCH 1/5] Optimize vMF sampling --- src/samplers/vonmisesfisher.jl | 65 +++++++++------------------------- 1 file changed, 17 insertions(+), 48 deletions(-) diff --git a/src/samplers/vonmisesfisher.jl b/src/samplers/vonmisesfisher.jl index 7fd979217..1534072fa 100644 --- a/src/samplers/vonmisesfisher.jl +++ b/src/samplers/vonmisesfisher.jl @@ -6,7 +6,7 @@ struct VonMisesFisherSampler b::Float64 x0::Float64 c::Float64 - Q::Matrix{Float64} + v::Vector{Float64} end function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64) @@ -14,8 +14,8 @@ function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64) b = _vmf_bval(p, κ) x0 = (1.0 - b) / (1.0 + b) c = κ * x0 + (p - 1) * log1p(-abs2(x0)) - Q = _vmf_rotmat(μ) - VonMisesFisherSampler(p, κ, b, x0, c, Q) + v = _vmf_householder_vec(μ) + VonMisesFisherSampler(p, κ, b, x0, c, v) end function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, @@ -36,7 +36,9 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, end # rotate - mul!(x, spl.Q, t) + scale = 2.0 * (spl.v' * t) + copyto!(x, t) + @. x -= (scale * spl.v) return x end @@ -59,8 +61,14 @@ _vmf_bval(p::Int, κ::Real) = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1) function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ) # generate the W value -- the key step in simulating vMF # - # following movMF's document + # following movMF's document for the p > 3 case + # and Wenzel Jakob's document for the p == 3 case # + if p == 3 + ξ = rand(rng) + w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ) + return w::Float64 + end r = (p - 1) / 2.0 betad = Beta(r, r) @@ -76,47 +84,8 @@ end _vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler) = _vmf_genw(rng, s.p, s.b, s.x0, s.c, s.κ) -function _vmf_rotmat(u::Vector{Float64}) - # construct a rotation matrix Q - # s.t. Q * [1,0,...,0]^T --> u - # - # Strategy: construct a full-rank matrix - # with first column being u, and then - # perform QR factorization - # - - p = length(u) - A = zeros(p, p) - copyto!(view(A,:,1), u) - - # let k the be index of entry with max abs - k = 1 - a = abs(u[1]) - for i = 2:p - @inbounds ai = abs(u[i]) - if ai > a - k = i - a = ai - end - end - - # other columns of A will be filled with - # indicator vectors, except the one - # that activates the k-th entry - i = 1 - for j = 2:p - if i == k - i += 1 - end - A[i, j] = 1.0 - end - - # perform QR factorization - Q = Matrix(qr!(A).Q) - if dot(view(Q,:,1), u) < 0.0 # the first column was negated - for i = 1:p - @inbounds Q[i,1] = -Q[i,1] - end - end - return Q +function _vmf_householder_vec(u::Vector{Float64}) + u = normalize(u) + u[1] -= 1.0 + return normalize!(u) end From 48c424511a3f8548ae03e78b0d694c3ff2661922 Mon Sep 17 00:00:00 2001 From: Ejaaz Merali Date: Sun, 23 Aug 2020 21:08:47 -0400 Subject: [PATCH 2/5] Compute Householder vector in 1 pass instead of 3 or 4 --- src/samplers/vonmisesfisher.jl | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/samplers/vonmisesfisher.jl b/src/samplers/vonmisesfisher.jl index 1534072fa..bbbb30670 100644 --- a/src/samplers/vonmisesfisher.jl +++ b/src/samplers/vonmisesfisher.jl @@ -36,8 +36,8 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, end # rotate - scale = 2.0 * (spl.v' * t) copyto!(x, t) + scale = 2.0 * (spl.v' * t) @. x -= (scale * spl.v) return x end @@ -84,8 +84,19 @@ end _vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler) = _vmf_genw(rng, s.p, s.b, s.x0, s.c, s.κ) -function _vmf_householder_vec(u::Vector{Float64}) - u = normalize(u) - u[1] -= 1.0 - return normalize!(u) +function _vmf_householder_vec(μ::Vector{Float64}) + # assuming μ is a unit-vector (which it should be) + # can compute v in a single pass over μ + + p = length(μ) + v = zeros(p) + v[1] = μ[1] - 1.0 + s = sqrt(-2*v[1]) + v[1] /= s + + @inbounds for i in 2:p + v[i] = μ[i] / s + end + + return v end From e1850d19c6b126232cc9e779f000577b411f4b6e Mon Sep 17 00:00:00 2001 From: Ejaaz Merali Date: Fri, 9 Oct 2020 13:25:06 -0400 Subject: [PATCH 3/5] VonMisesFisher: Add tests for the w samplers in the p=3 case --- src/samplers/vonmisesfisher.jl | 24 ++++++++++++---------- test/vonmisesfisher.jl | 37 ++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/src/samplers/vonmisesfisher.jl b/src/samplers/vonmisesfisher.jl index bbbb30670..9d6a67f19 100644 --- a/src/samplers/vonmisesfisher.jl +++ b/src/samplers/vonmisesfisher.jl @@ -58,18 +58,13 @@ end _vmf_bval(p::Int, κ::Real) = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1))) -function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ) - # generate the W value -- the key step in simulating vMF - # - # following movMF's document for the p > 3 case - # and Wenzel Jakob's document for the p == 3 case - # - if p == 3 - ξ = rand(rng) - w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ) - return w::Float64 - end +function _vmf_genw3(rng::AbstractRNG, p, b, x0, c, κ) + ξ = rand(rng) + w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ) + return w::Float64 +end +function _vmf_genwp(rng::AbstractRNG, p, b, x0, c, κ) r = (p - 1) / 2.0 betad = Beta(r, r) z = rand(rng, betad) @@ -81,6 +76,13 @@ function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ) return w::Float64 end +# generate the W value -- the key step in simulating vMF +# +# following movMF's document for the p != 3 case +# and Wenzel Jakob's document for the p == 3 case +_vmf_genw(rng::AbstractRNG, p, b, x0, c, κ) = (p == 3) ? _vmf_genw3(rng, p, b, x0, c, κ) : _vmf_genwp(rng, p, b, x0, c, κ) + + _vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler) = _vmf_genw(rng, s.p, s.b, s.x0, s.c, s.κ) diff --git a/test/vonmisesfisher.jl b/test/vonmisesfisher.jl index a7c07d975..5980c18d1 100644 --- a/test/vonmisesfisher.jl +++ b/test/vonmisesfisher.jl @@ -22,6 +22,35 @@ function gen_vmf_tdata(n::Int, p::Int, return X end + +function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing) + p = 3 + + if ismissing(rng) + μ = randn(p) + else + μ = randn(rng, p) + end + μ = μ ./ norm(μ) + + s = Distributions.VonMisesFisherSampler(μ, float(κ)) + + genw3_res = [Distributions._vmf_genw3(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns] + genwp_res = [Distributions._vmf_genwp(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns] + + @test isapprox(mean(genw3_res), mean(genwp_res), atol=0.01) + @test isapprox(std(genw3_res), std(genwp_res), atol=0.01/κ) + + # test mean and stdev against analytical formulas + coth_κ = coth(κ) + mean_w = coth_κ - 1/κ + var_w = 1 - coth_κ^2 + 1/κ^2 + + @test isapprox(mean(genw3_res), mean_w, atol=0.01) + @test isapprox(std(genw3_res), sqrt(var_w), atol=0.01/κ) +end + + function test_vonmisesfisher(p::Int, κ::Real, n::Int, ns::Int, rng::Union{AbstractRNG, Missing} = missing) if ismissing(rng) @@ -65,6 +94,7 @@ function test_vonmisesfisher(p::Int, κ::Real, n::Int, ns::Int, x = rand(rng, d) end @test norm(x) ≈ 1.0 + @test insupport(d, x) if ismissing(rng) X = rand(d, n) @@ -73,6 +103,7 @@ function test_vonmisesfisher(p::Int, κ::Real, n::Int, ns::Int, end for i = 1:n @test norm(X[:,i]) ≈ 1.0 + @test insupport(d, X[:,i]) end # MLE @@ -119,4 +150,10 @@ ns = 10^6 (2, 2)] test_vonmisesfisher(p, κ, n, ns, rng) end + + if !ismissing(rng) + @testset "Testing genw with $key at (3, $κ)" for κ in [0.1, 0.5, 1.0, 2.0, 5.0] + test_genw3(κ, ns, rng) + end + end end From 4df026767d25ff0f88016e30681e729b8851799a Mon Sep 17 00:00:00 2001 From: Ejaaz Merali Date: Fri, 9 Oct 2020 19:07:41 -0400 Subject: [PATCH 4/5] Remove temporary variable and test Householder transformation for vMF sampling --- src/samplers/vonmisesfisher.jl | 48 +++++++++++++++++++--------------- test/vonmisesfisher.jl | 20 ++++++++++++++ 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/src/samplers/vonmisesfisher.jl b/src/samplers/vonmisesfisher.jl index 9d6a67f19..c4c0fb335 100644 --- a/src/samplers/vonmisesfisher.jl +++ b/src/samplers/vonmisesfisher.jl @@ -18,37 +18,37 @@ function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64) VonMisesFisherSampler(p, κ, b, x0, c, v) end -function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, - x::AbstractVector, t::AbstractVector) +@inline function _vmf_rot!(spl::VonMisesFisherSampler, x::AbstractVector) + # rotate + scale = 2.0 * (spl.v' * x) + @. x -= (scale * spl.v) + return x +end + + +function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector) w = _vmf_genw(rng, spl) p = spl.p - t[1] = w + x[1] = w s = 0.0 - for i = 2:p - t[i] = ti = randn(rng) - s += abs2(ti) + @inbounds for i = 2:p + x[i] = xi = randn(rng) + s += abs2(xi) end - # normalize t[2:p] + # normalize x[2:p] r = sqrt((1.0 - abs2(w)) / s) - for i = 2:p - t[i] *= r + @inbounds for i = 2:p + x[i] *= r end - # rotate - copyto!(x, t) - scale = 2.0 * (spl.v' * t) - @. x -= (scale * spl.v) - return x + return _vmf_rot!(spl, x) end -_rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector) = - _rand!(rng, spl, x, Vector{Float64}(undef, length(x))) function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractMatrix) - t = Vector{Float64}(undef, size(x, 1)) - for j = 1:size(x, 2) - _rand!(rng, spl, view(x,:,j), t) + @inbounds for j in axes(x, 2) + _rand!(rng, spl, view(x,:,j)) end return x end @@ -80,7 +80,13 @@ end # # following movMF's document for the p != 3 case # and Wenzel Jakob's document for the p == 3 case -_vmf_genw(rng::AbstractRNG, p, b, x0, c, κ) = (p == 3) ? _vmf_genw3(rng, p, b, x0, c, κ) : _vmf_genwp(rng, p, b, x0, c, κ) +function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ) + if p == 3 + return _vmf_genw3(rng, p, b, x0, c, κ) + else + return _vmf_genwp(rng, p, b, x0, c, κ) + end +end _vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler) = @@ -91,7 +97,7 @@ function _vmf_householder_vec(μ::Vector{Float64}) # can compute v in a single pass over μ p = length(μ) - v = zeros(p) + v = similar(μ) v[1] = μ[1] - 1.0 s = sqrt(-2*v[1]) v[1] /= s diff --git a/test/vonmisesfisher.jl b/test/vonmisesfisher.jl index 5980c18d1..735b57cd5 100644 --- a/test/vonmisesfisher.jl +++ b/test/vonmisesfisher.jl @@ -22,6 +22,25 @@ function gen_vmf_tdata(n::Int, p::Int, return X end +function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing) + if ismissing(rng) + μ = randn(p) + x = randn(p) + else + μ = randn(rng, p) + x = randn(rng, p) + end + κ = norm(μ) + μ = μ ./ κ + + s = Distributions.VonMisesFisherSampler(μ, κ) + H = I - 2*s.v*s.v' + + @test Distributions._vmf_rot!(s, copy(x)) ≈ (H*x) + +end + + function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing) p = 3 @@ -149,6 +168,7 @@ ns = 10^6 (5, 2.0), (2, 2)] test_vonmisesfisher(p, κ, n, ns, rng) + test_vmf_rot(p, rng) end if !ismissing(rng) From 8d64f7d93291274e2afae03be10f8623d3f2231d Mon Sep 17 00:00:00 2001 From: Ejaaz Merali Date: Sat, 10 Oct 2020 12:23:08 -0400 Subject: [PATCH 5/5] Apply suggestions from code review Co-authored-by: John Zito --- src/samplers/vonmisesfisher.jl | 8 ++++---- test/vonmisesfisher.jl | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/samplers/vonmisesfisher.jl b/src/samplers/vonmisesfisher.jl index c4c0fb335..29ce3daf1 100644 --- a/src/samplers/vonmisesfisher.jl +++ b/src/samplers/vonmisesfisher.jl @@ -18,10 +18,10 @@ function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64) VonMisesFisherSampler(p, κ, b, x0, c, v) end -@inline function _vmf_rot!(spl::VonMisesFisherSampler, x::AbstractVector) +@inline function _vmf_rot!(v::AbstractVector, x::AbstractVector) # rotate - scale = 2.0 * (spl.v' * x) - @. x -= (scale * spl.v) + scale = 2.0 * (v' * x) + @. x -= (scale * v) return x end @@ -42,7 +42,7 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector) x[i] *= r end - return _vmf_rot!(spl, x) + return _vmf_rot!(spl.v, x) end diff --git a/test/vonmisesfisher.jl b/test/vonmisesfisher.jl index 735b57cd5..2d3e3a434 100644 --- a/test/vonmisesfisher.jl +++ b/test/vonmisesfisher.jl @@ -34,9 +34,10 @@ function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing) μ = μ ./ κ s = Distributions.VonMisesFisherSampler(μ, κ) - H = I - 2*s.v*s.v' + v = μ - vcat(1, zeros(p-1)) + H = I - 2*v*v'/(v'*v) - @test Distributions._vmf_rot!(s, copy(x)) ≈ (H*x) + @test Distributions._vmf_rot!(s.v, copy(x)) ≈ (H*x) end