Skip to content

Commit

Permalink
Use correct iteration numbers in chain (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jul 5, 2021
1 parent 7e70b1c commit 00da11a
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.6.2"
version = "0.6.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
18 changes: 15 additions & 3 deletions src/mcmcchains-connect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ function AbstractMCMC.bundle_samples(
sampler::MHSampler,
state,
chain_type::Type{Chains};
discard_initial=0,
thinning=1,
param_names=missing,
kwargs...
)
Expand All @@ -25,7 +27,9 @@ function AbstractMCMC.bundle_samples(
push!(param_names, :lp)

# Bundle everything up and return a Chains struct.
return Chains(vals, param_names, (internals = [:lp],))
return Chains(
vals, param_names, (internals = [:lp],); start=discard_initial + 1, thin=thinning,
)
end

function AbstractMCMC.bundle_samples(
Expand All @@ -34,6 +38,8 @@ function AbstractMCMC.bundle_samples(
sampler::MHSampler,
state,
chain_type::Type{Chains};
discard_initial=0,
thinning=1,
param_names=missing,
kwargs...
)
Expand All @@ -59,7 +65,9 @@ function AbstractMCMC.bundle_samples(
end

# Bundle everything up and return a Chains struct.
return Chains(vals, param_names, (internals = [:lp],))
return Chains(
vals, param_names, (internals = [:lp],); start=discard_initial + 1, thin=thinning,
)
end

function AbstractMCMC.bundle_samples(
Expand All @@ -68,6 +76,8 @@ function AbstractMCMC.bundle_samples(
sampler::Ensemble,
state,
chain_type::Type{Chains};
discard_initial=0,
thinning=1,
param_names=missing,
kwargs...
)
Expand Down Expand Up @@ -100,5 +110,7 @@ function AbstractMCMC.bundle_samples(
push!(param_names, :lp)

# Bundle everything up and return a Chains struct.
return Chains(vals, param_names, (internals=[:lp],))
return Chains(
vals, param_names, (internals = [:lp],); start=discard_initial + 1, thin=thinning,
)
end
35 changes: 33 additions & 2 deletions test/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,27 @@
# perform stretch move and sample from prior in initial step
Random.seed!(100)
sampler = Ensemble(1_000, StretchProposal([InverseGamma(2, 3), Normal(0, 1)]))

chain = sample(model, sampler, 1_000;
param_names = ["s", "m"], chain_type = Chains)

@test chain isa Chains
@test range(chain) == 1:1_000
@test mean(chain["s"]) 49/24 atol=0.1
@test mean(chain["m"]) 7/6 atol=0.1

chain2 = sample(
model,
sampler,
1_000;
param_names = ["s", "m"],
chain_type = Chains,
discard_initial=25,
thinning=4,
)
@test chain2 isa Chains
@test range(chain2) == range(26; step=4, length=1_000)
@test mean(chain2["s"]) 49/24 atol=0.1
@test mean(chain2["m"]) 7/6 atol=0.1
end

@testset "transformed space" begin
Expand All @@ -44,9 +60,24 @@
sampler = Ensemble(1_000, StretchProposal(MvNormal(2, 1)))
chain = sample(model, sampler, 1_000;
param_names = ["logs", "m"], chain_type = Chains)

@test chain isa Chains
@test range(chain) == 1:1_000
@test mean(exp, chain["logs"]) 49/24 atol=0.1
@test mean(chain["m"]) 7/6 atol=0.1

chain2 = sample(
model,
sampler,
1_000;
param_names = ["logs", "m"],
chain_type = Chains,
discard_initial=25,
thinning=4,
)
@test chain2 isa Chains
@test range(chain2) == range(26; step=4, length=1_000)
@test mean(exp, chain2["logs"]) 49/24 atol=0.1
@test mean(chain2["m"]) 7/6 atol=0.1
end
end
end
33 changes: 33 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,19 @@ include("util.jl")
param_names=["μ", "σ"], chain_type=Chains
)
@test chain1 isa Chains
@test range(chain1) == 1:10_000
@test mean(chain1["μ"]) 0.0 atol=0.1
@test mean(chain1["σ"]) 1.0 atol=0.1

chain1b = sample(
model, StaticMH([Normal(0,1), Normal(0, 1)]), 10_000;
param_names=["μ", "σ"], chain_type=Chains, discard_initial=25, thinning=4,
)
@test chain1b isa Chains
@test range(chain1b) == range(26; step=4, length=10_000)
@test mean(chain1b["μ"]) 0.0 atol=0.1
@test mean(chain1b["σ"]) 1.0 atol=0.1

# NamedTuple of parameters
chain2 = sample(
model,
Expand All @@ -92,16 +102,39 @@ include("util.jl")
chain_type=Chains
)
@test chain2 isa Chains
@test range(chain2) == 1:10_000
@test mean(chain2["μ"]) 0.0 atol=0.1
@test mean(chain2["σ"]) 1.0 atol=0.1

chain2b = sample(
model,
MetropolisHastings(
= StaticProposal(Normal(0,1)), σ = StaticProposal(Normal(0, 1)))
), 10_000;
chain_type=Chains, discard_initial=25, thinning=4,
)
@test chain2b isa Chains
@test range(chain2b) == range(26; step=4, length=10_000)
@test mean(chain2b["μ"]) 0.0 atol=0.1
@test mean(chain2b["σ"]) 1.0 atol=0.1

# Scalar parameter
chain3 = sample(
DensityModel(x -> loglikelihood(Normal(x, 1), data)),
StaticMH(Normal(0, 1)), 10_000; param_names=["μ"], chain_type=Chains
)
@test chain3 isa Chains
@test range(chain3) == 1:10_000
@test mean(chain3["μ"]) 0.0 atol=0.1

chain3b = sample(
DensityModel(x -> loglikelihood(Normal(x, 1), data)),
StaticMH(Normal(0, 1)), 10_000;
param_names=["μ"], chain_type=Chains, discard_initial=25, thinning=4,
)
@test chain3b isa Chains
@test range(chain3b) == range(26; step=4, length=10_000)
@test mean(chain3b["μ"]) 0.0 atol=0.1
end

@testset "Proposal styles" begin
Expand Down

2 comments on commit 00da11a

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/40294

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.3 -m "<description of version>" 00da11a5f7c955f835ead25369ffffbc47688614
git push origin v0.6.3

Please sign in to comment.