Skip to content

Commit

Permalink
Add ScaleShiftModule and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Micki-D committed Sep 1, 2023
1 parent 2c7c0c4 commit 0a94124
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 2 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,26 @@ uuid = "24d2106d-e7e1-4641-aa0a-4a5934943aa1"
version = "0.1.0"

[deps]
AffineMaps = "2c83c9a8-abf5-4329-a0d7-deffaf474661"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
HeterogeneousComputing = "2182be2a-124f-4a91-8389-f06db5907a21"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MonotonicSplines = "568f7cb4-8305-41bc-b90d-d32b39cc99d1"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AffineMaps = "0.1, 0.2"
ArgCheck = "2"
ArraysOfArrays = "0.5.1, 0.6"
ChangesOfVariables = "0.1.3"
Expand All @@ -28,7 +32,7 @@ HeterogeneousComputing = "0.1, 0.2"
InverseFunctions = "0.1"
Lux = "0.5"
MonotonicSplines = "0.1"
Optimisers = "0.2"
Optimisers = "0.2, 0.3"
StatsFuns = "1"
ValueShapes = "0.8.3, 0.9, 0.10"
Zygote = "0.6"
Expand Down
4 changes: 4 additions & 0 deletions src/AdaptiveFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,27 @@ Adaptive normalizing flows.
"""
module AdaptiveFlows

using AffineMaps
using ArgCheck
using ArraysOfArrays
using ChangesOfVariables
using FunctionChains
using Functors
using HeterogeneousComputing
using InverseFunctions
using LinearAlgebra
using Lux
using MonotonicSplines
using Optimisers
using Random
using Statistics
using StatsFuns
using ValueShapes
using Zygote

include("adaptive_flows.jl")
include("optimize_flow.jl")
include("rqspline_coupling.jl")
include("scale_shift.jl")
include("utils.jl")
end # module
10 changes: 10 additions & 0 deletions src/adaptive_flows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ function InverseFunctions.inverse(f::CompositeFlow)
return CompositeFlow(InverseFunctions.inverse(f.flow).fs)
end

function prepend_flow_module(f::CompositeFlow, new_module::F) where F<:AbstractFlow
return CompositeFlow([new_module, f.flow.fs...])
end
export prepend_flow_module

function append_flow_module(f::CompositeFlow, new_module::F) where F<:AbstractFlow
return CompositeFlow([f.flow.fs..., new_module])
end
export append_flow_module

"""
AbstractFlowModule <: AbstractFlow
Expand Down
34 changes: 34 additions & 0 deletions src/scale_shift.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT).

struct ScaleShiftModule <: AbstractFlowModule
A::Matrix{Real}
b::Vector{Real}
end

export ScaleShiftModule
@functor ScaleShiftModule

function ScaleShiftModule(stds::AbstractVector, means::AbstractVector)
A = Diagonal(inv.(stds))
return ScaleShiftModule(A, .- A * means)
end

function ScaleShiftModule(x::AbstractArray)
stds = vec(std(x, dims = 2))
means = vec(mean(x, dims = 2))
ScaleShiftModule(stds, means)
end

function ChangesOfVariables.with_logabsdet_jacobian(f::ScaleShiftModule, x::Any)
y, ladj = ChangesOfVariables.with_logabsdet_jacobian(MulAdd(f.A, f.b), x)

return y, fill(ladj, 1, size(y,2))
end

(f::ScaleShiftModule)(x::AbstractMatrix) = MulAdd(f.A, f.b)(x)
(f::ScaleShiftModule)(vs::AbstractValueShape) = vs

function InverseFunctions.inverse(f::ScaleShiftModule)
A = inv(f.A)
return ScaleShiftModule(A, .- A * f.b)
end
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
AffineMaps = "2c83c9a8-abf5-4329-a0d7-deffaf474661"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Expand All @@ -7,9 +8,11 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b"
HeterogeneousComputing = "2182be2a-124f-4a91-8389-f06db5907a21"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f"

Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import Test

Test.@testset "Package AdaptiveFlows" begin
include("test_aqua.jl")
include("test_adaptive_flows.jl")
include("test_aqua.jl")
include("test_docs.jl")
include("test_scale_shift.jl")
include("test_optimize_flow.jl")
include("test_rqspline_coupling.jl")
end # testset
5 changes: 5 additions & 0 deletions test/test_adaptive_flows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ x = randn(rng, n_dims, n_smpls)
vs_test = valshape(x)

comp_flow_test = CompositeFlow([RQSplineCouplingModule(4), RQSplineCouplingModule(4)])
prepended_flow_test = prepend_flow_module(comp_flow_test, ScaleShiftModule(ones(4), zeros(4)))
appended_flow_test = append_flow_module(comp_flow_test, ScaleShiftModule(ones(4), zeros(4)))

# test outputs
# comp_flow_y_test, comp_flow_ladj_test = with_logabsdet_jacobian(comp_flow_test, x)
Expand All @@ -32,4 +34,7 @@ comp_flow_ladj_test = readdlm("test_outputs/comp_flow_ladj_test.txt")

@test all(isapprox.(ChangesOfVariables.with_logabsdet_jacobian(InverseFunctions.inverse(comp_flow_test), comp_flow_y_test), (x, .- comp_flow_ladj_test)))
@test isapprox(InverseFunctions.inverse(comp_flow_test)(comp_flow_y_test), x)

@test prepended_flow_test.flow.fs[1] isa ScaleShiftModule
@test appended_flow_test.flow.fs[end] isa ScaleShiftModule
end
38 changes: 38 additions & 0 deletions test/test_scale_shift.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT).

using AdaptiveFlows
using Test

using ArraysOfArrays
using InverseFunctions
using LinearAlgebra
using Random
using Statistics
using ValueShapes

# test inputs
n_dims = 4
n_smpls = 10

rng = MersenneTwister(1234)
x = muladd(Diagonal(randn(rng, n_dims)), randn(rng, n_dims, n_smpls), randn(rng, n_dims))

smpls = nestedview(x)
vs_test = valshape(x)

scale_shift_test = ScaleShiftModule(x)
inv_scale_shift_test = inverse(scale_shift_test)

y_test = scale_shift_test(x)
x_inverted_test = inv_scale_shift_test(y_test)

stds_test = vec(std(y_test, dims = 2))
means_test = vec(mean(y_test, dims = 2))

@testset "ScaleShiftModule" begin
@test all(isapprox.(stds_test, 1)) && all(isapprox.(means_test, 0, atol = 1f-15))
@test all(isapprox.(x_inverted_test, x))

@test scale_shift_test(vs_test) == vs_test
@test all(isapprox.(with_logabsdet_jacobian(scale_shift_test, x)[2], 10.637223371435223))
end

0 comments on commit 0a94124

Please sign in to comment.