diff --git a/src/CellArray.jl b/src/CellArray.jl index c58a1c6..0bae3cb 100644 --- a/src/CellArray.jl +++ b/src/CellArray.jl @@ -317,6 +317,15 @@ end end +## CellArray properties + +@inline Base.getproperty(A::CellArray{T,N,B,T_array}, s::Symbol) where {T<:FieldArray,N,B,T_array} = _getproperty(A, Val(s)) +@inline _getproperty(A::CellArray{T,N,B,T_array}, s::Val) where {T<:FieldArray,N,B,T_array} = _getfield(A, s) +@inline _getfield(A::CellArray{T,N,B,T_array}, ::Val{:data}) where {T<:FieldArray,N,B,T_array} = getfield(A, :data) +@inline _getfield(A::CellArray{T,N,B,T_array}, ::Val{:dims}) where {T<:FieldArray,N,B,T_array} = getfield(A, :dims) +@inline _getfield(A::CellArray{T,N,B,T_array}, s::Val) where {T<:FieldArray,N,B,T_array} = field(A, s) + + ## API functions """ @@ -339,11 +348,13 @@ Return the blocklength of CellArray `A`. """ field(A, indices) + field(A, fieldname) -Return an array view of the field of CellArray `A` designated with `indices` (modifying the view will modify `A`). The view's dimensionality and size are equal to `A`'s. The operation is not supported if parameter `B` of `A` is neither `0` nor `1`. +Return an array view of the field of CellArray `A` designated with `indices` or `fieldname` (modifying the view will modify `A`). The view's dimensionality and size are equal to `A`'s. The operation is not supported if parameter `B` of `A` is neither `0` nor `1`. ## Arguments - `indices::Int|NTuple{N,Int}`: the `indices` that designate the field in accordance with `A`'s cell type. +- `fieldname::Symbol`: the `fieldname` that designates the field in accordance with `A`'s cell type. """ @inline field(A::CellArray{T,N,0,T_array}, index::Int) where {T,N,T_array} = view(plain(A), Base.OneTo.(size(A))..., index) @inline field(A::CellArray{T,N,0,T_array}, indices::NTuple{M,Int}) where {T_elem,M,T<:AbstractArray{T_elem,M},N, T_array} = view(plain(A), Base.OneTo.(size(A))..., indices...) @@ -351,6 +362,13 @@ Return an array view of the field of CellArray `A` designated with `indices` (mo @inline field(A::CellArray{T,N,1,T_array}, indices::NTuple{M,Int}) where {T_elem,M,T<:AbstractArray{T_elem,M},N, T_array} = view(plain(A), indices..., Base.OneTo.(size(A))...) @inline field(A::CellArray{T,N,B,T_array}, indices::Union{Int,NTuple{M,Int}}) where {T_elem,M,T<:AbstractArray{T_elem,M},N,B,T_array} = @ArgumentError("the operation is not supported if parameter `B` of `A` is neither `0` nor `1`.") @inline field(A::CellArray, indices::Int...) = field(A, indices) +@inline field(A::CellArray{T,N,B,T_array}, fieldname::Symbol) where {T<:FieldArray,N,B,T_array} = field(A, Val(fieldname)) + +@inline @generated function field(A::CellArray{T,N,B,T_array}, ::Val{fieldname}) where {T<:FieldArray{N2,T2,D},N,B,T_array,fieldname} where {N2,T2,D} + names = SArray{N2}(fieldnames(T)) + indices = Tuple(findfirst(x->x===fieldname, names)) + return :(field(A, $(indices...))) +end ## Helper functions diff --git a/test/test_CellArray.jl b/test/test_CellArray.jl index a95f8d6..569d4a6 100644 --- a/test/test_CellArray.jl +++ b/test/test_CellArray.jl @@ -1,7 +1,7 @@ using Test using CUDA, AMDGPU, Metal, StaticArrays import CellArrays -import CellArrays: CPUCellArray, @define_CuCellArray, @define_ROCCellArray, @define_MtlCellArray, cellsize, blocklength, _N +import CellArrays: CPUCellArray, @define_CuCellArray, @define_ROCCellArray, @define_MtlCellArray, cellsize, blocklength, field, _N import CellArrays: IncoherentArgumentError, ArgumentError @define_CuCellArray @@ -321,6 +321,24 @@ end @test blocklength(G) == 1 @test blocklength(H) == 4 end; + @testset "field" begin + @test size(field(A, 1)) == dims + @test size(field(C, 3,4)) == dims + @test size(field(E, 2,2,2,2)) == dims + @test size(field(G, 3,4)) == dims + @test size(field(E, :xxxx)) == dims + @test size(field(E, :yxxx)) == dims + @test size(field(E, :xyxx)) == dims + @test size(field(E, :yyxx)) == dims + @test size(field(E, :yyyy)) == dims + end; + @testset "field property" begin + @test E.xxxx == field(E, :xxxx) + @test E.yxxx == field(E, :yxxx) + @test E.xyxx == field(E, :xyxx) + @test E.yyxx == field(E, :yyxx) + @test E.yyyy == field(E, :yyyy) + end; end; @testset "3. Exceptions ($array_type arrays) (precision: $(nameof(Float)))" for (array_type, Array, CellArray, allowscalar, Float) in zip(array_types, ArrayConstructors, CellArrayConstructors, allowscalar_functions, precision_types) dims = (2,3)