Skip to content

Commit

Permalink
Make LU factorization work for more types (#26344)
Browse files Browse the repository at this point in the history
* Make LU factorization work for more types

* luop -> rationalop

* More robust type promotion

* lufact -> lu in tests

* Implements one for type D

* Move trickyarithmetic to a separate file

* Add constructor used by oneunit
  • Loading branch information
blegat authored and andreasnoack committed Jun 6, 2018
1 parent 6f1944b commit 89e445b
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 2 deletions.
23 changes: 21 additions & 2 deletions stdlib/LinearAlgebra/src/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,25 @@ function lu(A::Union{AbstractMatrix{T}, AbstractMatrix{Complex{T}}},
lu!(copy(A), pivot)
end

function lutype(T::Type)
# In generic_lufact!, the elements of the lower part of the matrix are
# obtained using the division of two matrix elements. Hence their type can
# be different (e.g. the division of two types with the same unit is a type
# without unit).
# The elements of the upper part are obtained by U - U * L
# where U is an upper part element and L is a lower part element.
# Therefore, the types LT, UT should be invariant under the map:
# (LT, UT) -> begin
# L = oneunit(UT) / oneunit(UT)
# U = oneunit(UT) - oneunit(UT) * L
# typeof(L), typeof(U)
# end
# The following should handle most cases
UT = typeof(oneunit(T) - oneunit(T) * (oneunit(T) / (oneunit(T) + zero(T))))
LT = typeof(oneunit(UT) / oneunit(UT))
S = promote_type(T, LT, UT)
end

# for all other types we must promote to a type which is stable under division
"""
lu(A, pivot=Val(true)) -> F::LU
Expand Down Expand Up @@ -191,14 +210,14 @@ true
```
"""
function lu(A::AbstractMatrix{T}, pivot::Union{Val{false}, Val{true}}) where T
S = typeof(zero(T)/one(T))
S = lutype(T)
AA = similar(A, S)
copyto!(AA, A)
lu!(AA, pivot)
end
# We can't assume an ordered field so we first try without pivoting
function lu(A::AbstractMatrix{T}) where T
S = typeof(zero(T)/one(T))
S = lutype(T)
AA = similar(A, S)
copyto!(AA, A)
F = lu!(AA, Val(false))
Expand Down
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/test/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,15 @@ end
@test allnames == ["L", "P", "U", "factors", "info", "ipiv", "p"]
end

include("trickyarithmetic.jl")

@testset "lu with type whose sum is another type" begin
A = TrickyArithmetic.A[1 2; 3 4]
ElT = TrickyArithmetic.D{TrickyArithmetic.C,TrickyArithmetic.C}
B = lu(A)
@test B isa LinearAlgebra.LU{ElT,Matrix{ElT}}
C = lu(A, Val(false))
@test C isa LinearAlgebra.LU{ElT,Matrix{ElT}}
end

end # module TestLU
60 changes: 60 additions & 0 deletions stdlib/LinearAlgebra/test/trickyarithmetic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
module TrickyArithmetic
struct A
x::Int
end
A(a::A) = a
Base.convert(::Type{A}, i::Int) = A(i)
Base.zero(::Union{A, Type{A}}) = A(0)
Base.one(::Union{A, Type{A}}) = A(1)
struct B
x::Int
end
struct C
x::Int
end
C(a::A) = C(a.x)
Base.zero(::Union{C, Type{C}}) = C(0)
Base.one(::Union{C, Type{C}}) = C(1)

Base.:(*)(x::Int, a::A) = B(x*a.x)
Base.:(*)(a::A, x::Int) = B(a.x*x)
Base.:(*)(a::Union{A,B}, b::Union{A,B}) = B(a.x*b.x)
Base.:(*)(a::Union{A,B,C}, b::Union{A,B,C}) = C(a.x*b.x)
Base.:(+)(a::Union{A,B,C}, b::Union{A,B,C}) = C(a.x+b.x)
Base.:(-)(a::Union{A,B,C}, b::Union{A,B,C}) = C(a.x-b.x)

struct D{NT, DT}
n::NT
d::DT
end
D{NT, DT}(d::D{NT, DT}) where {NT, DT} = d # called by oneunit
Base.zero(::Union{D{NT, DT}, Type{D{NT, DT}}}) where {NT, DT} = zero(NT) / one(DT)
Base.one(::Union{D{NT, DT}, Type{D{NT, DT}}}) where {NT, DT} = one(NT) / one(DT)
Base.convert(::Type{D{NT, DT}}, a::Union{A, B, C}) where {NT, DT} = NT(a) / one(DT)
#Base.convert(::Type{D{NT, DT}}, a::D) where {NT, DT} = NT(a.n) / DT(a.d)

Base.:(*)(a::D, b::D) = (a.n*b.n) / (a.d*b.d)
Base.:(*)(a::D, b::Union{A,B,C}) = (a.n * b) / a.d
Base.:(*)(a::Union{A,B,C}, b::D) = b * a
Base.inv(a::Union{A,B,C}) = A(1) / a
Base.inv(a::D) = a.d / a.n
Base.:(/)(a::Union{A,B,C}, b::Union{A,B,C}) = D(a, b)
Base.:(/)(a::D, b::Union{A,B,C}) = a.n / (a.d*b)
Base.:(/)(a::Union{A,B,C,D}, b::D) = a * inv(b)
Base.:(+)(a::Union{A,B,C}, b::D) = (a*b.d+b.n) / b.d
Base.:(+)(a::D, b::Union{A,B,C}) = b + a
Base.:(+)(a::D, b::D) = (a.n*b.d+a.d*b.n) / (a.d*b.d)
Base.:(-)(a::Union{A,B,C}) = typeof(a)(a.x)
Base.:(-)(a::D) = (-a.n) / a.d
Base.:(-)(a::Union{A,B,C,D}, b::Union{A,B,C,D}) = a + (-b)

Base.promote_rule(::Type{A}, ::Type{B}) = B
Base.promote_rule(::Type{B}, ::Type{A}) = B
Base.promote_rule(::Type{A}, ::Type{C}) = C
Base.promote_rule(::Type{C}, ::Type{A}) = C
Base.promote_rule(::Type{B}, ::Type{C}) = C
Base.promote_rule(::Type{C}, ::Type{B}) = C
Base.promote_rule(::Type{D{NT,DT}}, T::Type{<:Union{A,B,C}}) where {NT,DT} = D{promote_type(NT,T),DT}
Base.promote_rule(T::Type{<:Union{A,B,C}}, ::Type{D{NT,DT}}) where {NT,DT} = D{promote_type(NT,T),DT}
Base.promote_rule(::Type{D{NS,DS}}, ::Type{D{NT,DT}}) where {NS,DS,NT,DT} = D{promote_type(NS,NT),promote_type(DS,DT)}
end

0 comments on commit 89e445b

Please sign in to comment.