Skip to content

Commit

Permalink
Add two-argument first and last methods for any iterable (#34868)
Browse files Browse the repository at this point in the history
* Add `{first,last}(::AbstractVector, ::Integer)` methods

* Apply suggestions from code review

Co-Authored-By: Milan Bouchet-Valat <nalimilan@club.fr>

* Add tests for OffsetArrays

* Use `begin` in place of `firstindex`

* Generalise two-argument `first` and `last` to any iterable

Co-authored-by: Milan Bouchet-Valat <nalimilan@club.fr>
  • Loading branch information
giordano and nalimilan authored Jul 13, 2020
1 parent a23a4ff commit e24e2f0
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Standard library changes
* `unique(f, itr; seen=Set{T}())` now allows you to declare the container type used for
keeping track of values returned by `f` on elements of `itr` ([#36280]).
* `Libdl` has been moved to `Base.Libc.Libdl`, however it is still accessible as an stdlib ([#35628]).
* `first` and `last` functions now accept an integer as second argument to get that many
leading or trailing elements of any iterable ([#34868]).

#### LinearAlgebra
* New method `LinearAlgebra.issuccess(::CholeskyPivoted)` for checking whether pivoted Cholesky factorization was successful ([#36002]).
Expand Down
54 changes: 54 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,33 @@ function first(itr)
x[1]
end

"""
first(itr, n::Integer)
Get the first `n` elements of the iterable collection `itr`, or fewer elements if `v` is not
long enough.
# Examples
```jldoctest
julia> first(["foo", "bar", "qux"], 2)
2-element Vector{String}:
"foo"
"bar"
julia> first(1:6, 10)
1:6
julia> first(Bool[], 1)
Bool[]
```
"""
first(itr, n::Integer) = collect(Iterators.take(itr, n))
# Faster method for vectors
function first(v::AbstractVector, n::Integer)
n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
@inbounds v[begin:min(begin + n - 1, end)]
end

"""
last(coll)
Expand All @@ -361,6 +388,33 @@ julia> last([1; 2; 3; 4])
"""
last(a) = a[end]

"""
last(itr, n::Integer)
Get the last `n` elements of the iterable collection `itr`, or fewer elements if `v` is not
long enough.
# Examples
```jldoctest
julia> last(["foo", "bar", "qux"], 2)
2-element Vector{String}:
"bar"
"qux"
julia> last(1:6, 10)
1:6
julia> last(Float64[], 1)
Float64[]
```
"""
last(itr, n::Integer) = reverse!(collect(Iterators.take(Iterators.reverse(itr), n)))
# Faster method for arrays
function last(v::AbstractArray, n::Integer)
n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
@inbounds v[max(begin, end - n + 1):end]
end

"""
strides(A)
Expand Down
15 changes: 15 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1126,3 +1126,18 @@ end
end
end
end

@testset "first/last n elements of $(typeof(itr))" for itr in (collect(1:9),
[1 4 7; 2 5 8; 3 6 9],
ntuple(identity, 9))
@test first(itr, 6) == [itr[1:6]...]
@test first(itr, 25) == [itr[:]...]
@test first(itr, 25) !== itr
@test first(itr, 1) == [itr[1]]
@test_throws ArgumentError first(itr, -6)
@test last(itr, 6) == [itr[end-5:end]...]
@test last(itr, 25) == [itr[:]...]
@test last(itr, 25) !== itr
@test last(itr, 1) == [itr[end]]
@test_throws ArgumentError last(itr, -6)
end
15 changes: 15 additions & 0 deletions test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -609,3 +609,18 @@ end
@test_throws DimensionMismatch maximum!(fill(0, -4:-4, 7:7, -6:-5, 1:1), B)
@test_throws DimensionMismatch minimum!(fill(0, -4:-4, 7:7, -6:-5, 1:1), B)
end

@testset "first/last n elements of vector" begin
v0 = rand(6)
v = OffsetArray(v0, (-3,))
@test_throws ArgumentError first(v, -2)
@test first(v, 2) == v[begin:begin+1]
@test first(v, 100) == v0
@test first(v, 100) !== v
@test first(v, 1) == [v[begin]]
@test_throws ArgumentError last(v, -2)
@test last(v, 2) == v[end-1:end]
@test last(v, 100) == v0
@test last(v, 100) !== v
@test last(v, 1) == [v[end]]
end

2 comments on commit e24e2f0

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily benchmark build, I will reply here when finished:

@nanosoldier runbenchmarks(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here. cc @ararslan

Please sign in to comment.