Skip to content

Commit

Permalink
optimize +
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed May 22, 2023
1 parent 68d01c9 commit a706ae3
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))

function ChainRulesCore.add!!(xs::AbstractArray{<:Any,N}, oe::OneElement{<:Any,N}) where {N}
if !ChainRulesCore.is_inplaceable_destination(xs)
xs = collect(xs)
end
xs[oe.ind...] += oe.val
return xs
end

Base.:(+)(xs::AbstractArray, oe::OneElement) = add!!(copy(xs), oe)
Base.:(+)(oe::OneElement, xs::AbstractArray) = +(xs, oe)
Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2)

"""
_setindex_zero(x, dy, inds...)
Expand Down

0 comments on commit a706ae3

Please sign in to comment.