From c1af7dbc979bc8d77ee3667c16a3923d43b6e25c Mon Sep 17 00:00:00 2001 From: Nikitas Rontsis Date: Sun, 23 Oct 2022 18:16:54 +0100 Subject: [PATCH 1/7] Change Adjoints to be ComponentArray This PR aims to avoid issues like: ``` using ComponentArrays A = ComponentMatrix(ones(2, 2), Axis(:a, :b), FlatAxis()) A[:b, :] # works A'[:, :b] # fails! ``` by wrapping adjoint operations in the underlying data of the ComponentArray structure. By not having to care about adjoints of ComponentArray, we arguably also reduce overall cognitive load. --- src/linear_algebra.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 5d888e7f..36097445 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -42,4 +42,18 @@ for op in [:*, :\, :/] end end end -end \ No newline at end of file +end + + +for op in [:adjoint, :transpose] + @eval begin + function LinearAlgebra.$op(M::ComponentMatrix{T,A,Tuple{Ax1,Ax2}}) where {T,A,Ax1,Ax2} + data = $op(getdata(M)) + return ComponentArray(data, (Ax2(), Ax1())[1:ndims(data)]...) + end + + function LinearAlgebra.$op(M::ComponentVector{T,A,Tuple{Ax1}}) where {T,A,Ax1} + return ComponentMatrix($op(getdata(M)), FlatAxis(), Ax1()) + end + end +end From 0b7a72dcaeb85a72d8e2e8ff7192775ef105e001 Mon Sep 17 00:00:00 2001 From: nrontsis Date: Thu, 27 Oct 2022 17:52:50 +0100 Subject: [PATCH 2/7] Fix tests and cleanup lingering Adj/T references --- src/array_interface.jl | 1 - src/broadcasting.jl | 4 - src/compat/gpuarrays.jl | 203 ++++------------------------------------ src/componentarray.jl | 18 +--- src/linear_algebra.jl | 17 +--- src/show.jl | 2 +- test/gpu_tests.jl | 2 +- test/runtests.jl | 22 ++--- 8 files changed, 36 insertions(+), 233 deletions(-) diff --git a/src/array_interface.jl b/src/array_interface.jl index ffb7e93f..b1ffd5ec 100644 --- a/src/array_interface.jl +++ b/src/array_interface.jl @@ -59,7 +59,6 @@ function Base.vcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat) return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...)), getaxes(x)[2:end]...) end end -Base.vcat(x::CV...) where {CV<:AdjOrTransComponentArray} = ComponentArray(reduce(vcat, map(y->getdata(y.parent)', x)), getaxes(x[1])) Base.vcat(x::ComponentVector, args...) = vcat(getdata(x), getdata.(args)...) Base.vcat(x::ComponentVector, args::Union{Number, UniformScaling, AbstractVecOrMat}...) = vcat(getdata(x), getdata.(args)...) Base.vcat(x::ComponentVector, args::Vararg{AbstractVector{T}, N}) where {T,N} = vcat(getdata(x), getdata.(args)...) diff --git a/src/broadcasting.jl b/src/broadcasting.jl index 7a59f3e6..1c21045b 100644 --- a/src/broadcasting.jl +++ b/src/broadcasting.jl @@ -1,9 +1,5 @@ Base.BroadcastStyle(::Type{<:ComponentArray{T, N, A, Axes}}) where {T, N, A, Axes} = Broadcast.BroadcastStyle(A) -# Need special case here for adjoint vectors in order to avoid type instability in axistype -Broadcast.combine_axes(a::ComponentArray, b::AdjOrTransComponentVector) = (axes(a)[1], axes(b)[2]) -Broadcast.combine_axes(a::AdjOrTransComponentVector, b::ComponentArray) = (axes(b)[2], axes(a)[1]) - Broadcast.axistype(a::CombinedAxis, b::AbstractUnitRange) = a Broadcast.axistype(a::AbstractUnitRange, b::CombinedAxis) = b Broadcast.axistype(a::CombinedAxis, b::CombinedAxis) = CombinedAxis(FlatAxis(), Base.Broadcast.axistype(_array_axis(a), _array_axis(b))) diff --git a/src/compat/gpuarrays.jl b/src/compat/gpuarrays.jl index ace99e6a..8fe201d8 100644 --- a/src/compat/gpuarrays.jl +++ b/src/compat/gpuarrays.jl @@ -1,6 +1,7 @@ -const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax} -const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:GPUArrays.AbstractGPUVector,Ax} -const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:GPUArrays.AbstractGPUMatrix,Ax} +const AbstractGPUArrayOrAdj = Union{<:GPUArrays.AbstractGPUArray{T, N}, Adjoint{T, <:GPUArrays.AbstractGPUArray{T, N}}, Transpose{T, <:GPUArrays.AbstractGPUArray{T, N}}} where {T, N} +const GPUComponentArray = ComponentArray{T,N,<:AbstractGPUArrayOrAdj{T, N},Ax} where {T,N,Ax} +const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:AbstractGPUArrayOrAdj{T, 1},Ax} +const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:AbstractGPUArrayOrAdj{T, 2},Ax} const GPUComponentVecorMat{T,Ax} = Union{GPUComponentVector{T,Ax},GPUComponentMatrix{T,Ax}} GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x)) @@ -25,7 +26,10 @@ end LinearAlgebra.dot(x::GPUComponentArray, y::GPUComponentArray) = dot(getdata(x), getdata(y)) LinearAlgebra.norm(ca::GPUComponentArray, p::Real) = norm(getdata(ca), p) -LinearAlgebra.rmul!(ca::GPUComponentArray, b::Number) = GPUArrays.generic_rmul!(ca, b) +function LinearAlgebra.rmul!(ca::GPUComponentArray, b::Number) + GPUArrays.generic_rmul!(getdata(ca), b) + return ca +end function Base.map(f, x::GPUComponentArray, args...) data = map(f, getdata(x), getdata.(args)...) @@ -78,196 +82,23 @@ end function LinearAlgebra.mul!(C::GPUComponentVecorMat, A::GPUComponentVecorMat, B::GPUComponentVecorMat, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) + return GPUArrays.generic_matmatmul!(C, getdata(A), getdata(B), a, b) end function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::GPUComponentVecorMat, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::GPUComponentVecorMat, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, + A::AbstractGPUArrayOrAdj, B::GPUComponentVecorMat, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, B::GPUComponentVecorMat, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) + return GPUArrays.generic_matmatmul!(C, A, getdata(B), a, b) end function LinearAlgebra.mul!(C::GPUComponentVecorMat, A::GPUComponentVecorMat, - B::GPUComponentVecorMat, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real, - b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::GPUComponentVecorMat, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::GPUComponentVecorMat, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::GPUComponentVecorMat, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, B::GPUComponentVecorMat, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real, - b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real, - b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) + B::AbstractGPUArrayOrAdj, a::Number, b::Number) + return GPUArrays.generic_matmatmul!(C, getdata(A), B, a, b) end + function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Real, b::Real) + A::AbstractGPUArrayOrAdj, + B::AbstractGPUArrayOrAdj, a::Number, b::Number) return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end +end \ No newline at end of file diff --git a/src/componentarray.jl b/src/componentarray.jl index 453dc7ee..1fc581ec 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -118,17 +118,11 @@ const CArray = ComponentArray const CVector = ComponentVector const CMatrix = ComponentMatrix -const AdjOrTrans{T, A} = Union{Adjoint{T, A}, Transpose{T, A}} -const AdjOrTransComponentArray{T, A} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentArray -const AdjOrTransComponentVector{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentVector -const AdjOrTransComponentMatrix{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentMatrix - const ComponentVecOrMat = Union{ComponentVector, ComponentMatrix} -const AdjOrTransComponentVecOrMat = AdjOrTrans{T, <:ComponentVecOrMat} where T -const AbstractComponentArray = Union{ComponentArray, AdjOrTransComponentArray} -const AbstractComponentVecOrMat = Union{ComponentVecOrMat, AdjOrTransComponentVecOrMat} -const AbstractComponentVector = Union{ComponentVector, AdjOrTransComponentVector} -const AbstractComponentMatrix = Union{ComponentMatrix, AdjOrTransComponentMatrix} +const AbstractComponentArray = ComponentArray +const AbstractComponentVecOrMat = ComponentVecOrMat +const AbstractComponentVector = ComponentVector +const AbstractComponentMatrix = ComponentMatrix ## Constructor helpers @@ -288,12 +282,8 @@ julia> getaxes(ca) ``` """ @inline getaxes(x::ComponentArray) = getfield(x, :axes) -@inline getaxes(x::AdjOrTrans{T, <:ComponentVector}) where T = (FlatAxis(), getaxes(x.parent)[1]) -@inline getaxes(x::AdjOrTrans{T, <:ComponentMatrix}) where T = reverse(getaxes(x.parent)) @inline getaxes(::Type{<:ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = map(x->x(), (Axes.types...,)) -@inline getaxes(::Type{<:AdjOrTrans{T,CA}}) where {T,CA<:ComponentVector} = (FlatAxis(), getaxes(CA)[1]) |> typeof -@inline getaxes(::Type{<:AdjOrTrans{T,CA}}) where {T,CA<:ComponentMatrix} = reverse(getaxes(CA)) |> typeof ## Field access through these functions to reserve dot-getting for keys @inline getaxes(x::VarAxes) = getaxes(typeof(x)) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 36097445..0fa527db 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -13,7 +13,6 @@ _first_axis(x::AbstractComponentVecOrMat) = getaxes(x)[1] _second_axis(x::AbstractMatrix) = FlatAxis() _second_axis(x::ComponentMatrix) = getaxes(x)[2] -_second_axis(x::AdjOrTransComponentVecOrMat) = getaxes(x)[2] _out_axes(::typeof(*), a, b::AbstractVector) = (_first_axis(a), ) _out_axes(::typeof(*), a, b::AbstractMatrix) = (_first_axis(a), _second_axis(b)) @@ -27,19 +26,7 @@ for op in [:*, :\, :/] function Base.$op(A::AbstractComponentVecOrMat, B::AbstractComponentVecOrMat) C = $op(getdata(A), getdata(B)) ax = _out_axes($op, A, B) - return ComponentArray(C, ax) - end - end - for (adj, Adj) in zip([:adjoint, :transpose], [:Adjoint, :Transpose]) - @eval begin - function Base.$op(aᵀ::$Adj{T,<:ComponentVector}, B::AbstractComponentMatrix) where {T} - cᵀ = $op(getdata(aᵀ), getdata(B)) - ax2 = _out_axes($op, aᵀ, B)[2] - return $adj(ComponentArray(cᵀ', ax2)) - end - function Base.$op(A::$Adj{T,<:CV}, B::CV) where {T<:Real, CV<:ComponentVector{T}} - return $op(getdata(A), getdata(B)) - end + return ComponentArray(C, ax...) end end end @@ -56,4 +43,4 @@ for op in [:adjoint, :transpose] return ComponentMatrix($op(getdata(M)), FlatAxis(), Ax1()) end end -end +end \ No newline at end of file diff --git a/src/show.jl b/src/show.jl index 2a08a47f..dc5d1613 100644 --- a/src/show.jl +++ b/src/show.jl @@ -79,4 +79,4 @@ function Base.show(io::IO, ::MIME"text/plain", x::ComponentMatrix{T,A,Axes}) whe println(io, " with axes $(axs[1]) × $(axs[2])") Base.print_matrix(io, getdata(x)) return nothing -end +end \ No newline at end of file diff --git a/test/gpu_tests.jl b/test/gpu_tests.jl index 97e32965..e4a819ba 100644 --- a/test/gpu_tests.jl +++ b/test/gpu_tests.jl @@ -46,7 +46,7 @@ end @test rmul!(jlca3, 2) == ComponentArray(jla .* 2, Axis(a=1:2, b=3:4)) end @testset "mul!" begin - A = jlca .* jlca'; + A = jlca * jlca'; @test_nowarn mul!(deepcopy(A), A, A, 1, 2); @test_nowarn mul!(deepcopy(A), A', A', 1, 2); @test_nowarn mul!(deepcopy(A), A', A, 1, 2); diff --git a/test/runtests.jl b/test/runtests.jl index c1b58c49..b13dfe00 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,8 +28,8 @@ ca_composed = ComponentArray(a = 1, b = ca) ca2 = ComponentArray(nt2) -cmat = ComponentArray(a .* a', ax, ax) -cmat2 = ca2 .* ca2' +cmat = ComponentArray(a * a', ax, ax) +cmat2 = ca2 * ca2' caa = ComponentArray(a = ca, b = sq_mat) @@ -142,13 +142,13 @@ end @test hash(ca) != hash(getdata(ca)) @test hash(ca, zero(UInt)) != hash(getdata(ca), zero(UInt)) - ab = ComponentArray(a = 1, b = 2) - xy = ComponentArray(x = 1, y = 2) + ab = ComponentArray(a=1, b=2) + xy = ComponentArray(x=1, y=2) @test ab != xy @test hash(ab) != hash(xy) @test hash(ab, zero(UInt)) != hash(xy, zero(UInt)) - @test ab == LVector(a = 1, b = 2) + @test ab == LVector(a=1, b=2) # Issue #117 kw_fun(; a, b) = a // b @@ -369,11 +369,11 @@ end @testset "Broadcasting" begin temp = deepcopy(ca) @test eltype(Float32.(ca)) == Float32 - @test ca .* ca' == cmat + @test ca * ca' == cmat @test 1 .* (ca .+ ca) == ComponentArray(a .+ a, getaxes(ca)) @test typeof(ca .+ cmat) == typeof(cmat) - @test getaxes(false .* ca .* ca') == (ax, ax) - @test getaxes(false .* ca' .* ca) == (ax, ax) + @test getaxes(false .* ca * ca') == (ax, ax) + @test isa(ca' * ca, Float64) @test (vec(temp) .= vec(ca_Float32)) isa ComponentArray @test_broken getdata(ca_MVector .* ca_MVector) isa MArray @@ -393,8 +393,8 @@ end x1 = ComponentArray(a = [1.1, 2.1], b = [0.1]) x2 = ComponentArray(a = [1.1, 2.1], b = 0.1) x3 = ComponentArray(a = [1.1, 2.1], c = [0.1]) - xmat = x1 .* x2' - x1mat = x1 .* x1' + xmat = x1 * x2' + x1mat = x1 * x1' @test x1 + x2 isa Vector @test x1 + x3 isa Vector @test x2 + x3 isa Vector @@ -459,7 +459,7 @@ end @test ca * transpose(ca) == collect(cmat) @test ca * transpose(ca) == a * transpose(a) @test transpose(ca) * ca == transpose(a) * a - @test ca' * cmat == ComponentArray(a' * getdata(cmat), getaxes(ca)) + @test ca' * cmat == ComponentArray(a' * getdata(cmat), FlatAxis(), getaxes(ca)...) @test transpose(transpose(cmat)) == cmat @test transpose(transpose(ca)) == ca @test transpose(ca.c) * cmat[:c, :c] * ca.c isa Number From 8d8bbb609024c1fdca1aa8b850413287c43cd802 Mon Sep 17 00:00:00 2001 From: Nikitas Rontsis Date: Sun, 23 Oct 2022 17:54:48 +0100 Subject: [PATCH 3/7] Fix *cat inconsistencies --- src/array_interface.jl | 57 ++++++++++++------------------------------ 1 file changed, 16 insertions(+), 41 deletions(-) diff --git a/src/array_interface.jl b/src/array_interface.jl index b1ffd5ec..4b208d69 100644 --- a/src/array_interface.jl +++ b/src/array_interface.jl @@ -16,52 +16,27 @@ ArrayInterfaceCore.indices_do_not_alias(::Type{ComponentArray{T,N,A,Axes}}) wher ArrayInterfaceCore.instances_do_not_alias(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = ArrayInterfaceCore.instances_do_not_alias(A) # Cats -# TODO: Make this a little less copy-pastey -function Base.hcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat) - ax_x, ax_y = second_axis.((x,y)) - if reduce((accum, key) -> accum || (key in keys(ax_x)), keys(ax_y); init=false) || getaxes(x)[1] != getaxes(y)[1] - return hcat(getdata(x), getdata(y)) +function Base.cat(inputs::ComponentArray...; dims::Int) + combined_data = cat(getdata.(inputs)...; dims=dims) + axes_to_merge = [(getaxes(i)..., FlatAxis())[dims] for i in inputs] + rest_axes = [getaxes(i)[1:end .!= dims] for i in inputs] + no_duplicate_keys = (length(inputs) == 1 || isempty(intersect(keys.(axes_to_merge)...))) + if no_duplicate_keys && length(Set(rest_axes)) == 1 + offsets = cumsum(size.(inputs, 1) .- size(first(inputs), 1)) + merged_axis = Axis(merge(indexmap.(reindex.(axes_to_merge, offsets))...)) + result_axes = (first(rest_axes)[1:(dims - 1)]..., merged_axis, first(rest_axes)[dims:end]...) + return ComponentArray(combined_data, result_axes...) else - data_x, data_y = getdata.((x, y)) - ax_y = reindex(ax_y, size(x,2)) - idxmap_x, idxmap_y = indexmap.((ax_x, ax_y)) - axs = getaxes(x) - return ComponentArray(hcat(data_x, data_y), axs[1], Axis((;idxmap_x..., idxmap_y...)), axs[3:end]...) + return combined_data end end -second_axis(ca::AbstractComponentVecOrMat) = getaxes(ca)[2] -second_axis(::ComponentVector) = FlatAxis() - -# Are all these methods necessary? -# TODO: See what we can reduce down to without getting ambiguity errors -Base.vcat(x::ComponentVector, y::AbstractVector) = vcat(getdata(x), y) -Base.vcat(x::AbstractVector, y::ComponentVector) = vcat(x, getdata(y)) -function Base.vcat(x::ComponentVector, y::ComponentVector) - if reduce((accum, key) -> accum || (key in keys(x)), keys(y); init=false) - return vcat(getdata(x), getdata(y)) - else - data_x, data_y = getdata.((x, y)) - ax_x, ax_y = getindex.(getaxes.((x, y)), 1) - ax_y = reindex(ax_y, length(x)) - idxmap_x, idxmap_y = indexmap.((ax_x, ax_y)) - return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...))) - end +function Base._typed_hcat(::Type{T}, inputs::Base.AbstractVecOrTuple{ComponentArray}) where {T} + return Base.cat(map(i -> T.(i), inputs)...; dims=2) end -function Base.vcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat) - ax_x, ax_y = getindex.(getaxes.((x, y)), 1) - if reduce((accum, key) -> accum || (key in keys(ax_x)), keys(ax_y); init=false) || getaxes(x)[2:end] != getaxes(y)[2:end] - return vcat(getdata(x), getdata(y)) - else - data_x, data_y = getdata.((x, y)) - ax_y = reindex(ax_y, size(x,1)) - idxmap_x, idxmap_y = indexmap.((ax_x, ax_y)) - return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...)), getaxes(x)[2:end]...) - end +function Base._typed_vcat(::Type{T}, inputs::Base.AbstractVecOrTuple{ComponentArray}) where {T} + return Base.cat(map(i -> T.(i), inputs)...; dims=1) end -Base.vcat(x::ComponentVector, args...) = vcat(getdata(x), getdata.(args)...) -Base.vcat(x::ComponentVector, args::Union{Number, UniformScaling, AbstractVecOrMat}...) = vcat(getdata(x), getdata.(args)...) -Base.vcat(x::ComponentVector, args::Vararg{AbstractVector{T}, N}) where {T,N} = vcat(getdata(x), getdata.(args)...) function Base.hvcat(row_lengths::NTuple{N,Int}, xs::AbstractComponentVecOrMat...) where {N} i = 1 @@ -144,4 +119,4 @@ end Base.stride(x::ComponentArray, k) = stride(getdata(x), k) Base.stride(x::ComponentArray, k::Int64) = stride(getdata(x), k) -ArrayInterfaceCore.parent_type(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = A \ No newline at end of file +ArrayInterfaceCore.parent_type(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = A From de8d8bac12ec098f2070488501ee251d2ddca151 Mon Sep 17 00:00:00 2001 From: Nikitas Rontsis Date: Sun, 23 Oct 2022 17:59:35 +0100 Subject: [PATCH 4/7] Extend tests --- test/runtests.jl | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index b13dfe00..a9d405eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -482,14 +482,25 @@ end @test ldiv!(tempmat, lu(cmat + I), cmat) isa ComponentMatrix @test ldiv!(getdata(tempmat), lu(cmat + I), cmat) isa AbstractMatrix - vca2 = vcat(ca2', ca2') - hca2 = hcat(ca2, ca2) + for n in 1:3 # Issue 168 cats (on more than one) ComponentArrays + vca2 = vcat(repeat([ca2'], n)...) + hca2 = hcat(repeat([ca2], n)...) + hca2_reduced = reduce(hcat, repeat([ca2], n)) # Issue 113: reduce cats should match non-reduced ones + vca2_reduced = reduce(vcat, repeat([ca2'], n)) + @test hca2 == hca2_reduced + @test typeof(hca2) == typeof(hca2_reduced) + @test vca2 == vca2_reduced + @test typeof(vca2) == typeof(vca2_reduced) + @test hca2 isa ComponentMatrix + @test vca2 isa ComponentMatrix + @test all(vca2[1, :] .== ca2) + @test all(hca2[:, 1] .== ca2) + @test all(vca2' .== hca2) + @test hca2[:a, :] == vca2[:, :a] + end + temp = ComponentVector(q = 100, r = rand(3, 3, 3)) vtempca = [temp; ca] - @test all(vca2[1, :] .== ca2) - @test all(hca2[:, 1] .== ca2) - @test all(vca2' .== hca2) - @test hca2[:a, :] == vca2[:, :a] @test vtempca isa ComponentVector @test vtempca.r == temp.r @test vtempca.c == ca.c From 419105fdc0f8b1d5aebf44e066a13757df3d0590 Mon Sep 17 00:00:00 2001 From: Nikitas Rontsis Date: Mon, 24 Oct 2022 15:45:23 +0100 Subject: [PATCH 5/7] Fix offsets computation --- src/array_interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_interface.jl b/src/array_interface.jl index 4b208d69..fda942c8 100644 --- a/src/array_interface.jl +++ b/src/array_interface.jl @@ -22,7 +22,7 @@ function Base.cat(inputs::ComponentArray...; dims::Int) rest_axes = [getaxes(i)[1:end .!= dims] for i in inputs] no_duplicate_keys = (length(inputs) == 1 || isempty(intersect(keys.(axes_to_merge)...))) if no_duplicate_keys && length(Set(rest_axes)) == 1 - offsets = cumsum(size.(inputs, 1) .- size(first(inputs), 1)) + offsets = (0, cumsum(size.(inputs, dims))[1:(end - 1)]...) merged_axis = Axis(merge(indexmap.(reindex.(axes_to_merge, offsets))...)) result_axes = (first(rest_axes)[1:(dims - 1)]..., merged_axis, first(rest_axes)[dims:end]...) return ComponentArray(combined_data, result_axes...) From 7bc3561be1238d8e0fcf2e71fc36e89b9c34af25 Mon Sep 17 00:00:00 2001 From: nrontsis Date: Thu, 27 Oct 2022 18:54:23 +0100 Subject: [PATCH 6/7] Fix tests --- src/array_interface.jl | 2 ++ src/axis.jl | 2 +- test/runtests.jl | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/array_interface.jl b/src/array_interface.jl index fda942c8..8359b5fa 100644 --- a/src/array_interface.jl +++ b/src/array_interface.jl @@ -31,6 +31,8 @@ function Base.cat(inputs::ComponentArray...; dims::Int) end end +Base.hcat(inputs::ComponentArray...) = Base.cat(inputs...; dims=2) +Base.vcat(inputs::ComponentArray...) = Base.cat(inputs...; dims=1) function Base._typed_hcat(::Type{T}, inputs::Base.AbstractVecOrTuple{ComponentArray}) where {T} return Base.cat(map(i -> T.(i), inputs)...; dims=2) end diff --git a/src/axis.jl b/src/axis.jl index 29ea4222..4ffecc68 100644 --- a/src/axis.jl +++ b/src/axis.jl @@ -148,7 +148,7 @@ Base.keys(ax::AbstractAxis) = keys(indexmap(ax)) reindex(i, offset) = i .+ offset reindex(ax::FlatAxis, _) = ax reindex(ax::Axis, offset) = Axis(map(x->reindex(x, offset), indexmap(ax))) -reindex(ax::ViewAxis, offset) = ViewAxis(viewindex(ax) .+ offset, indexmap(ax)) +reindex(ax::ViewAxis{Inds,IdxMap,Ax}, offset) where {Inds, IdxMap, Ax} = ViewAxis(viewindex(ax) .+ offset, Ax()) # Get AbstractAxis index @inline Base.getindex(::AbstractAxis, idx) = ComponentIndex(idx) diff --git a/test/runtests.jl b/test/runtests.jl index a9d405eb..8560eae3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -507,7 +507,7 @@ end @test length(vtempca) == length(temp) + length(ca) @test [ca; ca; ca] isa Vector @test vcat(ca, 100) isa Vector - @test [ca' ca']' isa Vector + @test [ca' ca']' isa Adjoint{T, Matrix{T}} where T @test keys(getaxes([ca' temp']')[1]) == (:a, :b, :c, :q, :r) # Getting serious about axes From 5f92105e49106807082fa9054ff49e4117b78cc6 Mon Sep 17 00:00:00 2001 From: nrontsis Date: Thu, 24 Nov 2022 21:14:19 +0000 Subject: [PATCH 7/7] Correct bug in duplicate keys detection; introduce test that catches the bug --- src/array_interface.jl | 2 +- test/runtests.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/array_interface.jl b/src/array_interface.jl index 8359b5fa..3b7c0342 100644 --- a/src/array_interface.jl +++ b/src/array_interface.jl @@ -20,7 +20,7 @@ function Base.cat(inputs::ComponentArray...; dims::Int) combined_data = cat(getdata.(inputs)...; dims=dims) axes_to_merge = [(getaxes(i)..., FlatAxis())[dims] for i in inputs] rest_axes = [getaxes(i)[1:end .!= dims] for i in inputs] - no_duplicate_keys = (length(inputs) == 1 || isempty(intersect(keys.(axes_to_merge)...))) + no_duplicate_keys = (length(inputs) == 1 || allunique(vcat(collect.(keys.(axes_to_merge))...))) if no_duplicate_keys && length(Set(rest_axes)) == 1 offsets = (0, cumsum(size.(inputs, dims))[1:(end - 1)]...) merged_axis = Axis(merge(indexmap.(reindex.(axes_to_merge, offsets))...)) diff --git a/test/runtests.jl b/test/runtests.jl index 8560eae3..729fdc57 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -482,6 +482,7 @@ end @test ldiv!(tempmat, lu(cmat + I), cmat) isa ComponentMatrix @test ldiv!(getdata(tempmat), lu(cmat + I), cmat) isa AbstractMatrix + @test !(vcat(ca, ca2, ca) isa ComponentVector) for n in 1:3 # Issue 168 cats (on more than one) ComponentArrays vca2 = vcat(repeat([ca2'], n)...) hca2 = hcat(repeat([ca2], n)...)