From cd1c12cf847ae93c1002a30dec7394d39d4d1009 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 19 Jan 2024 20:38:03 +0530 Subject: [PATCH] mul for vectors --- ext/StructArraysLinearAlgebraExt.jl | 10 +++++++++- test/runtests.jl | 11 +++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/ext/StructArraysLinearAlgebraExt.jl b/ext/StructArraysLinearAlgebraExt.jl index ebb562d..7908e08 100644 --- a/ext/StructArraysLinearAlgebraExt.jl +++ b/ext/StructArraysLinearAlgebraExt.jl @@ -5,8 +5,9 @@ using LinearAlgebra import LinearAlgebra: mul! const StructMatrixC{T, A<:AbstractMatrix{T}} = StructArrays.StructMatrix{Complex{T}, @NamedTuple{re::A, im::A}} +const StructVectorC{T, A<:AbstractVector{T}} = StructArrays.StructVector{Complex{T}, @NamedTuple{re::A, im::A}} -function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number) +function _mul!(C, A, B, alpha, beta) mul!(C.re, A.re, B.re, alpha, beta) mul!(C.re, A.im, B.im, -alpha, oneunit(beta)) mul!(C.im, A.re, B.im, alpha, beta) @@ -14,4 +15,11 @@ function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Numbe C end +function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) +end +function mul!(C::StructVectorC, A::StructMatrixC, B::StructVectorC, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) +end + end diff --git a/test/runtests.jl b/test/runtests.jl index 9e97115..42ca450 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1514,9 +1514,16 @@ end @testset "LinearAlgebra" begin @testset "matrix * matrix" begin A = StructArray{ComplexF64}((rand(10,10), rand(10,10))) - B = StructArray{ComplexF64}((rand(10,10), rand(10,10))) + B = StructArray{ComplexF64}((rand(size(A)...), rand(size(A)...))) MA, MB = Matrix(A), Matrix(B) @test A * B ≈ MA * MB - @test A * A ≈ MA * MA + @test mul!(ones(ComplexF64,size(A)), A, B, 2.0, 3.0) ≈ 2 * A * B .+ 3 + end + @testset "matrix * vector" begin + A = StructArray{ComplexF64}((rand(10,10), rand(10,10))) + v = StructArray{ComplexF64}((rand(size(A,2)), rand(size(A,2)))) + MA, Mv = Matrix(A), Vector(v) + @test A * v ≈ MA * Mv + @test mul!(ones(ComplexF64,size(v)), A, v, 2.0, 3.0) ≈ 2 * A * v .+ 3 end end