Skip to content

Commit

Permalink
Rework find functions (#8)
Browse files Browse the repository at this point in the history
* Change the bounds checking behaviour of the find* functions to match that of
  `Vector`.
* Add an optimised generic fallback which, unlike the AbstractArray fallback,
  does not boundscheck in its loop body
* Add a fastpath for findprev to dispatch to Libc's memrchr
* More thoroughly test find functions
  • Loading branch information
jakobnissen authored Dec 11, 2024
1 parent cbab8ab commit 83b4218
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 27 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ authors = ["Jakob Nybo Nissen <jakobnybonissen@gmail.com>"]
[weakdeps]
StringViews = "354b36f9-a18e-4713-926e-db85100087ba"

[extensions]
StringViewsExt = "StringViews"

[compat]
Aqua = "0.8.7"
StringViews = "1"
Test = "1.11"
julia = "1.11"

[extensions]
StringViewsExt = "StringViews"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
StringViews = "354b36f9-a18e-4713-926e-db85100087ba"
Expand Down
114 changes: 103 additions & 11 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,7 @@ function Base.getindex(v::MemoryView, i::Integer)
@inbounds ref[]
end

function Base.similar(
mem::MemoryView{T1, M},
::Type{T2},
dims::Tuple{Int},
) where {T1, T2, M}
function Base.similar(::MemoryView{T1, M}, ::Type{T2}, dims::Tuple{Int}) where {T1, T2, M}
len = Int(only(dims))::Int
memory = Memory{T2}(undef, len)
MemoryView{T2, M}(unsafe, memoryref(memory), len)
Expand Down Expand Up @@ -89,6 +85,23 @@ end
# Indexing with a bare colon hands back the very same view object.
function Base.getindex(v::MemoryView, ::Colon)
    return v
end

# `view` over a unit range is forwarded to range indexing (`v[idx]`).
function Base.view(v::MemoryView, idx::AbstractUnitRange)
    return v[idx]
end

# Build a view of the same type as `mem` that starts at the same reference but
# has length `include_last`.
#
# Throws `BoundsError` (unless bounds checks are elided) when `include_last` is
# negative or larger than `length(mem)`. A value of zero is accepted.
function truncate(mem::MemoryView, include_last::Integer)
    newlen = Int(include_last)::Int
    # Compare as unsigned: a negative `newlen` wraps around to a huge value and
    # therefore fails the same single comparison as a too-large one.
    @boundscheck begin
        (newlen % UInt) <= (length(mem) % UInt) || throw(BoundsError(mem, newlen))
    end
    return typeof(mem)(unsafe, mem.ref, newlen)
end

# Build a view of the same type as `mem` covering elements `from:length(mem)`.
#
# `from` must lie in `1:length(mem)`, so the result always holds at least one
# element (hence "nonempty"); in particular `mem` itself must be nonempty.
# Throws `BoundsError` otherwise, unless bounds checks are elided by the caller.
function truncate_start_nonempty(mem::MemoryView, from::Integer)
    frm = Int(from)::Int
    # Unsigned comparison of the zero-based offset: `frm < 1` wraps around to a
    # huge value, so one `>=` test rejects both `frm < 1` and `frm > length(mem)`.
    @boundscheck if ((frm - 1) % UInt) >= length(mem) % UInt
        throw(BoundsError(mem, frm))
    end
    # The check above established that `frm` is a valid index into `mem`.
    newref = @inbounds memoryref(mem.ref, frm)
    return typeof(mem)(unsafe, newref, length(mem) - frm + 1)
end

function Base.unsafe_copyto!(dst::MutableMemoryView{T}, src::MemoryView{T}) where {T}
iszero(length(src)) && return dst
@inbounds unsafe_copyto!(dst.ref, src.ref, length(src))
Expand All @@ -105,6 +118,17 @@ function Base.copyto!(dst::MutableMemoryView{T}, src::MemoryView{T}) where {T}
unsafe_copyto!(dst, src)
end

# Optimised methods that don't boundscheck
# Generic forward search: return the first index `>= start` whose element
# satisfies `p`, or `nothing` when there is none (including when `start` is
# past the end of `mem`).
#
# Throws `BoundsError` when `start < 1`, unless bounds checks are elided.
# Element accesses inside the loop are not bounds checked.
function Base.findnext(p::Function, mem::MemoryView, start::Integer)
    first_idx = Int(start)::Int
    @boundscheck if first_idx < 1
        throw(BoundsError(mem, first_idx))
    end
    # The range is empty when `first_idx > length(mem)`, so we fall through.
    @inbounds for idx in first_idx:length(mem)
        if p(mem[idx])
            return idx
        end
    end
    return nothing
end

# The following two methods could be collapsed, but they aren't for two reasons:
# * To prevent ambiguity with Base
# * Because we DON'T want this code to run with MemoryView{Union{UInt8, Int8}}.
Expand All @@ -126,16 +150,25 @@ function Base.findnext(
_findnext(mem, p.x, start)
end

@inline function _findnext(
# `iszero` on a byte view is the same as searching for the zero byte, so
# forward to `_findnext` with `zero(eltype(mem))` as the needle.
function Base.findnext(
    ::typeof(iszero),
    mem::Union{MemoryView{Int8}, MemoryView{UInt8}},
    i::Integer,
)
    needle = zero(eltype(mem))
    return _findnext(mem, needle, i)
end

Base.@propagate_inbounds function _findnext(
mem::MemoryView{T},
byte::Union{T},
byte::T,
start::Integer,
) where {T <: Union{UInt8, Int8}}
start = Int(start)::Int
real_start = max(start, 1)
v = @inbounds ImmutableMemoryView(mem[real_start:end])
v_ind = @something memchr(v, byte) return nothing
v_ind + real_start - 1
@boundscheck(start < 1 && throw(BoundsError(mem, start)))
start > length(mem) && return nothing
im = @inbounds truncate_start_nonempty(ImmutableMemoryView(mem), start)
v_ind = @something memchr(im, byte) return nothing
v_ind + start - 1
end

function memchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UInt8}}
Expand All @@ -151,6 +184,65 @@ function memchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UI
p == C_NULL ? nothing : (p - ptr) % Int + 1
end

# Generic backward search: return the last index `<= start` whose element
# satisfies `p`, or `nothing` when there is none (including when `start < 1`).
#
# Throws `BoundsError` when `start > length(mem)`, unless bounds checks are
# elided. Element accesses inside the loop are not bounds checked.
function Base.findprev(p::Function, mem::MemoryView, start::Integer)
    last_idx = Int(start)::Int
    @boundscheck if last_idx > length(mem)
        throw(BoundsError(mem, last_idx))
    end
    # The range is empty when `last_idx < 1`, so we fall through to `nothing`.
    @inbounds for idx in last_idx:-1:1
        if p(mem[idx])
            return idx
        end
    end
    return nothing
end

# `==(x)` / `isequal(x)` with a `UInt8` needle on a `UInt8` view: forward the
# needle stored in the `Fix2` predicate to the byte-search helper `_findprev`.
function Base.findprev(
    p::Base.Fix2{<:Union{typeof(==), typeof(isequal)}, UInt8},
    mem::MemoryView{UInt8},
    start::Integer,
)
    needle = p.x
    return _findprev(mem, needle, start)
end

# `==(x)` / `isequal(x)` with an `Int8` needle on an `Int8` view: forward the
# needle stored in the `Fix2` predicate to the byte-search helper `_findprev`.
function Base.findprev(
    p::Base.Fix2{<:Union{typeof(==), typeof(isequal)}, Int8},
    mem::MemoryView{Int8},
    start::Integer,
)
    needle = p.x
    return _findprev(mem, needle, start)
end

# `iszero` on a byte view is the same as searching for the zero byte, so
# forward to `_findprev` with `zero(eltype(mem))` as the needle.
function Base.findprev(
    ::typeof(iszero),
    mem::Union{MemoryView{Int8}, MemoryView{UInt8}},
    i::Integer,
)
    needle = zero(eltype(mem))
    return _findprev(mem, needle, i)
end

# Shared implementation of the byte-search `findprev` methods.
#
# Returns the last index `<= start` at which `byte` occurs in `mem`, or
# `nothing` if there is none (including when `start < 1`). Throws `BoundsError`
# when `start > length(mem)`, unless the caller elides bounds checks.
Base.@propagate_inbounds function _findprev(
    mem::MemoryView{T},
    byte::T,
    start::Integer,
) where {T <: Union{UInt8, Int8}}
    last_idx = Int(start)::Int
    @boundscheck if last_idx > length(mem)
        throw(BoundsError(mem, last_idx))
    end
    if last_idx < 1
        return nothing
    end
    # With the guards above, `last_idx` is a valid length for the prefix, so
    # the check inside `truncate` is redundant here.
    prefix = @inbounds truncate(ImmutableMemoryView(mem), last_idx)
    return memrchr(prefix, byte)
end

# Find the last occurrence of `byte` in `mem` by calling the C function
# `memrchr`. Returns the 1-based index of the match, or `nothing` when the view
# is empty or the byte does not occur.
function memrchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UInt8}}
    if isempty(mem)
        return nothing
    end
    # Keep `mem` rooted for as long as its raw pointer is handed to C.
    GC.@preserve mem begin
        base = Ptr{UInt8}(pointer(mem))
        hit = @ccall memrchr(
            base::Ptr{UInt8},
            (byte % UInt8)::UInt8,
            length(mem)::Int,
        )::Ptr{Cvoid}
    end
    if hit == C_NULL
        return nothing
    end
    # The pointer difference is the zero-based offset of the match.
    return (hit - base) % Int + 1
end

# The signed and unsigned integer types of 8 through 128 bits, plus `Char`.
const Bits = Union{
    Int8,
    UInt8,
    Int16,
    UInt16,
    Int32,
    UInt32,
    Int64,
    UInt64,
    Int128,
    UInt128,
    Char,
}

Expand Down
82 changes: 69 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,19 +372,6 @@ end
@test v2 == v1
end

@testset "Find" begin
    # Generic element type: goes through the predicate-based search.
    ints = MemoryView([4, 3, 2])
    @test findfirst(==(2), ints) == 3

    # Signed bytes: both `iszero` and an explicit zero needle.
    signed_bytes = MemoryView(Int8[6, 2, 7, 0, 2])
    @test findfirst(iszero, signed_bytes) == 4
    @test findfirst(==(Int8(0)), signed_bytes) == 4

    # Unsigned bytes: a hit, and a miss when starting past the only match.
    unsigned_bytes = MemoryView(UInt8[1, 4, 2, 5, 6])
    @test findnext(==(0x04), unsigned_bytes, 1) == 2
    @test findnext(==(0x04), unsigned_bytes, 3) === nothing
end

@testset "Reverse and reverse!" begin
for v in [
["a", "abc", "a", "c", "kij"],
Expand Down Expand Up @@ -463,6 +450,75 @@ end
@test split_unaligned(v, Val(8)) == split_at(v, 3)
@test split_unaligned(v, Val(16)) == split_at(v, 7)
end

@testset "Find" begin
@testset "Generic find" begin
    mem = ImmutableMemoryView([1, 2, 3, 4])

    # findfirst, including on an empty view. The predicate must be passed
    # explicitly; without it the call does not exercise the predicate-based
    # method this testset targets.
    @test findfirst(isodd, mem) == 1
    @test findfirst(isodd, mem[2:end]) == 2
    @test findfirst(isodd, mem[1:0]) === nothing

    # findlast, including on an empty view.
    @test findlast(isodd, mem) == 3
    @test findlast(isodd, mem[1:2]) == 1
    @test findlast(isodd, mem[1:0]) === nothing

    # findnext accepts any Integer start; starting past the end yields nothing.
    @test findnext(isodd, mem, 0x02) == 3
    @test findnext(isodd, mem, 3) == 3
    @test findnext(isodd, mem, 0x04) === nothing
    @test findnext(isodd, mem, 10) === nothing

    # ... but a start below the first index is an error.
    @test_throws BoundsError findnext(isodd, mem, 0)
    @test_throws BoundsError findnext(isodd, mem, -1)

    # findprev accepts any Integer start; starting before the first index
    # yields nothing.
    @test findprev(isodd, mem, 4) == 3
    @test findprev(isodd, mem, 0x03) == 3
    @test findprev(isodd, mem, 2) == 1
    @test findprev(isodd, mem, 0x00) === nothing
    @test findprev(isodd, mem, -10) === nothing

    # ... but a start above the last index is an error.
    @test_throws BoundsError findprev(isodd, mem, 5)
    @test_throws BoundsError findprev(isodd, mem, 7)
end

@testset "Memchr routines" begin
    for T in Any[Int8, UInt8]
        bytes = MemoryView(T[6, 2, 7, 0, 2, 1])

        # Forward searches.
        @test findfirst(iszero, bytes) == 4
        @test findfirst(==(T(2)), bytes) == 2
        @test findnext(==(T(2)), bytes, 3) == 5
        @test findnext(==(T(7)), bytes, 4) === nothing
        @test findnext(==(T(2)), bytes, 7) === nothing
        @test_throws BoundsError findnext(iszero, bytes, 0)
        @test_throws BoundsError findnext(iszero, bytes, -3)

        # Backward searches.
        @test findlast(iszero, bytes) == 4
        @test findprev(iszero, bytes, 3) === nothing
        @test findprev(iszero, bytes, 4) == 4
        @test findprev(==(T(2)), bytes, 5) == 5
        @test findprev(==(T(2)), bytes, 4) == 2
        @test findprev(==(T(9)), bytes, 3) === nothing
        @test findprev(==(T(2)), bytes, -2) === nothing
        @test findprev(iszero, bytes, 0) === nothing
        @test_throws BoundsError findprev(iszero, bytes, 7)
    end

    # Int8(-1) shares its bit pattern with 0xff but is not equal to it, so a
    # search for the UInt8 needle in an Int8 view must not match.
    signed_bytes = MemoryView(Int8[2, 3, -1])
    @test findfirst(==(0xff), signed_bytes) === nothing
    @test findprev(==(0xff), signed_bytes, 3) === nothing
end

@testset "Find" begin
    # Generic element type: goes through the predicate-based search.
    ints = MemoryView([4, 3, 2])
    @test findfirst(==(2), ints) == 3

    # Signed bytes: both `iszero` and an explicit zero needle.
    signed_bytes = MemoryView(Int8[6, 2, 7, 0, 2])
    @test findfirst(iszero, signed_bytes) == 4
    @test findfirst(==(Int8(0)), signed_bytes) == 4

    # Unsigned bytes: a hit, and a miss when starting past the only match.
    unsigned_bytes = MemoryView(UInt8[1, 4, 2, 5, 6])
    @test findnext(==(0x04), unsigned_bytes, 1) == 2
    @test findnext(==(0x04), unsigned_bytes, 3) === nothing
end
end
end

@testset "Iterators.reverse" begin
Expand Down

0 comments on commit 83b4218

Please sign in to comment.