Skip to content

Commit

Permalink
Make sparse scale!, *, and / on columns fast and add tests (#24981)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasnoack authored Dec 9, 2017
1 parent f8d3fbb commit 0784e39
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
16 changes: 8 additions & 8 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1475,14 +1475,14 @@ end

# scale

scale!(x::AbstractSparseVector, a::Real) = (scale!(nonzeros(x), a); x)
scale!(x::AbstractSparseVector, a::Complex) = (scale!(nonzeros(x), a); x)
scale!(a::Real, x::AbstractSparseVector) = (scale!(nonzeros(x), a); x)
scale!(a::Complex, x::AbstractSparseVector) = (scale!(nonzeros(x), a); x)

(*)(x::AbstractSparseVector, a::Number) = SparseVector(length(x), copy(nonzeroinds(x)), nonzeros(x) * a)
(*)(a::Number, x::AbstractSparseVector) = SparseVector(length(x), copy(nonzeroinds(x)), a * nonzeros(x))
(/)(x::AbstractSparseVector, a::Number) = SparseVector(length(x), copy(nonzeroinds(x)), nonzeros(x) / a)
scale!(x::SparseVectorUnion, a::Real) = (scale!(nonzeros(x), a); x)
scale!(x::SparseVectorUnion, a::Complex) = (scale!(nonzeros(x), a); x)
scale!(a::Real, x::SparseVectorUnion) = (scale!(nonzeros(x), a); x)
scale!(a::Complex, x::SparseVectorUnion) = (scale!(nonzeros(x), a); x)

(*)(x::SparseVectorUnion, a::Number) = SparseVector(length(x), copy(nonzeroinds(x)), nonzeros(x) * a)
(*)(a::Number, x::SparseVectorUnion) = SparseVector(length(x), copy(nonzeroinds(x)), a * nonzeros(x))
(/)(x::SparseVectorUnion, a::Number) = SparseVector(length(x), copy(nonzeroinds(x)), nonzeros(x) / a)

# dot
function dot(x::StridedVector{Tx}, y::SparseVectorUnion{Ty}) where {Tx<:Number,Ty<:Number}
Expand Down
18 changes: 18 additions & 0 deletions test/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1240,3 +1240,21 @@ end
@test length(simA.rowval) == length(A.nzind)
@test length(simA.nzval) == length(A.nzval)
end

@testset "Fast operations on full column views" begin
n = 1000
A = sprandn(n, n, 0.01)
for j in 1:5:n
Aj, Ajview = A[:, j], view(A, :, j)
@test norm(Aj) == norm(Ajview)
@test dot(Aj, copy(Aj)) == dot(Ajview, Aj) # don't alias since it takes a different code path
@test scale!(Aj, 0.1) == scale!(Ajview, 0.1)
@test Aj*0.1 == Ajview*0.1
@test 0.1*Aj == 0.1*Ajview
@test Aj/0.1 == Ajview/0.1
@test LinAlg.axpy!(1.0, Aj, sparse(ones(n))) ==
LinAlg.axpy!(1.0, Ajview, sparse(ones(n)))
@test LinAlg.lowrankupdate!(Matrix(1.0*I, n, n), fill(1.0, n), Aj) ==
LinAlg.lowrankupdate!(Matrix(1.0*I, n, n), fill(1.0, n), Ajview)
end
end

0 comments on commit 0784e39

Please sign in to comment.