diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 156db42b9..7e1befd14 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -111,6 +111,17 @@ Base.size(A::OneElement) = map(length, A.axes) Base.axes(A::OneElement) = A.axes Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) +function ChainRulesCore.add!!(xs::AbstractArray{<:Any,N}, oe::OneElement{<:Any,N}) where {N} + if !ChainRulesCore.is_inplaceable_destination(xs) + xs = collect(xs) + end + xs[oe.ind...] += oe.val + return xs +end + +Base.:(+)(xs::AbstractArray, oe::OneElement) = add!!(copy(xs), oe) +Base.:(+)(oe::OneElement, xs::AbstractArray) = +(xs, oe) +Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2) """ _setindex_zero(x, dy, inds...)