Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add two-argument first and last methods for any iterable #34868

Merged
merged 5 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Standard library changes
* `unique(f, itr; seen=Set{T}())` now allows you to declare the container type used for
keeping track of values returned by `f` on elements of `itr` ([#36280]).
* `Libdl` has been moved to `Base.Libc.Libdl`, however it is still accessible as an stdlib ([#35628]).
* `first` and `last` functions now accept an integer as second argument to get that many
leading or trailing elements of any iterable ([#34868]).

#### LinearAlgebra
* New method `LinearAlgebra.issuccess(::CholeskyPivoted)` for checking whether pivoted Cholesky factorization was successful ([#36002]).
Expand Down
54 changes: 54 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,33 @@ function first(itr)
x[1]
end

"""
first(itr, n::Integer)
Get the first `n` elements of the iterable collection `itr`, or fewer elements if `v` is not
long enough.
# Examples
```jldoctest
julia> first(["foo", "bar", "qux"], 2)
2-element Vector{String}:
"foo"
"bar"
julia> first(1:6, 10)
1:6
julia> first(Bool[], 1)
Bool[]
```
"""
first(itr, n::Integer) = collect(Iterators.take(itr, n))
# Faster method for vectors
function first(v::AbstractVector, n::Integer)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this method for AbstractVector, but the corresponding specialized last method below is for AbstractArray?

n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
@inbounds v[begin:min(begin + n - 1, end)]
end

"""
last(coll)
Expand All @@ -361,6 +388,33 @@ julia> last([1; 2; 3; 4])
"""
last(a) = a[end]

"""
last(itr, n::Integer)
Get the last `n` elements of the iterable collection `itr`, or fewer elements if `v` is not
long enough.
# Examples
```jldoctest
julia> last(["foo", "bar", "qux"], 2)
2-element Vector{String}:
"bar"
"qux"
julia> last(1:6, 10)
1:6
julia> last(Float64[], 1)
Float64[]
```
"""
last(itr, n::Integer) = reverse!(collect(Iterators.take(Iterators.reverse(itr), n)))
# Faster method for arrays
function last(v::AbstractArray, n::Integer)
n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
@inbounds v[max(begin, end - n + 1):end]
end

"""
strides(A)
Expand Down
15 changes: 15 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1126,3 +1126,18 @@ end
end
end
end

@testset "first/last n elements of $(typeof(itr))" for itr in (collect(1:9),
[1 4 7; 2 5 8; 3 6 9],
ntuple(identity, 9))
@test first(itr, 6) == [itr[1:6]...]
@test first(itr, 25) == [itr[:]...]
@test first(itr, 25) !== itr
@test first(itr, 1) == [itr[1]]
@test_throws ArgumentError first(itr, -6)
@test last(itr, 6) == [itr[end-5:end]...]
@test last(itr, 25) == [itr[:]...]
@test last(itr, 25) !== itr
@test last(itr, 1) == [itr[end]]
@test_throws ArgumentError last(itr, -6)
end
15 changes: 15 additions & 0 deletions test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -609,3 +609,18 @@ end
@test_throws DimensionMismatch maximum!(fill(0, -4:-4, 7:7, -6:-5, 1:1), B)
@test_throws DimensionMismatch minimum!(fill(0, -4:-4, 7:7, -6:-5, 1:1), B)
end

@testset "first/last n elements of vector" begin
v0 = rand(6)
v = OffsetArray(v0, (-3,))
@test_throws ArgumentError first(v, -2)
@test first(v, 2) == v[begin:begin+1]
@test first(v, 100) == v0
@test first(v, 100) !== v
@test first(v, 1) == [v[begin]]
@test_throws ArgumentError last(v, -2)
@test last(v, 2) == v[end-1:end]
@test last(v, 100) == v0
@test last(v, 100) !== v
@test last(v, 1) == [v[end]]
end