Skip to content

Commit

Permalink
mul for vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Jan 19, 2024
1 parent 6f54802 commit cd1c12c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
10 changes: 9 additions & 1 deletion ext/StructArraysLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@ using LinearAlgebra
import LinearAlgebra: mul!

const StructMatrixC{T, A<:AbstractMatrix{T}} = StructArrays.StructMatrix{Complex{T}, @NamedTuple{re::A, im::A}}
const StructVectorC{T, A<:AbstractVector{T}} = StructArrays.StructVector{Complex{T}, @NamedTuple{re::A, im::A}}

function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number)
function _mul!(C, A, B, alpha, beta)
mul!(C.re, A.re, B.re, alpha, beta)
mul!(C.re, A.im, B.im, -alpha, oneunit(beta))
mul!(C.im, A.re, B.im, alpha, beta)
mul!(C.im, A.im, B.re, alpha, oneunit(beta))
C
end

function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end
function mul!(C::StructVectorC, A::StructMatrixC, B::StructVectorC, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end

end
11 changes: 9 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1514,9 +1514,16 @@ end
@testset "LinearAlgebra" begin
@testset "matrix * matrix" begin
A = StructArray{ComplexF64}((rand(10,10), rand(10,10)))
B = StructArray{ComplexF64}((rand(10,10), rand(10,10)))
B = StructArray{ComplexF64}((rand(size(A)...), rand(size(A)...)))
MA, MB = Matrix(A), Matrix(B)
@test A * B MA * MB
@test A * A MA * MA
@test mul!(ones(ComplexF64,size(A)), A, B, 2.0, 3.0) 2 * A * B .+ 3
end
@testset "matrix * vector" begin
A = StructArray{ComplexF64}((rand(10,10), rand(10,10)))
v = StructArray{ComplexF64}((rand(size(A,2)), rand(size(A,2))))
MA, Mv = Matrix(A), Vector(v)
@test A * v MA * Mv
@test mul!(ones(ComplexF64,size(v)), A, v, 2.0, 3.0) 2 * A * v .+ 3
end
end

0 comments on commit cd1c12c

Please sign in to comment.