diff --git a/Project.toml b/Project.toml index f3c9e95..176f2bf 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" @@ -14,6 +15,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [weakdeps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -22,6 +24,7 @@ StructArraysAdaptExt = "Adapt" StructArraysGPUArraysCoreExt = "GPUArraysCore" StructArraysSparseArraysExt = "SparseArrays" StructArraysStaticArraysExt = "StaticArrays" +StructArraysLinearAlgebraExt = "LinearAlgebra" [compat] Adapt = "3.4, 4" diff --git a/ext/StructArraysLinearAlgebraExt.jl b/ext/StructArraysLinearAlgebraExt.jl new file mode 100644 index 0000000..6699d85 --- /dev/null +++ b/ext/StructArraysLinearAlgebraExt.jl @@ -0,0 +1,25 @@ +module StructArraysLinearAlgebraExt + +using StructArrays +using LinearAlgebra +import LinearAlgebra: mul! + +const StructMatrixC{T, A<:AbstractMatrix{T}} = StructArrays.StructMatrix{Complex{T}, @NamedTuple{re::A, im::A}} +const StructVectorC{T, A<:AbstractVector{T}} = StructArrays.StructVector{Complex{T}, @NamedTuple{re::A, im::A}} + +function _mul!(C, A, B, alpha, beta) + mul!(C.re, A.re, B.re, alpha, beta) + mul!(C.re, A.im, B.im, -alpha, oneunit(beta)) + mul!(C.im, A.re, B.im, alpha, beta) + mul!(C.im, A.im, B.re, alpha, oneunit(beta)) + C +end + +function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) +end +function mul!(C::StructVectorC, A::StructMatrixC, B::StructVectorC, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) +end + +end diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 7c1b1dd..57058ce 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -30,6 +30,7 @@ end include("../ext/StructArraysGPUArraysCoreExt.jl") include("../ext/StructArraysSparseArraysExt.jl") include("../ext/StructArraysStaticArraysExt.jl") + include("../ext/StructArraysLinearAlgebraExt.jl") end end # module diff --git a/src/structarray.jl b/src/structarray.jl index d83ccbd..0f13ad4 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -106,6 +106,7 @@ _structarray(args::Tuple, ::Tuple) = _structarray(args, nothing) _structarray(args::NTuple{N, Any}, names::NTuple{N, Symbol}) where {N} = StructArray(NamedTuple{names}(args)) const StructVector{T, C<:Tup, I} = StructArray{T, 1, C, I} +const StructMatrix{T, C<:Tup, I} = StructArray{T, 2, C, I} StructVector{T}(args...; kwargs...) where {T} = StructArray{T}(args...; kwargs...) StructVector(args...; kwargs...) = StructArray(args...; kwargs...) diff --git a/test/runtests.jl b/test/runtests.jl index 47c1564..83b0b96 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1206,7 +1206,7 @@ end # The following code defines `MyArray1/2/3` with different `BroadcastStyle`s. # 1. `MyArray1` and `MyArray1` have `similar` defined. # We use them to simulate `BroadcastStyle` overloading `Base.copyto!`. -# 2. `MyArray3` has no `similar` defined. +# 2. `MyArray3` has no `similar` defined. # We use it to simulate `BroadcastStyle` overloading `Base.copy`. # 3. Their resolved style could be summaryized as (`-` means conflict) # | MyArray1 | MyArray2 | MyArray3 | Array @@ -1302,7 +1302,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS f(s) = s .+= 1 f(s) @test (@allocated f(s)) == 0 - + # issue #185 A = StructArray(randn(ComplexF64, 3, 3)) B = randn(ComplexF64, 3, 3) @@ -1321,7 +1321,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @testset "ambiguity check" begin test_set = Any[StructArray([1;2+im]), - 1:2, + 1:2, (1,2), StructArray(@SArray [1;1+2im]), (@SArray [1 2]), @@ -1550,6 +1550,23 @@ end @test Base.IteratorSize(S) == Base.IsInfinite() end +@testset "LinearAlgebra" begin + @testset "matrix * matrix" begin + A = StructArray{ComplexF64}((rand(10,10), rand(10,10))) + B = StructArray{ComplexF64}((rand(size(A)...), rand(size(A)...))) + MA, MB = Matrix(A), Matrix(B) + @test A * B ≈ MA * MB + @test mul!(ones(ComplexF64,size(A)), A, B, 2.0, 3.0) ≈ 2 * A * B .+ 3 + end + @testset "matrix * vector" begin + A = StructArray{ComplexF64}((rand(10,10), rand(10,10))) + v = StructArray{ComplexF64}((rand(size(A,2)), rand(size(A,2)))) + MA, Mv = Matrix(A), Vector(v) + @test A * v ≈ MA * Mv + @test mul!(ones(ComplexF64,size(v)), A, v, 2.0, 3.0) ≈ 2 * A * v .+ 3 + end +end + @testset "project quality" begin Aqua.test_all(StructArrays, ambiguities=(; broken=true)) end