Skip to content

Commit

Permalink
arrays of arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 11, 2022
1 parent 337f365 commit 756b450
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,12 @@ function _Tangent_biwalk(f, x, aux) # use with prune = NoT
au, _ = functor(typeof(x), aux)
y = _trainmap(f, ch, _trainable(x), au)
y isa Tuple{} && return NoT
Tangent{typeof(x), typeof(y)}(y)
p = ProjectTo(x)
if p isa ProjectTo # e.g. Array, NamedTuple
p(y)
else # p === identity for unknown structs
Tangent{typeof(x), typeof(y)}(y)
end
end

function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...)
Expand Down
1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ trainable(x) = functor(x)[1]
_trainable(x) = _trainable(functor(x)[1], trainable(x))
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
_trainable(ch::AbstractArray, tr::AbstractArray) = tr
function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple
@warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple"
map(c -> c in tr ? c : nothing, ch)
Expand Down
18 changes: 18 additions & 0 deletions test/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ m4 = (x = m1, y = m1, z = collect(4:6.0))
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
m6 = (a = m1, b = [4.0 + im], c = m1)
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]

@testset "flatten & rebuild" begin
@test destructure(m1)[1] isa Vector{Float64}
Expand All @@ -31,12 +32,20 @@ m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0
@test m6′.a === m6′.c
@test m6′.b == [7 + 4im]

# struct, trainable
@test destructure(m7)[1] == 1:3
m7′ = destructure(m7)[2]([10,20,30])
@test m7′.a == (sin, [10,20,30])
@test m7′.b == (cos, [4,5,6])
@test m7′.c == (tan, [7,8,9])

@test destructure(m8)[1] == 1:5
m8′ = destructure(m8)[2](1:5)
@test m8′[1].x === m8′[1].y
@test m8′[2].b.y === false
@test m8′[3][1] == [5.0]

# errors
@test_throws Exception destructure(m7)[2]([10,20])
@test_throws Exception destructure(m7)[2]([10,20,30,40])
end
Expand All @@ -57,6 +66,11 @@ end
@test g6.a isa Vector{Float64}
@test g6.b == [0+im]

g8 = gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
@test g8[1].x == [2,4,6]
@test g8[2].b.x == [8]
@test g8[3] == [[10.0]]

@testset "second derivative" begin
@test_broken gradient([1,2,3.0]) do v
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
Expand Down Expand Up @@ -90,6 +104,10 @@ end
@test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
@test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]

v8, re8 = destructure(m8)
@test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
@test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]

@testset "second derivative" begin
# ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
@test_broken gradient(collect(1:6.0)) do y
Expand Down

0 comments on commit 756b450

Please sign in to comment.