-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
267 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,83 @@ | ||
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT). | ||
|
||
abstract type AbstractFlow <: Function | ||
end | ||
|
||
|
||
struct CompositeFlow <: AbstractFlow | ||
flow::Function | ||
end | ||
|
||
export CompositeFlow | ||
@functor CompositeFlow | ||
|
||
function CompositeFlow(modules::Vector{F}) where F <: Function | ||
return CompositeFlow(fchain(modules)) | ||
end | ||
|
||
function ChangesOfVariables.with_logabsdet_jacobian( | ||
f::CompositeFlow, | ||
x::Any | ||
) | ||
with_logabsdet_jacobian(f.flow, x) | ||
end | ||
|
||
(f::CompositeFlow)(x::Any) = f.flow(x) | ||
(f::CompositeFlow)(vs::AbstractValueShape) = vs | ||
|
||
function InverseFunctions.inverse(f::CompositeFlow) | ||
return CompositeFlow(InverseFunctions.inverse(f.flow).fs) | ||
end | ||
|
||
|
||
struct RQSplineCouplingModule <: AbstractFlow | ||
flow_module::Function | ||
end | ||
|
||
export RQSplineCouplingModule | ||
@functor RQSplineCouplingModule | ||
|
||
|
||
""" | ||
RQSplineCouplingModule(n_dims::Integer, | ||
block_target_elements::Union{Vector{Vector{I}}, Vector{UnitRange{I}}} where I <: Integer, | ||
K::Union{Integer, Vector{Integer}} = 10 | ||
) | ||
Construct an instance of `RQSplineCouplingModule` for a `ǹ_dims` -dimensíonal input. Use `block_target_elements` | ||
to specify which block in the module transforms which components of the input. Use `K` to specify the desired | ||
number of spline segments used for the rational quadratic spline functions (defaults to 10). | ||
Note: This constructor does not ensure each element of the input is transformed by a block. If desired, this | ||
must be ensured in `block_target_elements`. | ||
""" | ||
function RQSplineCouplingModule(n_dims::Integer, | ||
block_target_elements::Union{Vector{Vector{I}}, Vector{UnitRange{I}}} where I <: Integer, | ||
K::Union{Integer, Vector{Integer}} = 10, | ||
compute_unit::AbstractComputeUnit = CPUnit() | ||
) | ||
@argcheck K isa Integer || length(K) == length(block_target_elements) throw(DomainError(K, "please specify the same number of values for K as there are blocks")) | ||
|
||
n_blocks = length(block_target_elements) | ||
blocks = Vector{RQSplineCouplingBlock}(undef, n_blocks) | ||
n_out_neural_net = K isa Vector ? 3 .* K .- 1 : 3 .* fill(K, n_blocks) .- 1 | ||
|
||
for i in 1:n_blocks | ||
transformation_mask = fill(false, n_dims) | ||
transformation_mask[block_target_elements[i]] .= true | ||
neural_net = get_neural_net(n_dims - sum(transformation_mask), n_out_neural_net[i]) | ||
blocks[i] = RQSplineCouplingBlock(transformation_mask, neural_net, compute_unit) | ||
end | ||
return fchain(blocks) | ||
end | ||
|
||
function RQSplineCouplingModule(n_dims::Integer, | ||
block_target_elements::Integer = 1, | ||
K::Union{Integer, Vector{Integer}} = 10 | ||
) | ||
|
||
n_blocks = ceil(Integer, n_dims / block_target_elements) | ||
vectorized_bte = [UnitRange(i + 1, i + block_target_elements) for i in range(start = 0, stop = (n_blocks - 2) * block_target_elements, step = block_target_elements)] | ||
push!(vectorized_bte, UnitRange((n_blocks - 1) * block_target_elements + 1, n_dims)) | ||
|
||
RQSplineCouplingModule(n_dims, vectorized_bte, K) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT). | ||
|
||
std_normal_logpdf(x::Real) = -(abs2(x) + log2π)/2 | ||
|
||
function mvnormal_negll_flow(flow::Function, X::AbstractMatrix{<:Real}) | ||
nsamples = size(X, 2) | ||
|
||
Y, ladj = with_logabsdet_jacobian(flow, X) | ||
ll = (sum(std_normal_logpdf.(Y[flow.mask,:])) + sum(ladj)) / nsamples | ||
|
||
return -ll | ||
end | ||
|
||
function mvnormal_negll_flow_grad(flow, X::AbstractMatrix{<:Real}) | ||
negll, back = Zygote.pullback(mvnormal_negll_flow, flow, X) | ||
d_flow = back(one(eltype(X)))[1] | ||
return negll, d_flow | ||
end | ||
|
||
|
||
function optimize_flow(smpls::VectorOfSimilarVectors{<:Real}, | ||
initial_flow, | ||
optimizer; | ||
nbatches::Integer = 100, | ||
nepochs::Integer = 100, | ||
optstate = Optimisers.setup(optimizer, deepcopy(initial_flow)), | ||
negll_history = Vector{Float64}(), | ||
shuffle_samples::Bool = false | ||
) | ||
batchsize = round(Int, length(smpls) / nbatches) | ||
batches = collect(Iterators.partition(smpls, batchsize)) | ||
flow = deepcopy(initial_flow) | ||
state = deepcopy(optstate) | ||
negll_hist = Vector{Float64}() | ||
for i in 1:nepochs | ||
for batch in batches | ||
negll, d_flow = mvnormal_negll_flow_grad(flow, flatview(batch)) | ||
state, flow = Optimisers.update(state, flow, d_flow) | ||
push!(negll_hist, negll) | ||
end | ||
if shuffle_samples | ||
batches = collect(Iterators.partition(shuffle(smpls), batchsize)) | ||
end | ||
end | ||
(result = flow, optimizer_state = state, negll_history = vcat(negll_history, negll_hist)) | ||
end | ||
export optimize_flow | ||
|
||
# temporary hack | ||
function optimize_flow_sequentially(smpls::VectorOfSimilarVectors{<:Real}, | ||
initial_flow::CompositeFlow, | ||
optimizer; | ||
nbatches::Integer = 100, | ||
nepochs::Integer = 100, | ||
optstate = Optimisers.setup(optimizer, deepcopy(initial_flow)), | ||
negll_history = Vector{Float64}(), | ||
shuffle_samples::Bool = false | ||
) | ||
|
||
optimized_blocks = Vector{Function}(undef, length(initial_flow.flow.fs)) | ||
for block in initial_flow.flow.fs | ||
res = optimize_flow(smpls, block, optimizer; nbatches, nepochs, optstate, negll_history, shuffle_samples) | ||
optimized_blocks[i] = res.result | ||
end | ||
return CompositeFlow(optimized_blocks) | ||
end | ||
|
||
export optimize_flow_sequentially |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT). | ||
|
||
struct RQSplineCouplingBlock <: Function | ||
mask::Vector{Bool} | ||
nn::Chain | ||
nn_parameters::NamedTuple | ||
nn_state::NamedTuple | ||
end | ||
|
||
export RQSplineCouplingBlock | ||
@functor RQSplineCouplingBlock | ||
|
||
function RQSplineCouplingBlock(mask::Vector{Bool}, nn::Chain, compute_unit::AbstractComputeUnit) | ||
rng = Random.default_rng() | ||
Random.seed!(rng, 0) | ||
|
||
lux_compute_unit = compute_unit isa CPUnit ? cpu_device() : gpu_device() | ||
|
||
nn_parameters, nn_state = Lux.setup(rng, nn) .|> lux_compute_unit | ||
|
||
return RQSplineCouplingBlock(mask, nn, nn_parameters, nn_state) | ||
end | ||
|
||
|
||
function ChangesOfVariables.with_logabsdet_jacobian( | ||
f::RQSplineCouplingBlock, | ||
x::Any | ||
) | ||
apply_rqs_coupling_flow(f, x) | ||
end | ||
|
||
(f::RQSplineCouplingBlock)(x::Any) = apply_rqs_coupling_flow(f, x)[1] | ||
(f::RQSplineCouplingBlock)(vs::AbstractValueShape) = vs | ||
|
||
function InverseFunctions.inverse(f::RQSplineCouplingBlock) | ||
return InverseRQSplineCouplingBlock(f.nn, f.mask) | ||
end | ||
|
||
|
||
struct InverseRQSplineCouplingBlock <: Function | ||
mask::Vector{Bool} | ||
nn::Chain | ||
end | ||
|
||
export InverseRQSplineCouplingBlock | ||
@functor InverseRQSplineCouplingBlock | ||
|
||
function ChangesOfVariables.with_logabsdet_jacobian( | ||
f::InverseRQSplineCouplingBlock, | ||
x::Any | ||
) | ||
return apply_rqs_coupling_flow(f, x) | ||
end | ||
|
||
(f::InverseRQSplineCouplingBlock)(x::Any) = apply_rqs_coupling_flow(f, x)[1] | ||
(f::InverseRQSplineCouplingBlock)(vs::AbstractValueShape) = vs | ||
|
||
function InverseFunctions.inverse(f::InverseRQSplineCouplingBlock) | ||
return RQSplineCouplingBlock(f.nn, f.mask) | ||
end | ||
|
||
|
||
function apply_rqs_coupling_flow(flow::Union{RQSplineCouplingBlock, InverseRQSplineCouplingBlock}, x::Any) # make x typestable | ||
|
||
rq_spline = flow isa RQSplineCouplingBlock ? RQSpline : InvRQSpline | ||
n_dims_to_transform = sum(flow.mask) | ||
|
||
input_mask = .~flow.mask | ||
y, ladj = with_logabsdet_jacobian(rq_spline(get_params(flow.nn(x[input_mask,:], flow.nn_parameters, flow.nn_state)[1], n_dims_to_transform)...), x[flow.mask,:]) | ||
|
||
return MonotonicSplines._sort_dimensions(y, x, flow.mask), ladj | ||
end | ||
|
||
export apply_rqs_coupling_flow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT). | ||
|
||
function get_neural_net(n_in::Integer, | ||
n_out::Integer, | ||
n_hidden_layers::Integer = 1, | ||
n_in_hidden::Integer = 20 | ||
#compute_device::AbstractComputeDevice = CPUnit(), | ||
) | ||
|
||
layers = vcat([Dense(n_in, n_in_hidden, relu)], repeat([Dense(n_in_hidden, n_in_hidden, relu)], n_hidden_layers), [Dense(n_in_hidden, n_out)]) | ||
return Chain(layers...) | ||
end |