diff --git a/Project.toml b/Project.toml index f630477..e8dfb17 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveFactorization" uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4" authors = ["Yingbo Ma "] -version = "0.2.21" +version = "0.2.22" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/lu.jl b/src/lu.jl index ed462dd..458ab4d 100644 --- a/src/lu.jl +++ b/src/lu.jl @@ -1,7 +1,7 @@ using LoopVectorization using TriangularSolve: ldiv! using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS, - LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat + LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat using StrideArraysCore using Polyester: @batch @@ -41,32 +41,35 @@ init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn) if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_cols!) function LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange, - B::StridedVecOrMat) + B::StridedVecOrMat) return B end end if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_rows!) function LinearAlgebra._ipiv_rows!(::(LU{T, <:AbstractMatrix{T}, NotIPIV} where {T}), - ::OrdinalRange, - B::StridedVecOrMat) + ::OrdinalRange, + B::StridedVecOrMat) return B end end if CUSTOMIZABLE_PIVOT function LinearAlgebra.ldiv!(A::LU{T, <:StridedMatrix, <:NotIPIV}, - B::StridedVecOrMat{T}) where {T <: BlasFloat} + B::StridedVecOrMat{T}) where {T <: BlasFloat} ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), B)) end end -function lu!(A, pivot = Val(true), thread = Val(false); check = true, kwargs...) +function lu!(A, pivot = Val(true), thread = Val(false); + check::Union{Bool, Val{true}, Val{false}} = Val(true), kwargs...) m, n = size(A) minmn = min(m, n) npivot = normalize_pivot(pivot) # we want the type on both branches to match. When pivot = Val(false), we construct # a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not. F = if pivot === Val(true) && minmn < 10 # avx introduces small performance degradation - LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); check = check) + LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); + check = ((check isa Bool && check) || (check === Val(true))) + ) else lu!(A, init_pivot(npivot, minmn), npivot, thread; check = check, kwargs...) @@ -87,11 +90,11 @@ recurse(_) = false _ptrarray(ipiv) = PtrArray(ipiv) _ptrarray(ipiv::NotIPIV) = ipiv function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer}, - pivot = Val(true), thread = Val(false); - check::Bool = true, - # the performance is not sensitive wrt blocksize, and 8 is a good default - blocksize::Integer = length(A) ≥ 40_000 ? 8 : 16, - threshold::Integer = pick_threshold()) where {T} + pivot = Val(true), thread = Val(false); + check::Union{Bool, Val{true}, Val{false}} = Val(true), + # the performance is not sensitive wrt blocksize, and 8 is a good default + blocksize::Integer = length(A) ≥ 40_000 ? 8 : 16, + threshold::Integer = pick_threshold()) where {T} pivot = normalize_pivot(pivot) info = zero(BlasInt) m, n = size(A) @@ -113,12 +116,12 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer}, else # generic fallback info = _generic_lufact!(A, pivot, ipiv, info) end - check && checknonsingular(info) + ((check isa Bool && check) || (check === Val(true))) && checknonsingular(info) LU(A, ipiv, info) end @inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize, - ::Val{true}) where {Pivot} + ::Val{true}) where {Pivot} if length(A) * _sizeof(eltype(A)) > 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(2)) _recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(true)) @@ -127,11 +130,11 @@ end end end @inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize, - ::Val{false}) where {Pivot} + ::Val{false}) where {Pivot} _recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false)) end @inline function _recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize, - ::Val{Thread}) where {Pivot, Thread} + ::Val{Thread}) where {Pivot, Thread} info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, Val(Thread))::Int @inbounds if m < n # fat matrix # [AL AR] @@ -175,7 +178,7 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false}) nothing end function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, blocksize, - thread)::BlasInt where {T, Pivot} + thread)::BlasInt where {T, Pivot} @inbounds begin if n <= max(blocksize, 1) info = _generic_lufact!(A, Val(Pivot), ipiv, info)