diff --git a/src/numerics.jl b/src/numerics.jl index cfce54f..6b6af87 100644 --- a/src/numerics.jl +++ b/src/numerics.jl @@ -2,6 +2,7 @@ using Compat # Compat@v3.23 for sincospi() end using SparseArrays +using SparseArrays: AbstractSparseMatrixCSC # COV_EXCL_START @@ -23,18 +24,27 @@ unchecked_sqrt(x::T) where {T <: Integer} = unchecked_sqrt(float(x)) unchecked_sqrt(x) = Base.sqrt(x) """ - quadprod(A, b, n, dir=:col) + quadprod(A::AbstractSparseMatrixCSC, b::AbstractVecOrMat, n::Integer) Computes the quadratic product ``ABA^\\top`` efficiently for the case where ``B`` is all zero -except for the `n`th column or row vector `b`, for `dir = :col` or `dir = :row`, -respectively. +except for a small number of columns `b` starting at the `n`th. """ -@inline function quadprod(A, b, n, dir::Symbol=:col) - if dir == :col - return (A * sparse(b)) * view(A, :, n)' - elseif dir == :row - return view(A, :, n) * (A * sparse(b))' +function quadprod(A::AbstractSparseMatrixCSC, b::AbstractVecOrMat, n::Integer) + size(b, 1) == size(A, 2) || throw(DimensionMismatch()) + + # sparse * dense naturally returns dense, but we want to dispatch to + # a sparse-sparse matrix multiplication, so forceably sparsify. + # - Tests with a few example matrices A show that `sparse(A * b)` is faster than + # `A * sparse(b)`. + w = sparse(A * b) + p = n + size(b, 2) - 1 + + if ndims(w) == 1 + # vector outer product using column view into matrix is fast + C = w * transpose(view(A, :, n)) else - error("Unrecognized direction `dir = $(repr(dir))`.") + # views are not fast for multiple columns; subset copies are faster + C = w * transpose(A[:, n:p]) end + return C end diff --git a/test/numerics.jl b/test/numerics.jl index 7b99016..df0f0a2 100644 --- a/test/numerics.jl +++ b/test/numerics.jl @@ -11,14 +11,18 @@ end using SparseArrays using CMB: quadprod (m, n) = (10, 25) - i = 7 + i, j = 7, 3 A = sprand(T, m, n, 0.5) + + # single vector b = rand(T, n) - Bc = sparse(collect(1:n), fill(i,n), b, n, n) - Br = sparse(fill(i,n), collect(1:n), b, n, n) + B = sparse(collect(1:n), fill(i,n), b, n, n) + @test A * B * A' == quadprod(A, b, i) + @test @inferred(quadprod(A, b, i)) isa SparseMatrixCSC{T, Int} - @test A * Bc * A' == quadprod(A, b, i, :col) - @test @inferred(quadprod(A, b, i, :col)) isa SparseMatrixCSC - @test A * Br * A' ≈ quadprod(A, b, i, :row) - @test @inferred(quadprod(A, b, i, :row)) isa SparseMatrixCSC + # block of columns + b = rand(T, n, j) + B[:,i:i+j-1] .= b + @test A * B * A' == quadprod(A, b, i) + @test @inferred(quadprod(A, b, i)) isa SparseMatrixCSC{T, Int} end