diff --git a/src/numerics.jl b/src/numerics.jl
index cfce54f..c214f61 100644
--- a/src/numerics.jl
+++ b/src/numerics.jl
@@ -2,6 +2,7 @@
     using Compat # Compat@v3.23 for sincospi()
 end
 using SparseArrays
+using SparseArrays: AbstractSparseMatrix
 
 # COV_EXCL_START
 
@@ -23,18 +24,27 @@
 unchecked_sqrt(x::T) where {T <: Integer} = unchecked_sqrt(float(x))
 unchecked_sqrt(x) = Base.sqrt(x)
 """
-    quadprod(A, b, n, dir=:col)
+    quadprod(A::AbstractSparseMatrix, b::AbstractVecOrMat, n::Integer)
 
 Computes the quadratic product ``ABA^\\top`` efficiently for the case where ``B`` is all zero
-except for the `n`th column or row vector `b`, for `dir = :col` or `dir = :row`,
-respectively.
+except for a contiguous block of columns `b` starting at the `n`th.
 """
-@inline function quadprod(A, b, n, dir::Symbol=:col)
-    if dir == :col
-        return (A * sparse(b)) * view(A, :, n)'
-    elseif dir == :row
-        return view(A, :, n) * (A * sparse(b))'
+function quadprod(A::AbstractSparseMatrix, b::AbstractVecOrMat, n::Integer)
+    size(b, 1) == size(A, 2) || throw(DimensionMismatch("b must have $(size(A, 2)) rows to match the columns of A"))
+
+    # sparse * dense naturally returns dense, but we want to dispatch to
+    # a sparse-sparse matrix multiplication, so forcibly sparsify.
+    # Tests with a few example matrices A show that `sparse(A * b)` is
+    # faster than `A * sparse(b)`.
+    w = sparse(A * b)
+    p = n + size(b, 2) - 1
+
+    if ndims(w) == 1
+        # vector outer product using a column view into the matrix is fast
+        C = w * transpose(view(A, :, n))
     else
-        error("Unrecognized direction `dir = $(repr(dir))`.")
+        # views are not fast for multiple columns; subset copies are faster
+        C = w * transpose(A[:, n:p])
     end
+    return C
 end
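
As a quick sanity check of the new method, the sketch below materializes the dense `B` that `quadprod` implicitly represents and compares the two computations. The sizes and matrices here are made up for illustration (they are not from the package's test suite), and the snippet assumes the patched `quadprod` is in scope:

```julia
using SparseArrays, LinearAlgebra

A = sprand(8, 5, 0.4)   # sparse A with 5 columns
b = randn(5, 2)         # dense block occupying columns n:n+1 of B
n = 3

# Materialize the implicit B: all zero except columns n:n+size(b, 2)-1.
B = zeros(5, 5)
B[:, n:n+size(b, 2)-1] = b

@assert quadprod(A, b, n) ≈ A * B * transpose(A)

# The vector path covers a single-column B:
c = randn(5)
@assert quadprod(A, c, n) ≈ A * [zeros(5, n - 1) c zeros(5, 5 - n)] * transpose(A)
```

Note the use of `transpose` rather than `'` to match the docstring's ``ABA^\top``; the two differ for complex element types.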
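
The performance claim in the comment (that `sparse(A * b)` beats `A * sparse(b)`) can be spot-checked with a harness like the one below. This is a hypothetical setup with arbitrary size and density, and it assumes BenchmarkTools is installed:

```julia
using BenchmarkTools, SparseArrays

A = sprand(1_000, 1_000, 0.01)  # sparse operator
b = randn(1_000, 2)             # narrow dense block

@btime sparse($A * $b)  # sparse-dense product, then sparsify the result
@btime $A * sparse($b)  # sparsify the block, then sparse-sparse product
```

Relative timings depend on the size and sparsity of `A`, so it is worth rerunning with matrices representative of the actual workload.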