Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accelerate minimum(A; dims = 1) for cartesian indexed cases. #43618

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,11 @@ _axes(bc::Broadcasted{<:AbstractArrayStyle{0}}, ::Nothing) = ()

@inline Base.axes(bc::Broadcasted{<:Any, <:NTuple{N}}, d::Integer) where N =
d <= N ? axes(bc)[d] : OneTo(1)
@inline Base.axes(bc::Broadcasted{<:Any, Nothing}, d::Integer) =
(ax = axes(bc); d <= length(ax) ? ax[d] : OneTo(1))

Base.axes1(bc::Broadcasted) = axes(bc, 1)
Base.axes1(bc::Broadcasted{<:AbstractArrayStyle{0}}) = OneTo(1)

BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style()
BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} =
Expand Down
17 changes: 14 additions & 3 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@ mapreduce_empty(::typeof(identity), op, T) = reduce_empty(op, T)
mapreduce_empty(::typeof(abs), op, T) = abs(reduce_empty(op, T))
mapreduce_empty(::typeof(abs2), op, T) = abs2(reduce_empty(op, T))

mapreduce_empty(f::typeof(abs), ::typeof(max), T) = abs(zero(T))
mapreduce_empty(f::typeof(abs2), ::typeof(max), T) = abs2(zero(T))
mapreduce_empty(::typeof(abs), ::typeof(max), T) = abs(zero(T))
mapreduce_empty(::typeof(abs2), ::typeof(max), T) = abs2(zero(T))

# For backward compatibility:
mapreduce_empty_iter(f, op, itr, ItrEltype) =
Expand Down Expand Up @@ -412,6 +412,10 @@ mapreduce_first(f, op, x) = reduce_first(op, f(x))

_mapreduce(f, op, A::AbstractArrayOrBroadcasted) = _mapreduce(f, op, IndexStyle(A), A)

_mapreduce(f, op, A::AbstractVector) = _mapreduce(f, op, IndexLinear(), A)

_mapreduce(f, op, A::AbstractZeroDimArray) = mapreduce_first(f, op, A[])

function _mapreduce(f, op, ::IndexLinear, A::AbstractArrayOrBroadcasted)
inds = LinearIndices(A)
n = length(inds)
Expand All @@ -437,7 +441,14 @@ end

mapreduce(f, op, a::Number) = mapreduce_first(f, op, a)

_mapreduce(f, op, ::IndexCartesian, A::AbstractArrayOrBroadcasted) = mapfoldl(f, op, A)
function _mapreduce(f, op, ::IndexCartesian, A::AbstractArrayOrBroadcasted)
(isempty(A) || length(axes1(A)) < 16) && return mapfoldl(f, op, A)
mapfoldl(op, CartesianIndices(tail(axes(A)))) do IA
@inline elf(i) = @inbounds f(A[i, IA])
ax1 = axes1(A)
mapreduce_impl(elf, op, ax1, firstindex(ax1), lastindex(ax1))
end
end

"""
reduce(op, itr; [init])
Expand Down
17 changes: 12 additions & 5 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,18 +267,25 @@ function _mapreducedim!(f, op, R::AbstractArray, A::AbstractArrayOrBroadcasted)
if reducedim1(R, A)
# keep the accumulator as a local variable when reducing along the first dimension
i1 = first(axes1(R))
ax1 = axes1(A)
@inbounds for IA in CartesianIndices(indsAt)
IR = Broadcast.newindex(IA, keep, Idefault)
r = R[i1,IR]
@simd for i in axes(A, 1)
r = op(r, f(A[i, IA]))
if op === min || op === max #|| op === _extrema_op
elf(i) = @inbounds f(A[i, IA])
r = mapreduce_impl(elf, op, ax1, firstindex(ax1), lastindex(ax1))
R[i1,IR] = op(R[i1,IR], r)
else
r = R[i1,IR]
@simd for i in ax1
r = op(r, f(A[i, IA]))
end
R[i1,IR] = r
end
R[i1,IR] = r
end
else
@inbounds for IA in CartesianIndices(indsAt)
IR = Broadcast.newindex(IA, keep, Idefault)
@simd for i in axes(A, 1)
@simd for i in axes1(A)
R[i,IR] = op(R[i,IR], f(A[i,IA]))
end
end
Expand Down
16 changes: 14 additions & 2 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -911,9 +911,9 @@ end
@test IndexStyle(bc) == IndexLinear()
@test reduce(paren, bc) == reduce(paren, xs)
# If `Broadcasted` does not have `IndexLinear` style, it should
# hit the `foldl` branch:
# behave like a cartesian-indexed Array (PR #43618)
@test IndexStyle(bcraw) == IndexCartesian()
@test reduce(paren, bcraw) == foldl(paren, xs)
@test reduce(paren, bcraw) == reduce(paren, view(xs, 1:length(xs), 1:1))

# issue #41055
bc = Broadcast.instantiate(Broadcast.broadcasted(Base.literal_pow, Ref(^), [1,2], Ref(Val(2))))
Expand Down Expand Up @@ -1079,3 +1079,15 @@ end
y = randn(2)
@inferred(test(x, y)) == [0, 0]
end

@testset "axes1 and axes" begin
bc = Base.broadcasted(+, reshape(1:6,3,:), 1)
@test Base.axes(bc) == (Base.OneTo(3),Base.OneTo(2))
@test Base.axes1(bc) == axes(bc, 1) == Base.OneTo(3)
bc = Broadcast.instantiate(bc)
@test Base.axes(bc) == (Base.OneTo(3),Base.OneTo(2))
@test Base.axes1(bc) == axes(bc, 1) == Base.OneTo(3)
bc = Base.broadcasted(+, fill(1), 1)
@test Base.axes(bc) == ()
@test Base.axes1(bc) == axes(bc, 1) == Base.OneTo(1)
end