diff --git a/NEWS.md b/NEWS.md index d0c2437ed5c0a..dd86f5c8409c1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -65,6 +65,8 @@ Standard library changes * `unique(f, itr; seen=Set{T}())` now allows you to declare the container type used for keeping track of values returned by `f` on elements of `itr` ([#36280]). * `Libdl` has been moved to `Base.Libc.Libdl`, however it is still accessible as an stdlib ([#35628]). +* `first` and `last` functions now accept an integer as second argument to get that many + leading or trailing elements of any iterable ([#34868]). #### LinearAlgebra * New method `LinearAlgebra.issuccess(::CholeskyPivoted)` for checking whether pivoted Cholesky factorization was successful ([#36002]). diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 70544fa7ce6e3..68dce7e7d0b20 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -344,14 +344,15 @@ function first(itr) end """ - first(v::AbstractVector, n::Integer) + first(itr, n::Integer) -Get the first `n` elements of vector `v`, or fewer elements if `v` is not long enough. +Get the first `n` elements of the iterable collection `itr`, or fewer elements if `v` is not +long enough. # Examples ```jldoctest julia> first(["foo", "bar", "qux"], 2) -2-element Array{String,1}: +2-element Vector{String}: "foo" "bar" @@ -359,10 +360,15 @@ julia> first(1:6, 10) 1:6 julia> first(Bool[], 1) -0-element Array{Bool,1} +Bool[] ``` """ -first(v::AbstractVector, n::Integer) = @inbounds v[begin:min(begin + n - 1, end)] +first(itr, n::Integer) = collect(Iterators.take(itr, n)) +# Faster method for vectors +function first(v::AbstractVector, n::Integer) + n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) + @inbounds v[begin:min(begin + n - 1, end)] +end """ last(coll) @@ -383,14 +389,15 @@ julia> last([1; 2; 3; 4]) last(a) = a[end] """ - last(v::AbstractVector, n::Integer) + last(itr, n::Integer) -Get the last `n` elements of vector `v`, or fewer elements if `v` is not long enough. +Get the last `n` elements of the iterable collection `itr`, or fewer elements if `v` is not +long enough. # Examples ```jldoctest julia> last(["foo", "bar", "qux"], 2) -2-element Array{String,1}: +2-element Vector{String}: "bar" "qux" @@ -398,10 +405,15 @@ julia> last(1:6, 10) 1:6 julia> last(Float64[], 1) -0-element Array{Float64,1} +Float64[] ``` """ -last(v::AbstractArray, n::Integer) = @inbounds v[max(begin, end - n + 1):end] +last(itr, n::Integer) = reverse!(collect(Iterators.take(Iterators.reverse(itr), n))) +# Faster method for arrays +function last(v::AbstractArray, n::Integer) + n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) + @inbounds v[max(begin, end - n + 1):end] +end """ strides(A) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index f23c5270c6771..2db1638c8c950 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1127,16 +1127,17 @@ end end end -@testset "first/last n elements of vector" begin - v = [1, 13, 42] - @test first(v, -2) == [] - @test first(v, 2) == v[1:2] - @test first(v, 100) == v - @test first(v, 100) !== v - @test first(v, 1) != v[1] - @test last(v, -2) == [] - @test last(v, 2) == v[end-1:end] - @test last(v, 100) == v - @test last(v, 100) !== v - @test last(v, 1) != v[end] +@testset "first/last n elements of $(typeof(itr))" for itr in (collect(1:9), + [1 4 7; 2 5 8; 3 6 9], + ntuple(identity, 9)) + @test first(itr, 6) == [itr[1:6]...] + @test first(itr, 25) == [itr[:]...] + @test first(itr, 25) !== itr + @test first(itr, 1) == [itr[1]] + @test_throws ArgumentError first(itr, -6) + @test last(itr, 6) == [itr[end-5:end]...] + @test last(itr, 25) == [itr[:]...] + @test last(itr, 25) !== itr + @test last(itr, 1) == [itr[end]] + @test_throws ArgumentError last(itr, -6) end diff --git a/test/offsetarray.jl b/test/offsetarray.jl index caa8199670958..4f8e5224db01b 100644 --- a/test/offsetarray.jl +++ b/test/offsetarray.jl @@ -611,15 +611,16 @@ end end @testset "first/last n elements of vector" begin - f = firstindex(v) - @test first(v, -2) == [] - @test first(v, 2) == v[f:f+1] + v0 = rand(6) + v = OffsetArray(v0, (-3,)) + @test_throws ArgumentError first(v, -2) + @test first(v, 2) == v[begin:begin+1] @test first(v, 100) == v0 @test first(v, 100) !== v - @test first(v, 1) != v[f] - @test last(v, -2) == [] + @test first(v, 1) == [v[begin]] + @test_throws ArgumentError last(v, -2) @test last(v, 2) == v[end-1:end] @test last(v, 100) == v0 @test last(v, 100) !== v - @test last(v, 1) != v[end] + @test last(v, 1) == [v[end]] end