diff --git a/Project.toml b/Project.toml index 4f5f210..91a8710 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" @@ -14,6 +15,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [weakdeps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -22,6 +24,7 @@ StructArraysAdaptExt = "Adapt" StructArraysGPUArraysCoreExt = "GPUArraysCore" StructArraysSparseArraysExt = "SparseArrays" StructArraysStaticArraysExt = "StaticArrays" +StructArraysLinearAlgebraExt = "LinearAlgebra" [compat] Adapt = "3.4" @@ -29,6 +32,7 @@ ConstructionBase = "1" DataAPI = "1" GPUArraysCore = "0.1.2" InfiniteArrays = "0.13" +LinearAlgebra = "1" StaticArrays = "1.5.6" Tables = "1" julia = "1.6" diff --git a/ext/StructArraysLinearAlgebraExt.jl b/ext/StructArraysLinearAlgebraExt.jl new file mode 100644 index 0000000..ebb562d --- /dev/null +++ b/ext/StructArraysLinearAlgebraExt.jl @@ -0,0 +1,17 @@ +module StructArraysLinearAlgebraExt + +using StructArrays +using LinearAlgebra +import LinearAlgebra: mul! + +const StructMatrixC{T, A<:AbstractMatrix{T}} = StructArrays.StructMatrix{Complex{T}, @NamedTuple{re::A, im::A}} + +function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number) + mul!(C.re, A.re, B.re, alpha, beta) + mul!(C.re, A.im, B.im, -alpha, oneunit(beta)) + mul!(C.im, A.re, B.im, alpha, beta) + mul!(C.im, A.im, B.re, alpha, oneunit(beta)) + C +end + +end diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 7c1b1dd..57058ce 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -30,6 +30,7 @@ end include("../ext/StructArraysGPUArraysCoreExt.jl") include("../ext/StructArraysSparseArraysExt.jl") include("../ext/StructArraysStaticArraysExt.jl") + include("../ext/StructArraysLinearAlgebraExt.jl") end end # module diff --git a/src/structarray.jl b/src/structarray.jl index 64ff17a..239c5d7 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -106,6 +106,7 @@ _structarray(args::Tuple, ::Tuple) = _structarray(args, nothing) _structarray(args::NTuple{N, Any}, names::NTuple{N, Symbol}) where {N} = StructArray(NamedTuple{names}(args)) const StructVector{T, C<:Tup, I} = StructArray{T, 1, C, I} +const StructMatrix{T, C<:Tup, I} = StructArray{T, 2, C, I} StructVector{T}(args...; kwargs...) where {T} = StructArray{T}(args...; kwargs...) StructVector(args...; kwargs...) = StructArray(args...; kwargs...) diff --git a/test/runtests.jl b/test/runtests.jl index ed6f573..9e97115 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1510,3 +1510,13 @@ end S = StructArray{Complex{Int}}((1:∞, 1:∞)) @test Base.IteratorSize(S) == Base.IsInfinite() end + +@testset "LinearAlgebra" begin + @testset "matrix * matrix" begin + A = StructArray{ComplexF64}((rand(10,10), rand(10,10))) + B = StructArray{ComplexF64}((rand(10,10), rand(10,10))) + MA, MB = Matrix(A), Matrix(B) + @test A * B ≈ MA * MB + @test A * A ≈ MA * MA + end +end