Skip to content

Commit

Permalink
[BlockSparseArrays] Simplifications of blocks for blocksparse `Adjo…
Browse files Browse the repository at this point in the history
…int` and `Transpose` (#1580)
  • Loading branch information
lkdvos authored Nov 15, 2024
1 parent 1a70987 commit dbec36b
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,17 @@ function BlockArrays.viewblock(
) where {T,N}
return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...)
end

# migrate wrapper layer for viewing `adjoint` and `transpose`.
for (f, F) in ((:adjoint, :Adjoint), (:transpose, :Transpose))
@eval begin
function Base.view(A::$F{<:Any,<:AbstractBlockSparseVector}, b::Block{1})
return $f(view(parent(A), b))
end

Base.view(A::$F{<:Any,<:AbstractBlockSparseMatrix}, b::Block{2}) = view(A, Tuple(b)...)
function Base.view(A::$F{<:Any,<:AbstractBlockSparseMatrix}, b1::Block{1}, b2::Block{1})
return $f(view(parent(A), b2, b1))
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -186,79 +186,8 @@ end
reverse_index(index) = reverse(index)
reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))

# Represents the array of arrays of a `Transpose`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`.
struct SparseTransposeBlocks{T,BlockType<:AbstractArray{T},Array<:Transpose{T}} <:
AbstractSparseMatrix{BlockType}
array::Array
end
function blocksparse_blocks(a::Transpose)
return SparseTransposeBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
end
function Base.size(a::SparseTransposeBlocks)
return reverse(size(blocks(parent(a.array))))
end
function Base.getindex(a::SparseTransposeBlocks, index::Vararg{Int,2})
return transpose(blocks(parent(a.array))[reverse(index)...])
end
# TODO: This should be handled by generic `AbstractSparseArray` code.
function Base.getindex(a::SparseTransposeBlocks, index::CartesianIndex{2})
return a[Tuple(index)...]
end
# TODO: Create a generic `parent_index` function to map an index
# a parent index.
function Base.isassigned(a::SparseTransposeBlocks, index::Vararg{Int,2})
return isassigned(blocks(parent(a.array)), reverse(index)...)
end
function SparseArrayInterface.stored_indices(a::SparseTransposeBlocks)
return map(reverse_index, stored_indices(blocks(parent(a.array))))
end
# TODO: Either make this the generic interface or define
# `SparseArrayInterface.sparse_storage`, which is used
# to defined this.
SparseArrayInterface.stored_length(a::SparseTransposeBlocks) = length(stored_indices(a))
function SparseArrayInterface.sparse_storage(a::SparseTransposeBlocks)
return error("Not implemented")
end

# Represents the array of arrays of a `Adjoint`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`.
struct SparseAdjointBlocks{T,BlockType<:AbstractArray{T},Array<:Adjoint{T}} <:
AbstractSparseMatrix{BlockType}
array::Array
end
function blocksparse_blocks(a::Adjoint)
return SparseAdjointBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
end
function Base.size(a::SparseAdjointBlocks)
return reverse(size(blocks(parent(a.array))))
end
# TODO: Create a generic `parent_index` function to map an index
# a parent index.
function Base.getindex(a::SparseAdjointBlocks, index::Vararg{Int,2})
return blocks(parent(a.array))[reverse(index)...]'
end
# TODO: Create a generic `parent_index` function to map an index
# a parent index.
# TODO: This should be handled by generic `AbstractSparseArray` code.
function Base.getindex(a::SparseAdjointBlocks, index::CartesianIndex{2})
return a[Tuple(index)...]
end
# TODO: Create a generic `parent_index` function to map an index
# a parent index.
function Base.isassigned(a::SparseAdjointBlocks, index::Vararg{Int,2})
return isassigned(blocks(parent(a.array)), reverse(index)...)
end
function SparseArrayInterface.stored_indices(a::SparseAdjointBlocks)
return map(reverse_index, stored_indices(blocks(parent(a.array))))
end
# TODO: Either make this the generic interface or define
# `SparseArrayInterface.sparse_storage`, which is used
# to defined this.
SparseArrayInterface.stored_length(a::SparseAdjointBlocks) = length(stored_indices(a))
function SparseArrayInterface.sparse_storage(a::SparseAdjointBlocks)
return error("Not implemented")
end
blocksparse_blocks(a::Transpose) = transpose(blocks(parent(a)))
blocksparse_blocks(a::Adjoint) = adjoint(blocks(parent(a)))

# Represents the array of arrays of a `SubArray`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`.
Expand Down
58 changes: 50 additions & 8 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using BlockArrays:
mortar
using Compat: @compat
using GPUArraysCore: @allowscalar
using LinearAlgebra: Adjoint, dot, mul!, norm
using LinearAlgebra: Adjoint, Transpose, dot, mul!, norm
using NDTensors.BlockSparseArrays:
@view!,
BlockSparseArray,
Expand All @@ -33,7 +33,7 @@ using NDTensors.GPUArraysCoreExtensions: cpu
using NDTensors.SparseArrayInterface: stored_length
using NDTensors.SparseArrayDOKs: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK
using NDTensors.TensorAlgebra: contract
using Test: @test, @test_broken, @test_throws, @testset
using Test: @test, @test_broken, @test_throws, @testset, @inferred
include("TestBlockSparseArraysUtils.jl")

using NDTensors: NDTensors
Expand Down Expand Up @@ -70,12 +70,6 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
@test adjoint(a) isa Adjoint{elt,<:BlockSparseArray}
@test_broken adjoint(a)[Block(1), :] isa Adjoint{elt,<:BlockSparseArray}
# could also be directly a BlockSparseArray

a = dev(BlockSparseArray{elt}([1], [1, 1]))
@allowscalar a[1, 2] = 1
@test [a[Block(Tuple(it))] for it in eachindex(block_stored_indices(a))] isa Vector
ah = adjoint(a)
@test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector
end
@testset "Constructors" begin
# BlockSparseMatrix
Expand Down Expand Up @@ -210,6 +204,54 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
## @test b[Block()[]] == 2
end
end

@testset "Transpose" begin
a = dev(BlockSparseArray{elt}([2, 2], [3, 3, 1]))
a[Block(1, 1)] = dev(randn(elt, 2, 3))
a[Block(2, 3)] = dev(randn(elt, 2, 1))

at = @inferred transpose(a)
@test at isa Transpose
@test size(at) == reverse(size(a))
@test blocksize(at) == reverse(blocksize(a))
@test stored_length(at) == stored_length(a)
@test block_stored_length(at) == block_stored_length(a)
for bind in block_stored_indices(a)
bindt = Block(reverse(Int.(Tuple(bind))))
@test bindt in block_stored_indices(at)
end

@test @views(at[Block(1, 1)]) == transpose(a[Block(1, 1)])
@test @views(at[Block(1, 1)]) isa Transpose
@test @views(at[Block(3, 2)]) == transpose(a[Block(2, 3)])
# TODO: BlockView == AbstractArray calls scalar code
@test @allowscalar @views(at[Block(1, 2)]) == transpose(a[Block(2, 1)])
@test @views(at[Block(1, 2)]) isa Transpose
end

@testset "Adjoint" begin
a = dev(BlockSparseArray{elt}([2, 2], [3, 3, 1]))
a[Block(1, 1)] = dev(randn(elt, 2, 3))
a[Block(2, 3)] = dev(randn(elt, 2, 1))

at = @inferred adjoint(a)
@test at isa Adjoint
@test size(at) == reverse(size(a))
@test blocksize(at) == reverse(blocksize(a))
@test stored_length(at) == stored_length(a)
@test block_stored_length(at) == block_stored_length(a)
for bind in block_stored_indices(a)
bindt = Block(reverse(Int.(Tuple(bind))))
@test bindt in block_stored_indices(at)
end

@test @views(at[Block(1, 1)]) == adjoint(a[Block(1, 1)])
@test @views(at[Block(1, 1)]) isa Adjoint
@test @views(at[Block(3, 2)]) == adjoint(a[Block(2, 3)])
# TODO: BlockView == AbstractArray calls scalar code
@test @allowscalar @views(at[Block(1, 2)]) == adjoint(a[Block(2, 1)])
@test @views(at[Block(1, 2)]) isa Adjoint
end
end
@testset "Tensor algebra" begin
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Adapt: WrappedArray
using LinearAlgebra: Adjoint, Transpose

const WrappedAbstractSparseArray{T,N,A} = WrappedArray{
T,N,<:AbstractSparseArray,<:AbstractSparseArray{T,N}
Expand All @@ -7,3 +8,15 @@ const WrappedAbstractSparseArray{T,N,A} = WrappedArray{
const AnyAbstractSparseArray{T,N} = Union{
<:AbstractSparseArray{T,N},<:WrappedAbstractSparseArray{T,N}
}

function stored_indices(a::Adjoint)
return Iterators.map(I -> CartesianIndex(reverse(Tuple(I))), stored_indices(parent(a)))
end
stored_length(a::Adjoint) = stored_length(parent(a))
sparse_storage(a::Adjoint) = Iterators.map(adjoint, sparse_storage(parent(a)))

function stored_indices(a::Transpose)
return Iterators.map(I -> CartesianIndex(reverse(Tuple(I))), stored_indices(parent(a)))
end
stored_length(a::Transpose) = stored_length(parent(a))
sparse_storage(a::Transpose) = Iterators.map(transpose, sparse_storage(parent(a)))

0 comments on commit dbec36b

Please sign in to comment.