[BlockSparseArrays] Initial support for more general blocks, such as GPU blocks #1560
Conversation
Here is a demonstration:

```julia
using BlockArrays: Block, blockedrange
using Metal: mtl
using NDTensors.BlockSparseArrays: BlockSparseArray

# Block diagonal matrix with two random n x n Float32 blocks.
function randn_block_diag(n)
  return BlockSparseArray(
    [Block(1, 1), Block(2, 2)],
    [randn(Float32, n, n), randn(Float32, n, n)],
    blockedrange.(([n, n], [n, n])),
  )
end

n = 4000
a = randn_block_diag(n)
b = randn_block_diag(n)
@time a * b

# Move the stored blocks to GPU with Metal.jl and multiply again.
a_gpu = mtl(a)
b_gpu = mtl(b)
@time a_gpu * b_gpu
```

which outputs:

```
0.478664 seconds (69 allocations: 122.073 MiB, 4.58% gc time)
0.000443 seconds (442 allocations: 13.984 KiB)
```
Some basic operations also work with blocks that are distributed arrays, using Dagger.jl:

```julia
using Distributed: @everywhere, addprocs
addprocs(2)
using Adapt: Adapt, adapt
using BlockArrays: Block, blockedrange
@everywhere using Dagger: Dagger, AutoBlocks, DArray, distribute
using NDTensors.BlockSparseArrays: BlockSparseArray, BlockZero, block_size

# Adaptor that distributes each stored block as a Dagger `DArray`.
struct DArrayAdaptor end
function Adapt.adapt_storage(::DArrayAdaptor, a::AbstractArray)
  return distribute(a)
end
function Dagger.distribute(a::BlockSparseArray)
  return adapt(DArrayAdaptor(), a)
end

# Block diagonal matrix with two random n x n Float32 blocks.
function randn_block_diag(n)
  return BlockSparseArray(
    [Block(1, 1), Block(2, 2)],
    [randn(Float32, n, n), randn(Float32, n, n)],
    blockedrange.(([n, n], [n, n])),
  )
end

# Construct zero blocks of the output as distributed arrays.
function (f::BlockZero)(arraytype::Type{<:DArray}, I::CartesianIndex)
  blck_size = block_size(f.axes, Block(Tuple(I)))
  return zeros(AutoBlocks(), eltype(arraytype), blck_size...)
end

n = 4
a = randn_block_diag(n)
b = randn_block_diag(n)
c = a * b
a_d = distribute(a)
b_d = distribute(b)
c_d = a_d * b_d
c ≈ c_d
```
I've converted some of the BlockSparseArray tests to also (optionally) run on GPU backends, and on CPU to run with the JLArray backend. This caught a few scalar indexing bugs which we can investigate in future PRs (@kmp5VT), but basic slicing, scalar multiplication, addition, and matrix multiplication operations work on GPU. I'll merge this once tests pass, and it can be used as a starting point for future work.
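For reference, here is a minimal sketch of the kinds of operations that now work, using the CPU-backed `JLArray` test backend from JLArrays.jl (the dimensions, variable names, and use of `jl` here are illustrative assumptions, not code from this PR):

```julia
using BlockArrays: Block, blockedrange
using JLArrays: jl  # CPU-backed GPU-array emulator, handy for testing
using NDTensors.BlockSparseArrays: BlockSparseArray

a = BlockSparseArray(
  [Block(1, 1), Block(2, 2)],
  [randn(Float32, 2, 2), randn(Float32, 2, 2)],
  blockedrange.(([2, 2], [2, 2])),
)
a_jl = jl(a)       # adapt the stored blocks to `JLArray`s

a_jl[Block(1, 1)]  # block slicing
2.0f0 * a_jl       # scalar multiplication
a_jl + a_jl        # addition
a_jl * a_jl        # matrix multiplication
```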
This fixes broken functionality for `BlockSparseArray`s that have blocks that aren't just `Array`, such as blocks that are GPU arrays. Before this PR, the library supported constructing block sparse arrays with more general blocks, but functionality like adding or multiplying them was broken or implicitly moved data to CPU.
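For example, after this PR operations along these lines run with the blocks staying on GPU (an illustrative sketch in the spirit of the demonstration above, not code from the PR):

```julia
using BlockArrays: Block, blockedrange
using Metal: mtl
using NDTensors.BlockSparseArrays: BlockSparseArray

a = BlockSparseArray(
  [Block(1, 1), Block(2, 2)],
  [randn(Float32, 8, 8), randn(Float32, 8, 8)],
  blockedrange.(([8, 8], [8, 8])),
)
a_gpu = mtl(a)  # stored blocks become Metal arrays
a_gpu + a_gpu   # addition with GPU blocks
2.0f0 * a_gpu   # scalar multiplication with GPU blocks
a_gpu * a_gpu   # matrix multiplication with GPU blocks
```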
To-do:

- `Adapt.jl` overloads for `AbstractBlockSparseArray`, defined in terms of mapping `adapt` over nonzero/stored blocks (see the sketch after this list).
- Move `similartype` to `TypeParameterAccessors` (completed in [TypeParameterAccessors] `similartype` #1561).
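As a rough sketch of the first to-do item, the Adapt.jl overload could look something like the following; `map_stored_blocks` is a hypothetical helper standing in for whatever internal function maps a function over the stored blocks, so the real implementation will differ:

```julia
using Adapt: Adapt, adapt
using NDTensors.BlockSparseArrays: AbstractBlockSparseArray

# Sketch: adapt a block sparse array by adapting each stored (nonzero) block,
# leaving the block structure and axes unchanged. `map_stored_blocks` is a
# hypothetical helper, not an actual function in BlockSparseArrays.
function Adapt.adapt_structure(to, a::AbstractBlockSparseArray)
  return map_stored_blocks(block -> adapt(to, block), a)
end
```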
Future work:

- `BlockSparseArray` with `Diagonal` blocks on GPU.
- `BlockSparseMatrix` with `Diagonal` blocks on GPU.
- Remaining issues with `Diagonal` blocks more generally (for example, accessing non-allocated blocks is currently broken).

@kmp5VT this should also help with making block sparse arrays that have distributed blocks, though I haven't tested that. This PR should also give you some guidance on where to look to fix issues that come up with that, such as where in the code the output of matrix multiplication is defined, so that it can be customized if needed.