Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix *cat inconsistencies #169

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 18 additions & 42 deletions src/array_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,53 +16,29 @@ ArrayInterfaceCore.indices_do_not_alias(::Type{ComponentArray{T,N,A,Axes}}) wher
ArrayInterfaceCore.instances_do_not_alias(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = ArrayInterfaceCore.instances_do_not_alias(A)

# Cats
# TODO: Make this a little less copy-pastey
function Base.hcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat)
ax_x, ax_y = second_axis.((x,y))
if reduce((accum, key) -> accum || (key in keys(ax_x)), keys(ax_y); init=false) || getaxes(x)[1] != getaxes(y)[1]
return hcat(getdata(x), getdata(y))
function Base.cat(inputs::ComponentArray...; dims::Int)
combined_data = cat(getdata.(inputs)...; dims=dims)
axes_to_merge = [(getaxes(i)..., FlatAxis())[dims] for i in inputs]
rest_axes = [getaxes(i)[1:end .!= dims] for i in inputs]
no_duplicate_keys = (length(inputs) == 1 || allunique(vcat(collect.(keys.(axes_to_merge))...)))
if no_duplicate_keys && length(Set(rest_axes)) == 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the extra check length(Set(rest_axes)) == 1 that was not in place in the previous implementation.

offsets = (0, cumsum(size.(inputs, dims))[1:(end - 1)]...)
merged_axis = Axis(merge(indexmap.(reindex.(axes_to_merge, offsets))...))
result_axes = (first(rest_axes)[1:(dims - 1)]..., merged_axis, first(rest_axes)[dims:end]...)
return ComponentArray(combined_data, result_axes...)
else
data_x, data_y = getdata.((x, y))
ax_y = reindex(ax_y, size(x,2))
idxmap_x, idxmap_y = indexmap.((ax_x, ax_y))
axs = getaxes(x)
return ComponentArray(hcat(data_x, data_y), axs[1], Axis((;idxmap_x..., idxmap_y...)), axs[3:end]...)
return combined_data
end
end

second_axis(ca::AbstractComponentVecOrMat) = getaxes(ca)[2]
second_axis(::ComponentVector) = FlatAxis()

# Are all these methods necessary?
# TODO: See what we can reduce down to without getting ambiguity errors
Base.vcat(x::ComponentVector, y::AbstractVector) = vcat(getdata(x), y)
Base.vcat(x::AbstractVector, y::ComponentVector) = vcat(x, getdata(y))
function Base.vcat(x::ComponentVector, y::ComponentVector)
if reduce((accum, key) -> accum || (key in keys(x)), keys(y); init=false)
return vcat(getdata(x), getdata(y))
else
data_x, data_y = getdata.((x, y))
ax_x, ax_y = getindex.(getaxes.((x, y)), 1)
ax_y = reindex(ax_y, length(x))
idxmap_x, idxmap_y = indexmap.((ax_x, ax_y))
return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...)))
end
Base.hcat(inputs::ComponentArray...) = Base.cat(inputs...; dims=2)
Base.vcat(inputs::ComponentArray...) = Base.cat(inputs...; dims=1)
function Base._typed_hcat(::Type{T}, inputs::Base.AbstractVecOrTuple{ComponentArray}) where {T}
Copy link
Contributor Author

@nrontsis nrontsis Oct 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hcat and vcat call _typed_h/vcat under the hood.

Unfortunately, implementing the (seemingly private) _typed_hcat for ComponentArrays is apparently necessary for e.g.

reduce(hcat, [cv1, cv2])

to be identical to hcat(cv1, cv2), as the above calls directly _typed_hcat.

Since unit tests are in place for the above behaviour, I thought the above is okay-ish?

return Base.cat(map(i -> T.(i), inputs)...; dims=2)
end
function Base.vcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat)
ax_x, ax_y = getindex.(getaxes.((x, y)), 1)
if reduce((accum, key) -> accum || (key in keys(ax_x)), keys(ax_y); init=false) || getaxes(x)[2:end] != getaxes(y)[2:end]
return vcat(getdata(x), getdata(y))
else
data_x, data_y = getdata.((x, y))
ax_y = reindex(ax_y, size(x,1))
idxmap_x, idxmap_y = indexmap.((ax_x, ax_y))
return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...)), getaxes(x)[2:end]...)
end
function Base._typed_vcat(::Type{T}, inputs::Base.AbstractVecOrTuple{ComponentArray}) where {T}
return Base.cat(map(i -> T.(i), inputs)...; dims=1)
end
Base.vcat(x::CV...) where {CV<:AdjOrTransComponentArray} = ComponentArray(reduce(vcat, map(y->getdata(y.parent)', x)), getaxes(x[1]))
Base.vcat(x::ComponentVector, args...) = vcat(getdata(x), getdata.(args)...)
Base.vcat(x::ComponentVector, args::Union{Number, UniformScaling, AbstractVecOrMat}...) = vcat(getdata(x), getdata.(args)...)
Base.vcat(x::ComponentVector, args::Vararg{AbstractVector{T}, N}) where {T,N} = vcat(getdata(x), getdata.(args)...)

function Base.hvcat(row_lengths::NTuple{N,Int}, xs::AbstractComponentVecOrMat...) where {N}
i = 1
Expand Down Expand Up @@ -145,4 +121,4 @@ end
Base.stride(x::ComponentArray, k) = stride(getdata(x), k)
Base.stride(x::ComponentArray, k::Int64) = stride(getdata(x), k)

ArrayInterfaceCore.parent_type(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = A
ArrayInterfaceCore.parent_type(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = A
2 changes: 1 addition & 1 deletion src/axis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ Base.keys(ax::AbstractAxis) = keys(indexmap(ax))
reindex(i, offset) = i .+ offset
reindex(ax::FlatAxis, _) = ax
reindex(ax::Axis, offset) = Axis(map(x->reindex(x, offset), indexmap(ax)))
reindex(ax::ViewAxis, offset) = ViewAxis(viewindex(ax) .+ offset, indexmap(ax))
reindex(ax::ViewAxis{Inds,IdxMap,Ax}, offset) where {Inds, IdxMap, Ax} = ViewAxis(viewindex(ax) .+ offset, Ax())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bugfix that was only now discovered by the tests, because now cat calls reindex to all axis, as compared to only the first or the second one that the old hcats and vcats were calling before.


# Get AbstractAxis index
@inline Base.getindex(::AbstractAxis, idx) = ComponentIndex(idx)
Expand Down
4 changes: 0 additions & 4 deletions src/broadcasting.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
Base.BroadcastStyle(::Type{<:ComponentArray{T, N, A, Axes}}) where {T, N, A, Axes} = Broadcast.BroadcastStyle(A)

# Need special case here for adjoint vectors in order to avoid type instability in axistype
Broadcast.combine_axes(a::ComponentArray, b::AdjOrTransComponentVector) = (axes(a)[1], axes(b)[2])
Broadcast.combine_axes(a::AdjOrTransComponentVector, b::ComponentArray) = (axes(b)[2], axes(a)[1])

Broadcast.axistype(a::CombinedAxis, b::AbstractUnitRange) = a
Broadcast.axistype(a::AbstractUnitRange, b::CombinedAxis) = b
Broadcast.axistype(a::CombinedAxis, b::CombinedAxis) = CombinedAxis(FlatAxis(), Base.Broadcast.axistype(_array_axis(a), _array_axis(b)))
Expand Down
203 changes: 17 additions & 186 deletions src/compat/gpuarrays.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax}
const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:GPUArrays.AbstractGPUVector,Ax}
const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:GPUArrays.AbstractGPUMatrix,Ax}
const AbstractGPUArrayOrAdj = Union{<:GPUArrays.AbstractGPUArray{T, N}, Adjoint{T, <:GPUArrays.AbstractGPUArray{T, N}}, Transpose{T, <:GPUArrays.AbstractGPUArray{T, N}}} where {T, N}
const GPUComponentArray = ComponentArray{T,N,<:AbstractGPUArrayOrAdj{T, N},Ax} where {T,N,Ax}
const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:AbstractGPUArrayOrAdj{T, 1},Ax}
const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:AbstractGPUArrayOrAdj{T, 2},Ax}
const GPUComponentVecorMat{T,Ax} = Union{GPUComponentVector{T,Ax},GPUComponentMatrix{T,Ax}}

GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x))
Expand All @@ -25,7 +26,10 @@ end

LinearAlgebra.dot(x::GPUComponentArray, y::GPUComponentArray) = dot(getdata(x), getdata(y))
LinearAlgebra.norm(ca::GPUComponentArray, p::Real) = norm(getdata(ca), p)
LinearAlgebra.rmul!(ca::GPUComponentArray, b::Number) = GPUArrays.generic_rmul!(ca, b)
function LinearAlgebra.rmul!(ca::GPUComponentArray, b::Number)
GPUArrays.generic_rmul!(getdata(ca), b)
return ca
end

function Base.map(f, x::GPUComponentArray, args...)
data = map(f, getdata(x), getdata.(args)...)
Expand Down Expand Up @@ -78,196 +82,23 @@ end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
return GPUArrays.generic_matmatmul!(C, getdata(A), getdata(B), a, b)
end

function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
A::AbstractGPUArrayOrAdj,
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, B::GPUComponentVecorMat,
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
return GPUArrays.generic_matmatmul!(C, A, getdata(B), a, b)
end

function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, B::GPUComponentVecorMat,
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
B::AbstractGPUArrayOrAdj, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, getdata(A), B, a, b)
end

function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Real, b::Real)
A::AbstractGPUArrayOrAdj,
B::AbstractGPUArrayOrAdj, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
end
18 changes: 4 additions & 14 deletions src/componentarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,11 @@ const CArray = ComponentArray
const CVector = ComponentVector
const CMatrix = ComponentMatrix

const AdjOrTrans{T, A} = Union{Adjoint{T, A}, Transpose{T, A}}
const AdjOrTransComponentArray{T, A} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentArray
const AdjOrTransComponentVector{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentVector
const AdjOrTransComponentMatrix{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentMatrix

const ComponentVecOrMat = Union{ComponentVector, ComponentMatrix}
const AdjOrTransComponentVecOrMat = AdjOrTrans{T, <:ComponentVecOrMat} where T
const AbstractComponentArray = Union{ComponentArray, AdjOrTransComponentArray}
const AbstractComponentVecOrMat = Union{ComponentVecOrMat, AdjOrTransComponentVecOrMat}
const AbstractComponentVector = Union{ComponentVector, AdjOrTransComponentVector}
const AbstractComponentMatrix = Union{ComponentMatrix, AdjOrTransComponentMatrix}
const AbstractComponentArray = ComponentArray
const AbstractComponentVecOrMat = ComponentVecOrMat
const AbstractComponentVector = ComponentVector
const AbstractComponentMatrix = ComponentMatrix


## Constructor helpers
Expand Down Expand Up @@ -288,12 +282,8 @@ julia> getaxes(ca)
```
"""
@inline getaxes(x::ComponentArray) = getfield(x, :axes)
@inline getaxes(x::AdjOrTrans{T, <:ComponentVector}) where T = (FlatAxis(), getaxes(x.parent)[1])
@inline getaxes(x::AdjOrTrans{T, <:ComponentMatrix}) where T = reverse(getaxes(x.parent))

@inline getaxes(::Type{<:ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = map(x->x(), (Axes.types...,))
@inline getaxes(::Type{<:AdjOrTrans{T,CA}}) where {T,CA<:ComponentVector} = (FlatAxis(), getaxes(CA)[1]) |> typeof
@inline getaxes(::Type{<:AdjOrTrans{T,CA}}) where {T,CA<:ComponentMatrix} = reverse(getaxes(CA)) |> typeof

## Field access through these functions to reserve dot-getting for keys
@inline getaxes(x::VarAxes) = getaxes(typeof(x))
Expand Down
Loading