Skip to content

Commit

Permalink
Forward complex matrix multiplication to components
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Jan 19, 2024
1 parent d9791eb commit 6f54802
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 0 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

Expand All @@ -22,13 +24,15 @@ StructArraysAdaptExt = "Adapt"
StructArraysGPUArraysCoreExt = "GPUArraysCore"
StructArraysSparseArraysExt = "SparseArrays"
StructArraysStaticArraysExt = "StaticArrays"
StructArraysLinearAlgebraExt = "LinearAlgebra"

[compat]
Adapt = "3.4"
ConstructionBase = "1"
DataAPI = "1"
GPUArraysCore = "0.1.2"
InfiniteArrays = "0.13"
LinearAlgebra = "1"
StaticArrays = "1.5.6"
Tables = "1"
julia = "1.6"
Expand Down
17 changes: 17 additions & 0 deletions ext/StructArraysLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module StructArraysLinearAlgebraExt

using StructArrays
using LinearAlgebra
import LinearAlgebra: mul!

const StructMatrixC{T, A<:AbstractMatrix{T}} = StructArrays.StructMatrix{Complex{T}, @NamedTuple{re::A, im::A}}

function mul!(C::StructMatrixC, A::StructMatrixC, B::StructMatrixC, alpha::Number, beta::Number)
mul!(C.re, A.re, B.re, alpha, beta)
mul!(C.re, A.im, B.im, -alpha, oneunit(beta))
mul!(C.im, A.re, B.im, alpha, beta)
mul!(C.im, A.im, B.re, alpha, oneunit(beta))
C
end

end
1 change: 1 addition & 0 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ end
include("../ext/StructArraysGPUArraysCoreExt.jl")
include("../ext/StructArraysSparseArraysExt.jl")
include("../ext/StructArraysStaticArraysExt.jl")
include("../ext/StructArraysLinearAlgebraExt.jl")
end

end # module
1 change: 1 addition & 0 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ _structarray(args::Tuple, ::Tuple) = _structarray(args, nothing)
_structarray(args::NTuple{N, Any}, names::NTuple{N, Symbol}) where {N} = StructArray(NamedTuple{names}(args))

const StructVector{T, C<:Tup, I} = StructArray{T, 1, C, I}
const StructMatrix{T, C<:Tup, I} = StructArray{T, 2, C, I}
StructVector{T}(args...; kwargs...) where {T} = StructArray{T}(args...; kwargs...)
StructVector(args...; kwargs...) = StructArray(args...; kwargs...)

Expand Down
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1510,3 +1510,13 @@ end
S = StructArray{Complex{Int}}((1:∞, 1:∞))
@test Base.IteratorSize(S) == Base.IsInfinite()
end

@testset "LinearAlgebra" begin
@testset "matrix * matrix" begin
A = StructArray{ComplexF64}((rand(10,10), rand(10,10)))
B = StructArray{ComplexF64}((rand(10,10), rand(10,10)))
MA, MB = Matrix(A), Matrix(B)
@test A * B MA * MB
@test A * A MA * MA
end
end

0 comments on commit 6f54802

Please sign in to comment.