diff --git a/src/oneelement.jl b/src/oneelement.jl index 556f692c..ed375f98 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -392,6 +392,14 @@ function triu(A::OneElementMatrix, k::Integer=0) OneElement(nzband < k ? zero(A.val) : A.val, A.ind, axes(A)) end +# diag +function diag(O::OneElementMatrix, k::Integer=0) + Base.require_one_based_indexing(O) + len = length(diagind(O, k)) + ind = O.ind[2] - O.ind[1] == k ? (k >= 0 ? O.ind[2] - k : O.ind[1] + k) : len + 1 + OneElement(getindex_value(O), ind, len) +end + # broadcast function broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::OneElement{<:Any,N}) where {N} diff --git a/test/runtests.jl b/test/runtests.jl index 0163df8a..ad5f50a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2699,6 +2699,16 @@ end B = OneElement(2, (1, 2), (Base.IdentityUnitRange(1:1), Base.IdentityUnitRange(2:2))) @test repr(B) == "OneElement(2, (1, 2), (Base.IdentityUnitRange(1:1), Base.IdentityUnitRange(2:2)))" end + + @testset "diag" begin + @testset for sz in [(0,0), (0,1), (1,0), (1,1), (4,4), (4,6), (6,3)], ind in CartesianIndices(sz) + O = OneElement(4, Tuple(ind), sz) + @testset for k in -maximum(sz):maximum(sz) + @test diag(O, k) == diag(Array(O), k) + @test diag(O, k) isa OneElement{Int,1} + end + end + end end @testset "repeat" begin