Skip to content

Commit

Permalink
Merge pull request #1 from FluxML/dg/grad
Browse files Browse the repository at this point in the history
Allow gradients in fmap
  • Loading branch information
ToucheSir authored Nov 21, 2021
2 parents d58d273 + b5b872b commit 40c547c
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 4 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ julia = "1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["Test", "Zygote"]
27 changes: 27 additions & 0 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,29 @@ Equivalent to `functor(x)[1]`.
"""
children(x) = functor(x)[1]

function functor_tuple(f, x::Tuple, dx::Tuple)
map(x, dx) do x, x̄
_default_walk(f, x, x̄)
end
end
functor_tuple(f, x, dx) = f(x, dx)
functor_tuple(f, x, ::Nothing) = x

# @functor Chain
# Chain -> func = (layers = (Dense,Dense),), gs -> (layers...)
function _default_walk(f, x, dx)
func, re = functor(x)
map(func, dx) do x, x̄
# functor_tuple(f, x, x̄)
f(x, x̄)
end |> re
end

function _default_walk(f, x)
func, re = functor(x)
re(map(f, func))
end
_default_walk(f, ::Nothing, ::Nothing) = nothing

"""
fmap(f, x; exclude = isleaf, walk = Functors._default_walk)
Expand Down Expand Up @@ -205,3 +224,11 @@ function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
end
return output
end

# Allow gradients and other constructs that match the structure of the functor
# to allow for `map` style computations and return a modified version of the struct.
# This way we can use `fmap` to update the params with their gradients
function fmap(f, x, dx...; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x, dx...) : _default_walk((x...) -> fmap(f, x..., cache = cache), x, dx...)
end
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using Functors, Test
using Zygote

@testset "Functors.jl" begin

include("basics.jl")
include("base.jl")

include("basics.jl")
include("update.jl")
end
23 changes: 23 additions & 0 deletions test/update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
@testset "Generalized fmap over equivalent functors" begin
struct M{F,T,S}
σ::F
W::T
b::S
end

@functor M

(m::M)(x) = m.σ.(m.W * x .+ m.b)

m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3))
x = ones(Float32, 4, 2)
m̄, _ = gradient((m,x) -> sum(m(x)), m, x)
= Functors.fmap(m, m̄) do x, y
isnothing(x) && return y
isnothing(y) && return x
x .- 0.1f0 .* y
end

@test.W fill(0.8f0, size(m.W))
@test.b fill(-0.2f0, size(m.b))
end

0 comments on commit 40c547c

Please sign in to comment.