Skip to content

Commit

Permalink
simplify min/max init
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Jan 4, 2022
1 parent 941e3b9 commit e09c63d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 49 deletions.
47 changes: 11 additions & 36 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ end
# reducedim_initarray is called by
reducedim_initarray(A::AbstractArrayOrBroadcasted, region, init, ::Type{R}) where {R} = fill!(similar(A,R,reduced_indices(A,region)), init)
reducedim_initarray(A::AbstractArrayOrBroadcasted, region, init::T) where {T} = reducedim_initarray(A, region, init, T)
# TODO: extend this to minimum and maximum
reducedim_initarray(A::AbstractArrayOrBroadcasted, region, ::UndefInitializer, ::Type{R}) where {R} = similar(A,R,reduced_indices(A,region))
# TODO: better way to handle reducedim initialization
#
Expand Down Expand Up @@ -126,45 +125,21 @@ function _reducedim_init(f, op, fv, fop, A, region)
end

# initialization when computing minima and maxima requires a little care
for (f1, f2, initval, typeextreme) in ((:min, :max, :Inf, :typemax), (:max, :min, :(-Inf), :typemin))
@eval function reducedim_init(f, op::typeof($f1), A::AbstractArray, region)
# First compute the reduce indices. This will throw an ArgumentError
# if any region is invalid
ri = reduced_indices(A, region)
function reducedim_init(f::F, ::Union{typeof(min),typeof(max)}, A::AbstractArray, region) where {F}
# First compute the reduce indices. This will throw an ArgumentError
# if any region is invalid
ri = reduced_indices(A, region)

# Next, throw if reduction is over a region with length zero
any(i -> isempty(axes(A, i)), region) && _empty_reduce_error()
# Next, throw if reduction is over a region with length zero
any(i -> isempty(axes(A, i)), region) && _empty_reduce_error()

# Make a view of the first slice of the region
A1 = view(A, ri...)
# Make a view of the first slice of the region
A1 = view(A, ri...)

if isempty(A1)
# If the slice is empty just return non-view version as the initial array
return copy(A1)
else
# otherwise use the min/max of the first slice as initial value
v0 = mapreduce(f, $f2, A1)

T = _realtype(f, promote_union(eltype(A)))
Tr = v0 isa T ? T : typeof(v0)

# but NaNs and missing need to be avoided as initial values
if (v0 == v0) === false
# v0 is NaN
v0 = $initval
elseif isunordered(v0)
# v0 is missing or a third-party unordered value
Tnm = nonmissingtype(Tr)
# TODO: Some types, like BigInt, don't support typemin/typemax.
# So a Matrix{Union{BigInt, Missing}} can still error here.
v0 = $typeextreme(Tnm)
end
# v0 may have changed type.
Tr = v0 isa T ? T : typeof(v0)
# calculate the output type
T = promote_typejoin_union(_return_type(f, Tuple{eltype(A)}))

return reducedim_initarray(A, region, v0, Tr)
end
end
map!(f, reducedim_initarray(A,region,undef,T), A1)
end
reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(max), A::AbstractArray{T}, region) where {T} =
reducedim_initarray(A, region, zero(f(zero(T))), _realtype(f, T))
Expand Down
20 changes: 7 additions & 13 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,35 +378,29 @@ A = circshift(reshape(1:24,2,3,4), (0,1,1))
@test size(extrema(A,dims=(1,2,3))) == size(maximum(A,dims=(1,2,3)))
@test extrema(x->div(x, 2), A, dims=(2,3)) == reshape([(0,11),(1,12)],2,1,1)

# TODO: drop `a′` once `minimum` and `maximum` is fixed
# (the following test_broken pass)
function test_extrema(a, a′ = a; dims_test = ((), 1, 2, (1,2), 3))
function test_extrema(a; dims_test = ((), 1, 2, (1,2), 3))
for dims in dims_test
vext = extrema(a; dims)
vmin, vmax = minimum(a; dims), maximum(a; dims)
@test all(x -> isequal(x[1], x[2:3]), zip(vext,vmin,vmax)) || foreach(i -> display(i),(eltype(a), vext,vmin,vmax))
vmin, vmax = minimum(a; dims), maximum(a; dims)
@test all(x -> isequal(x[1], x[2:3]), zip(vext,vmin,vmax))
end
end
@test_broken minimum([missing BigInt(1)], dims = 2)[1] === missing
@testset "0.0,-0.0 test for extrema with dims" begin
@test extrema([-0.0;0.0], dims = 1)[1] === (-0.0,0.0)
@test tuple(extrema([-0.0;0.0], dims = 2)...) === ((-0.0, -0.0), (0.0, 0.0))
end
@testset "NaN/missing test for extrema with dims #43599" begin
for sz = (3, 10, 100)
for T in (Int, BigInt, Float64, BigFloat)
Aₘ = Matrix{Union{Float64, Missing}}(rand(-sz:sz, sz, sz))
Aₘ = Matrix{Union{T, Missing}}(rand(-sz:sz, sz, sz))
Aₘ[rand(1:sz*sz, sz)] .= missing
ATₘ = Matrix{Union{T, Missing}}(Aₘ)
test_extrema(ATₘ, Aₘ)
test_extrema(Aₘ)
if T <: AbstractFloat
Aₙ = map(i -> ismissing(i) ? T(NaN) : i, Aₘ)
ATₙ = map(i -> ismissing(i) ? T(NaN) : i, ATₘ)
test_extrema(ATₙ, Aₙ)
test_extrema(Aₙ)
p = rand(1:sz*sz, sz)
Aₘ[p] .= NaN
ATₘ[p] .= NaN
test_extrema(ATₘ, Aₘ)
test_extrema(Aₘ)
end
end
end
Expand Down

0 comments on commit e09c63d

Please sign in to comment.