Skip to content

Commit

Permalink
Merge pull request #30102 from JuliaLang/nl/findminmax
Browse files Browse the repository at this point in the history
Fix findmin/findmax to return cartesian indices for BitMatrix
  • Loading branch information
nalimilan authored Nov 28, 2018
2 parents 076ea45 + a9161d9 commit 402c747
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 31 deletions.
30 changes: 30 additions & 0 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1510,6 +1510,36 @@ function findprev(testf::Function, B::BitArray, start::Integer)
end
#findlast(testf::Function, B::BitArray) = findprev(testf, B, 1) ## defined in array.jl

function findmax(a::BitArray)
isempty(a) && throw(ArgumentError("BitArray must be non-empty"))
m, mi = false, 1
ti = 1
ac = a.chunks
for i = 1:length(ac)
@inbounds k = trailing_zeros(ac[i])
ti += k
k == 64 || return (true, @inbounds keys(a)[ti])
end
return m, @inbounds keys(a)[mi]
end

function findmin(a::BitArray)
isempty(a) && throw(ArgumentError("BitArray must be non-empty"))
m, mi = true, 1
ti = 1
ac = a.chunks
for i = 1:length(ac)-1
@inbounds k = trailing_ones(ac[i])
ti += k
k == 64 || return (false, @inbounds keys(a)[ti])
end
l = Base._mod64(length(a)-1) + 1
@inbounds k = trailing_ones(ac[end] & Base._msk_end(l))
ti += k
k == l || return (false, @inbounds keys(a)[ti])
return (m, @inbounds keys(a)[mi])
end

# findall helper functions
# Generic case (>2 dimensions)
function allindices!(I, B::BitArray)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module LinearAlgebra
import Base: \, /, *, ^, +, -, ==
import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech,
asin, asinh, atan, atanh, axes, big, broadcast, ceil, conj, convert, copy, copyto!, cos,
cosh, cot, coth, csc, csch, eltype, exp, findmax, findmin, fill!, floor, getindex, hcat,
cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat,
getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, length, log, map, ndims,
oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
setindex!, show, similar, sin, sincos, sinh, size, size_to_strides, sqrt, StridedReinterpretArray,
Expand Down
30 changes: 0 additions & 30 deletions stdlib/LinearAlgebra/src/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,36 +171,6 @@ function istril(A::BitMatrix)
return true
end

function findmax(a::BitArray)
isempty(a) && throw(ArgumentError("BitArray must be non-empty"))
m, mi = false, 1
ti = 1
ac = a.chunks
for i = 1:length(ac)
@inbounds k = trailing_zeros(ac[i])
ti += k
k == 64 || return (true, ti)
end
return m, mi
end

function findmin(a::BitArray)
isempty(a) && throw(ArgumentError("BitArray must be non-empty"))
m, mi = true, 1
ti = 1
ac = a.chunks
for i = 1:length(ac)-1
@inbounds k = trailing_ones(ac[i])
ti += k
k == 64 || return (false, ti)
end
l = Base._mod64(length(a)-1) + 1
@inbounds k = trailing_ones(ac[end] & Base._msk_end(l))
ti += k
k == l || return (false, ti)
return m, mi
end

# fast 8x8 bit transpose from Henry S. Warrens's "Hacker's Delight"
# http://www.hackersdelight.org/hdcodetxt/transpose8.c.txt
function transpose8x8(x::UInt64)
Expand Down
5 changes: 5 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,11 @@ end
@test isnan(findmax([NaN, NaN, 0.0/0.0])[1])
@test findmax([NaN, NaN, 0.0/0.0])[2] == 1

# Check that cartesian indices are returned for matrices
@test argmax([10 12; 9 11]) === CartesianIndex(1, 2)
@test argmin([10 12; 9 11]) === CartesianIndex(2, 1)
@test findmax([10 12; 9 11]) === (12, CartesianIndex(1, 2))
@test findmin([10 12; 9 11]) === (9, CartesianIndex(2, 1))
end

@testset "permutedims" begin
Expand Down
2 changes: 2 additions & 0 deletions test/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,8 @@ timesofar("linalg")
for b1 in [falses(v1), trues(v1),
BitArray([1,0,1,1,0]),
BitArray([0,0,1,1,0]),
BitArray([1 0; 1 1]),
BitArray([0 0; 1 1]),
bitrand(v1)]
@check_bit_operation findmin(b1)
@check_bit_operation findmax(b1)
Expand Down

0 comments on commit 402c747

Please sign in to comment.