Skip to content

Commit

Permalink
Fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 13, 2024
1 parent a3d0b91 commit bd54ec3
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions src/BlockArraysExtensions/BlockArraysExtensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,39 @@ using BlockArrays:
findblockindex
using Dictionaries: Dictionary, Indices
using GradedUnitRanges: blockedunitrange_getindices, to_blockindices
using SparseArraysBase: SparseArraysBase, storedlength, eachstoredindex
using SparseArraysBase:
SparseArraysBase,
eachstoredindex,
getunstoredindex,
isstored,
setunstoredindex!,
storedlength

# A return type for `blocks(array)` when `array` isn't blocked.
# Represents a vector with just that single block.
struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
end
Base.parent(a::SingleBlockView) = a.array
blocks_maybe_single(a) = blocks(a)
blocks_maybe_single(a::Array) = SingleBlockView(a)
function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
@assert all(isone, index)
return a.array
return parent(a)
end

# A wrapper around a potentially blocked array that is not blocked.
struct NonBlockedArray{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
end
Base.size(a::NonBlockedArray) = size(a.array)
Base.getindex(a::NonBlockedArray{<:Any,N}, I::Vararg{Integer,N}) where {N} = a.array[I...]
Base.parent(a::NonBlockedArray) = a.array
Base.size(a::NonBlockedArray) = size(parent(a))
Base.getindex(a::NonBlockedArray{<:Any,N}, I::Vararg{Integer,N}) where {N} = parent(a)[I...]
# Views of `NonBlockedArray`/`NonBlockedVector` are eager.
# This fixes an issue in Julia 1.11 where reindexing defaults to using views.
# TODO: Maybe reconsider this design, and allows views to work in slicing.
Base.view(a::NonBlockedArray, I...) = a[I...]
BlockArrays.blocks(a::NonBlockedArray) = SingleBlockView(a.array)
BlockArrays.blocks(a::NonBlockedArray) = SingleBlockView(parent(a))
const NonBlockedVector{T,Array} = NonBlockedArray{T,1,Array}
NonBlockedVector(array::AbstractVector) = NonBlockedArray(array)

Expand Down Expand Up @@ -100,10 +108,10 @@ Base.view(S::BlockIndices, i) = S[i]
function Base.getindex(
a::NonBlockedVector{<:Integer,<:BlockIndices}, I::UnitRange{<:Integer}
)
ax = only(axes(a.array.indices))
ax = only(axes(parent(a).indices))
brs = to_blockindices(ax, I)
inds = blockedunitrange_getindices(ax, I)
return NonBlockedVector(a.array[BlockSlice(brs, inds)])
return NonBlockedVector(parent(a)[BlockSlice(brs, inds)])
end

function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
Expand Down Expand Up @@ -511,25 +519,30 @@ struct BlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
block::Tuple{Vararg{Block{1,Int},N}}
end
Base.parent(a::BlockView) = a.array
function Base.axes(a::BlockView)
# TODO: Try to avoid conversion to `Base.OneTo{Int}`, or just convert
# the element type to `Int` with `Int.(...)`.
# When the axes of `a.array` are `GradedOneTo`, the block is `LabelledUnitRange`,
# When the axes of `parent(a)` are `GradedOneTo`, the block is `LabelledUnitRange`,
# which has element type `LabelledInteger`. That causes conversion problems
# in some generic Base Julia code, for example when printing `BlockView`.
return ntuple(ndims(a)) do dim
return Base.OneTo{Int}(only(axes(axes(a.array, dim)[a.block[dim]])))
return Base.OneTo{Int}(only(axes(axes(parent(a), dim)[a.block[dim]])))
end
end
function Base.size(a::BlockView)
return length.(axes(a))
end
function Base.getindex(a::BlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
return blocks(a.array)[Int.(a.block)...][index...]
return blocks(parent(a))[Int.(a.block)...][index...]
end
function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) where {N}
blocks(a.array)[Int.(a.block)...] = blocks(a.array)[Int.(a.block)...]
blocks(a.array)[Int.(a.block)...][index...] = value
I = Int.(a.block)
if !isstored(blocks(parent(a)), I...)
unstored_value = getunstoredindex(blocks(parent(a)), I...)
setunstoredindex!(blocks(parent(a)), unstored_value, I...)
end
blocks(parent(a))[I...][index...] = value
return a
end

Expand All @@ -538,15 +551,15 @@ function SparseArraysBase.storedlength(a::BlockView)
# a Bool in `BlockView`.
I = CartesianIndex(Int.(a.block))
# TODO: Use `block_eachstoredindex`.
if I eachstoredindex(blocks(a.array))
return storedlength(blocks(a.array)[I])
if I eachstoredindex(blocks(parent(a)))
return storedlength(blocks(parent(a))[I])
end
return 0
end

## # Allow more fine-grained control:
## function ArrayLayouts.sub_materialize(layout, a::BlockView, ax)
## return blocks(a.array)[Int.(a.block)...]
## return blocks(parent(a))[Int.(a.block)...]
## end
## function ArrayLayouts.sub_materialize(layout, a::BlockView)
## return sub_materialize(layout, a, axes(a))
Expand All @@ -555,7 +568,7 @@ end
## return sub_materialize(MemoryLayout(a), a)
## end
function ArrayLayouts.sub_materialize(a::BlockView)
return blocks(a.array)[Int.(a.block)...]
return blocks(parent(a))[Int.(a.block)...]
end

function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N}
Expand Down

0 comments on commit bd54ec3

Please sign in to comment.