diff --git a/NEWS.md b/NEWS.md index d0c2437ed5c0a..dd86f5c8409c1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -65,6 +65,8 @@ Standard library changes * `unique(f, itr; seen=Set{T}())` now allows you to declare the container type used for keeping track of values returned by `f` on elements of `itr` ([#36280]). * `Libdl` has been moved to `Base.Libc.Libdl`, however it is still accessible as an stdlib ([#35628]). +* `first` and `last` functions now accept an integer as second argument to get that many + leading or trailing elements of any iterable ([#34868]). #### LinearAlgebra * New method `LinearAlgebra.issuccess(::CholeskyPivoted)` for checking whether pivoted Cholesky factorization was successful ([#36002]). diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 70544fa7ce6e3..fa93112ead432 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -344,9 +344,10 @@ function first(itr) end """ - first(v::AbstractVector, n::Integer) + first(itr, n::Integer) -Get the first `n` elements of vector `v`, or fewer elements if `v` is not long enough. +Get the first `n` elements of the iterable collection `itr`, or fewer elements if `v` is not +long enough. # Examples ```jldoctest @@ -362,7 +363,12 @@ julia> first(Bool[], 1) 0-element Array{Bool,1} ``` """ -first(v::AbstractVector, n::Integer) = @inbounds v[begin:min(begin + n - 1, end)] +first(itr, n::Integer) = collect(Iterators.take(itr, n)) +# Faster method for vectors +function first(v::AbstractVector, n::Integer) + n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) + @inbounds v[begin:min(begin + n - 1, end)] +end """ last(coll) @@ -383,9 +389,10 @@ julia> last([1; 2; 3; 4]) last(a) = a[end] """ - last(v::AbstractVector, n::Integer) + last(itr, n::Integer) -Get the last `n` elements of vector `v`, or fewer elements if `v` is not long enough. +Get the last `n` elements of the iterable collection `itr`, or fewer elements if `v` is not +long enough. # Examples ```jldoctest @@ -401,7 +408,12 @@ julia> last(Float64[], 1) 0-element Array{Float64,1} ``` """ -last(v::AbstractArray, n::Integer) = @inbounds v[max(begin, end - n + 1):end] +last(itr, n::Integer) = reverse!(collect(Iterators.take(Iterators.reverse(itr), n))) +# Faster method for arrays +function last(v::AbstractArray, n::Integer) + n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) + @inbounds v[max(begin, end - n + 1):end] +end """ strides(A) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index f23c5270c6771..2db1638c8c950 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1127,16 +1127,17 @@ end end end -@testset "first/last n elements of vector" begin - v = [1, 13, 42] - @test first(v, -2) == [] - @test first(v, 2) == v[1:2] - @test first(v, 100) == v - @test first(v, 100) !== v - @test first(v, 1) != v[1] - @test last(v, -2) == [] - @test last(v, 2) == v[end-1:end] - @test last(v, 100) == v - @test last(v, 100) !== v - @test last(v, 1) != v[end] +@testset "first/last n elements of $(typeof(itr))" for itr in (collect(1:9), + [1 4 7; 2 5 8; 3 6 9], + ntuple(identity, 9)) + @test first(itr, 6) == [itr[1:6]...] + @test first(itr, 25) == [itr[:]...] + @test first(itr, 25) !== itr + @test first(itr, 1) == [itr[1]] + @test_throws ArgumentError first(itr, -6) + @test last(itr, 6) == [itr[end-5:end]...] + @test last(itr, 25) == [itr[:]...] + @test last(itr, 25) !== itr + @test last(itr, 1) == [itr[end]] + @test_throws ArgumentError last(itr, -6) end diff --git a/test/offsetarray.jl b/test/offsetarray.jl index caa8199670958..127af84b741ab 100644 --- a/test/offsetarray.jl +++ b/test/offsetarray.jl @@ -612,12 +612,12 @@ end @testset "first/last n elements of vector" begin f = firstindex(v) - @test first(v, -2) == [] + @test_throws ArgumentError first(v, -2) @test first(v, 2) == v[f:f+1] @test first(v, 100) == v0 @test first(v, 100) !== v @test first(v, 1) != v[f] - @test last(v, -2) == [] + @test_throws ArgumentError last(v, -2) @test last(v, 2) == v[end-1:end] @test last(v, 100) == v0 @test last(v, 100) !== v