From e38170dd3048c16521fcf11acb47055b4ed38f6a Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 1 Nov 2023 23:02:09 -0400
Subject: [PATCH 1/3] in-place destructure

---
 src/Optimisers.jl   |  2 +-
 src/destructure.jl  | 73 +++++++++++++++++++++++++++++++++++++++++----
 test/destructure.jl | 42 ++++++++++++++++++++++++++
 3 files changed, 111 insertions(+), 6 deletions(-)

diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 2e115c4..81db21a 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -13,7 +13,7 @@ include("utils.jl")
 include("adjust.jl")
 include("destructure.jl")
 
-export destructure
+export destructure, destructure!
 
 include("trainables.jl")
 export trainables
diff --git a/src/destructure.jl b/src/destructure.jl
index a628452..ccc9add 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -9,6 +9,8 @@ Copies all [`trainable`](@ref Optimisers.trainable), [`isnumeric`](@ref Optimise
 to a vector, and returns also a function which reverses this transformation.
 Differentiable.
 
+See also [`destructure!`](@ref).
+
 # Example
 ```jldoctest
 julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
@@ -31,6 +33,36 @@ function destructure(x)
   flat, Restructure(x, off, len)
 end
 
+"""
+    destructure!(model) -> vector, reconstructor
+
+This is a variant of [`destructure`](@ref), whose reconstruction function mutates the model.
+Requires that all trainable parameters in the model be mutable arrays!
+
+# Example
+```jldoctest
+julia> m = (x=[1.0, 2.0], y=(sin, Float32[3.0 4.0], cos));
+
+julia> v, re! = destructure!(m)
+([1.0, 2.0, 3.0, 4.0], Restructure!(NamedTuple, ..., 4))
+
+julia> m === re!([3, 5, 7, 9])  # mutates the original m, and returns it
+true
+
+julia> m
+(x = [3.0, 5.0], y = (sin, Float32[7.0 9.0], cos))
+```
+"""
+function destructure!(x)
+  flat, off, len = _flatten(x)
+  flat, Restructure!(x, off, len)
+end
+
+# function destructure!(flat::AbstractVector, x)
+#   flat, off, len = _flatten!(flat, x)
+#   flat, Restructure!(x, off, len)
+# end
+
 """
     Restructure(Model, ..., length)
 
@@ -55,12 +87,20 @@ struct Restructure{T,S}
   model::T
   offsets::S
   length::Int
+  mutate::Bool
 end
 
-(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length)
+Restructure(model, offsets, length) = Restructure(model, offsets, length, false)
+Restructure!(model, offsets, length) = Restructure(model, offsets, length, true)
+
+(re::Restructure)(flat::AbstractVector) = re.mutate ? _rebuild!(re.model, re.offsets, flat, re.length) : _rebuild(re.model, re.offsets, flat, re.length)
 (re::Restructure)(x, flat::AbstractVector) = re(flat)(x)
-Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
 Base.length(re::Restructure) = re.length
 
+function Base.show(io::IO, re::Restructure{T}) where T
+  print(io, "Restructure", re.mutate ? "!" : "")
+  print(io, "(", T.name.name, ", ..., ", re.length, ")")
+end
+
 # This flattens a model, and returns a web of offsets for later use:
 function _flatten(x)
   isnumeric(x) && return vcat(_vec(x)), 0, length(x)  # trivial case
@@ -75,6 +115,17 @@ function _flatten(x)
   isempty(arrays) && return Bool[], off, 0
   return reduce(vcat, arrays), off, len[]
 end
+# function _flatten!(flat, x)
+#   isnumeric(x) && return copyto!(flat, _vec(x))  # trivial case
+#   len = Ref(0)
+#   off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
+#     o = len[]
+#     copyto!(flat, o, _vec(y))
+#     len[] = o + length(y)
+#     o
+#   end
+#   flat, off, len[]
+# end
 
 struct TrainableStructWalk <: AbstractWalk end
 
@@ -97,10 +148,18 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _Trai
     _getat(y, o, flat)
   end
 end
+# (mutating version, same arguments & same return)
+function _rebuild!(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
+  len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
+  fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
+    copyto!(y, _getat(y, o, flat, view))
+  end
+  x
+end
 
-_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
-_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
-  ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y)))  # ProjectTo is just correcting eltypes
+_getat(y::Number, o::Int, flat::AbstractVector, _...) = ProjectTo(y)(flat[o + 1])
+_getat(y::AbstractArray, o::Int, flat::AbstractVector, get=getindex) =
+  ProjectTo(y)(reshape(get(flat, o .+ (1:length(y))), axes(y)))  # ProjectTo is just correcting eltypes
 
 struct _Trainable_biwalk <: AbstractWalk end
 
@@ -135,6 +194,10 @@ function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...)
   _rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT)
   _rebuild(x, off, flat, len; kw...), _rebuild_back
 end
+function ChainRulesCore.rrule(::typeof(_rebuild!), x, off, flat, len; kw...)
+  _rebuild!_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT)
+  _rebuild!(x, off, flat, len; kw...), _rebuild!_back
+end
 
 _zero(x) = map!(zero, similar(x, float(eltype(x))), x)  # mutable zero array for _grad!
 ChainRulesCore.@non_differentiable _zero(x)
diff --git a/test/destructure.jl b/test/destructure.jl
index 232a900..ff742cb 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -24,8 +24,10 @@ m9 = (a = m1, b = mat, c = [mat, m1])
   @test destructure(m9)[1] == 1:7
 
   @test destructure(m1)[2](7:9) == [7,8,9]
+  @test m1 == 1:3  # not mutated
   @test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9])
   @test destructure(m3)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9])
+  @test m3.z == 4:6  # not mutated
   m4′ = destructure(m4)[2](4:9)
   @test m4′ == (x = [4,5,6], y = [4,5,6], z = [7,8,9])
   @test m4′.x === m4′.y
@@ -60,11 +62,31 @@ m9 = (a = m1, b = mat, c = [mat, m1])
   @test_throws Exception destructure(m7)[2]([10,20,30,40])
 end
 
+@testset "destructure!" begin
+  m3′ = deepcopy(m3)
+  @test destructure!(m3′)[1] == 1:6
+  @test destructure!(m3′)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9])
+  @test m3′ == (x = [4,5,6], y = sin, z = [7,8,9])
+
+  m7′ = deepcopy(m7)
+  @test destructure!(m7′)[1] == 1:3
+  destructure!(m7′)[2]([10,20,30])
+  @test m7′.a == (sin, [10,20,30])
+  @test m7′.b == (cos, [4,5,6])
+  @test m7′.c == (tan, [7,8,9])
+
+  # errors
+  @test_throws Exception destructure!(m7)[2]([10,20])
+  @test_throws Exception destructure!(m7)[2]([10,20,30,40])
+end
+
 @testset "gradient of flatten" begin
   @test gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
+  @test gradient(m -> destructure!(m)[1][1], m1)[1] == [1,0,0]
   @test gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
   @test gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
   @test gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
+  @test gradient(m -> destructure!(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
   @test gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])
 
   g5 = gradient(m -> destructure(m)[1][3], m5)[1]
@@ -206,6 +228,26 @@ end
   end
 end
 
+@testset "gradient of rebuild!" begin
+  re1 = destructure!(deepcopy(m1))[2]
+  @test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
+
+  re2 = destructure!(deepcopy(m2))[2]
+  @test gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0]
+
+  re3 = destructure!(deepcopy(m3))[2]
+  @test gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0]
+  @test gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0]
+
+  re4 = destructure!(deepcopy(m4))[2]
+  @test gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0]
+  @test gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0]
+  @test gradient(rand(6)) do x
+    m = re4(x)
+    m.x[1] + 2*m.y[2] + 3*m.z[3]
+  end[1] == [1,2,0, 0,0,3]
+end
+
 @testset "Flux issue 1826" begin
   v, re = destructure((x=[1,2.0], y=[3,4,5.0]))
   @test gradient(zero(v)) do w
From 058a25b7f7c88e4e0c43421e0e8e091c9f4954d7 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 11 Apr 2024 11:30:09 -0400
Subject: [PATCH 2/3] use 5-arg copyto

---
 src/destructure.jl | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/destructure.jl b/src/destructure.jl
index ccc9add..baea015 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -152,14 +152,18 @@ end
 function _rebuild!(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
   len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
   fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
-    copyto!(y, _getat(y, o, flat, view))
+    # copyto!(y, _getat_view(y, o, flat))
+    copyto!(y, 1, flat, o+1, length(y))
   end
   x
 end
 
-_getat(y::Number, o::Int, flat::AbstractVector, _...) = ProjectTo(y)(flat[o + 1])
-_getat(y::AbstractArray, o::Int, flat::AbstractVector, get=getindex) =
-  ProjectTo(y)(reshape(get(flat, o .+ (1:length(y))), axes(y)))  # ProjectTo is just correcting eltypes
+_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
+_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
+  ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y)))  # ProjectTo is just correcting eltypes
+
+# _getat_view(y::AbstractArray, o::Int, flat::AbstractVector) =
+#   view(flat, o .+ (1:length(y)))
 
 struct _Trainable_biwalk <: AbstractWalk end
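The replacement above swaps an allocating slice (or view) for Base's five-argument `copyto!`, which copies `length(y)` elements of `flat` starting at `o + 1` straight into `y`. A standalone illustration of those Base semantics, independent of this package:

```julia
y = zeros(2, 2)
flat = [10.0, 20.0, 30.0, 40.0, 50.0]
o = 1   # offset: skip the first element of flat

# Same result as copyto!(y, flat[o+1 : o+4]), but with no temporary array:
copyto!(y, 1, flat, o + 1, length(y))

y == [20.0 40.0; 30.0 50.0]   # true: filled in column-major (linear-index) order
```

One behavioural note on the diff: `copyto!` converts element types on assignment, so the `ProjectTo` eltype correction in `_getat` is no longer involved when rebuilding in place.
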
From de381dd676c462b8ab39196553e1af5e25ff2be3 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 11 Apr 2024 11:50:04 -0400
Subject: [PATCH 3/3] restore 2-arg version, and add scary warning

---
 src/destructure.jl | 50 ++++++++++++++++++++++++++++++----------------
 1 file changed, 33 insertions(+), 17 deletions(-)

diff --git a/src/destructure.jl b/src/destructure.jl
index baea015..bc83ff1 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -35,9 +35,22 @@ end
 
 """
     destructure!(model) -> vector, reconstructor
+    destructure!(vector, model) -> vector, reconstructor
 
-This is a variant of [`destructure`](@ref), whose reconstruction function mutates the model.
-Requires that all trainable parameters in the model be mutable arrays!
+These are variants of [`destructure`](@ref), returning a reconstruction function
+which mutates the original model, instead of making a new one.
+The second method also mutates an existing flat vector.
+
+They require that all trainable parameters in the model be mutable arrays,
+else `re!` will give an error.
+
+!!! warning "Gradients"
+    Despite using mutation, these should be safe to use within Zygote,
+    with the important caveat that you must use the model returned, `m2 = re!(v)`, not the original.
+    Even though `m2 === m`, Zygote has to see the returned object being used
+    in order to trace which results flow where.
+    If you discard `m2` and call for example `Flux.mse(m(x), y)` with the original model `m`,
+    Zygote will give silently wrong results.
 
 # Example
 ```jldoctest
@@ -51,6 +64,9 @@ true
 
 julia> m
 (x = [3.0, 5.0], y = (sin, Float32[7.0 9.0], cos))
+
+julia> v2, re2! = destructure!(rand(4), m)  # works the same way
+([3.0, 5.0, 7.0, 9.0], Restructure!(NamedTuple, ..., 4))
 ```
 """
 function destructure!(x)
@@ -58,10 +74,10 @@ function destructure!(x)
   flat, Restructure!(x, off, len)
 end
 
-# function destructure!(flat::AbstractVector, x)
-#   flat, off, len = _flatten!(flat, x)
-#   flat, Restructure!(x, off, len)
-# end
+function destructure!(flat::AbstractVector, x)
+  flat, off, len = _flatten!(flat, x)
+  flat, Restructure!(x, off, len)
+end
 
 """
     Restructure(Model, ..., length)
@@ -115,17 +131,17 @@ function _flatten(x)
   isempty(arrays) && return Bool[], off, 0
   return reduce(vcat, arrays), off, len[]
 end
-# function _flatten!(flat, x)
-#   isnumeric(x) && return copyto!(flat, _vec(x))  # trivial case
-#   len = Ref(0)
-#   off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
-#     o = len[]
-#     copyto!(flat, o, _vec(y))
-#     len[] = o + length(y)
-#     o
-#   end
-#   flat, off, len[]
-# end
+function _flatten!(flat, x)
+  isnumeric(x) && return copyto!(flat, _vec(x)), 0, length(x)  # trivial case
+  len = Ref(0)
+  off = fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
+    o = len[]
+    copyto!(flat, o+1, _vec(y))
+    len[] = o + length(y)
+    o
+  end
+  flat, off, len[]
+end
 
 struct TrainableStructWalk <: AbstractWalk end
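The warning added above is easiest to see in code. A sketch of the right and wrong ways to differentiate through `re!`, assuming Zygote and this patch series; the tiny model and `sum` loss are illustrative:

```julia
using Optimisers, Zygote  # with these patches applied

m = (w = [2.0, 3.0],)
v, re! = destructure!(m)

# Right: the result of re! feeds into the loss, so Zygote can connect
# the parameters back to v through the rrule for _rebuild!:
gradient(v -> sum(re!(v).w), v)[1]   # [1.0, 1.0]

# Wrong: re!(v) mutates m but its return value is discarded; the loss
# reads the captured m, which Zygote treats as independent of v:
gradient(v) do v
    re!(v)
    sum(m.w)    # gradient w.r.t. v comes back as nothing, silently
end
```
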