Skip to content

Commit

Permalink
Vector indexing for OneElement
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed May 9, 2024
1 parent 4f8a966 commit b1f8d93
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,27 @@ OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz)

Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
Base.getindex(A::OneElement{T,0}) where {T} = getindex_value(A)
Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, kj::Vararg{Int,N}) where {T,N}
@boundscheck checkbounds(A, kj...)
ifelse(kj == A.ind, A.val, zero(T))
end
const VectorIndsWithColon = Union{AbstractRange{Int}, Colon, Int}
const VectorInds = Union{AbstractRange{Int}, Int}
# retain the values from Ainds corresponding to the vector indices in inds
_index_shape(Ainds, inds::Tuple{Integer, Vararg{Any}}) = _index_shape(Base.tail(Ainds), Base.tail(inds))
_index_shape(Ainds, inds::Tuple{AbstractVector, Vararg{Any}}) = (Ainds[1], _index_shape(Base.tail(Ainds), Base.tail(inds))...)
_index_shape(::Tuple{}, ::Tuple{}) = ()
@inline function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorInds,N}) where {T,N}
@boundscheck checkbounds(A, inds...)
shape = _index_shape(inds, inds)
nzind = _index_shape(A.ind, inds) .- first.(shape) .+ firstindex.(shape)
containsval = all(in.(A.ind, inds))
OneElement(getindex_value(A), containsval ? Int.(nzind) : Int.(lastindex.(shape,1)).+1, axes.(shape,1))
end
Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N}
getindex(A, Base.to_indices(A, inds)...)
end

"""
nzind(A::OneElement{T,N}) -> CartesianIndex{N}
Expand Down
78 changes: 78 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2148,10 +2148,12 @@ end
@test FillArrays.nzind(A) == CartesianIndex()
@test A == Fill(2, ())
@test A[] === 2
@test A[1] === A[1,1] === 2

e₁ = OneElement(2, 5)
@test e₁ == [0,1,0,0,0]
@test FillArrays.nzind(e₁) == CartesianIndex(2)
@test e₁[2] === e₁[2,1] === e₁[2,1,1] === 1
@test_throws BoundsError e₁[6]

f₁ = AbstractArray{Float64}(e₁)
Expand Down Expand Up @@ -2193,6 +2195,82 @@ end
@test A[1,1] === A[1,2] === A[2,1] === zero(S)
end

@testset "Vector indexing" begin
@testset "1D" begin
A = OneElement(2, 2, 4)
@test @inferred(A[:]) === @inferred(A[axes(A)...]) === A
@test @inferred(A[3:4]) isa OneElement{Int,1}
@test @inferred(A[3:4]) == Zeros(2)
@test @inferred(A[1:2]) === OneElement(2, 2, 2)
@test @inferred(A[2:3]) === OneElement(2, 1, 2)
@test @inferred(A[Base.IdentityUnitRange(2:3)]) isa OneElement{Int,1}
@test @inferred(A[Base.IdentityUnitRange(2:3)]) == OneElement(2,(2,),(Base.IdentityUnitRange(2:3),))
@test A[:,:] == reshape(A, size(A)..., 1)

B = OneElement(2, (2,), (Base.IdentityUnitRange(-1:4),))
@test @inferred(A[:]) === @inferred(A[axes(A)...]) === A
@test @inferred(A[3:4]) isa OneElement{Int,1}
@test @inferred(A[3:4]) == Zeros(2)
@test @inferred(A[2:3]) === OneElement(2, 1, 2)

C = OneElement(2, (2,), (Base.OneTo(big(4)),))
@test @inferred(C[1:4]) === OneElement(2, 2, 4)

D = OneElement(2, (2,), (InfiniteArrays.OneToInf(),))
D2 = D[:]
@test axes(D2) == axes(D)
@test D2[2] == D[2]
D3 = D[axes(D)...]
@test axes(D3) == axes(D)
@test D3[2] == D[2]
end
@testset "2D" begin
A = OneElement(2, (2,3), (4,5))
@test @inferred(A[:,:]) === @inferred(A[axes(A)...]) === A
@test @inferred(A[:,1]) isa OneElement{Int,1}
@test @inferred(A[:,1]) == Zeros(4)
@test @inferred(A[1,:]) isa OneElement{Int,1}
@test @inferred(A[1,:]) == Zeros(5)
@test @inferred(A[:,3]) === OneElement(2, 2, 4)
@test @inferred(A[2,:]) === OneElement(2, 3, 5)
@test @inferred(A[1:1,:]) isa OneElement{Int,2}
@test @inferred(A[1:1,:]) == Zeros(1,5)
@test @inferred(A[4:4,:]) isa OneElement{Int,2}
@test @inferred(A[4:4,:]) == Zeros(1,5)
@test @inferred(A[2:2,:]) === OneElement(2, (1,3), (1,5))
@test @inferred(A[1:4,:]) === OneElement(2, (2,3), (4,5))
@test @inferred(A[:,3:3]) === OneElement(2, (2,1), (4,1))
@test @inferred(A[:,1:5]) === OneElement(2, (2,3), (4,5))
@test @inferred(A[1:4,1:4]) === OneElement(2, (2,3), (4,4))
@test @inferred(A[2:4,2:4]) === OneElement(2, (1,2), (3,3))
@test @inferred(A[2:4,3:4]) === OneElement(2, (1,1), (3,2))
@test @inferred(A[4:4,5:5]) isa OneElement{Int,2}
@test @inferred(A[4:4,5:5]) == Zeros(1,1)
@test @inferred(A[Base.IdentityUnitRange(2:4), :]) isa OneElement{Int,2}
@test axes(A[Base.IdentityUnitRange(2:4), :]) == (Base.IdentityUnitRange(2:4), axes(A,2))
@test @inferred(A[:,:,:]) == reshape(A, size(A)...,1)

B = OneElement(2, (2,3), (Base.IdentityUnitRange(2:4),Base.IdentityUnitRange(2:5)))
@test @inferred(B[:,:]) === @inferred(B[axes(B)...]) === B
@test @inferred(B[:,3]) === OneElement(2, (2,), (Base.IdentityUnitRange(2:4),))
@test @inferred(B[3:4, 4:5]) isa OneElement{Int,2}
@test @inferred(B[3:4, 4:5]) == Zeros(2,2)
b = @inferred(B[Base.IdentityUnitRange(3:4), Base.IdentityUnitRange(4:5)])
@test b == Zeros(axes(b))

C = OneElement(2, (2,3), (Base.OneTo(big(4)), Base.OneTo(big(5))))
@test @inferred(C[1:4, 1:5]) === OneElement(2, (2,3), Int.(size(C)))

D = OneElement(2, (2,3), (InfiniteArrays.OneToInf(), InfiniteArrays.OneToInf()))
D2 = @inferred D[:,:]
@test axes(D2) == axes(D)
@test D2[2,3] == D[2,3]
D3 = @inferred D[axes(D)...]
@test axes(D3) == axes(D)
@test D3[2,3] == D[2,3]
end
end

@testset "adjoint/transpose" begin
A = OneElement(3im, (2,4), (4,6))
@test A' === OneElement(-3im, (4,2), (6,4))
Expand Down

0 comments on commit b1f8d93

Please sign in to comment.