Skip to content

Commit

Permalink
Fix SparseVector broadcasting oddities (#26474)
Browse files Browse the repository at this point in the history
* Fix SparseVector broadcasting oddities

Lots of SparseVector oddities fixed here, all by deleting code.  Previous behaviors that are now fixed include:

```julia
julia> spzeros(5) .* [1]
ERROR: DimensionMismatch("")
Stacktrace:
 [1] _binarymap(::typeof(*), ::SparseVector{Float64,Int64}, ::Array{Int64,1}, ::Int64) at /home/mbauman/julia-wip3/usr/share/julia/site/v0.7/SparseArrays/src/sparsevector.jl:1325
 [2] _vmul at /home/mbauman/julia-wip3/usr/share/julia/site/v0.7/SparseArrays/src/sparsevector.jl:1369 [inlined]
 [3] broadcast(::typeof(*), ::SparseVector{Float64,Int64}, ::Array{Int64,1}) at /home/mbauman/julia-wip3/usr/share/julia/site/v0.7/SparseArrays/src/sparsevector.jl:1388
 [4] top-level scope

julia> spzeros(5) .* ones(5)
5-element Array{Float64,1}:
 0.0
 0.0
 0.0
 0.0
 0.0

julia> spzeros(5) .* ones(5) .* 1 # this is as expected
5-element SparseVector{Float64,Int64} with 0 stored entries
```

Those all now behave as you'd expect, returning a SparseVector in all cases.

* Remove debugging code

[ci skip]
  • Loading branch information
mbauman authored Mar 16, 2018
1 parent 83b0231 commit cf65ee1
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 132 deletions.
5 changes: 5 additions & 0 deletions stdlib/SparseArrays/src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ import Base: asyncmap

Base.@deprecate_binding blkdiag blockdiag

@deprecate complex(x::AbstractSparseVector{<:Real}, y::AbstractSparseVector{<:Real}) complex.(x, y)
@deprecate complex(x::AbstractVector{<:Real}, y::AbstractSparseVector{<:Real}) complex.(x, y)
@deprecate complex(x::AbstractSparseVector{<:Real}, y::AbstractVector{<:Real}) complex.(x, y)


# END 0.7 deprecations

# BEGIN 1.0 deprecations
Expand Down
6 changes: 0 additions & 6 deletions stdlib/SparseArrays/src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,6 @@ end
storedvals = Vector{entrytype}(undef, maxnnz)
return SparseMatrixCSC(shape..., pointers, storedinds, storedvals)
end
# Ambiguity killers, TODO: nix conflicting specializations
ambiguityfunnel(f::Tf, x, y) where {Tf} = _aresameshape(x, y) ? _noshapecheck_map(f, x, y) : _diffshape_broadcast(f, x, y)
broadcast(::typeof(+), x::SparseVector, y::SparseVector) = ambiguityfunnel(+, x, y) # base/sparse/sparsevectors.jl:1266
broadcast(::typeof(-), x::SparseVector, y::SparseVector) = ambiguityfunnel(-, x, y) # base/sparse/sparsevectors.jl:1266
broadcast(::typeof(*), x::SparseVector, y::SparseVector) = ambiguityfunnel(*, x, y) # base/sparse/sparsevectors.jl:1266


# (4) _map_zeropres!/_map_notzeropres! specialized for a single sparse vector/matrix
"Stores only the nonzero entries of `map(f, Array(A))` in `C`."
Expand Down
135 changes: 9 additions & 126 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1273,138 +1273,21 @@ function _binarymap_mode_2!(f::Function, mx::Int, my::Int,
return ir
end

function _binarymap(f::Function,
x::AbstractVector{Tx},
y::AbstractSparseVector{Ty},
mode::Int) where {Tx,Ty}
0 <= mode <= 2 || throw(ArgumentError("Incorrect mode $mode."))
R = typeof(f(zero(Tx), zero(Ty)))
n = length(x)
length(y) == n || throw(DimensionMismatch())
# definition of a few known broadcasted/mapped binary functions — all others defer to HigherOrderFunctions

ynzind = nonzeroinds(y)
ynzval = nonzeros(y)
m = length(ynzind)

dst = Vector{R}(undef, n)
if mode == 0
ii = 1
@inbounds for i = 1:m
j = ynzind[i]
while ii < j
dst[ii] = zero(R); ii += 1
end
dst[j] = f(x[j], ynzval[i]); ii += 1
end
@inbounds while ii <= n
dst[ii] = zero(R); ii += 1
end
else # mode >= 1
ii = 1
@inbounds for i = 1:m
j = ynzind[i]
while ii < j
dst[ii] = f(x[ii], zero(Ty)); ii += 1
end
dst[j] = f(x[j], ynzval[i]); ii += 1
end
@inbounds while ii <= n
dst[ii] = f(x[ii], zero(Ty)); ii += 1
end
_bcast_binary_map(f, x, y, mode) = length(x) == length(y) ? _binarymap(f, x, y, mode) : HigherOrderFns._diffshape_broadcast(f, x, y)
for (fun, mode) in [(:+, 1), (:-, 1), (:*, 0), (:min, 2), (:max, 2)]
fun in (:+, :-) && @eval begin
# Addition and subtraction can be defined directly on the arrays (without map/broadcast)
$(fun)(x::AbstractSparseVector, y::AbstractSparseVector) = _binarymap($(fun), x, y, $mode)
end
return dst
end

function _binarymap(f::Function,
x::AbstractSparseVector{Tx},
y::AbstractVector{Ty},
mode::Int) where {Tx,Ty}
0 <= mode <= 2 || throw(ArgumentError("Incorrect mode $mode."))
R = typeof(f(zero(Tx), zero(Ty)))
n = length(x)
length(y) == n || throw(DimensionMismatch())

xnzind = nonzeroinds(x)
xnzval = nonzeros(x)
m = length(xnzind)

dst = Vector{R}(undef, n)
if mode == 0
ii = 1
@inbounds for i = 1:m
j = xnzind[i]
while ii < j
dst[ii] = zero(R); ii += 1
end
dst[j] = f(xnzval[i], y[j]); ii += 1
end
@inbounds while ii <= n
dst[ii] = zero(R); ii += 1
end
else # mode >= 1
ii = 1
@inbounds for i = 1:m
j = xnzind[i]
while ii < j
dst[ii] = f(zero(Tx), y[ii]); ii += 1
end
dst[j] = f(xnzval[i], y[j]); ii += 1
end
@inbounds while ii <= n
dst[ii] = f(zero(Tx), y[ii]); ii += 1
end
end
return dst
end


### Binary arithmetics: +, -, *

for (vop, fun, mode) in [(:_vadd, :+, 1),
(:_vsub, :-, 1),
(:_vmul, :*, 0)]
@eval begin
$(vop)(x::AbstractSparseVector, y::AbstractSparseVector) = _binarymap($(fun), x, y, $mode)
$(vop)(x::AbstractVector, y::AbstractSparseVector) = _binarymap($(fun), x, y, $mode)
$(vop)(x::AbstractSparseVector, y::AbstractVector) = _binarymap($(fun), x, y, $mode)
map(::typeof($fun), x::AbstractSparseVector, y::AbstractSparseVector) = _binarymap($fun, x, y, $mode)
broadcast(::typeof($fun), x::AbstractSparseVector, y::AbstractSparseVector) = _bcast_binary_map($fun, x, y, $mode)
broadcast(::typeof($fun), x::SparseVector, y::SparseVector) = _bcast_binary_map($fun, x, y, $mode)
end
end

# to workaround the ambiguities with BitVector
broadcast(::typeof(*), x::BitVector, y::AbstractSparseVector{Bool}) = _vmul(x, y)
broadcast(::typeof(*), x::AbstractSparseVector{Bool}, y::BitVector) = _vmul(x, y)

# definition of operators

for (op, vop) in [(:+, :_vadd), (:-, :_vsub), (:*, :_vmul)]
op != :* && @eval begin
$(op)(x::AbstractSparseVector, y::AbstractSparseVector) = $(vop)(x, y)
$(op)(x::AbstractVector, y::AbstractSparseVector) = $(vop)(x, y)
$(op)(x::AbstractSparseVector, y::AbstractVector) = $(vop)(x, y)
end
@eval begin
broadcast(::typeof($op), x::AbstractSparseVector, y::AbstractSparseVector) = $(vop)(x, y)
broadcast(::typeof($op), x::AbstractVector, y::AbstractSparseVector) = $(vop)(x, y)
broadcast(::typeof($op), x::AbstractSparseVector, y::AbstractVector) = $(vop)(x, y)
end
end

# definition of other binary functions

broadcast(::typeof(min), x::SparseVector{<:Real}, y::SparseVector{<:Real}) = _binarymap(min, x, y, 2)
broadcast(::typeof(min), x::AbstractSparseVector{<:Real}, y::AbstractSparseVector{<:Real}) = _binarymap(min, x, y, 2)
broadcast(::typeof(min), x::AbstractVector{<:Real}, y::AbstractSparseVector{<:Real}) = _binarymap(min, x, y, 2)
broadcast(::typeof(min), x::AbstractSparseVector{<:Real}, y::AbstractVector{<:Real}) = _binarymap(min, x, y, 2)

broadcast(::typeof(max), x::SparseVector{<:Real}, y::SparseVector{<:Real}) = _binarymap(max, x, y, 2)
broadcast(::typeof(max), x::AbstractSparseVector{<:Real}, y::AbstractSparseVector{<:Real}) = _binarymap(max, x, y, 2)
broadcast(::typeof(max), x::AbstractVector{<:Real}, y::AbstractSparseVector{<:Real}) = _binarymap(max, x, y, 2)
broadcast(::typeof(max), x::AbstractSparseVector{<:Real}, y::AbstractVector{<:Real}) = _binarymap(max, x, y, 2)

complex(x::AbstractSparseVector{<:Real}, y::AbstractSparseVector{<:Real}) = _binarymap(complex, x, y, 1)
complex(x::AbstractVector{<:Real}, y::AbstractSparseVector{<:Real}) = _binarymap(complex, x, y, 1)
complex(x::AbstractSparseVector{<:Real}, y::AbstractVector{<:Real}) = _binarymap(complex, x, y, 1)

### Reduction

sum(x::AbstractSparseVector) = sum(nonzeros(x))
Expand Down
24 changes: 24 additions & 0 deletions stdlib/SparseArrays/test/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,30 @@ end
@test spzeros(1,2) .* spzeros(0,1) == zeros(0,2)
end

@testset "sparse vector broadcast of two arguments" begin
sv1, sv5 = sprand(1, 1.), sprand(5, 1.)
for (sa, sb) in ((sv1, sv1), (sv1, sv5), (sv5, sv1), (sv5, sv5))
fa, fb = Vector(sa), Vector(sb)
for f in (+, -, *, min, max)
@test @inferred(broadcast(f, sa, sb))::SparseVector == broadcast(f, fa, fb)
@test @inferred(broadcast(f, Vector(sa), sb))::SparseVector == broadcast(f, fa, fb)
@test @inferred(broadcast(f, sa, Vector(sb)))::SparseVector == broadcast(f, fa, fb)
@test @inferred(broadcast(f, SparseMatrixCSC(sa), sb))::SparseMatrixCSC == broadcast(f, reshape(fa, Val(2)), fb)
@test @inferred(broadcast(f, sa, SparseMatrixCSC(sb)))::SparseMatrixCSC == broadcast(f, fa, reshape(fb, Val(2)))
if length(fa) == length(fb)
@test @inferred(map(f, sa, sb))::SparseVector == broadcast(f, fa, fb)
end
end
if length(fa) == length(fb)
for f in (+, -)
@test @inferred(f(sa, sb))::SparseVector == f(fa, fb)
@test @inferred(f(Vector(sa), sb))::SparseVector == f(fa, fb)
@test @inferred(f(sa, Vector(sb)))::SparseVector == f(fa, fb)
end
end
end
end

@testset "aliasing and indexed assignment or broadcast!" begin
A = sparsevec([0, 0, 1, 1])
B = sparsevec([1, 1, 0, 0])
Expand Down

0 comments on commit cf65ee1

Please sign in to comment.