Skip to content

Commit

Permalink
Initial working Implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Micki-D committed Aug 23, 2023
1 parent 222da06 commit 78a89cd
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 60 deletions.
49 changes: 2 additions & 47 deletions src/adaptive_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,54 +30,9 @@ function InverseFunctions.inverse(f::CompositeFlow)
end


struct RQSplineCouplingModule <: AbstractFlow
flow_module::Function
abstract type AbstractFlowModule <: AbstractFlow
end

export RQSplineCouplingModule
@functor RQSplineCouplingModule


"""
RQSplineCouplingModule(n_dims::Integer,
block_target_elements::Union{Vector{Vector{I}}, Vector{UnitRange{I}}} where I <: Integer,
K::Union{Integer, Vector{Integer}} = 10
)
Construct an instance of `RQSplineCouplingModule` for a `ǹ_dims` -dimensíonal input. Use `block_target_elements`
to specify which block in the module transforms which components of the input. Use `K` to specify the desired
number of spline segments used for the rational quadratic spline functions (defaults to 10).
Note: This constructor does not ensure each element of the input is transformed by a block. If desired, this
must be ensured in `block_target_elements`.
"""
function RQSplineCouplingModule(n_dims::Integer,
block_target_elements::Union{Vector{Vector{I}}, Vector{UnitRange{I}}} where I <: Integer,
K::Union{Integer, Vector{Integer}} = 10,
compute_unit::AbstractComputeUnit = CPUnit()
)
@argcheck K isa Integer || length(K) == length(block_target_elements) throw(DomainError(K, "please specify the same number of values for K as there are blocks"))

n_blocks = length(block_target_elements)
blocks = Vector{RQSplineCouplingBlock}(undef, n_blocks)
n_out_neural_net = K isa Vector ? 3 .* K .- 1 : 3 .* fill(K, n_blocks) .- 1

for i in 1:n_blocks
transformation_mask = fill(false, n_dims)
transformation_mask[block_target_elements[i]] .= true
neural_net = get_neural_net(n_dims - sum(transformation_mask), n_out_neural_net[i])
blocks[i] = RQSplineCouplingBlock(transformation_mask, neural_net, compute_unit)
end
return fchain(blocks)
end

function RQSplineCouplingModule(n_dims::Integer,
block_target_elements::Integer = 1,
K::Union{Integer, Vector{Integer}} = 10
)

n_blocks = ceil(Integer, n_dims / block_target_elements)
vectorized_bte = [UnitRange(i + 1, i + block_target_elements) for i in range(start = 0, stop = (n_blocks - 2) * block_target_elements, step = block_target_elements)]
push!(vectorized_bte, UnitRange((n_blocks - 1) * block_target_elements + 1, n_dims))

RQSplineCouplingModule(n_dims, vectorized_bte, K)
abstract type AbstractFlowBlock <: AbstractFlowModule
end
59 changes: 47 additions & 12 deletions src/optimize_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,34 @@

std_normal_logpdf(x::Real) = -(abs2(x) + log2π)/2

function mvnormal_negll_flow(flow::Function, X::AbstractMatrix{<:Real})
function mvnormal_negll_flow(flow::F, X::AbstractMatrix{<:Real}) where F<:AbstractFlow
nsamples = size(X, 2)

Y, ladj = with_logabsdet_jacobian(flow, X)
ll = (sum(std_normal_logpdf.(Y)) + sum(ladj)) / nsamples

return -ll
end

function mvnormal_negll_flow(flow::B, X::AbstractMatrix{<:Real}) where B<:AbstractFlowBlock
nsamples = size(X, 2)

Y, ladj = with_logabsdet_jacobian(flow, X)
ll = (sum(std_normal_logpdf.(Y[flow.mask,:])) + sum(ladj)) / nsamples

return -ll
end
export mvnormal_negll_flow

function mvnormal_negll_flow_grad(flow, X::AbstractMatrix{<:Real})
function mvnormal_negll_flow_grad(flow::F, X::AbstractMatrix{<:Real}) where F<:AbstractFlow
negll, back = Zygote.pullback(mvnormal_negll_flow, flow, X)
d_flow = back(one(eltype(X)))[1]
return negll, d_flow
end

export mvnormal_negll_flow_grad

function optimize_flow(smpls::VectorOfSimilarVectors{<:Real},
initial_flow,
initial_flow::F where F<:AbstractFlow,
optimizer;
nbatches::Integer = 100,
nepochs::Integer = 100,
Expand All @@ -46,23 +56,48 @@ function optimize_flow(smpls::VectorOfSimilarVectors{<:Real},
end
export optimize_flow

# temporary hack

function optimize_flow_sequentially(smpls::VectorOfSimilarVectors{<:Real},
initial_flow::CompositeFlow,
optimizer;
nbatches::Integer = 100,
nepochs::Integer = 100,
optstate = Optimisers.setup(optimizer, deepcopy(initial_flow)),
negll_history = Vector{Float64}(),
shuffle_samples::Bool = false
)

optimized_blocks = Vector{Function}(undef, length(initial_flow.flow.fs))
for block in initial_flow.flow.fs
res = optimize_flow(smpls, block, optimizer; nbatches, nepochs, optstate, negll_history, shuffle_samples)
optimized_blocks[i] = res.result
optimized_modules = Vector{AbstractFlow}(undef, length(initial_flow.flow.fs))
module_optimizer_states = Vector{NamedTuple}(undef, length(initial_flow.flow.fs))
module_negll_hists = Vector{Vector}(undef, length(initial_flow.flow.fs))

for (i,flow_module) in enumerate(initial_flow.flow.fs)
opt_module, opt_state, negll_hist = optimize_flow_sequentially(smpls, flow_module, optimizer; nbatches, nepochs, shuffle_samples)
optimized_modules[i] = opt_module
module_optimizer_states[i] = opt_state
module_negll_hists[i] = negll_hist
end
return CompositeFlow(optimized_blocks)

(CompositeFlow(optimized_modules), module_optimizer_states, module_negll_hists)
end

function optimize_flow_sequentially(smpls::VectorOfSimilarVectors{<:Real},
initial_flow::M where M<:AbstractFlowModule,
optimizer;
nbatches::Integer = 100,
nepochs::Integer = 100,
shuffle_samples::Bool = false
)

optimized_blocks = Vector{AbstractFlow}(undef, length(initial_flow.flow_module.fs))
block_optimizer_states = Vector{NamedTuple}(undef, length(initial_flow.flow_module.fs))
block_negll_hists = Vector{Vector}(undef, length(initial_flow.flow_module.fs))

for (i,block) in enumerate(initial_flow.flow_module.fs)
opt_flow, opt_state, negll_hist = optimize_flow(smpls, block, optimizer; nbatches, nepochs, shuffle_samples = shuffle_samples)
optimized_blocks[i] = opt_flow
block_optimizer_states[i] = opt_state
block_negll_hists[i] = negll_hist
end

(typeof(initial_flow)(optimized_blocks), block_optimizer_states, block_negll_hists)
end
export optimize_flow_sequentially
72 changes: 71 additions & 1 deletion src/rqspline_coupling.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,76 @@
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT).

struct RQSplineCouplingBlock <: Function
struct RQSplineCouplingModule <: AbstractFlowModule
flow_module::Function
end

export RQSplineCouplingModule
@functor RQSplineCouplingModule

function RQSplineCouplingModule(blocks::Vector{F}) where F <: Function
return RQSplineCouplingModule(fchain(blocks))
end

"""
RQSplineCouplingModule(n_dims::Integer,
block_target_elements::Union{Vector{Vector{I}}, Vector{UnitRange{I}}} where I <: Integer,
K::Union{Integer, Vector{Integer}} = 10
)
Construct an instance of `RQSplineCouplingModule` for a `ǹ_dims` -dimensíonal input. Use `block_target_elements`
to specify which block in the module transforms which components of the input. Use `K` to specify the desired
number of spline segments used for the rational quadratic spline functions (defaults to 10).
Note: This constructor does not ensure each element of the input is transformed by a block. If desired, this
must be ensured in `block_target_elements`.
"""
function RQSplineCouplingModule(n_dims::Integer,
block_target_elements::Union{Vector{Vector{I}}, Vector{UnitRange{I}}} where I <: Integer,
K::Union{Integer, Vector{Integer}} = 10,
compute_unit::AbstractComputeUnit = CPUnit()
)
@argcheck K isa Integer || length(K) == length(block_target_elements) throw(DomainError(K, "please specify the same number of values for K as there are blocks"))

n_blocks = length(block_target_elements)
blocks = Vector{RQSplineCouplingBlock}(undef, n_blocks)
n_out_neural_net = K isa Vector ? 3 .* K .- 1 : 3 .* fill(K, n_blocks) .- 1

for i in 1:n_blocks
transformation_mask = fill(false, n_dims)
transformation_mask[block_target_elements[i]] .= true
neural_net = get_neural_net(n_dims - sum(transformation_mask), n_out_neural_net[i])
blocks[i] = RQSplineCouplingBlock(transformation_mask, neural_net, compute_unit)
end
return RQSplineCouplingModule(fchain(blocks))
end

function RQSplineCouplingModule(n_dims::Integer,
block_target_elements::Integer = 1,
K::Union{Integer, Vector{Integer}} = 10
)

n_blocks = ceil(Integer, n_dims / block_target_elements)
vectorized_bte = [UnitRange(i + 1, i + block_target_elements) for i in range(start = 0, stop = (n_blocks - 2) * block_target_elements, step = block_target_elements)]
push!(vectorized_bte, UnitRange((n_blocks - 1) * block_target_elements + 1, n_dims))

RQSplineCouplingModule(n_dims, vectorized_bte, K)
end

function ChangesOfVariables.with_logabsdet_jacobian(
f::RQSplineCouplingModule,
x::Any
)
with_logabsdet_jacobian(f.flow_module, x)
end

(f::RQSplineCouplingModule)(x::Any) = f.flow_module(x)
(f::RQSplineCouplingModule)(vs::AbstractValueShape) = vs

function InverseFunctions.inverse(f::RQSplineCouplingModule)
return RQSplineCouplingModule(InverseFunctions.inverse(f.flow_module).fs)
end


struct RQSplineCouplingBlock <: AbstractFlowBlock
mask::Vector{Bool}
nn::Chain
nn_parameters::NamedTuple
Expand Down

0 comments on commit 78a89cd

Please sign in to comment.