Skip to content

Commit

Permalink
Support array inputs of shape (draws, [chains[, parameters...]]) (#49)
Browse files Browse the repository at this point in the history
* Update and test utility functions

* Support arbitrary numbers of dimensions

* Update tests
  • Loading branch information
sethaxen authored Apr 15, 2023
1 parent 0a704ea commit a41028c
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 67 deletions.
45 changes: 24 additions & 21 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,21 @@ function Base.getproperty(r::PSISResult, k::Symbol)
if k === :weights
log_weights = getfield(r, :log_weights)
getfield(r, :normalized) && return exp.(log_weights)
return LogExpFunctions.softmax(log_weights; dims=sample_dims(log_weights))
return LogExpFunctions.softmax(log_weights; dims=_sample_dims(log_weights))
elseif k === :nparams
log_weights = getfield(r, :log_weights)
return if ndims(log_weights) == 1
1
else
prod(Base.Fix1(size, log_weights), param_dims(log_weights))
param_dims = _param_dims(log_weights)
prod(Base.Fix1(size, log_weights), param_dims; init=1)
end
elseif k === :ndraws
log_weights = getfield(r, :log_weights)
return ndims(log_weights) == 1 ? length(log_weights) : size(log_weights, 1)
return size(log_weights, 1)
elseif k === :nchains
log_weights = getfield(r, :log_weights)
return ndims(log_weights) == 3 ? size(log_weights, 2) : 1
return size(log_weights, 2)
end
k === :pareto_shape && return pareto_shape(r)
k === :ess && return ess_is(r)
Expand Down Expand Up @@ -182,18 +183,13 @@ While `psis` computes smoothed log weights out-of-place, `psis!` smooths them in
# Arguments
- `log_ratios`: an array of logarithms of importance ratios, with one of the following
sizes:
+ `(draws,)`: a vector of draws for a single parameter from a single chain
+ `(draws, params)`: a matrix of draws for a multiple parameter from a single chain
+ `(draws, chains, params...)`: an array of draws for multiple parameters from
multiple chains, e.g. as might be generated with Markov chain Monte Carlo.
- `log_ratios`: an array of logarithms of importance ratios, with size
`(draws, [chains, [parameters...]])`, where `chains>1` would be used when chains are
generated using Markov chain Monte Carlo.
- `reff::Union{Real,AbstractArray}`: the ratio(s) of effective sample size of
`log_ratios` and the actual sample size `reff = ess/(ndraws * nchains)`, used to account
`log_ratios` and the actual sample size `reff = ess/(draws * chains)`, used to account
for autocorrelation, e.g. due to Markov chain Monte Carlo. If an array, it must have the
size `(params...,)` to match `log_ratios`.
size `(parameters...,)` to match `log_ratios`.
# Keywords
Expand Down Expand Up @@ -221,7 +217,7 @@ function psis(logr, reff=1; kwargs...)
return psis!(logw, reff; kwargs...)
end

function psis!(logw::AbstractVector, reff=1; normalize::Bool=true, warn::Bool=true)
function psis!(logw::AbstractVecOrMat, reff=1; normalize::Bool=true, warn::Bool=true)
S = length(logw)
reff_val = first(reff)
M = tail_length(reff_val, S)
Expand All @@ -247,10 +243,17 @@ function psis!(logw::AbstractVector, reff=1; normalize::Bool=true, warn::Bool=tr
_maybe_log_normalize!(logw, normalize)
return PSISResult(logw, reff_val, M, tail_dist, normalize)
end
function psis!(logw::AbstractMatrix, reff=1; kwargs...)
result = psis!(vec(logw), only(reff); kwargs...)
# unflatten log_weights
return PSISResult(
logw, result.reff, result.tail_length, result.tail_dist, result.normalized
)
end
function psis!(logw::AbstractArray, reff=1; normalize::Bool=true, warn::Bool=true)
T = typeof(float(one(eltype(logw))))
# if an array defines custom indices (e.g. AbstractDimArray), we preserve them
param_axes = map(Base.Fix1(axes, logw), param_dims(logw))
param_axes = _param_axes(logw)

# allocate containers
reffs = similar(logw, eltype(reff), param_axes)
Expand All @@ -259,11 +262,11 @@ function psis!(logw::AbstractArray, reff=1; normalize::Bool=true, warn::Bool=tru
tail_dists = similar(logw, Union{Missing,GeneralizedPareto{T}}, param_axes)

# call psis! in parallel for all parameters
Threads.@threads for inds in CartesianIndices(param_axes)
logw_i = vec(param_draws(logw, inds))
result_i = psis!(logw_i, reffs[inds]; normalize=normalize, warn=false)
tail_lengths[inds] = result_i.tail_length
tail_dists[inds] = result_i.tail_dist
Threads.@threads for i in _eachparamindex(logw)
logw_i = _selectparam(logw, i)
result_i = psis!(logw_i, reffs[i]; normalize=normalize, warn=false)
tail_lengths[i] = result_i.tail_length
tail_dists[i] = result_i.tail_dist
end

# combine results
Expand Down
3 changes: 1 addition & 2 deletions src/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@ function ess_is(r::PSISResult; bad_shape_missing::Bool=true)
return _apply_missing(neff, r.tail_dist; bad_shape_missing=bad_shape_missing)
end
function ess_is(weights; reff=1)
dims = sample_dims(weights)
dims = _sample_dims(weights)
return reff ./ dropdims(sum(abs2, weights; dims=dims); dims=dims)
end
ess_is(weights::AbstractVector; reff::Real=1) = reff / sum(abs2, weights)

function _apply_missing(neff, dist; bad_shape_missing)
return bad_shape_missing && pareto_shape(dist) > 0.7 ? missing : neff
Expand Down
29 changes: 13 additions & 16 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,27 @@ missing_to_nan(x::AbstractArray{>:Missing}) = replace(x, missing => NaN)
missing_to_nan(::Missing) = NaN
missing_to_nan(x) = x

# dimensions corresponding to draws (and maybe chains)
_sample_dims(x::AbstractArray) = ntuple(identity, min(2, ndims(x)))

# dimension corresponding to parameters
function param_dims(x)
N = ndims(x)
@assert N > 1
N == 2 && return (2,)
N 3 && return ntuple(i -> i + 2, N - 2)
end
_param_dims(x::AbstractArray) = ntuple(i -> i + 2, max(0, ndims(x) - 2))

# axes corresponding to parameters
_param_axes(x::AbstractArray) = map(Base.Fix1(axes, x), _param_dims(x))

# view of all draws
function param_draws(x::AbstractArray, i::CartesianIndex)
# iterate over all parameters; combine with _selectparam
_eachparamindex(x::AbstractArray) = CartesianIndices(_param_axes(x))

# view of all draws for a param
function _selectparam(x::AbstractArray, i::CartesianIndex)
sample_dims = ntuple(_ -> Colon(), ndims(x) - length(i))
return view(x, sample_dims..., i)
end

# dimensions corresponding to draws and chains
function sample_dims(x::AbstractArray)
d = param_dims(x)
return filter((d), ntuple(identity, ndims(x)))
end
sample_dims(::AbstractVector) = Colon()

function _maybe_log_normalize!(x::AbstractArray, normalize::Bool)
if normalize
x .-= LogExpFunctions.logsumexp(x; dims=sample_dims(x))
x .-= LogExpFunctions.logsumexp(x; dims=_sample_dims(x))
end
return x
end
12 changes: 6 additions & 6 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ using DimensionalData: Dimensions, DimArray
proposal = Normal()
target = TDist(7)
rng = MersenneTwister(42)
x = rand(rng, proposal, 100, 30)
x = rand(rng, proposal, 100, 1, 30)
log_ratios = logpdf.(target, x) .- logpdf.(proposal, x)
reff = [100; ones(29)]
result = psis(log_ratios, reff)
Expand Down Expand Up @@ -110,8 +110,8 @@ end
]
proposal = Exponential(θ)
k_exp = 1 - θ
for sz in ((100_000,), (100_000, 5), (100_000, 4, 5))
dims = length(sz) == 1 ? Colon() : 1:(length(sz) - 1)
for sz in ((100_000,), (100_000, 4), (100_000, 4, 5))
dims = length(sz) < 3 ? Colon() : 1:(length(sz) - 1)
rng = MersenneTwister(42)
x = rand(rng, proposal, sz)
logr = logpdf.(target, x) .- logpdf.(proposal, x)
Expand All @@ -126,16 +126,16 @@ end
@test !(r2.log_weights r.log_weights)
@test r2.weights r.weights

if length(sz) == 3
if length(sz) > 1
@test all(r.tail_length .== PSIS.tail_length(1, 400_000))
else
@test all(r.tail_length .== PSIS.tail_length(1, 100_000))
end

k = r.pareto_shape
@test k isa (length(sz) == 1 ? Number : AbstractVector)
@test k isa (length(sz) < 3 ? Number : AbstractVector)
tail_dist = r.tail_dist
if length(sz) == 1
if length(sz) < 3
@test tail_dist isa PSIS.GeneralizedPareto
@test tail_dist.k == k
else
Expand Down
2 changes: 1 addition & 1 deletion test/plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ using Test
end

@testset "plot(::PSISResult; seriestype=:path)" begin
result = psis(randn(100, 10))
result = psis(randn(100, 2, 10))
plt = plot(result; seriestype=:path)
@test length(plt.series_list) == 1
@test plt[1][1][:x] == eachindex(result.pareto_shape)
Expand Down
53 changes: 32 additions & 21 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,51 @@ using Test
using DimensionalData: Dimensions, DimArray

@testset "utils" begin
@testset "param_dim" begin
@testset "_param_dims" begin
x = randn(100)
@test PSIS._param_dims(x) == ()

x = randn(100, 10)
@test PSIS.param_dims(x) == (2,)
@test PSIS._param_dims(x) == ()

x = randn(100, 4, 10)
@test PSIS.param_dims(x) == (3,)
@test PSIS._param_dims(x) == (3,)

x = randn(100, 4, 5, 10)
@test PSIS.param_dims(x) == (3, 4)
@test PSIS._param_dims(x) == (3, 4)

x = randn(100, 4, 5, 6, 10)
@test PSIS.param_dims(x) == (3, 4, 5)
@test PSIS._param_dims(x) == (3, 4, 5)
end

@testset "param_draws" begin
x = randn(100, 10)
@test PSIS.param_draws(x, CartesianIndex(3)) === view(x, :, 3)

x = randn(100, 4, 10)
@test PSIS.param_draws(x, CartesianIndex(5)) === view(x, :, :, 5)

x = randn(100, 4, 5, 10)
@test PSIS.param_draws(x, CartesianIndex(5, 6)) === view(x, :, :, 5, 6)

x = randn(100, 4, 5, 6, 10)
@test PSIS.param_draws(x, CartesianIndex(5, 6, 7)) === view(x, :, :, 5, 6, 7)
@testset "_eachparamindex/_selectparam" begin
x = randn(100)
@test size(PSIS._eachparamindex(x)) == ()
@test PSIS._selectparam(x, PSIS._eachparamindex(x)[1]) == x

x = randn(100, 4)
@test size(PSIS._eachparamindex(x)) == ()
@test PSIS._selectparam(x, PSIS._eachparamindex(x)[1]) == x

x = randn(100, 4, 5)
@test size(PSIS._eachparamindex(x)) == (5,)
@test PSIS._selectparam.(Ref(x), PSIS._eachparamindex(x)) ==
collect(eachslice(x; dims=3))

x = randn(100, 4, 5, 3)
@test size(PSIS._eachparamindex(x)) == (5, 3)
if VERSION v"1.9"
@test PSIS._selectparam.(Ref(x), PSIS._eachparamindex(x)) ==
eachslice(x; dims=(3, 4))
end
end

@testset "sample_dims" begin
@testset "_sample_dims" begin
x = randn(100)
@test PSIS.sample_dims(x) === Colon()
@test PSIS._sample_dims(x) === (1,)
x = randn(100, 10)
@test PSIS.sample_dims(x) === (1,)
@test PSIS._sample_dims(x) === (1, 2)
x = randn(100, 4, 10)
@test PSIS.sample_dims(x) === (1, 2)
@test PSIS._sample_dims(x) === (1, 2)
end
end

0 comments on commit a41028c

Please sign in to comment.