Optimize vMF sampling #1162

Merged
merged 5 commits into from
Oct 20, 2020

Contributor

@emerali emerali commented Aug 24, 2020

Resolves #1161

Had some free time so I just implemented it right away lol

@codecov-commenter

codecov-commenter commented Aug 24, 2020

Codecov Report

Merging #1162 into master will decrease coverage by 0.02%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #1162      +/-   ##
==========================================
- Coverage   79.91%   79.89%   -0.03%     
==========================================
  Files         115      115              
  Lines        5905     5899       -6     
==========================================
- Hits         4719     4713       -6     
  Misses       1186     1186              
| Impacted Files | Coverage Δ |
|---|---|
| src/samplers/vonmisesfisher.jl | 100.00% <100.00%> (ø) |

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 253c3a0...48c4245.

@emerali
Contributor Author

emerali commented Sep 3, 2020

Some (rough) Benchmarks:

master (e127e9d):

julia> using BenchmarkTools, Distributions

julia> for p in [2, 3, 4, 1000]
           vmf = VonMisesFisher(rand(p)); @show p
           show(stdout, MIME"text/plain"(), @benchmark rand($vmf))
           println()
           show(stdout, MIME"text/plain"(), @benchmark rand($vmf, 1000))
           println()
       end
p = 2
BenchmarkTools.Trial: 
  memory estimate:  928 bytes
  allocs estimate:  10
  --------------
  minimum time:     2.208 μs (0.00% GC)
  median time:      2.545 μs (0.00% GC)
  mean time:        2.772 μs (0.80% GC)
  maximum time:     226.321 μs (97.56% GC)
  --------------
  samples:          10000
  evals/sample:     9
BenchmarkTools.Trial: 
  memory estimate:  16.56 KiB
  allocs estimate:  10
  --------------
  minimum time:     226.167 μs (0.00% GC)
  median time:      246.178 μs (0.00% GC)
  mean time:        258.843 μs (0.12% GC)
  maximum time:     1.349 ms (80.53% GC)
  --------------
  samples:          10000
  evals/sample:     1
p = 3
BenchmarkTools.Trial: 
  memory estimate:  1.22 KiB
  allocs estimate:  10
  --------------
  minimum time:     3.831 μs (0.00% GC)
  median time:      4.264 μs (0.00% GC)
  mean time:        4.619 μs (0.65% GC)
  maximum time:     175.653 μs (94.06% GC)
  --------------
  samples:          10000
  evals/sample:     8
BenchmarkTools.Trial: 
  memory estimate:  24.63 KiB
  allocs estimate:  11
  --------------
  minimum time:     315.719 μs (0.00% GC)
  median time:      334.156 μs (0.00% GC)
  mean time:        347.443 μs (0.13% GC)
  maximum time:     1.338 ms (72.30% GC)
  --------------
  samples:          10000
  evals/sample:     1
p = 4
BenchmarkTools.Trial: 
  memory estimate:  1.50 KiB
  allocs estimate:  10
  --------------
  minimum time:     5.311 μs (0.00% GC)
  median time:      5.901 μs (0.00% GC)
  mean time:        6.273 μs (0.55% GC)
  maximum time:     183.746 μs (95.64% GC)
  --------------
  samples:          10000
  evals/sample:     6
BenchmarkTools.Trial: 
  memory estimate:  32.72 KiB
  allocs estimate:  11
  --------------
  minimum time:     314.511 μs (0.00% GC)
  median time:      337.742 μs (0.00% GC)
  mean time:        348.028 μs (0.15% GC)
  maximum time:     1.086 ms (67.73% GC)
  --------------
  samples:          10000
  evals/sample:     1
p = 1000
BenchmarkTools.Trial: 
  memory estimate:  31.08 MiB
  allocs estimate:  16
  --------------
  minimum time:     176.674 ms (0.00% GC)
  median time:      229.130 ms (0.26% GC)
  mean time:        266.637 ms (4.73% GC)
  maximum time:     516.976 ms (21.10% GC)
  --------------
  samples:          19
  evals/sample:     1
BenchmarkTools.Trial: 
  memory estimate:  38.70 MiB
  allocs estimate:  17
  --------------
  minimum time:     806.788 ms (0.08% GC)
  median time:      938.293 ms (0.07% GC)
  mean time:        948.694 ms (2.52% GC)
  maximum time:     1.123 s (11.19% GC)
  --------------
  samples:          6
  evals/sample:     1

@emerali
Contributor Author

emerali commented Sep 28, 2020

This PR (48c4245):

julia> using BenchmarkTools, Distributions

julia> for p in [2, 3, 4, 1000]
           vmf = VonMisesFisher(rand(p)); @show p
           show(stdout, MIME"text/plain"(), @benchmark rand($vmf))
           println()
           show(stdout, MIME"text/plain"(), @benchmark rand($vmf, 1000))
           println()
       end
p = 2
BenchmarkTools.Trial: 
  memory estimate:  288 bytes
  allocs estimate:  3
  --------------
  minimum time:     318.882 ns (0.00% GC)
  median time:      353.138 ns (0.00% GC)
  mean time:        389.258 ns (2.49% GC)
  maximum time:     8.583 μs (91.68% GC)
  --------------
  samples:          10000
  evals/sample:     246
BenchmarkTools.Trial: 
  memory estimate:  15.94 KiB
  allocs estimate:  3
  --------------
  minimum time:     207.156 μs (0.00% GC)
  median time:      226.899 μs (0.00% GC)
  mean time:        237.945 μs (0.10% GC)
  maximum time:     1.189 ms (71.64% GC)
  --------------
  samples:          10000
  evals/sample:     1
p = 3
BenchmarkTools.Trial: 
  memory estimate:  336 bytes
  allocs estimate:  3
  --------------
  minimum time:     241.591 ns (0.00% GC)
  median time:      254.943 ns (0.00% GC)
  mean time:        280.737 ns (1.98% GC)
  maximum time:     3.008 μs (86.80% GC)
  --------------
  samples:          10000
  evals/sample:     421
BenchmarkTools.Trial: 
  memory estimate:  23.73 KiB
  allocs estimate:  4
  --------------
  minimum time:     114.325 μs (0.00% GC)
  median time:      119.033 μs (0.00% GC)
  mean time:        124.773 μs (0.34% GC)
  maximum time:     1.250 ms (88.97% GC)
  --------------
  samples:          10000
  evals/sample:     1
p = 4
BenchmarkTools.Trial: 
  memory estimate:  336 bytes
  allocs estimate:  3
  --------------
  minimum time:     355.152 ns (0.00% GC)
  median time:      406.272 ns (0.00% GC)
  mean time:        441.461 ns (1.19% GC)
  maximum time:     4.822 μs (86.08% GC)
  --------------
  samples:          10000
  evals/sample:     204
BenchmarkTools.Trial: 
  memory estimate:  31.55 KiB
  allocs estimate:  4
  --------------
  minimum time:     247.321 μs (0.00% GC)
  median time:      275.635 μs (0.00% GC)
  mean time:        285.842 μs (0.17% GC)
  maximum time:     1.177 ms (75.45% GC)
  --------------
  samples:          10000
  evals/sample:     1
p = 1000
BenchmarkTools.Trial: 
  memory estimate:  23.81 KiB
  allocs estimate:  3
  --------------
  minimum time:     13.957 μs (0.00% GC)
  median time:      15.014 μs (0.00% GC)
  mean time:        16.490 μs (3.78% GC)
  maximum time:     1.668 ms (80.91% GC)
  --------------
  samples:          10000
  evals/sample:     1
BenchmarkTools.Trial: 
  memory estimate:  7.64 MiB
  allocs estimate:  4
  --------------
  minimum time:     10.936 ms (0.00% GC)
  median time:      11.199 ms (0.00% GC)
  mean time:        11.792 ms (3.50% GC)
  maximum time:     83.431 ms (86.34% GC)
  --------------
  samples:          424
  evals/sample:     1

@emerali
Contributor Author

emerali commented Sep 28, 2020

Added a couple more benchmarks for p=2 and p=4. You can see that there's a dip in the computation time for p=3, since we avoid rejection sampling there. Hence the reduction in computation time for p=3 doesn't just come from the improved rotation operation.
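
For readers unfamiliar with why p=3 needs no rejection step: on the sphere in ℝ³ the density of w = μᵀx is proportional to exp(κw) on [-1, 1], so its CDF can be inverted in closed form. The following is a minimal sketch of that idea, not necessarily the code in this PR; the function name `sample_w3` is hypothetical.

```julia
using Random

# Inverse-CDF draw of w = μ'x for a vMF distribution on the sphere in ℝ³.
# Solving (exp(κw) - exp(-κ)) / (exp(κ) - exp(-κ)) = u for w gives the
# expression below. `sample_w3` is an illustrative name, not a package API.
function sample_w3(rng::AbstractRNG, κ::Real)
    u = rand(rng)                                # u ~ Uniform[0, 1)
    return 1 + log(u + (1 - u) * exp(-2κ)) / κ   # always lands in [-1, 1]
end

rng = MersenneTwister(42)
ws = [sample_w3(rng, 2.0) for _ in 1:200_000]
```

Each draw costs one uniform variate and one logarithm, with no accept/reject loop, which is consistent with the p=3 timings above being lower than those for p=2 and p=4.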

Additionally, I noticed that the vMF distribution errors for p=1. In this case the distribution technically reduces to a discrete distribution on {+1, -1} (basically a post-processed Bernoulli). Should vMF have a special case for that, or should the constructor error out and tell the user to use a Bernoulli distribution instead? I'd assume that anyone using the vMF distribution likely won't make use of this case, but it's probably a good idea for the sampling function to not error out in that scenario regardless.

There's a similar issue for kappa = 0 (vMF becomes a uniform spherical distribution), though currently the vMF constructor rejects that value.
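
To make the two degenerate cases concrete, here is a rough sketch of what samplers for them could look like. Both function names are hypothetical and neither exists in the package as of this PR. The p=1 probability follows from P(x = μ) ∝ exp(κ) and P(x = -μ) ∝ exp(-κ).

```julia
using Random, LinearAlgebra

# κ = 0: uniform on the unit sphere in ℝᵖ, obtained by normalising a
# standard Gaussian vector. (Hypothetical helper.)
rand_uniform_sphere(rng::AbstractRNG, p::Integer) = normalize!(randn(rng, p))

# p = 1: x ∈ {-1, +1} with P(x = μ) = 1 / (1 + exp(-2κ)), where μ ∈ {-1, +1}.
# (Hypothetical helper; effectively a relabelled Bernoulli draw.)
function rand_vmf_1d(rng::AbstractRNG, μ::Real, κ::Real)
    return rand(rng) < 1 / (1 + exp(-2κ)) ? float(μ) : -float(μ)
end

rng = MersenneTwister(1)
x = rand_uniform_sphere(rng, 5)
xs = [rand_vmf_1d(rng, 1, 1.0) for _ in 1:100_000]
```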

@johnczito
Member

johnczito commented Oct 9, 2020

Thanks for this! Looks good to me. My only suggestion is to take this opportunity to make the unit tests more convincing. Might include:

  • Test the two w samplers against one another in the p = 3 case;
  • Check that the x that's constructed at the end of rand! is the same as if you naively computed H(μ - e₁) * t.

I am in favor of things collapsing to the special cases when p = 1 or κ = 0 (especially the latter), but I think that should be handled in a separate PR.

@codecov-io

codecov-io commented Oct 9, 2020

Codecov Report

Merging #1162 into master will increase coverage by 1.88%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #1162      +/-   ##
==========================================
+ Coverage   79.91%   81.80%   +1.88%     
==========================================
  Files         115      116       +1     
  Lines        5905     6550     +645     
==========================================
+ Hits         4719     5358     +639     
- Misses       1186     1192       +6     
| Impacted Files | Coverage Δ |
|---|---|
| src/samplers/vonmisesfisher.jl | 100.00% <100.00%> (ø) |
| src/univariate/discrete/bernoulli.jl | 90.76% <0.00%> (-0.76%) ⬇️ |
| src/univariate/discrete/skellam.jl | 81.25% <0.00%> (-0.24%) ⬇️ |
| src/qq.jl | 100.00% <0.00%> (ø) |
| src/convolution.jl | 100.00% <0.00%> (ø) |
| src/samplers/gamma.jl | 100.00% <0.00%> (ø) |
| src/multivariate/product.jl | 100.00% <0.00%> (ø) |
| src/truncated/exponential.jl | 100.00% <0.00%> (ø) |
| src/univariate/continuous/ksonesided.jl | 0.00% <0.00%> (ø) |
| src/univariate/discrete/soliton.jl | 90.90% <0.00%> (ø) |
... and 91 more

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 253c3a0...8d64f7d.

@emerali
Contributor Author

emerali commented Oct 9, 2020

I've added the tests for the w sampler. It now compares the sample statistics of the p=3 case against those of the rejection sampling method. In addition, the sample statistics for the p=3 case are compared against analytical formulas for the mean/variance.
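
A self-contained sketch of this kind of check is shown here. It is illustrative rather than the actual test file: the names `w_inverse`, `w_rejection`, `w_mean` and `w_var` are hypothetical, and the rejection step is a Wood-style sampler specialised to p=3, where the Beta((p-1)/2, (p-1)/2) proposal reduces to Uniform(0, 1). The analytical moments follow from the density κ·exp(κw) / (2 sinh κ) on [-1, 1].

```julia
using Random

# Closed-form (inverse-CDF) draw of w = μ'x, valid only for p = 3.
w_inverse(rng, κ) = (u = rand(rng); 1 + log(u + (1 - u) * exp(-2κ)) / κ)

# Wood-style rejection draw specialised to p = 3; Beta(1, 1) is Uniform(0, 1),
# so no Beta sampler is needed here.
function w_rejection(rng, κ)
    p = 3
    b = (p - 1) / (2κ + sqrt(4κ^2 + (p - 1)^2))
    x0 = (1 - b) / (1 + b)
    c = κ * x0 + (p - 1) * log1p(-x0^2)
    while true
        z = rand(rng)
        w = (1 - (1 + b) * z) / (1 - (1 - b) * z)
        # Accept when the log acceptance ratio exceeds a log-uniform draw.
        if κ * w + (p - 1) * log(1 - x0 * w) - c >= log(rand(rng))
            return w
        end
    end
end

# Analytical mean and variance of w for p = 3.
w_mean(κ) = coth(κ) - 1 / κ
w_var(κ) = 1 / κ^2 - csch(κ)^2

# Plain sample moments, to avoid depending on anything beyond Random.
smean(xs) = sum(xs) / length(xs)
svar(xs) = sum(abs2, xs .- smean(xs)) / (length(xs) - 1)

rng = MersenneTwister(7)
κ = 2.0
a = [w_inverse(rng, κ) for _ in 1:200_000]
r = [w_rejection(rng, κ) for _ in 1:200_000]
```

With 200,000 draws the standard error of each sample mean is roughly 0.001, so a tolerance of about 0.01 leaves a wide margin for a statistical test like this.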

> Check that the x that's constructed at the end of rand! is the same as if you naively computed H(μ - e₁) * t.

I'm having trouble reading this notation. Do you mean constructing the Householder matrix explicitly and applying it to the vector t as a comparison?

@emerali
Contributor Author

emerali commented Oct 9, 2020

I agree with dealing with those special cases in a separate PR, particularly the κ = 0 case, since I feel that a UniformSpherical distribution deserves an implementation in this package.

Another optimization I believe should be possible would be to defer the sampling for the p=2 case to the univariate VonMises distribution, and then post-process the result to produce a vector (via sincos), though again, that might be better suited for a separate PR.
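
As a rough illustration of that post-processing step: if θ is an angle drawn from a univariate von Mises distribution centred at 0 with concentration κ (for example via `rand(VonMises(0, κ))`, which is an assumption about how the draw would be obtained), then rotating μ by θ gives a point on the unit circle whose density is proportional to exp(κ μᵀx). The helper name `vmf2_from_angle` is hypothetical.

```julia
# Rotate the unit 2-vector μ counter-clockwise by the angle θ.
# θ is assumed to come from a von Mises sampler centred at 0; this function
# only does the sincos post-processing. (Hypothetical helper.)
function vmf2_from_angle(μ::AbstractVector{<:Real}, θ::Real)
    s, c = sincos(θ)
    return [c * μ[1] - s * μ[2], s * μ[1] + c * μ[2]]
end

μ = [0.6, 0.8]
x0 = vmf2_from_angle(μ, 0.0)          # zero angle leaves μ unchanged
xq = vmf2_from_angle([1.0, 0.0], π / 2)
```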

@emerali
Contributor Author

emerali commented Oct 9, 2020

Also, quick question, is the temporary variable t in _rand! necessary anymore? AFAICT, we can skip it now, since the only reason it was needed before (AFAIK) was because the old implementation needed to apply the matrix Q, but we can do this matrix application fully in-place now. Just need a sanity check on this.

@johnczito
Member

> I'm having trouble reading this notation. Do you mean constructing the Householder matrix explicitly and applying it to the vector t as a comparison?

Exactly. I believe the correctness of it, but no reason not to document it.
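
A minimal sketch of that comparison, assuming H(v) denotes the Householder reflector I - 2vvᵀ with v = (μ - e₁)/‖μ - e₁‖, which maps e₁ to μ. The helper names are hypothetical and this is not the test that was added to the PR.

```julia
using LinearAlgebra

# Unit vector v such that H = I - 2vv' maps e₁ to μ (μ must have unit norm).
# If μ == e₁ the vector stays zero and H is the identity.
function householder_vector(μ::AbstractVector{<:Real})
    v = float.(μ)          # broadcasting allocates a fresh copy
    v[1] -= 1
    nv = norm(v)
    nv > 0 && (v ./= nv)
    return v
end

# Naive version: build the p×p matrix explicitly, then multiply.
reflect_explicit(v, t) = (I - 2 * v * v') * t

# In-place version: O(p) work, no matrix and no temporary vector.
function reflect!(x, v)
    s = 2 * dot(v, x)      # compute the scalar before mutating x
    x .-= s .* v
    return x
end

μ = normalize([1.0, 2.0, 3.0, 4.0])
v = householder_vector(μ)
t = [0.3, -0.5, 0.7, 0.1]
```

The explicit form needs O(p²) memory and time, which matches the large allocations seen for p = 1000 on master, while the in-place form touches only the vector being rotated.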

> I agree with dealing with those special cases in a separate PR, particularly the κ = 0 case, since I feel that a UniformSpherical distribution deserves an implementation in this package.

Yeah, if Normal(0, 0) and TDist(Inf) work without error, then these special cases should too.

> is the temporary variable t in _rand! necessary anymore?

Not sure.

src/samplers/vonmisesfisher.jl (outdated, resolved)
src/samplers/vonmisesfisher.jl (outdated, resolved)
test/vonmisesfisher.jl (outdated, resolved)
test/vonmisesfisher.jl (outdated, resolved)
Co-authored-by: John Zito <johnczito@users.noreply.github.com>
@emerali
Contributor Author

emerali commented Oct 10, 2020

Nice, once this gets merged I'll get started on the UniformSpherical/UniformBall distributions, and then the special cases for vMF will come after.

@johnczito
Member

I'll merge this in 24 hours if there is no other feedback.

@johnczito johnczito merged commit a2b7f71 into JuliaStats:master Oct 20, 2020
@emerali emerali deleted the vmf-opt branch October 20, 2020 23:48
Successfully merging this pull request may close these issues.

RFC: Performance Improvements for the VonMisesFisher sampler
4 participants