Skip to content

Commit

Permalink
Diag for OneElement returns a OneElement (#383)
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored Aug 27, 2024
1 parent 05b76ad commit 6bab762
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,14 @@ function triu(A::OneElementMatrix, k::Integer=0)
OneElement(nzband < k ? zero(A.val) : A.val, A.ind, axes(A))
end

# diag
function diag(O::OneElementMatrix, k::Integer=0)
Base.require_one_based_indexing(O)
len = length(diagind(O, k))
ind = O.ind[2] - O.ind[1] == k ? (k >= 0 ? O.ind[2] - k : O.ind[1] + k) : len + 1
OneElement(getindex_value(O), ind, len)
end

# broadcast

function broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::OneElement{<:Any,N}) where {N}
Expand Down
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2699,6 +2699,16 @@ end
B = OneElement(2, (1, 2), (Base.IdentityUnitRange(1:1), Base.IdentityUnitRange(2:2)))
@test repr(B) == "OneElement(2, (1, 2), (Base.IdentityUnitRange(1:1), Base.IdentityUnitRange(2:2)))"
end

@testset "diag" begin
@testset for sz in [(0,0), (0,1), (1,0), (1,1), (4,4), (4,6), (6,3)], ind in CartesianIndices(sz)
O = OneElement(4, Tuple(ind), sz)
@testset for k in -maximum(sz):maximum(sz)
@test diag(O, k) == diag(Array(O), k)
@test diag(O, k) isa OneElement{Int,1}
end
end
end
end

@testset "repeat" begin
Expand Down

0 comments on commit 6bab762

Please sign in to comment.