From 6f5480216333d56124581bd7b60835d83b9f3e51 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 19 Jan 2024 17:11:06 +0530 Subject: [PATCH 1/4] Forward complex matrix multiplication to components --- Project.toml | 4 ++++ ext/StructArraysLinearAlgebraExt.jl | 17 +++++++++++++++++ src/StructArrays.jl | 1 + src/structarray.jl | 1 + test/runtests.jl | 10 ++++++++++ 5 files changed, 33 insertions(+) create mode 100644 ext/StructArraysLinearAlgebraExt.jl diff --git a/Project.toml b/Project.toml index 4f5f210..91a8710 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" @@ -14,6 +15,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [weakdeps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -22,6 +24,7 @@ StructArraysAdaptExt = "Adapt" StructArraysGPUArraysCoreExt = "GPUArraysCore" StructArraysSparseArraysExt = "SparseArrays" StructArraysStaticArraysExt = "StaticArrays" +StructArraysLinearAlgebraExt = "LinearAlgebra" [compat] Adapt = "3.4" @@ -29,6 +32,7 @@ ConstructionBase = "1" DataAPI = "1" GPUArraysCore = "0.1.2" InfiniteArrays = "0.13" +LinearAlgebra = "1" StaticArrays = "1.5.6" Tables = "1" julia = "1.6" diff --git a/ext/StructArraysLinearAlgebraExt.jl b/ext/StructArraysLinearAlgebraExt.jl new file mode 100644 index 0000000..ebb562d --- /dev/null +++ b/ext/StructArraysLinearAlgebraExt.jl @@ -0,0 +1,17 @@ +module StructArraysLinearAlgebraExt + +using StructArrays +using LinearAlgebra +import LinearAlgebra: mul! + +const StructMatrixC{T, A<:AbstractMatrix{T}} = StructArrays.StructMatrix{Complex{T}, @NamedTuple{re::A, im::A}} + +function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number) + mul!(C.re, A.re, B.re, alpha, beta) + mul!(C.re, A.im, B.im, -alpha, oneunit(beta)) + mul!(C.im, A.re, B.im, alpha, beta) + mul!(C.im, A.im, B.re, alpha, oneunit(beta)) + C +end + +end diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 7c1b1dd..57058ce 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -30,6 +30,7 @@ end include("../ext/StructArraysGPUArraysCoreExt.jl") include("../ext/StructArraysSparseArraysExt.jl") include("../ext/StructArraysStaticArraysExt.jl") + include("../ext/StructArraysLinearAlgebraExt.jl") end end # module diff --git a/src/structarray.jl b/src/structarray.jl index 64ff17a..239c5d7 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -106,6 +106,7 @@ _structarray(args::Tuple, ::Tuple) = _structarray(args, nothing) _structarray(args::NTuple{N, Any}, names::NTuple{N, Symbol}) where {N} = StructArray(NamedTuple{names}(args)) const StructVector{T, C<:Tup, I} = StructArray{T, 1, C, I} +const StructMatrix{T, C<:Tup, I} = StructArray{T, 2, C, I} StructVector{T}(args...; kwargs...) where {T} = StructArray{T}(args...; kwargs...) StructVector(args...; kwargs...) = StructArray(args...; kwargs...) diff --git a/test/runtests.jl b/test/runtests.jl index ed6f573..9e97115 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1510,3 +1510,13 @@ end S = StructArray{Complex{Int}}((1:∞, 1:∞)) @test Base.IteratorSize(S) == Base.IsInfinite() end + +@testset "LinearAlgebra" begin + @testset "matrix * matrix" begin + A = StructArray{ComplexF64}((rand(10,10), rand(10,10))) + B = StructArray{ComplexF64}((rand(10,10), rand(10,10))) + MA, MB = Matrix(A), Matrix(B) + @test A * B ≈ MA * MB + @test A * A ≈ MA * MA + end +end From cd1c12cf847ae93c1002a30dec7394d39d4d1009 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 19 Jan 2024 20:38:03 +0530 Subject: [PATCH 2/4] mul for vectors --- ext/StructArraysLinearAlgebraExt.jl | 10 +++++++++- test/runtests.jl | 11 +++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/ext/StructArraysLinearAlgebraExt.jl b/ext/StructArraysLinearAlgebraExt.jl index ebb562d..7908e08 100644 --- a/ext/StructArraysLinearAlgebraExt.jl +++ b/ext/StructArraysLinearAlgebraExt.jl @@ -5,8 +5,9 @@ using LinearAlgebra import LinearAlgebra: mul! const StructMatrixC{T, A<:AbstractMatrix{T}} = StructArrays.StructMatrix{Complex{T}, @NamedTuple{re::A, im::A}} +const StructVectorC{T, A<:AbstractVector{T}} = StructArrays.StructVector{Complex{T}, @NamedTuple{re::A, im::A}} -function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number) +function _mul!(C, A, B, alpha, beta) mul!(C.re, A.re, B.re, alpha, beta) mul!(C.re, A.im, B.im, -alpha, oneunit(beta)) mul!(C.im, A.re, B.im, alpha, beta) @@ -14,4 +15,11 @@ function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Numbe C end +function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) +end +function mul!(C::StructVectorC, A::StructMatrixC, B::StructVectorC, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) +end + end diff --git a/test/runtests.jl b/test/runtests.jl index 9e97115..42ca450 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1514,9 +1514,16 @@ end @testset "LinearAlgebra" begin @testset "matrix * matrix" begin A = StructArray{ComplexF64}((rand(10,10), rand(10,10))) - B = StructArray{ComplexF64}((rand(10,10), rand(10,10))) + B = StructArray{ComplexF64}((rand(size(A)...), rand(size(A)...))) MA, MB = Matrix(A), Matrix(B) @test A * B ≈ MA * MB - @test A * A ≈ MA * MA + @test mul!(ones(ComplexF64,size(A)), A, B, 2.0, 3.0) ≈ 2 * A * B .+ 3 + end + @testset "matrix * vector" begin + A = StructArray{ComplexF64}((rand(10,10), rand(10,10))) + v = StructArray{ComplexF64}((rand(size(A,2)), rand(size(A,2)))) + MA, Mv = Matrix(A), Vector(v) + @test A * v ≈ MA * Mv + @test mul!(ones(ComplexF64,size(v)), A, v, 2.0, 3.0) ≈ 2 * A * v .+ 3 end end From 1d50f7e328198e2880d133d9b19a1b05ce2872c1 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 19 Jan 2024 20:46:12 +0530 Subject: [PATCH 3/4] Convert tabs to spaces --- ext/StructArraysLinearAlgebraExt.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ext/StructArraysLinearAlgebraExt.jl b/ext/StructArraysLinearAlgebraExt.jl index 7908e08..6699d85 100644 --- a/ext/StructArraysLinearAlgebraExt.jl +++ b/ext/StructArraysLinearAlgebraExt.jl @@ -8,18 +8,18 @@ const StructMatrixC{T, A<:AbstractMatrix{T}} = StructArrays.StructMatrix{Complex const StructVectorC{T, A<:AbstractVector{T}} = StructArrays.StructVector{Complex{T}, @NamedTuple{re::A, im::A}} function _mul!(C, A, B, alpha, beta) - mul!(C.re, A.re, B.re, alpha, beta) - mul!(C.re, A.im, B.im, -alpha, oneunit(beta)) - mul!(C.im, A.re, B.im, alpha, beta) - mul!(C.im, A.im, B.re, alpha, oneunit(beta)) - C + mul!(C.re, A.re, B.re, alpha, beta) + mul!(C.re, A.im, B.im, -alpha, oneunit(beta)) + mul!(C.im, A.re, B.im, alpha, beta) + mul!(C.im, A.im, B.re, alpha, oneunit(beta)) + C end function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number) - _mul!(C, A, B, alpha, beta) + _mul!(C, A, B, alpha, beta) end function mul!(C::StructVectorC, A::StructMatrixC, B::StructVectorC, alpha::Number, beta::Number) - _mul!(C, A, B, alpha, beta) + _mul!(C, A, B, alpha, beta) end end From dbf8168b00254b3d95a19d6bf90e905a795288bc Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 10 Jun 2024 17:03:20 +0530 Subject: [PATCH 4/4] Remove extra end --- test/runtests.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index b2c35a5..83b0b96 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1206,7 +1206,7 @@ end # The following code defines `MyArray1/2/3` with different `BroadcastStyle`s. # 1. `MyArray1` and `MyArray1` have `similar` defined. # We use them to simulate `BroadcastStyle` overloading `Base.copyto!`. -# 2. `MyArray3` has no `similar` defined. +# 2. `MyArray3` has no `similar` defined. # We use it to simulate `BroadcastStyle` overloading `Base.copy`. # 3. Their resolved style could be summaryized as (`-` means conflict) # | MyArray1 | MyArray2 | MyArray3 | Array @@ -1302,7 +1302,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS f(s) = s .+= 1 f(s) @test (@allocated f(s)) == 0 - + # issue #185 A = StructArray(randn(ComplexF64, 3, 3)) B = randn(ComplexF64, 3, 3) @@ -1321,7 +1321,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @testset "ambiguity check" begin test_set = Any[StructArray([1;2+im]), - 1:2, + 1:2, (1,2), StructArray(@SArray [1;1+2im]), (@SArray [1 2]), @@ -1566,9 +1566,7 @@ end @test mul!(ones(ComplexF64,size(v)), A, v, 2.0, 3.0) ≈ 2 * A * v .+ 3 end end - + @testset "project quality" begin Aqua.test_all(StructArrays, ambiguities=(; broken=true)) end - -end