Initial Implementation
Micki-D committed Aug 22, 2023
1 parent 5b9bd3f commit 222da06
Showing 7 changed files with 267 additions and 2 deletions.
14 changes: 14 additions & 0 deletions Project.toml
@@ -3,6 +3,20 @@ uuid = "24d2106d-e7e1-4641-aa0a-4a5934943aa1"
version = "0.1.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
HeterogeneousComputing = "2182be2a-124f-4a91-8389-f06db5907a21"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MonotonicSplines = "568f7cb4-8305-41bc-b90d-d32b39cc99d1"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
julia = "1.6"
2 changes: 1 addition & 1 deletion README.md
@@ -7,7 +7,7 @@
[![Codecov](https://codecov.io/gh/bat/AdaptiveFlows.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/bat/AdaptiveFlows.jl)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)


This package aims to provide a toolkit for constructing modular, high-performance [Normalizing Flows](https://arxiv.org/abs/1908.09257).
## Documentation

* [Documentation for stable version](https://bat.github.io/AdaptiveFlows.jl/stable)
19 changes: 18 additions & 1 deletion src/AdaptiveFlows.jl
@@ -7,6 +7,23 @@ Adaptive normalizing flows.
"""
module AdaptiveFlows

include("adaptive_flow.jl")
using ArgCheck
using ArraysOfArrays
using ChangesOfVariables
using FunctionChains
using Functors
using HeterogeneousComputing
using InverseFunctions
using Lux
using MonotonicSplines
using Optimisers
using Random
using StatsFuns
using ValueShapes
using Zygote

include("adaptive_flow.jl")
include("optimize_flow.jl")
include("rqspline_coupling.jl")
include("utils.jl")
end # module
80 changes: 80 additions & 0 deletions src/adaptive_flow.jl
@@ -1,3 +1,83 @@
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT).

abstract type AbstractFlow <: Function
end


struct CompositeFlow <: AbstractFlow
flow::Function
end

export CompositeFlow
@functor CompositeFlow

function CompositeFlow(modules::Vector{F}) where F <: Function
return CompositeFlow(fchain(modules))
end

function ChangesOfVariables.with_logabsdet_jacobian(
f::CompositeFlow,
x::Any
)
with_logabsdet_jacobian(f.flow, x)
end

(f::CompositeFlow)(x::Any) = f.flow(x)
(f::CompositeFlow)(vs::AbstractValueShape) = vs

function InverseFunctions.inverse(f::CompositeFlow)
return CompositeFlow(InverseFunctions.inverse(f.flow).fs)
end


struct RQSplineCouplingModule <: AbstractFlow
flow_module::Function
end

export RQSplineCouplingModule
@functor RQSplineCouplingModule


"""
RQSplineCouplingModule(n_dims::Integer,
block_target_elements::Union{Vector{Vector{I}}, Vector{UnitRange{I}}} where I <: Integer,
K::Union{Integer, Vector{Integer}} = 10
)
Construct an instance of `RQSplineCouplingModule` for a `ǹ_dims` -dimensíonal input. Use `block_target_elements`
to specify which block in the module transforms which components of the input. Use `K` to specify the desired
number of spline segments used for the rational quadratic spline functions (defaults to 10).
Note: This constructor does not ensure each element of the input is transformed by a block. If desired, this
must be ensured in `block_target_elements`.
"""
function RQSplineCouplingModule(n_dims::Integer,
block_target_elements::Union{Vector{Vector{I}}, Vector{UnitRange{I}}} where I <: Integer,
K::Union{Integer, Vector{Integer}} = 10,
compute_unit::AbstractComputeUnit = CPUnit()
)
@argcheck K isa Integer || length(K) == length(block_target_elements) DomainError(K, "please specify the same number of values for K as there are blocks")

n_blocks = length(block_target_elements)
blocks = Vector{RQSplineCouplingBlock}(undef, n_blocks)
n_out_neural_net = K isa Vector ? 3 .* K .- 1 : 3 .* fill(K, n_blocks) .- 1

for i in 1:n_blocks
transformation_mask = fill(false, n_dims)
transformation_mask[block_target_elements[i]] .= true
neural_net = get_neural_net(n_dims - sum(transformation_mask), n_out_neural_net[i])
blocks[i] = RQSplineCouplingBlock(transformation_mask, neural_net, compute_unit)
end
return fchain(blocks)
end

function RQSplineCouplingModule(n_dims::Integer,
block_target_elements::Integer = 1,
K::Union{Integer, Vector{Integer}} = 10
)

n_blocks = ceil(Integer, n_dims / block_target_elements)
vectorized_bte = [UnitRange(i + 1, i + block_target_elements) for i in range(start = 0, stop = (n_blocks - 2) * block_target_elements, step = block_target_elements)]
push!(vectorized_bte, UnitRange((n_blocks - 1) * block_target_elements + 1, n_dims))

RQSplineCouplingModule(n_dims, vectorized_bte, K)
end
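
A hedged sketch (not part of the commit) of the intended call pattern for the constructors above: it builds a coupling module for four-dimensional input and evaluates it together with its log-abs-det Jacobian. It assumes the package and ChangesOfVariables are loaded and that samples are stored as matrix columns; `K` and the block layout follow the documented defaults.

using AdaptiveFlows, ChangesOfVariables

x = randn(4, 100)                           # 100 four-dimensional samples as matrix columns

# Two coupling blocks, each transforming two components; K = 10 spline segments (the default)
flow_module = RQSplineCouplingModule(4, 2)

y = flow_module(x)                                   # forward transformation
y2, ladj = with_logabsdet_jacobian(flow_module, x)   # forward transformation plus log|det J|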
68 changes: 68 additions & 0 deletions src/optimize_flow.jl
@@ -0,0 +1,68 @@
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT).

std_normal_logpdf(x::Real) = -(abs2(x) + log2π)/2

function mvnormal_negll_flow(flow::Function, X::AbstractMatrix{<:Real})
nsamples = size(X, 2)

Y, ladj = with_logabsdet_jacobian(flow, X)
ll = (sum(std_normal_logpdf.(Y[flow.mask,:])) + sum(ladj)) / nsamples

return -ll
end

function mvnormal_negll_flow_grad(flow, X::AbstractMatrix{<:Real})
negll, back = Zygote.pullback(mvnormal_negll_flow, flow, X)
d_flow = back(one(eltype(X)))[1]
return negll, d_flow
end


function optimize_flow(smpls::VectorOfSimilarVectors{<:Real},
initial_flow,
optimizer;
nbatches::Integer = 100,
nepochs::Integer = 100,
optstate = Optimisers.setup(optimizer, deepcopy(initial_flow)),
negll_history = Vector{Float64}(),
shuffle_samples::Bool = false
)
batchsize = round(Int, length(smpls) / nbatches)
batches = collect(Iterators.partition(smpls, batchsize))
flow = deepcopy(initial_flow)
state = deepcopy(optstate)
negll_hist = Vector{Float64}()
for i in 1:nepochs
for batch in batches
negll, d_flow = mvnormal_negll_flow_grad(flow, flatview(batch))
state, flow = Optimisers.update(state, flow, d_flow)
push!(negll_hist, negll)
end
if shuffle_samples
batches = collect(Iterators.partition(shuffle(smpls), batchsize))
end
end
(result = flow, optimizer_state = state, negll_history = vcat(negll_history, negll_hist))
end
export optimize_flow

# temporary hack
function optimize_flow_sequentially(smpls::VectorOfSimilarVectors{<:Real},
initial_flow::CompositeFlow,
optimizer;
nbatches::Integer = 100,
nepochs::Integer = 100,
optstate = Optimisers.setup(optimizer, deepcopy(initial_flow)),
negll_history = Vector{Float64}(),
shuffle_samples::Bool = false
)

optimized_blocks = Vector{Function}(undef, length(initial_flow.flow.fs))
for (i, block) in enumerate(initial_flow.flow.fs)
# set up a fresh optimizer state per block inside optimize_flow; the optstate built for the composite flow would not match an individual block
res = optimize_flow(smpls, block, optimizer; nbatches, nepochs, negll_history, shuffle_samples)
optimized_blocks[i] = res.result
end
return CompositeFlow(optimized_blocks)
end

export optimize_flow_sequentially
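
A hedged usage sketch for `optimize_flow` (not part of the commit). It trains a single `RQSplineCouplingBlock`, since `mvnormal_negll_flow` reads the block's `mask` field; `Optimisers.Adam` and the sample layout are assumptions, and the internal helpers `get_neural_net` and `CPUnit` are accessed via the module prefix because they are not exported in this commit.

using AdaptiveFlows, ArraysOfArrays, Optimisers

# 10_000 four-dimensional samples in the VectorOfSimilarVectors layout optimize_flow expects
samples = nestedview(randn(4, 10_000))

# One coupling block: transform the first two components, condition on the last two.
# The 29 network outputs correspond to the default K = 10 spline segments (3K - 1).
block = RQSplineCouplingBlock(Bool[true, true, false, false],
                              AdaptiveFlows.get_neural_net(2, 29),
                              AdaptiveFlows.CPUnit())

res = optimize_flow(samples, block, Adam(1f-3); nbatches = 100, nepochs = 10)

res.result           # the trained block
res.optimizer_state  # can be passed back in as `optstate` to continue training
res.negll_history    # per-batch negative log-likelihood trace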
74 changes: 74 additions & 0 deletions src/rqspline_coupling.jl
@@ -0,0 +1,74 @@
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT).

struct RQSplineCouplingBlock <: Function
mask::Vector{Bool}
nn::Chain
nn_parameters::NamedTuple
nn_state::NamedTuple
end

export RQSplineCouplingBlock
@functor RQSplineCouplingBlock

function RQSplineCouplingBlock(mask::Vector{Bool}, nn::Chain, compute_unit::AbstractComputeUnit)
rng = Random.default_rng()
Random.seed!(rng, 0)

lux_compute_unit = compute_unit isa CPUnit ? cpu_device() : gpu_device()

nn_parameters, nn_state = Lux.setup(rng, nn) .|> lux_compute_unit

return RQSplineCouplingBlock(mask, nn, nn_parameters, nn_state)
end


function ChangesOfVariables.with_logabsdet_jacobian(
f::RQSplineCouplingBlock,
x::Any
)
apply_rqs_coupling_flow(f, x)
end

(f::RQSplineCouplingBlock)(x::Any) = apply_rqs_coupling_flow(f, x)[1]
(f::RQSplineCouplingBlock)(vs::AbstractValueShape) = vs

function InverseFunctions.inverse(f::RQSplineCouplingBlock)
return InverseRQSplineCouplingBlock(f.mask, f.nn)
end


struct InverseRQSplineCouplingBlock <: Function
mask::Vector{Bool}
nn::Chain
end

export InverseRQSplineCouplingBlock
@functor InverseRQSplineCouplingBlock

function ChangesOfVariables.with_logabsdet_jacobian(
f::InverseRQSplineCouplingBlock,
x::Any
)
return apply_rqs_coupling_flow(f, x)
end

(f::InverseRQSplineCouplingBlock)(x::Any) = apply_rqs_coupling_flow(f, x)[1]
(f::InverseRQSplineCouplingBlock)(vs::AbstractValueShape) = vs

function InverseFunctions.inverse(f::InverseRQSplineCouplingBlock)
# The forward block also stores network parameters and state, which the inverse block does not carry;
# re-initialize them on the CPU here (an assumption of this fix, matching the constructor above).
return RQSplineCouplingBlock(f.mask, f.nn, CPUnit())
end


function apply_rqs_coupling_flow(flow::Union{RQSplineCouplingBlock, InverseRQSplineCouplingBlock}, x::Any) # make x typestable

rq_spline = flow isa RQSplineCouplingBlock ? RQSpline : InvRQSpline
n_dims_to_transform = sum(flow.mask)

input_mask = .~flow.mask
# The unmasked components condition the transformation: feed them through the neural net
# and reshape its output into spline parameters.
nn_output = flow.nn(x[input_mask,:], flow.nn_parameters, flow.nn_state)[1]
spline_params = get_params(nn_output, n_dims_to_transform)
# Transform the masked components with the resulting (inverse) rational quadratic spline.
y, ladj = with_logabsdet_jacobian(rq_spline(spline_params...), x[flow.mask,:])

return MonotonicSplines._sort_dimensions(y, x, flow.mask), ladj
end

export apply_rqs_coupling_flow
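
The coupling pattern used throughout this file (condition on the unmasked components, transform the masked ones, and accumulate log|det J|) can be illustrated with a self-contained toy affine coupling. This is a didactic sketch only; the real block replaces the hand-rolled scale and shift below with a Lux network and rational quadratic splines.

# Toy affine coupling illustrating the masking/conditioning pattern of apply_rqs_coupling_flow.
mask = Bool[true, true, false, false]

function toy_affine_coupling(x::AbstractMatrix{<:Real}, mask::Vector{Bool})
    x_cond  = x[.!mask, :]               # conditioning rows, passed through unchanged
    x_trans = x[mask, :]                 # rows to be transformed
    s = tanh.(sum(x_cond, dims = 1))     # 1×N "log-scale" predicted from the conditioner
    t = sum(x_cond, dims = 1)            # 1×N "shift" predicted from the conditioner
    y_trans = exp.(s) .* x_trans .+ t    # elementwise affine transform of the masked rows
    ladj = vec(s) .* sum(mask)           # log|det J| per sample: each masked row contributes s
    y = copy(x)
    y[mask, :] .= y_trans
    return y, ladj
end

y, ladj = toy_affine_coupling(randn(4, 100), mask)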
12 changes: 12 additions & 0 deletions src/utils.jl
@@ -0,0 +1,12 @@
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT).

function get_neural_net(n_in::Integer,
n_out::Integer,
n_hidden_layers::Integer = 1,
n_in_hidden::Integer = 20
#compute_device::AbstractComputeDevice = CPUnit(),
)

# Fully connected architecture: n_in → n_in_hidden → ⋯ → n_in_hidden → n_out,
# with relu on all but the final (linear) layer.
layers = vcat([Dense(n_in, n_in_hidden, relu)], repeat([Dense(n_in_hidden, n_in_hidden, relu)], n_hidden_layers), [Dense(n_in_hidden, n_out)])
return Chain(layers...)
end
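
A hedged sketch of this helper in use, assuming Lux and Random are available; `get_neural_net` is not exported in this commit, hence the module prefix, and the output size 29 matches the default K = 10 spline segments used above.

using Lux, Random
using AdaptiveFlows

# Default architecture: n_in → 20 → 20 → n_out, relu on all but the final (linear) layer
nn = AdaptiveFlows.get_neural_net(2, 29)
ps, st = Lux.setup(Random.default_rng(), nn)
y, _ = nn(randn(Float32, 2, 16), ps, st)   # y is a 29 × 16 matrix: one parameter vector per sample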
