Skip to content

Commit

Permalink
fix performance tests for gamma
Browse files Browse the repository at this point in the history
  • Loading branch information
simonbyrne committed Nov 19, 2014
1 parent 97ae69b commit ed7b00d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
22 changes: 15 additions & 7 deletions perf/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,30 @@ benchmark_exponential() = (

## gamma

import Distributions: GammaRmathSampler, GammaMTSampler
import Distributions: GammaRmathSampler, GammaGDSampler, GammaGSSampler,
GammaMTSampler, GammaIPSampler

getname(::Type{GammaRmathSampler}) = "rmath"
getname(::Type{GammaGDSampler}) = "GD"
getname(::Type{GammaGSSampler}) = "GS"
getname(::Type{GammaMTSampler}) = "MT"
getname(::Type{GammaIPSampler}) = "IP"

benchmark_gamma() = (
make_procs(GammaRmathSampler, GammaMTSampler),
"(α, scale)", [(α, 1.0) for α in [0.5, 1.0, 2.0, 5.0, 20.0]])
benchmark_gamma_hi() = (
make_procs(GammaRmathSampler, GammaMTSampler, GammaGDSampler),
"Dist", [(Gamma(α, 1.0),) for α in [1.5, 2.0, 3.0, 5.0, 20.0]])

benchmark_gamma_lo() = (
make_procs(GammaRmathSampler, GammaGSSampler, GammaIPSampler),
"Dist", [(Gamma(α, 1.0),) for α in [0.1, 0.5, 0.9]])

### main

const dnames = ["categorical",
"binomial",
"poisson",
"exponential",
"gamma"]
"gamma_hi","gamma_lo"]

function printhelp()
println("Require exactly one argument. Usage:")
Expand Down Expand Up @@ -141,10 +148,11 @@ function do_benchmark(dname; verbose::Int=2)
dname == "binomial" ? benchmark_binomial() :
dname == "poisson" ? benchmark_poisson() :
dname == "exponential" ? benchmark_exponential() :
dname == "gamma" ? benchmark_gamma() :
dname == "gamma_hi" ? benchmark_gamma_hi() :
dname == "gamma_lo" ? benchmark_gamma_lo() :
error("benchmarking function for $dname has not been implemented.")

r = run(procs, cfgs; duration=0.2, verbose=verbose)
r = run(procs, cfgs; duration=0.5, verbose=verbose)
println()
show(r; unit=:mps, cfghead=cfghead)
end
Expand Down
4 changes: 2 additions & 2 deletions src/samplers/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ end
function GammaIPSampler{S<:Sampleable}(d::Gamma,::Type{S})
GammaIPSampler(Gamma(1.0+d.shape,d.scale), -1.0/d.shape)
end
GammaIPSampler(d::Gamma) = GammaIPSampler(d,GammaGDSampler)
GammaIPSampler(d::Gamma) = GammaIPSampler(d,GammaMTSampler)

function rand(s::GammaIPSampler)
x = rand(s.s)
Expand All @@ -231,7 +231,7 @@ end
# function sampler(d::Gamma)
# if d.shape < 1.0
# # TODO: d.shape = 0.5 : use scaled chisq
# GammaIPSampler(d,GammaGDSampler)
# GammaIPSampler(d)
# elseif d.shape == 1.0
# Exponential(d.scale)
# else
Expand Down

0 comments on commit ed7b00d

Please sign in to comment.