Skip to content

Commit

Permalink
cannot fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Apr 19, 2024
1 parent 6f053f3 commit d6d761b
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 177 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.3.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Debugger = "31a5f54b-26ea-5ae9-a837-f05ce5417438"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
104 changes: 70 additions & 34 deletions src/trainables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,37 @@ function ChainRulesCore.rrule(::typeof(_trainables_nt), model)
return ps, _trainables_nt_back
end


struct TrainableNamedTupleWalk <: AbstractWalk end

function (::TrainableNamedTupleWalk)(recurse, x)
ch = trainable(x)
y = map(recurse, make_named_tuple(ch))
return y
end

struct TrainableNamedTupleBackWalk <: AbstractWalk end

function (::TrainableNamedTupleBackWalk)(recurse, model, Δps)
# @show 1 typeof(model) typeof(Δps)
ch = trainable(model)
Δ = unmake_named_tuple(ch, Δps)
# @show 2 typeof(ch) typeof(Δ)
Δ === nothing && return nothing
Δ === ZeroTangent() && return ZeroTangent()
y = mapvalue(recurse, ch, Δ)
# @show 3 typeof(model) typeof(ch) typeof(Δ) typeof(y)
return y
end


struct RestructureFromNT{T}
x::T
end

(re::RestructureFromNT)(ps) = rebuild_from_nt(re.x, ps)
(re::RestructureFromNT)(ps) = restructure_from_nt(re.x, ps)

function rebuild_from_nt(model, ps)
function restructure_from_nt(model, ps)
walk = RestructureFromNamedTupleWalk()
return fmap(model, ps; exclude=isnumeric, walk, cache=nothing) do x, p
return p
Expand All @@ -201,51 +225,57 @@ end
struct RestructureFromNamedTupleWalk <: AbstractWalk end

function (::RestructureFromNamedTupleWalk)(recurse, x, nt)
@show 1 x nt
children, re = functor(x)
@show 2 children
newchildren = map_commons(recurse, children, nt)
@show 3 x nt children newchildren
return re(newchildren)
end

function ChainRulesCore.rrule(::typeof(rebuild_from_nt), x, ps)
model = rebuild_from_nt(x, ps)
function rebuild_from_nt_back(Δmodel_raw)
function ChainRulesCore.rrule(::typeof(restructure_from_nt), x, ps)
model = restructure_from_nt(x, ps)
proj_ps = ProjectTo(ps)

function restructure_from_nt_back(Δmodel_raw)
Δmodel = unthunk(Δmodel_raw)
walk = RestructureFromNamedTupleWalk()
Δps = fmap(ps, Δmodel; exclude=isnumeric, walk, cache=nothing) do p, Δ
walk = RestructureFromNamedTupleBackWalk()
function exclude(x)
@show "exclude" x isnumeric(x)
# i += 1
# return i > 1
return isnumeric(x)
end
Δps = fmap(ps, Δmodel; exclude, walk, cache=nothing) do p, Δ
@show "fmap" Δ p

return Δ
end
return (NoTangent(), NoTangent(), Δps)
@show "rrule" Δmodel x ps Δps
@show typeof(Δmodel) typeof(ps) typeof(Δps)
Δps = (_1=ones(3), _2=zeros(3))
Δpst = Tangent{typeof(Δps)}(; Δps...)
# pR
return (NoTangent(), NoTangent(), proj_ps(Δpst))
end
return model, rebuild_from_nt_back
return model, restructure_from_nt_back
end


struct TrainableNamedTupleBackWalk <: AbstractWalk end

function (::TrainableNamedTupleBackWalk)(recurse, model, Δps)
@show 1 typeof(model) typeof(Δps)
ch = trainable(model)
Δ = unmake_named_tuple(ch, Δps)
@show 2 typeof(ch) typeof(Δ)
Δ === nothing && return nothing
Δ === ZeroTangent() && return ZeroTangent()
y = mapvalue(recurse, ch, Δ)
@show 3 typeof(model) typeof(ch) typeof(Δ) typeof(y)
struct RestructureFromNamedTupleBackWalk <: AbstractWalk end

function (::RestructureFromNamedTupleBackWalk)(recurse, ps, Δmodel)
@show 1 typeof(Δmodel) typeof(ps)
Δm = make_named_tuple(Δmodel)
@show 2 typeof(Δm) ps Δm
# Δm isa Float64 && return Δm
# Δm isa Array && return Δm
# ps isa Float64 && return ps
# ps isa Array && return ps
# return nothing
Δm === nothing && return nothing
Δm === ZeroTangent() && return ZeroTangent()
y = mapvalue(recurse, ps, Δm)
@show 3 typeof(Δmodel) typeof(Δm) typeof(y)
return y
end

struct TrainableNamedTupleWalk <: AbstractWalk end

function (::TrainableNamedTupleWalk)(recurse, x)
ch = trainable(x)
y = map(recurse, make_named_tuple(ch))
return y
end


function map_commons(f, x::NamedTuple{xkeys}, y) where {xkeys}
ykeys = propertynames(y)
vals = map(k -> k in ykeys ? f(x[k], getproperty(y, k)) : x[k], xkeys)
Expand All @@ -270,12 +300,18 @@ function map_commons(f, x::Vector, y)
return vals
end

make_named_tuple(x::NamedTuple) = x
make_named_tuple(x) = x
make_named_tuple(x::AbstractDict{Symbol}) = NamedTuple(x)
make_named_tuple(x::AbstractDict) = NamedTuple(Symbol("_", k) => v for (k, v) in pairs(x))
make_named_tuple(x::Tuple) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x)
make_named_tuple(x::Vector) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x)

make_named_tuple(x::Tangent{<:Any,<:NamedTuple}) = x
make_named_tuple(x::Tangent{<:Any,<:AbstractDict{Symbol}}) = NamedTuple(x)
make_named_tuple(x::Tangent{<:Any,<:AbstractDict}) = NamedTuple(Symbol("_", k) => v for (k, v) in pairs(x))
make_named_tuple(x::Tangent{<:Any,<:Tuple}) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x)
make_named_tuple(x::Tangent{<:Any,<:Vector}) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x)


unmake_named_tuple(x::NamedTuple, ps) = ps

Expand Down
4 changes: 4 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@

mapvalue(f, x...) = map(f, x...)
mapvalue(f, x::NamedTuple, ys::NamedTuple...) = map(f, x, ys...)
mapvalue(f, x, y::NamedTuple{ykeys}) where {ykeys} =
NamedTuple{ykeys}((f(getproperty(x ,k), yk) for (k, yk) in pairs(y))) # used in rrule for restructure_from_nt

mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)

# without theses, tuples are returned instead of NamedTuples
Expand Down
Loading

0 comments on commit d6d761b

Please sign in to comment.