Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding one for structured matrices that preserves type #29777

Merged
merged 8 commits into from
Feb 6, 2019
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,10 @@ function fill!(A::Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}, x)
not be filled with $x, since some of its entries are constrained."))
end

one(A::Diagonal{T}) where T = Diagonal(fill!(similar(A.diag, typeof(one(T))), one(T)))
one(A::Bidiagonal{T}) where T = Bidiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))), A.uplo)
one(A::Tridiagonal{T}) where T = Tridiagonal(fill!(similar(A.du, typeof(one(T))), zero(one(T))), fill!(similar(A.d, typeof(one(T))), one(T)), fill!(similar(A.dl, typeof(one(T))), zero(one(T))))
one(A::SymTridiagonal{T}) where T = SymTridiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))))
# equals and approx equals methods for structured matrices
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl

Expand Down
82 changes: 82 additions & 0 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,88 @@ end
@test isa((@inferred vcat(Float64[], spzeros(1))), SparseVector)
end


# for testing types with a dimension
const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
isdefined(Main, :Furlongs) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Furlongs.jl"))
using .Main.Furlongs

@testset "zero and one for structured matrices" begin
for elty in (Int64, Float64, ComplexF64)
D = Diagonal(rand(elty, 10))
Bu = Bidiagonal(rand(elty, 10), rand(elty, 9), 'U')
Bl = Bidiagonal(rand(elty, 10), rand(elty, 9), 'L')
T = Tridiagonal(rand(elty, 9),rand(elty, 10), rand(elty, 9))
S = SymTridiagonal(rand(elty, 10), rand(elty, 9))
mats = [D, Bu, Bl, T, S]
for A in mats
@test iszero(zero(A))
@test isone(one(A))
@test zero(A) == zero(Matrix(A))
@test one(A) == one(Matrix(A))
end

@test zero(D) isa Diagonal
@test one(D) isa Diagonal

@test zero(Bu) isa Bidiagonal
@test one(Bu) isa Bidiagonal
@test zero(Bl) isa Bidiagonal
@test one(Bl) isa Bidiagonal
@test zero(Bu).uplo == one(Bu).uplo == Bu.uplo
@test zero(Bl).uplo == one(Bl).uplo == Bl.uplo

@test zero(T) isa Tridiagonal
@test one(T) isa Tridiagonal
@test zero(S) isa SymTridiagonal
@test one(S) isa SymTridiagonal
end

# ranges
D = Diagonal(1:10)
Bu = Bidiagonal(1:10, 1:9, 'U')
Bl = Bidiagonal(1:10, 1:9, 'L')
T = Tridiagonal(1:9, 1:10, 1:9)
S = SymTridiagonal(1:10, 1:9)
mats = [D, Bu, Bl, T, S]
for A in mats
@test iszero(zero(A))
@test isone(one(A))
@test zero(A) == zero(Matrix(A))
@test one(A) == one(Matrix(A))
end

@test zero(D) isa Diagonal
@test one(D) isa Diagonal

@test zero(Bu) isa Bidiagonal
@test one(Bu) isa Bidiagonal
@test zero(Bl) isa Bidiagonal
@test one(Bl) isa Bidiagonal
@test zero(Bu).uplo == one(Bu).uplo == Bu.uplo
@test zero(Bl).uplo == one(Bl).uplo == Bl.uplo

@test zero(T) isa Tridiagonal
@test one(T) isa Tridiagonal
@test zero(S) isa SymTridiagonal
@test one(S) isa SymTridiagonal

# eltype with dimensions
D = Diagonal{Furlong{2, Int64}}([1, 2, 3, 4])
Bu = Bidiagonal{Furlong{2, Int64}}([1, 2, 3, 4], [1, 2, 3], 'U')
Bl = Bidiagonal{Furlong{2, Int64}}([1, 2, 3, 4], [1, 2, 3], 'L')
T = Tridiagonal{Furlong{2, Int64}}([1, 2, 3], [1, 2, 3, 4], [1, 2, 3])
S = SymTridiagonal{Furlong{2, Int64}}([1, 2, 3, 4], [1, 2, 3])
mats = [D, Bu, Bl, T, S]
for A in mats
@test iszero(zero(A))
@test isone(one(A))
@test zero(A) == zero(Matrix(A))
@test one(A) == one(Matrix(A))
@test eltype(one(A)) == typeof(one(eltype(A)))
end
end

@testset "== for structured matrices" begin
diag = rand(10)
offdiag = rand(9)
Expand Down