Skip to content

Commit

Permalink
Generalise two-argument first and last to any iterable
Browse files Browse the repository at this point in the history
  • Loading branch information
giordano committed Jul 10, 2020
1 parent 0e0eced commit 9ad53fa
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 20 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Standard library changes
* `unique(f, itr; seen=Set{T}())` now allows you to declare the container type used for
keeping track of values returned by `f` on elements of `itr` ([#36280]).
* `Libdl` has been moved to `Base.Libc.Libdl`, however it is still accessible as an stdlib ([#35628]).
* `first` and `last` functions now accept an integer as second argument to get that many
leading or trailing elements of any iterable ([#34868]).

#### LinearAlgebra
* New method `LinearAlgebra.issuccess(::CholeskyPivoted)` for checking whether pivoted Cholesky factorization was successful ([#36002]).
Expand Down
24 changes: 18 additions & 6 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,10 @@ function first(itr)
end

"""
first(v::AbstractVector, n::Integer)
first(itr, n::Integer)
Get the first `n` elements of vector `v`, or fewer elements if `v` is not long enough.
Get the first `n` elements of the iterable collection `itr`, or fewer elements if `v` is not
long enough.
# Examples
```jldoctest
Expand All @@ -362,7 +363,12 @@ julia> first(Bool[], 1)
0-element Array{Bool,1}
```
"""
first(v::AbstractVector, n::Integer) = @inbounds v[begin:min(begin + n - 1, end)]
first(itr, n::Integer) = collect(Iterators.take(itr, n))
# Faster method for vectors
function first(v::AbstractVector, n::Integer)
n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
@inbounds v[begin:min(begin + n - 1, end)]
end

"""
last(coll)
Expand All @@ -383,9 +389,10 @@ julia> last([1; 2; 3; 4])
last(a) = a[end]

"""
last(v::AbstractVector, n::Integer)
last(itr, n::Integer)
Get the last `n` elements of vector `v`, or fewer elements if `v` is not long enough.
Get the last `n` elements of the iterable collection `itr`, or fewer elements if `v` is not
long enough.
# Examples
```jldoctest
Expand All @@ -401,7 +408,12 @@ julia> last(Float64[], 1)
0-element Array{Float64,1}
```
"""
last(v::AbstractArray, n::Integer) = @inbounds v[max(begin, end - n + 1):end]
last(itr, n::Integer) = reverse!(collect(Iterators.take(Iterators.reverse(itr), n)))
# Faster method for arrays
function last(v::AbstractArray, n::Integer)
n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
@inbounds v[max(begin, end - n + 1):end]
end

"""
strides(A)
Expand Down
25 changes: 13 additions & 12 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1127,16 +1127,17 @@ end
end
end

@testset "first/last n elements of vector" begin
v = [1, 13, 42]
@test first(v, -2) == []
@test first(v, 2) == v[1:2]
@test first(v, 100) == v
@test first(v, 100) !== v
@test first(v, 1) != v[1]
@test last(v, -2) == []
@test last(v, 2) == v[end-1:end]
@test last(v, 100) == v
@test last(v, 100) !== v
@test last(v, 1) != v[end]
@testset "first/last n elements of $(typeof(itr))" for itr in (collect(1:9),
[1 4 7; 2 5 8; 3 6 9],
ntuple(identity, 9))
@test first(itr, 6) == [itr[1:6]...]
@test first(itr, 25) == [itr[:]...]
@test first(itr, 25) !== itr
@test first(itr, 1) == [itr[1]]
@test_throws ArgumentError first(itr, -6)
@test last(itr, 6) == [itr[end-5:end]...]
@test last(itr, 25) == [itr[:]...]
@test last(itr, 25) !== itr
@test last(itr, 1) == [itr[end]]
@test_throws ArgumentError last(itr, -6)
end
4 changes: 2 additions & 2 deletions test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -612,12 +612,12 @@ end

@testset "first/last n elements of vector" begin
f = firstindex(v)
@test first(v, -2) == []
@test_throws ArgumentError first(v, -2)
@test first(v, 2) == v[f:f+1]
@test first(v, 100) == v0
@test first(v, 100) !== v
@test first(v, 1) != v[f]
@test last(v, -2) == []
@test_throws ArgumentError last(v, -2)
@test last(v, 2) == v[end-1:end]
@test last(v, 100) == v0
@test last(v, 100) !== v
Expand Down

0 comments on commit 9ad53fa

Please sign in to comment.