diff --git a/src/implementations/LinearAlgebra.jl b/src/implementations/LinearAlgebra.jl index 0caddd7..de806ad 100644 --- a/src/implementations/LinearAlgebra.jl +++ b/src/implementations/LinearAlgebra.jl @@ -98,6 +98,24 @@ function operate!( return broadcast!(op, A, B) end +function operate_to!( + output::AbstractArray, + op::Union{typeof(+),typeof(-)}, + A::AbstractArray, +) + if axes(output) != axes(A) + throw( + DimensionMismatch( + "Cannot sum or substract a matrix of axes `$(axes(A))`" * + " into a matrix of axes `$(axes(output))`, expected" * + " axes `$(axes(A))`.", + ), + ) + end + # We don't have `MA.broadcast_to!` as it would be exactly `Base.broadcast!`. + return Base.broadcast!(op, output, A) +end + function operate_to!( output::AbstractArray, op::Union{typeof(+),typeof(-)}, diff --git a/test/matmul.jl b/test/matmul.jl index fdbf9fe..90fc363 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -139,6 +139,11 @@ end "Cannot sum or substract matrices of axes `$(axes(A))` and `$(axes(B))` into a matrix of axes `$(axes(output))`, expected axes `$(axes(B))`.", ) @test_throws err MA.operate_to!(output, +, A, B) + err = DimensionMismatch( + "Cannot sum or substract a matrix of axes `$(axes(A))` into a matrix of axes `$(axes(output))`, expected axes `$(axes(A))`.", + ) + @test_throws err MA.operate_to!(output, +, A) + @test_throws err MA.operate_to!(output, -, A) end @testset "unsupported_product" begin unsupported_product() @@ -471,17 +476,14 @@ function test_sparse_vector_sum(::Type{T}) where {T} x = SparseArrays.sparsevec([1, 3], T[5, 7]) y = copy(x) z = copy(y) - alloc_test(() -> MA.operate!(+, y, z), 0) - alloc_test(() -> MA.operate!(-, y, z), 0) - alloc_test(() -> MA.add!!(y, z), 0) - alloc_test(() -> MA.sub!!(y, z), 0) alloc_test(() -> MA.operate_to!(x, +, y, z), 0) alloc_test(() -> MA.operate_to!(x, -, y, z), 0) - alloc_test(() -> MA.add_to!!(x, y, z), 0) - alloc_test(() -> MA.sub_to!!(x, y, z), 0) + alloc_test(() -> MA.operate_to!(x, +, y), 0) + alloc_test(() -> MA.operate_to!(x, -, y), 0) return end @testset "Array sum" begin test_array_sum(Int) + test_sparse_vector_sum(Int) end