Skip to content

Commit

Permalink
Aggressive constprop in sparse * dense (#460)
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored Oct 30, 2023
1 parent 0b36fdd commit 81fc6f3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUni
LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion, B::DenseInputVector, _add::MulAddMul) =
spdensemul!(C, tA, 'N', A, B, _add)

function spdensemul!(C, tA, tB, A, B, _add)
Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add)
if tA == 'N'
_spmatmul!(C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta)
elseif tA == 'T'
Expand Down Expand Up @@ -97,7 +97,7 @@ end
*(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::DenseTriangular) =
(T = promote_op(matprod, eltype(A), eltype(B)); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B))

function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::AbstractSparseMatrixCSC, _add::MulAddMul)
Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::AbstractSparseMatrixCSC, _add::MulAddMul)
transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
if tB == 'N'
_spmul!(C, transA(A), B, _add.alpha, _add.beta)
Expand Down
4 changes: 2 additions & 2 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1847,7 +1847,7 @@ function (*)(A::_StridedOrTriangularMatrix{Ta}, x::AbstractSparseVector{Tx}) whe
mul!(y, A, x)
end

function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::StridedMatrix, x::AbstractSparseVector,
Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::StridedMatrix, x::AbstractSparseVector,
_add::MulAddMul = MulAddMul())
if tA == 'N'
_spmul!(y, A, x, _add.alpha, _add.beta)
Expand Down Expand Up @@ -1967,7 +1967,7 @@ function densemv(A::AbstractSparseMatrixCSC, x::AbstractSparseVector; trans::Abs
end

# * and mul!
function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::AbstractSparseMatrixCSC, x::AbstractSparseVector,
Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::AbstractSparseMatrixCSC, x::AbstractSparseVector,
_add::MulAddMul = MulAddMul())
if tA == 'N'
_spmul!(y, A, x, _add.alpha, _add.beta)
Expand Down

0 comments on commit 81fc6f3

Please sign in to comment.