diff --git a/base/bitarray.jl b/base/bitarray.jl index afae953078d4d..f74a477b0260c 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -1510,6 +1510,36 @@ function findprev(testf::Function, B::BitArray, start::Integer) end #findlast(testf::Function, B::BitArray) = findprev(testf, B, 1) ## defined in array.jl +function findmax(a::BitArray) + isempty(a) && throw(ArgumentError("BitArray must be non-empty")) + m, mi = false, 1 + ti = 1 + ac = a.chunks + for i = 1:length(ac) + @inbounds k = trailing_zeros(ac[i]) + ti += k + k == 64 || return (true, @inbounds keys(a)[ti]) + end + return m, @inbounds keys(a)[mi] +end + +function findmin(a::BitArray) + isempty(a) && throw(ArgumentError("BitArray must be non-empty")) + m, mi = true, 1 + ti = 1 + ac = a.chunks + for i = 1:length(ac)-1 + @inbounds k = trailing_ones(ac[i]) + ti += k + k == 64 || return (false, @inbounds keys(a)[ti]) + end + l = Base._mod64(length(a)-1) + 1 + @inbounds k = trailing_ones(ac[end] & Base._msk_end(l)) + ti += k + k == l || return (false, @inbounds keys(a)[ti]) + return (m, @inbounds keys(a)[mi]) +end + # findall helper functions # Generic case (>2 dimensions) function allindices!(I, B::BitArray) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index e30faea9e67bc..57b48b424c70c 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -10,7 +10,7 @@ module LinearAlgebra import Base: \, /, *, ^, +, -, == import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech, asin, asinh, atan, atanh, axes, big, broadcast, ceil, conj, convert, copy, copyto!, cos, - cosh, cot, coth, csc, csch, eltype, exp, findmax, findmin, fill!, floor, getindex, hcat, + cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat, getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, length, log, map, ndims, oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech, setindex!, show, similar, sin, sincos, sinh, size, size_to_strides, sqrt, StridedReinterpretArray, diff --git a/stdlib/LinearAlgebra/src/bitarray.jl b/stdlib/LinearAlgebra/src/bitarray.jl index ea127ddfeebf4..3e38b073a992b 100644 --- a/stdlib/LinearAlgebra/src/bitarray.jl +++ b/stdlib/LinearAlgebra/src/bitarray.jl @@ -171,36 +171,6 @@ function istril(A::BitMatrix) return true end -function findmax(a::BitArray) - isempty(a) && throw(ArgumentError("BitArray must be non-empty")) - m, mi = false, 1 - ti = 1 - ac = a.chunks - for i = 1:length(ac) - @inbounds k = trailing_zeros(ac[i]) - ti += k - k == 64 || return (true, ti) - end - return m, mi -end - -function findmin(a::BitArray) - isempty(a) && throw(ArgumentError("BitArray must be non-empty")) - m, mi = true, 1 - ti = 1 - ac = a.chunks - for i = 1:length(ac)-1 - @inbounds k = trailing_ones(ac[i]) - ti += k - k == 64 || return (false, ti) - end - l = Base._mod64(length(a)-1) + 1 - @inbounds k = trailing_ones(ac[end] & Base._msk_end(l)) - ti += k - k == l || return (false, ti) - return m, mi -end - # fast 8x8 bit transpose from Henry S. Warrens's "Hacker's Delight" # http://www.hackersdelight.org/hdcodetxt/transpose8.c.txt function transpose8x8(x::UInt64) diff --git a/test/arrayops.jl b/test/arrayops.jl index 962d8a37149ee..896bf399f9d4e 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -567,6 +567,11 @@ end @test isnan(findmax([NaN, NaN, 0.0/0.0])[1]) @test findmax([NaN, NaN, 0.0/0.0])[2] == 1 + # Check that cartesian indices are returned for matrices + @test argmax([10 12; 9 11]) === CartesianIndex(1, 2) + @test argmin([10 12; 9 11]) === CartesianIndex(2, 1) + @test findmax([10 12; 9 11]) === (12, CartesianIndex(1, 2)) + @test findmin([10 12; 9 11]) === (9, CartesianIndex(2, 1)) end @testset "permutedims" begin diff --git a/test/bitarray.jl b/test/bitarray.jl index d1ac4bb5ed061..f3923df37c44b 100644 --- a/test/bitarray.jl +++ b/test/bitarray.jl @@ -1523,6 +1523,8 @@ timesofar("linalg") for b1 in [falses(v1), trues(v1), BitArray([1,0,1,1,0]), BitArray([0,0,1,1,0]), + BitArray([1 0; 1 1]), + BitArray([0 0; 1 1]), bitrand(v1)] @check_bit_operation findmin(b1) @check_bit_operation findmax(b1)