Skip to content

Commit

Permalink
Merge pull request #1295 from FluxML/bc/kwargs-getindex-nokeygrad
Browse files Browse the repository at this point in the history
Handle nothing grads for `Pairs.data`
  • Loading branch information
ToucheSir authored Aug 29, 2022
2 parents 7e057d1 + 4183226 commit 0ede287
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ function _pullback(cx::AContext, ::typeof(literal_getindex),
ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, ::Val{K}) where K
val, gf_back = _pullback(cx, literal_getfield, NamedTuple(ps), Val(K))
function kwargs_literal_getindex_pullback(Δ)
dps = (data = gf_back(Δ)[2], itr = nothing)
dps = (data = gradindex(gf_back(Δ), 2), itr = nothing)
return (nothing, dps, nothing)
end
return val, kwargs_literal_getindex_pullback
Expand Down
4 changes: 4 additions & 0 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,10 @@ end
h(somedata) = g(; somedata...)
@test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = nothing, z = 3.0),)
@test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = nothing, z = 3.0, x = 2.3),)

# for when no kwargs have grads backpropogated
no_kwarg_grad(x; kwargs...) = x[kwargs[:i]]
@test gradient(x -> no_kwarg_grad(x; i=1), [1]) == ([1],)
end

@testset "Iterators" begin
Expand Down

0 comments on commit 0ede287

Please sign in to comment.