[NDTensors] BlockSparseArrays prototype #1205

Merged · 9 commits · Oct 4, 2023
1 change: 1 addition & 0 deletions NDTensors/Project.toml
@@ -5,6 +5,7 @@ version = "0.2.11"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
62 changes: 62 additions & 0 deletions NDTensors/src/BlockSparseArrays/README.md
@@ -0,0 +1,62 @@
# BlockSparseArrays.jl

A Julia `BlockSparseArray` type based on the `BlockArrays.jl` interface.

It wraps an elementwise `SparseArray` type that stores non-zero values in a dictionary-of-keys format, specifically a `Dictionary` from `Dictionaries.jl`. `BlockArrays.jl` reinterprets the `SparseArray` as a blocked data structure.

```julia
using NDTensors.BlockSparseArrays
using BlockArrays
using Dictionaries

# Block dimensions
i1 = [2, 3]
i2 = [2, 3]

i_axes = (blockedrange(i1), blockedrange(i2))

function block_size(axes, block)
return length.(getindex.(axes, Block.(block.n)))
end

# Data
nz_blocks = [Block(1, 1), Block(2, 2)]
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks]
nz_block_lengths = prod.(nz_block_sizes)

# Blocks with discontiguous underlying data
d_blocks = randn.(nz_block_sizes)

# Blocks with contiguous underlying data
# d_data = PseudoBlockVector(randn(sum(nz_block_lengths)), nz_block_lengths)
# d_blocks = [reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for i in 1:length(nz_blocks)]

block_data = Dictionary([CartesianIndex(nz_block.n) for nz_block in nz_blocks], d_blocks)
block_storage = SparseArray{valtype(block_data),length(i_axes)}(block_data, blocklength.(i_axes))

B = BlockSparseArray(block_storage, i_axes)

# Access a block
B[Block(1, 1)]

# Access a block that is not stored (structurally zero); returns a zero matrix
B[Block(1, 2)]

# Set a zero block
B[Block(1, 2)] = randn(2, 3)

# Matrix multiplication (not optimized for sparsity yet)
B * B
```

## TODO

- Define an `AbstractBlockSparseArray` type along with two concrete types: one whose blocks make no assumptions about data layout (they could be slices into contiguous data or not), and one that uses contiguous memory in the background (which could be any `AbstractVector` wrapped in a `PseudoBlockVector` that tracks the blocks, as shown above).
- Define fast linear algebra (matmul, SVD, QR, etc.) that takes advantage of sparsity.
- Define tensor contraction and addition using the `TensorOperations.jl` tensor operations interface (`tensoradd!`, `tensorcontract!`, and `tensortrace!`). See `SparseArrayKit.jl` for examples of overloading for sparse data structures.
- Use `SparseArrayKit.jl` as the elementwise sparse array backend. It would need to be generalized a little: for example, it assumes that `zero` is defined for the element type, which isn't the case when the values are matrices, since constructing a zero block requires shape information (though it could output a universal zero tensor).
- Implement `SparseArrays` functionality such as `findnz`, `findall(!iszero, B)`, `nnz`, `nonzeros`, `dropzeros`, and `droptol!`, along with block versions of those (which would get forwarded to the `SparseArray` data structure, where they are treated as elementwise sparsity). `SparseArrayKit.jl` has functions `nonzero_pairs`, `nonzero_keys`, `nonzero_values`, and `nonzero_length` which could have analogous block functions (see the sketch after this list).
- Look at other packages that deal with block sparsity such as `BlockSparseMatrices.jl` and `BlockBandedMatrices.jl` for ideas on code design and interfaces.
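
As a rough illustration of the point above about forwarding block-level queries to the `SparseArray` storage, here is a minimal sketch. It is not part of this PR, and the function names are hypothetical:

```julia
# Hypothetical block-level nonzero accessors, forwarding to the
# dictionary-of-keys storage of the wrapped SparseArray.
block_nonzero_keys(a::BlockSparseArray) = [Block(Tuple(I)) for I in keys(a.blocks.data)]
block_nonzero_values(a::BlockSparseArray) = collect(a.blocks.data)
block_nonzero_length(a::BlockSparseArray) = length(a.blocks.data)

# For the `B` in the example above (right after construction),
# `block_nonzero_keys(B)` would return `[Block(1, 1), Block(2, 2)]`
# and `block_nonzero_length(B)` would return 2.
```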
10 changes: 10 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/BlockSparseArrays.jl
@@ -0,0 +1,10 @@
module BlockSparseArrays
using BlockArrays
using Dictionaries

export BlockSparseArray, SparseArray

include("sparsearray.jl")
include("blocksparsearray.jl")

end
86 changes: 86 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl
@@ -0,0 +1,86 @@
using BlockArrays: block

# Also add a version with contiguous underlying data.
struct BlockSparseArray{
T,N,R<:SparseArray{<:AbstractArray{T,N},N},BS<:NTuple{N,AbstractUnitRange{Int}}
} <: AbstractBlockArray{T,N}
blocks::R
axes::BS
end

Base.axes(block_arr::BlockSparseArray) = block_arr.axes

function Base.copy(block_arr::BlockSparseArray)
return BlockSparseArray(deepcopy(block_arr.blocks), copy.(block_arr.axes))
end

function BlockArrays.viewblock(block_arr::BlockSparseArray, block)
blks = block.n
@boundscheck blockcheckbounds(block_arr, blks...)
block_size = length.(getindex.(axes(block_arr), Block.(blks)))
# TODO: Make this `Zeros`?
zero = zeros(eltype(block_arr), block_size)
# return block_arr.blocks[blks...] # Fails because zero isn't defined
return get_nonzero(block_arr.blocks, blks, zero)
end

function Base.getindex(block_arr::BlockSparseArray{T,N}, bi::BlockIndex{N}) where {T,N}
@boundscheck blockcheckbounds(block_arr, Block(bi.I))
bl = view(block_arr, block(bi))
inds = bi.α
@boundscheck checkbounds(bl, inds...)
v = bl[inds...]
return v
end

function Base.setindex!(
block_arr::BlockSparseArray{T,N}, v, i::Vararg{Integer,N}
) where {T,N}
@boundscheck checkbounds(block_arr, i...)
block_indices = findblockindex.(axes(block_arr), i)
block = map(block_index -> Block(block_index.I), block_indices)
offsets = map(block_index -> only(block_index.α), block_indices)
block_view = @view block_arr[block...]
block_view[offsets...] = v
block_arr[block...] = block_view
return block_arr
end

function BlockArrays._check_setblock!(
block_arr::BlockSparseArray{T,N}, v, block::NTuple{N,Integer}
) where {T,N}
for i in 1:N
bsz = length(axes(block_arr, i)[Block(block[i])])
if size(v, i) != bsz
throw(
DimensionMismatch(
string(
"tried to assign $(size(v)) array to ",
length.(getindex.(axes(block_arr), block)),
" block",
),
),
)
end
end
end

function Base.setindex!(
block_arr::BlockSparseArray{T,N}, v, block::Vararg{Block{1},N}
) where {T,N}
blks = Int.(block)
@boundscheck blockcheckbounds(block_arr, blks...)
@boundscheck BlockArrays._check_setblock!(block_arr, v, blks)
# This fails since it tries to replace the element
block_arr.blocks[blks...] = v
# Use .= here to overwrite data.
## block_view = @view block_arr[Block(blks)]
## block_view .= v
return block_arr
end

function Base.getindex(block_arr::BlockSparseArray{T,N}, i::Vararg{Integer,N}) where {T,N}
@boundscheck checkbounds(block_arr, i...)
v = block_arr[findblockindex.(axes(block_arr), i)...]
return v
end
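
For orientation (commentary, not part of the diff): the scalar `getindex`/`setindex!` definitions above work by resolving an element index into a block plus an intra-block offset with `findblockindex`. A minimal sketch, assuming the `B` constructed in the README example:

```julia
using BlockArrays: Block, findblockindex

# With block sizes i1 = i2 = [2, 3], element (4, 5) lies in Block(2, 2),
# at offset (2, 3) within that block.
bis = findblockindex.(axes(B), (4, 5))  # one BlockIndex per dimension
B[4, 5] == B[Block(2, 2)][2, 3]         # true
```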
31 changes: 31 additions & 0 deletions NDTensors/src/BlockSparseArrays/src/sparsearray.jl
@@ -0,0 +1,31 @@
struct SparseArray{T,N} <: AbstractArray{T,N}
data::Dictionary{CartesianIndex{N},T}
dims::NTuple{N,Int64}
end

Base.size(a::SparseArray) = a.dims

function Base.setindex!(a::SparseArray{T,N}, v, I::CartesianIndex{N}) where {T,N}
set!(a.data, I, v)
return a
end
function Base.setindex!(a::SparseArray{T,N}, v, I::Vararg{Int,N}) where {T,N}
return setindex!(a, v, CartesianIndex(I))
end

function Base.getindex(a::SparseArray{T,N}, I::CartesianIndex{N}) where {T,N}
return get(a.data, I, nothing)
end
function Base.getindex(a::SparseArray{T,N}, I::Vararg{Int,N}) where {T,N}
return getindex(a, CartesianIndex(I))
end

# `getindex` but uses a default if the value is
# structurally zero.
function get_nonzero(a::SparseArray{T,N}, I::CartesianIndex{N}, zero) where {T,N}
@boundscheck checkbounds(a, I)
return get(a.data, I, zero)
end
function get_nonzero(a::SparseArray{T,N}, I::NTuple{N,Int}, zero) where {T,N}
return get_nonzero(a, CartesianIndex(I), zero)
end
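
As a standalone illustration of this dictionary-of-keys storage (commentary, not part of the diff), here is a minimal sketch of constructing a `SparseArray` of matrix-valued entries and querying it with `get_nonzero`:

```julia
using NDTensors.BlockSparseArrays: SparseArray, get_nonzero
using Dictionaries: Dictionary

# A 2x2 elementwise-sparse array of matrices with one stored entry at (1, 1).
data = Dictionary([CartesianIndex(1, 1)], [randn(2, 2)])
a = SparseArray{Matrix{Float64},2}(data, (2, 2))

a[CartesianIndex(1, 1)]              # returns the stored 2x2 matrix
get_nonzero(a, (2, 2), zeros(3, 3))  # structurally zero entry, returns the default
```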
2 changes: 2 additions & 0 deletions NDTensors/src/NDTensors.jl
@@ -19,6 +19,8 @@ using TupleTools

include("SetParameters/src/SetParameters.jl")
using .SetParameters
include("BlockSparseArrays/src/BlockSparseArrays.jl")
using .BlockSparseArrays

using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo
