Skip to content

Commit

Permalink
Fix OneElement multiplication with array elements (#335)
Browse files Browse the repository at this point in the history
* Fix OneElement multiplication with array elements

* Fix matmul for array elements in OneElMat * StridedMat
  • Loading branch information
jishnub authored Apr 25, 2024
1 parent b0ee65f commit 8734371
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 39 deletions.
67 changes: 28 additions & 39 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,12 @@ function mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha
end

@inline function __mul!(y, A::AbstractMatrix, x::OneElement, alpha, beta)
αx = alpha * x.val
= Ref(x.val * alpha)
ind1 = x.ind[1]
if iszero(beta)
y .= αx .* view(A, :, ind1)
y .= view(A, :, ind1) .*
else
y .= αx .* view(A, :, ind1) .+ beta .* y
y .= view(A, :, ind1) .*.+ y .* beta
end
return y
end
Expand All @@ -171,13 +171,14 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::OneElementMatrix, alpha,
mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta)
return C
end
nzrow, nzcol = B.ind
if iszero(beta)
C .= zero(eltype(C))
C .= Ref(zero(eltype(C)))
else
view(C, :, 1:B.ind[2]-1) .*= beta
view(C, :, B.ind[2]+1:size(C,2)) .*= beta
view(C, :, 1:nzcol-1) .*= beta
view(C, :, nzcol+1:size(C,2)) .*= beta
end
y = view(C, :, B.ind[2])
y = view(C, :, nzcol)
__mul!(y, A, B, alpha, beta)
C
end
Expand All @@ -187,17 +188,14 @@ function _mul!(C::AbstractMatrix, A::Diagonal, B::OneElementMatrix, alpha, beta)
mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta)
return C
end
if iszero(beta)
C .= zero(eltype(C))
else
view(C, :, 1:B.ind[2]-1) .*= beta
view(C, :, B.ind[2]+1:size(C,2)) .*= beta
end
ABα = A * B * alpha
nzrow, nzcol = B.ind
ABα = A * B * alpha
if iszero(beta)
C[B.ind...] = ABα[B.ind...]
C .= Ref(zero(eltype(C)))
C[nzrow, nzcol] = ABα[nzrow, nzcol]
else
view(C, :, 1:nzcol-1) .*= beta
view(C, :, nzcol+1:size(C,2)) .*= beta
y = view(C, :, nzcol)
y .= view(ABα, :, nzcol) .+ y .* beta
end
Expand All @@ -210,19 +208,16 @@ function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractMatrix, alpha,
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
return C
end
if iszero(beta)
C .= zero(eltype(C))
else
view(C, 1:A.ind[1]-1, :) .*= beta
view(C, A.ind[1]+1:size(C,1), :) .*= beta
end
y = view(C, A.ind[1], :)
ind2 = A.ind[2]
nzrow, nzcol = A.ind
y = view(C, nzrow, :)
Aval = A.val
if iszero(beta)
y .= Aval .* view(B, ind2, :) .* alpha
C .= Ref(zero(eltype(C)))
y .= Ref(Aval) .* view(B, nzcol, :) .* alpha
else
y .= Aval .* view(B, ind2, :) .* alpha .+ y .* beta
view(C, 1:nzrow-1, :) .*= beta
view(C, nzrow+1:size(C,1), :) .*= beta
y .= Ref(Aval) .* view(B, nzcol, :) .* alpha .+ y .* beta
end
C
end
Expand All @@ -232,17 +227,14 @@ function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::Diagonal, alpha, beta)
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
return C
end
if iszero(beta)
C .= zero(eltype(C))
else
view(C, 1:A.ind[1]-1, :) .*= beta
view(C, A.ind[1]+1:size(C,1), :) .*= beta
end
ABα = A * B * alpha
nzrow, nzcol = A.ind
ABα = A * B * alpha
if iszero(beta)
C[A.ind...] = ABα[A.ind...]
C .= Ref(zero(eltype(C)))
C[nzrow, nzcol] = ABα[nzrow, nzcol]
else
view(C, 1:nzrow-1, :) .*= beta
view(C, nzrow+1:size(C,1), :) .*= beta
y = view(C, nzrow, :)
y .= view(ABα, nzrow, :) .+ y .* beta
end
Expand All @@ -256,16 +248,13 @@ function _mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractVector, alpha,
return C
end
nzrow, nzcol = A.ind
if iszero(beta)
C .= zero(eltype(C))
else
view(C, 1:nzrow-1) .*= beta
view(C, nzrow+1:size(C,1)) .*= beta
end
Aval = A.val
if iszero(beta)
C .= Ref(zero(eltype(C)))
C[nzrow] = Aval * B[nzcol] * alpha
else
view(C, 1:nzrow-1) .*= beta
view(C, nzrow+1:size(C,1)) .*= beta
C[nzrow] = Aval * B[nzcol] * alpha + C[nzrow] * beta
end
C
Expand Down
87 changes: 87 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2318,6 +2318,93 @@ end
@test mul!(C, O, D, 2, 2) == 2 * O * D .+ 2
end
end
@testset "array elements" begin
A = [SMatrix{2,3}(1:6)*(i+j) for i in 1:3, j in 1:2]
@testset "StridedMatrix * OneElementMatrix" begin
B = OneElement(SMatrix{3,2}(1:6), (size(A,2),2), (size(A,2),4))
C = [SMatrix{2,2}(1:4) for i in axes(A,1), j in axes(B,2)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end
@testset "StridedMatrix * OneElementVector" begin
B = OneElement(SMatrix{3,2}(1:6), (size(A,2),), (size(A,2),))
C = [SMatrix{2,2}(1:4) for i in axes(A,1)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end

A = OneElement(SMatrix{3,2}(1:6), (3,2), (5,4))
@testset "OneElementMatrix * StridedMatrix" begin
B = [SMatrix{2,3}(1:6)*(i+j) for i in axes(A,2), j in 1:2]
C = [SMatrix{3,3}(1:9) for i in axes(A,1), j in axes(B,2)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end
@testset "OneElementMatrix * StridedVector" begin
B = [SMatrix{2,3}(1:6)*i for i in axes(A,2)]
C = [SMatrix{3,3}(1:9) for i in axes(A,1)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end
@testset "OneElementMatrix * OneElementMatrix" begin
B = OneElement(SMatrix{2,3}(1:6), (2,4), (size(A,2), 3))
C = [SMatrix{3,3}(1:9) for i in axes(A,1), j in axes(B,2)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end
@testset "OneElementMatrix * OneElementVector" begin
B = OneElement(SMatrix{2,3}(1:6), 2, size(A,2))
C = [SMatrix{3,3}(1:9) for i in axes(A,1)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end
end
@testset "non-commutative" begin
A = OneElement(quat(rand(4)...), (2,3), (3,4))
for (B,C) in (
# OneElementMatrix * OneElementVector
(OneElement(quat(rand(4)...), 3, size(A,2)),
[quat(rand(4)...) for i in axes(A,1)]),

# OneElementMatrix * OneElementMatrix
(OneElement(quat(rand(4)...), (3,2), (size(A,2), 4)),
[quat(rand(4)...) for i in axes(A,1), j in 1:4]),
)
@test mul!(copy(C), A, B) A * B
α, β = quat(0,0,1,0), quat(1,0,1,0)
@test mul!(copy(C), A, B, α, β) mul!(copy(C), A, Array(B), α, β) A * B * α + C * β
end

A = [quat(rand(4)...)*(i+j) for i in 1:2, j in 1:3]
for (B,C) in (
# StridedMatrix * OneElementVector
(OneElement(quat(rand(4)...), 1, size(A,2)),
[quat(rand(4)...) for i in axes(A,1)]),

# StridedMatrix * OneElementMatrix
(OneElement(quat(rand(4)...), (2,2), (size(A,2), 4)),
[quat(rand(4)...) for i in axes(A,1), j in 1:4]),
)
@test mul!(copy(C), A, B) A * B
α, β = quat(0,0,1,0), quat(1,0,1,0)
@test mul!(copy(C), A, B, α, β) mul!(copy(C), A, Array(B), α, β) A * B * α + C * β
end

A = OneElement(quat(rand(4)...), (2,2), (3, 4))
for (B,C) in (
# OneElementMatrix * StridedMatrix
([quat(rand(4)...) for i in axes(A,2), j in 1:3],
[quat(rand(4)...) for i in axes(A,1), j in 1:3]),

# OneElementMatrix * StridedVector
([quat(rand(4)...) for i in axes(A,2)],
[quat(rand(4)...) for i in axes(A,1)]),
)
@test mul!(copy(C), A, B) A * B
α, β = quat(0,0,1,0), quat(1,0,1,0)
@test mul!(copy(C), A, B, α, β) mul!(copy(C), A, Array(B), α, β) A * B * α + C * β
end
end
end

@testset "multiplication/division by a number" begin
Expand Down

0 comments on commit 8734371

Please sign in to comment.