Skip to content

Commit

Permalink
∇getindex(::AbstractZero) paths
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Dec 21, 2022
1 parent c579410 commit 7bb7d98
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,9 @@ function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...)
end

function rrule(::typeof(getindex), x::AbstractArray, inds...)
function getindex_pullback(dy)
nots = map(Returns(NoTangent()), inds)
return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
end
nots = map(Returns(NoTangent()), inds)
getindex_pullback(dy) = (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
getindex_pullback(z::AbstractZero) = (NoTangent(), z, nots...)
return x[inds...], getindex_pullback
end

Expand All @@ -90,6 +89,7 @@ function ∇getindex(x::AbstractArray, dy, inds...)
∇getindex!(dx, dy, plain_inds...)
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
end
∇getindex(x::AbstractArray, z::AbstractZero, inds...) = z

"""
_setindex_zero(x, dy, inds...)
Expand Down Expand Up @@ -191,10 +191,9 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...)
end

function rrule(::typeof(view), x::AbstractArray, inds...)
function view_pullback(dy)
nots = map(Returns(NoTangent()), inds)
return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
end
nots = map(Returns(NoTangent()), inds)
view_pullback(dy) = (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
view_pullback(z::AbstractZero) = (NoTangent(), z, nots...)
return view(x, inds...), view_pullback
end

Expand Down

0 comments on commit 7bb7d98

Please sign in to comment.