Skip to content

Commit

Permalink
Merge pull request FluxML#962 from mcabbott/getindex2
Browse files Browse the repository at this point in the history
RFC: more efficient `∇getindex`
  • Loading branch information
oxinabox authored May 25, 2021
2 parents bcc3921 + 1f492a7 commit 7c66eff
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 4 deletions.
22 changes: 20 additions & 2 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ end

# Custom adjoint for `view`: the first returned value is the ordinary `view(x, inds...)`
# result, the second is the closure built by `∇getindex(x, inds)`, so `view` shares its
# gradient rule with `getindex`.
@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)

∇getindex(x::AbstractArray, inds) = dy -> begin
if inds isa NTuple{<:Any, Integer}
∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin
if inds isa NTuple{N,Int} && T <: Number
dx = OneElement(dy, inds, axes(x))
elseif inds isa NTuple{<:Any, Integer}
dx = _zero(x, typeof(dy))
dx[inds...] = dy
else
Expand All @@ -44,6 +46,22 @@ end
return (dx, map(_->nothing, inds)...)
end

"""
    OneElement(val, ind, axes) <: AbstractArray

Extremely simple `struct` used for the gradient of scalar `getindex`.

Acts as a read-only array with the given `axes` whose entry at index `ind` is `val`
and whose every other entry is `zero(T)`, where `T` is the type of `val`. Only `val`,
`ind` and `axes` are stored.

# Arguments
- `val::Number`: the value reported at position `ind`.
- `ind::NTuple{N,Int}`: the position of `val`, one `Int` per dimension.
- `axes::NTuple{N,AbstractUnitRange}`: the axes of the array, one range per dimension.

# Throws
- `MethodError` if `axes` does not provide exactly one range per entry of `ind`.
- `BoundsError` from `getindex` when the requested index lies outside `axes`.

# Examples
```julia-repl
julia> A = OneElement(5.0, (2, 1), (Base.OneTo(2), Base.OneTo(2)));

julia> A[2, 1], A[1, 2]
(5.0, 0.0)
```
"""
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
    val::T
    ind::I
    axes::A
    # `axes` must hold exactly one range per dimension; otherwise `size` and `axes`
    # would disagree with the `N` declared in `AbstractArray{T,N}`.
    OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} =
        new{T,N,I,A}(val, ind, axes)
end

Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes

# `@inline` lets a caller's `@inbounds` elide the bounds check below.
@inline function Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N}
    # Without this check an out-of-range index would silently read as zero.
    @boundscheck checkbounds(A, i...)
    return ifelse(i == A.ind, A.val, zero(T))
end


# `_zero(xs, T)` allocates a blank array shaped like `xs`, ready to receive an entry
# of type `T`.

# Numeric `xs`, `T === Nothing`: keep the element type of `xs` and fill with its zero.
function _zero(xs::AbstractArray{<:Number}, T::Type{Nothing})
    buffer = similar(xs)
    return fill!(buffer, zero(eltype(xs)))
end

# Numeric `xs`, any other `T`: element type `T`, filled with `false` converted to `T`.
function _zero(xs::AbstractArray{<:Number}, T)
    buffer = similar(xs, T)
    return fill!(buffer, false)
end

# Non-numeric `xs`: element type `Union{Nothing,T}`, every slot set to `nothing`.
function _zero(xs::AbstractArray, T)
    buffer = similar(xs, Union{Nothing, T})
    return fill!(buffer, nothing)
end
Expand Down
4 changes: 2 additions & 2 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ accum(x, y) =

# Three or more arguments: combine the first two, then fold in the rest.
function accum(x, y, zs...)
    return accum(accum(x, y), zs...)
end

# Two containers of the same kind: accumulate element by element via broadcasting.
function accum(x::Tuple, y::Tuple)
    return accum.(x, y)
end

function accum(x::AbstractArray, y::AbstractArray)
    return accum.(x, y)
end

# Any number of containers of the same kind: a single broadcast call over all of them.
function accum(x::Tuple, ys::Tuple...)
    return accum.(x, ys...)
end

function accum(x::AbstractArray, ys::AbstractArray...)
    return accum.(x, ys...)
end

@generated function accum(x::NamedTuple, y::NamedTuple)
# assumes that y has no keys apart from those also in x
Expand Down
18 changes: 18 additions & 0 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,24 @@ end
@test x[1] == x[2]
end

# Checks gradient accumulation when the same intermediates feed into many sums.
@testset "accumulation" begin
# from https://github.com/FluxML/Zygote.jl/issues/905
# Every line adds up all earlier intermediates, so each x_k (k >= 3) is twice the
# previous one and the returned x10 equals 256 * x1.
function net(x1)
x2 = x1
x3 = x1 + x2
x4 = x1 + x2 + x3
x5 = x1 + x2 + x3 + x4
x6 = x1 + x2 + x3 + x4 + x5
x7 = x1 + x2 + x3 + x4 + x5 + x6
x8 = x1 + x2 + x3 + x4 + x5 + x6 + x7
x9 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8
x10 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
end
# Sum of squares of the output of `net`.
loss(x) = sum(abs2, net(x))
# With net(x) == 256x the gradient of sum((256x)^2) is 2 * 256^2 * x, i.e. 131072
# for an all-ones input.
@test gradient(loss, ones(10,10))[1] == fill(131072, 10, 10)
# Upper bound on what `@allocated` reports for one gradient call on a 1000×1000 input.
@test 150_000_000 > @allocated gradient(loss, ones(1000,1000))
end

@testset "tuples & broadcasting" begin
@test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
@test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
Expand Down

0 comments on commit 7c66eff

Please sign in to comment.