From 24a6111c851c8dfab87628c5227b11a2dbf89648 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 8 Aug 2022 16:28:29 -0700 Subject: [PATCH] Treat Pairs(NamedTuple) as NamedTuple for indexing This prevents issues with double-counting when using kwargs. --- src/lib/base.jl | 28 ++++++++++++++++++++++++++-- test/features.jl | 17 ++++++++++++++--- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/lib/base.jl b/src/lib/base.jl index 79dfb77b6..21ca62b1c 100644 --- a/src/lib/base.jl +++ b/src/lib/base.jl @@ -119,11 +119,11 @@ end # named tuple @adjoint function pairs(t::NamedTuple{N}) where N - + pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,) pairs_namedtuple_pullback(dx::Tuple{}) = (NamedTuple(),) - + function pairs_namedtuple_pullback(Δ::Dict) t0 = map(zero, t) for (idx, v) in Δ @@ -145,6 +145,30 @@ else @adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, (;dict...)) end +# Keyword arguments pretend to be a Dict, but are secretly wrapping a NamedTuple. +# We can treat them much the same, just with some plumbing to handle the extra `itr` field. +function _pullback(::AContext, ::typeof(getindex), + ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, k) + # So we don't close over kwarg values in the pullback + data = map(_ -> nothing, NamedTuple(ps)) + function kwargs_getindex_pullback(Δ) + dps = (data = Base.setindex(data, Δ, k), itr = nothing) + return (nothing, dps, nothing) + end + return ps[k], kwargs_getindex_pullback +end + +function _pullback(cx::AContext, ::typeof(literal_getindex), + ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, ::Val{K}) where K + val, gf_back = _pullback(cx, literal_getfield, NamedTuple(ps), Val(K)) + function kwargs_literal_getindex_pullback(Δ) + dps = (data = gf_back(Δ)[2], itr = nothing) + return (nothing, dps, nothing) + end + return val, kwargs_literal_getindex_pullback +end + +# Misc. @adjoint function Base.getfield(p::Pair, i::Int) function pair_getfield_pullback(Δ) f, s = i == 1 ? (Δ, nothing) : (nothing, Δ) diff --git a/test/features.jl b/test/features.jl index d4f68d36b..4c16267f2 100644 --- a/test/features.jl +++ b/test/features.jl @@ -552,6 +552,17 @@ end @test gradient(x -> x[].a, Ref((a=1, b=2))) == ((x = (a = 1, b = nothing),),) @test gradient(x -> x[1][].a, [Ref((a=1, b=2)), Ref((a=3, b=4))]) == ([(x = (a = 1, b = nothing),), nothing],) @test gradient(x -> x[1].a, [(a=1, b=2), "three"]) == ([(a = 1, b = nothing), nothing],) + + @testset "indexing kwargs" begin + inner_lit_index(; kwargs...) = kwargs[:x] + outer_lit_index(; kwargs...) = inner_lit_index(; x=kwargs[:x]) + + inner_dyn_index(k; kwargs...) = kwargs[k] + outer_dyn_index(k; kwargs...) = inner_dyn_index(k; x=kwargs[k]) + + @test gradient(x -> outer_lit_index(; x), 0.0) == (1.0,) + @test gradient((x, k) -> outer_dyn_index(k; x), 0.0, :x) == (1.0, nothing) + end end function type_test() @@ -562,7 +573,7 @@ end @testset "Pairs" begin @test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0 - @test (x->10*pairs((a=x, b=2))[2])'(100) === 0 + @test (x->10*pairs((a=x, b=2))[2])'(100) === nothing foo(;kw...) = 1 @test gradient(() -> foo(a=1,b=2.0)) === () @@ -578,8 +589,8 @@ end @testset "kwarg splatting, pass in object" begin g(; kwargs...) = kwargs[:x] * kwargs[:z] h(somedata) = g(; somedata...) - @test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = 0.0, z = 3.0),) - @test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = 0.0, z = 3.0, x = 2.3),) + @test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = nothing, z = 3.0),) + @test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = nothing, z = 3.0, x = 2.3),) end @testset "Iterators" begin