Skip to content

Commit

Permalink
Merge branch 'master' into jishnub/blockstructuredbcast
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored May 15, 2024
2 parents 05dc110 + 28aaafc commit 743978b
Show file tree
Hide file tree
Showing 14 changed files with 180 additions and 70 deletions.
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
typed_hcat, vec, view, zero
using Base: IndexLinear, promote_eltype, promote_op, print_matrix,
@propagate_inbounds, reduce, typed_hvcat, typed_vcat, require_one_based_indexing,
splat
splat, BitInteger
using Base.Broadcast: Broadcasted, broadcasted
using Base.PermutedDimsArrays: CommutativeOps
using OpenBLAS_jll
Expand Down
19 changes: 7 additions & 12 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,19 +195,14 @@ end

#Converting from Bidiagonal to dense Matrix
function Matrix{T}(A::Bidiagonal) where T
n = size(A, 1)
B = Matrix{T}(undef, n, n)
n == 0 && return B
n > 1 && fill!(B, zero(T))
@inbounds for i = 1:n - 1
B[i,i] = A.dv[i]
if A.uplo == 'U'
B[i,i+1] = A.ev[i]
else
B[i+1,i] = A.ev[i]
end
B = Matrix{T}(undef, size(A))
if haszero(T) # optimized path for types with zero(T) defined
size(B,1) > 1 && fill!(B, zero(T))
copyto!(view(B, diagind(B)), A.dv)
copyto!(view(B, diagind(B, A.uplo == 'U' ? 1 : -1)), A.ev)
else
copyto!(B, A)
end
B[n,n] = A.dv[n]
return B
end
Matrix(A::Bidiagonal{T}) where {T} = Matrix{promote_type(T, typeof(zero(T)))}(A)
Expand Down
11 changes: 6 additions & 5 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,12 @@ AbstractMatrix{T}(D::Diagonal{T}) where {T} = copy(D)
Matrix(D::Diagonal{T}) where {T} = Matrix{promote_type(T, typeof(zero(T)))}(D)
Array(D::Diagonal{T}) where {T} = Matrix(D)
function Matrix{T}(D::Diagonal) where {T}
n = size(D, 1)
B = Matrix{T}(undef, n, n)
n > 1 && fill!(B, zero(T))
@inbounds for i in 1:n
B[i,i] = D.diag[i]
B = Matrix{T}(undef, size(D))
if haszero(T) # optimized path for types with zero(T) defined
size(B,1) > 1 && fill!(B, zero(T))
copyto!(view(B, diagind(B)), D.diag)
else
copyto!(B, D)
end
return B
end
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ julia> opnorm(A, 1)
5.0
```
"""
function opnorm(A::AbstractMatrix, p::Real=2)
Base.@constprop :aggressive function opnorm(A::AbstractMatrix, p::Real=2)
if p == 2
return opnorm2(A)
elseif p == 1
Expand Down
88 changes: 54 additions & 34 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,10 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA,
return _rmul_or_fill!(C, β)
end
if size(C) == size(A) == size(B) == (2,2)
return @stable_muladdmul matmul2x2!(C, tA, tB, A, B, MulAddMul(α, β))
return matmul2x2!(C, tA, tB, A, B, α, β)
end
if size(C) == size(A) == size(B) == (3,3)
return @stable_muladdmul matmul3x3!(C, tA, tB, A, B, MulAddMul(α, β))
return matmul3x3!(C, tA, tB, A, B, α, β)
end
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
Expand Down Expand Up @@ -459,6 +459,23 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, t
A
end

_fullstride2(A, f=identity) = f(stride(A, 2)) >= size(A, 1)
# for some standard StridedArrays, the _fullstride2 condition is known to hold at compile-time
# We specialize the function for certain StridedArray subtypes

# Similar to Base.RangeIndex, but only include range types where the step is statically known to be non-zero
const IncreasingRangeIndex = Union{BitInteger, AbstractUnitRange{<:BitInteger}}
const NonConstRangeIndex = Union{IncreasingRangeIndex, StepRange{<:BitInteger, <:BitInteger}}
# StridedArray subtypes for which _fullstride2(::T) === true is known from the type
const DenseOrStridedReshapedReinterpreted = Union{DenseArray, Base.StridedReshapedArray, Base.StridedReinterpretArray}
# Similar to Base.StridedSubArray, except with a NonConstRangeIndex instead of a RangeIndex
StridedSubArrayStandard{T,N,A,
I<:Tuple{Vararg{Union{NonConstRangeIndex, Base.ReshapedUnitRange, Base.AbstractCartesianIndex}}}} = Base.StridedSubArray{T,N,A,I}
_fullstride2(A::Union{DenseOrStridedReshapedReinterpreted,StridedSubArrayStandard}, ::typeof(abs)) = true
StridedSubArrayIncr{T,N,A,
I<:Tuple{Vararg{Union{IncreasingRangeIndex, Base.ReshapedUnitRange, Base.AbstractCartesianIndex}}}} = Base.StridedSubArray{T,N,A,I}
_fullstride2(A::Union{DenseOrStridedReshapedReinterpreted,StridedSubArrayIncr}, ::typeof(identity)) = true

Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar,
A::StridedVecOrMat{T}, x::StridedVector{T},
α::Number=true, β::Number=false) where {T<:BlasFloat}
Expand All @@ -472,7 +489,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
alpha, beta = promote(α, β, zero(T))
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
stride(A, 1) == 1 && _fullstride2(A, abs) &&
!iszero(stride(x, 1)) && # We only check input's stride here.
if tA_uc in ('N', 'T', 'C')
return BLAS.gemv!(tA, alpha, A, x, beta, y)
Expand Down Expand Up @@ -503,9 +520,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
alpha, beta = promote(α, β, zero(T))
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
stride(y, 1) == 1 && tA_uc == 'N' && # reinterpret-based optimization is valid only for contiguous `y`
!iszero(stride(x, 1))
stride(A, 1) == 1 && _fullstride2(A, abs) &&
stride(y, 1) == 1 && tA_uc == 'N' && # reinterpret-based optimization is valid only for contiguous `y`
!iszero(stride(x, 1))
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
return y
else
Expand All @@ -527,8 +544,8 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
alpha, beta = promote(α, β, zero(T))
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
@views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
!iszero(stride(x, 1)) && tA_uc in ('N', 'T', 'C')
stride(A, 1) == 1 && _fullstride2(A, abs) &&
!iszero(stride(x, 1)) && tA_uc in ('N', 'T', 'C')
xfl = reinterpret(reshape, T, x) # Use reshape here.
yfl = reinterpret(reshape, T, y)
BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :])
Expand Down Expand Up @@ -565,10 +582,9 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst
if iszero(beta) || issymmetric(C)
α, β = promote(alpha, beta, zero(T))
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
stride(A, 1) == stride(C, 1) == 1 &&
stride(A, 2) >= size(A, 1) &&
stride(C, 2) >= size(C, 1))
beta isa Union{Bool,T} &&
stride(A, 1) == stride(C, 1) == 1 &&
_fullstride2(A) && _fullstride2(C))
return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U')
end
end
Expand Down Expand Up @@ -601,10 +617,9 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
if iszero(β) || issymmetric(C)
alpha, beta = promote(α, β, zero(T))
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
stride(A, 1) == stride(C, 1) == 1 &&
stride(A, 2) >= size(A, 1) &&
stride(C, 2) >= size(C, 1))
beta isa Union{Bool,T} &&
stride(A, 1) == stride(C, 1) == 1 &&
_fullstride2(A) && _fullstride2(C))
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
end
end
Expand Down Expand Up @@ -652,11 +667,9 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab

alpha, beta = promote(α, β, zero(T))
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
stride(A, 2) >= size(A, 1) &&
stride(B, 2) >= size(B, 1) &&
stride(C, 2) >= size(C, 1))
beta isa Union{Bool,T} &&
stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
_fullstride2(A) && _fullstride2(B) && _fullstride2(C))
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
Expand Down Expand Up @@ -688,11 +701,9 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}

# Make-sure reinterpret-based optimization is BLAS-compatible.
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
stride(A, 2) >= size(A, 1) &&
stride(B, 2) >= size(B, 1) &&
stride(C, 2) >= size(C, 1) && tA_uc == 'N')
beta isa Union{Bool,T} &&
stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
_fullstride2(A) && _fullstride2(B) && _fullstride2(C) && tA_uc == 'N')
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
return C
end
Expand Down Expand Up @@ -985,9 +996,8 @@ Base.@constprop :aggressive function __matmul2x2_elements(tA, A::AbstractMatrix)
end
Base.@constprop :aggressive __matmul2x2_elements(tA, tB, A, B) = __matmul2x2_elements(tA, A), __matmul2x2_elements(tB, B)

Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
(A11, A12, A21, A22), (B11, B12, B21, B22) = _matmul2x2_elements(C, tA, tB, A, B)
function _modify2x2!(Aelements, Belements, C, _add)
(A11, A12, A21, A22), (B11, B12, B21, B22) = Aelements, Belements
@inbounds begin
_modify!(_add, A11*B11 + A12*B21, C, (1,1))
_modify!(_add, A21*B11 + A22*B21, C, (2,1))
Expand All @@ -996,6 +1006,12 @@ Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::Ab
end # inbounds
C
end
Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
α = true, β = false)
Aelements, Belements = _matmul2x2_elements(C, tA, tB, A, B)
@stable_muladdmul _modify2x2!(Aelements, Belements, C, MulAddMul(α, β))
C
end

# Multiply 3x3 matrices
Base.@constprop :aggressive function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
Expand Down Expand Up @@ -1050,12 +1066,9 @@ Base.@constprop :aggressive function __matmul3x3_elements(tA, A::AbstractMatrix)
end
Base.@constprop :aggressive __matmul3x3_elements(tA, tB, A, B) = __matmul3x3_elements(tA, A), __matmul3x3_elements(tB, B)

Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())

function _modify3x3!(Aelements, Belements, C, _add)
(A11, A12, A13, A21, A22, A23, A31, A32, A33),
(B11, B12, B13, B21, B22, B23, B31, B32, B33) = _matmul3x3_elements(C, tA, tB, A, B)

(B11, B12, B13, B21, B22, B23, B31, B32, B33) = Aelements, Belements
@inbounds begin
_modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1))
_modify!(_add, A21*B11 + A22*B21 + A23*B31, C, (2,1))
Expand All @@ -1071,6 +1084,13 @@ Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::Ab
end # inbounds
C
end
Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
α = true, β = false)

Aelements, Belements = _matmul3x3_elements(C, tA, tB, A, B)
@stable_muladdmul _modify3x3!(Aelements, Belements, C, MulAddMul(α, β))
C
end

const RealOrComplex = Union{Real,Complex}

Expand Down
5 changes: 5 additions & 0 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,11 @@ function svd(A::RealHermSymComplexHerm; full::Bool=false)
end
return SVD(vecs, vals, V')
end
function svd(A::RealHermSymComplexHerm{Float16}; full::Bool = false)
T = eltype(A)
F = svd(eigencopy_oftype(A, eigtype(T)); full)
return SVD{T}(F)
end

function svdvals!(A::RealHermSymComplexHerm)
vals = eigvals!(A)
Expand Down
19 changes: 17 additions & 2 deletions stdlib/LinearAlgebra/src/symmetriceigen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ function eigen(A::RealHermSymComplexHerm; sortby::Union{Function,Nothing}=nothin
S = eigtype(eltype(A))
eigen!(eigencopy_oftype(A, S), sortby=sortby)
end
function eigen(A::RealHermSymComplexHerm{Float16}; sortby::Union{Function,Nothing}=nothing)
S = eigtype(eltype(A))
E = eigen!(eigencopy_oftype(A, S), sortby=sortby)
values = convert(AbstractVector{Float16}, E.values)
vectors = convert(AbstractMatrix{isreal(E.vectors) ? Float16 : Complex{Float16}}, E.vectors)
return Eigen(values, vectors)
end

eigen!(A::RealHermSymComplexHerm{<:BlasReal,<:StridedMatrix}, irange::UnitRange) =
Eigen(LAPACK.syevr!('V', 'I', A.uplo, A.data, 0.0, 0.0, irange.start, irange.stop, -1.0)...)
Expand Down Expand Up @@ -295,8 +302,16 @@ function eigvals!(A::StridedMatrix{T}, F::LU{T,<:StridedMatrix}; sortby::Union{F
return eigvals!(A; sortby)
end


function eigen(A::Hermitian{Complex{T}, <:Tridiagonal}; kwargs...) where {T}
eigen(A::Hermitian{<:Complex, <:Tridiagonal}; kwargs...) =
_eigenhermtridiag(A; kwargs...)
# disambiguation
function eigen(A::Hermitian{Complex{Float16}, <:Tridiagonal}; kwargs...)
E = _eigenhermtridiag(A; kwargs...)
values = convert(AbstractVector{Float16}, E.values)
vectors = convert(AbstractMatrix{ComplexF16}, E.vectors)
return Eigen(values, vectors)
end
function _eigenhermtridiag(A::Hermitian{<:Complex,<:Tridiagonal}; kwargs...)
(; dl, d, du) = parent(A)
N = length(d)
if N <= 1
Expand Down
33 changes: 18 additions & 15 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,17 @@ function Matrix{T}(M::SymTridiagonal) where T
n = size(M, 1)
Mf = Matrix{T}(undef, n, n)
n == 0 && return Mf
n > 2 && fill!(Mf, zero(T))
@inbounds for i = 1:n-1
Mf[i,i] = symmetric(M.dv[i], :U)
Mf[i+1,i] = transpose(M.ev[i])
Mf[i,i+1] = M.ev[i]
if haszero(T) # optimized path for types with zero(T) defined
n > 2 && fill!(Mf, zero(T))
@inbounds for i = 1:n-1
Mf[i,i] = symmetric(M.dv[i], :U)
Mf[i+1,i] = transpose(M.ev[i])
Mf[i,i+1] = M.ev[i]
end
Mf[n,n] = symmetric(M.dv[n], :U)
else
copyto!(Mf, M)
end
Mf[n,n] = symmetric(M.dv[n], :U)
return Mf
end
Matrix(M::SymTridiagonal{T}) where {T} = Matrix{promote_type(T, typeof(zero(T)))}(M)
Expand Down Expand Up @@ -586,15 +590,14 @@ axes(M::Tridiagonal) = (ax = axes(M.d,1); (ax, ax))

function Matrix{T}(M::Tridiagonal) where {T}
A = Matrix{T}(undef, size(M))
n = length(M.d)
n == 0 && return A
n > 2 && fill!(A, zero(T))
for i in 1:n-1
A[i,i] = M.d[i]
A[i+1,i] = M.dl[i]
A[i,i+1] = M.du[i]
end
A[n,n] = M.d[n]
if haszero(T) # optimized path for types with zero(T) defined
size(A,1) > 2 && fill!(A, zero(T))
copyto!(view(A, diagind(A)), M.d)
copyto!(view(A, diagind(A,1)), M.du)
copyto!(view(A, diagind(A,-1)), M.dl)
else
copyto!(A, M)
end
A
end
Matrix(M::Tridiagonal{T}) where {T} = Matrix{promote_type(T, typeof(zero(T)))}(M)
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -904,4 +904,11 @@ end
@test mul!(C1, B, sv, 1, 2) == mul!(C2, B, v, 1 ,2)
end

@testset "Matrix conversion for non-numeric" begin
B = Bidiagonal(fill(Diagonal([1,3]), 3), fill(Diagonal([1,3]), 2), :U)
M = Matrix{eltype(B)}(B)
@test M isa Matrix{eltype(B)}
@test M == B
end

end # module TestBidiagonal
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1298,4 +1298,11 @@ end
@test yadj == x'
end

@testset "Matrix conversion for non-numeric" begin
D = Diagonal(fill(Diagonal([1,3]), 2))
M = Matrix{eltype(D)}(D)
@test M isa Matrix{eltype(D)}
@test M == D
end

end # module TestDiagonal
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,18 @@ end
C .= C0 = rand(eltype(C), size(C))
@test mul!(C, vf, transpose(vf), 2, 3) 2vf * vf' .+ 3C0
end

@testset "zero stride" begin
for AAv in (view(AA, StepRangeLen(2,0,size(AA,1)), :),
view(AA, StepRangeLen.(2,0,size(AA))...),
view(complex.(AA, AA), StepRangeLen.(2,0,size(AA))...),)
for BB2 in (BB, complex.(BB, BB))
C = AAv * BB2
@test allequal(C)
@test C Array(AAv) * BB2
end
end
end
end

@testset "generic_matvecmul for vectors of vectors" begin
Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ Random.seed!(1)
Base.zero(x::Union{TypeWithoutZero, TypeWithZero}) = zero(typeof(x))
Base.zero(::Type{<:Union{TypeWithoutZero, TypeWithZero}}) = TypeWithZero()
LinearAlgebra.symmetric(::TypeWithoutZero, ::Symbol) = TypeWithoutZero()
LinearAlgebra.symmetric_type(::Type{TypeWithoutZero}) = TypeWithoutZero
Base.copy(A::TypeWithoutZero) = A
Base.transpose(::TypeWithoutZero) = TypeWithoutZero()
d = fill(TypeWithoutZero(), 3)
du = fill(TypeWithoutZero(), 2)
Expand Down
Loading

0 comments on commit 743978b

Please sign in to comment.