
[BlockSparseArrays] Initial support for more general blocks, such as GPU blocks #1560

Merged
mtfishman merged 20 commits into main from BlockSparseArrays_gpu on Nov 8, 2024

Conversation

@mtfishman (Member) commented Nov 1, 2024

This PR fixes broken functionality for BlockSparseArrays whose blocks aren't plain Arrays, such as blocks that are GPU arrays. Before this PR, the library supported constructing block sparse arrays with more general blocks, but operations like adding or multiplying them were either broken or implicitly moved data to CPU.

To-do:

  • Add tests.
  • Fix slicing operations on GPU (they are still hardcoded to allocate the output on CPU).
  • Define Adapt.jl overloads for AbstractBlockSparseArray in terms of mapping adapt over nonzero/stored blocks (see the sketch after this list).
  • Move similartype to TypeParameterAccessors (completed in [TypeParameterAccessors] similartype #1561).
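
As a rough illustration of the Adapt.jl item above: the overloads would map adapt over the stored blocks while leaving the block structure alone. This is a minimal sketch, assuming a hypothetical map_stored_blocks helper rather than the actual NDTensors internals:

using Adapt: Adapt, adapt
using NDTensors.BlockSparseArrays: AbstractBlockSparseArray

# Sketch: adapting a block sparse array adapts each stored block (for example
# moving it to GPU) and keeps the block structure (axes) unchanged.
# `map_stored_blocks` is a hypothetical helper that applies a function to
# every stored block and rebuilds the array around the results.
function Adapt.adapt_structure(to, a::AbstractBlockSparseArray)
  return map_stored_blocks(block -> adapt(to, block), a)
end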

Future work:

  • Fix slicing off-diagonal blocks of BlockSparseArray with Diagonal blocks on GPU.
  • Fix matrix multiplication of BlockSparseMatrix with Diagonal blocks on GPU.
  • Look into fixing operations for blocks that are Diagonal (for example, accessing non-allocated blocks is currently broken); see the construction sketched after this list.
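
For reference, a block sparse matrix with Diagonal blocks (the case the items above refer to) can be built with the same constructor used in the demonstration below; a minimal sketch:

using BlockArrays: Block, blockedrange
using LinearAlgebra: Diagonal
using NDTensors.BlockSparseArrays: BlockSparseArray

# Block-diagonal matrix whose stored blocks are themselves `Diagonal`.
n = 4
a = BlockSparseArray(
  [Block(1, 1), Block(2, 2)],
  [Diagonal(randn(Float32, n)), Diagonal(randn(Float32, n))],
  blockedrange.(([n, n], [n, n])),
)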

@kmp5VT this should also help with making block sparse arrays that have distributed blocks, though I haven't tested that. This PR should also give you some guidance on where to look when fixing issues that come up with that, such as where in the code the output of matrix multiplication is defined so that it can be customized if needed (see the BlockZero overload in the Dagger example below).

@mtfishman (Member, Author) commented Nov 1, 2024

Here is a demonstration:

using BlockArrays: Block, blockedrange
using Metal: mtl
using NDTensors.BlockSparseArrays: BlockSparseArray

# Block-diagonal BlockSparseMatrix with two stored n×n blocks.
function randn_block_diag(n)
  return BlockSparseArray(
    [Block(1, 1), Block(2, 2)],
    [randn(Float32, n, n), randn(Float32, n, n)],
    blockedrange.(([n, n], [n, n])),
  )
end

n = 4000
a = randn_block_diag(n)
b = randn_block_diag(n)
@time a * b  # block sparse matmul with Array blocks on CPU

# Move the stored blocks to GPU with Metal's `mtl` and repeat.
a_gpu = mtl(a)
b_gpu = mtl(b)
@time a_gpu * b_gpu  # block sparse matmul with MtlArray blocks

which outputs:

  0.478664 seconds (69 allocations: 122.073 MiB, 4.58% gc time)
  0.000443 seconds (442 allocations: 13.984 KiB)

@mtfishman changed the title from "[BlockSparseArrays] Initial GPU support" to "[WIP] [BlockSparseArrays] Initial GPU support" on Nov 7, 2024
@mtfishman marked this pull request as draft on November 7, 2024 15:05
@mtfishman (Member, Author) commented

Some basic operations with Dagger.DArray, like matrix multiplication, now work:

using Distributed: @everywhere, addprocs
addprocs(2)
using Adapt: Adapt, adapt
using BlockArrays: Block, blockedrange
@everywhere using Dagger: Dagger, AutoBlocks, DArray, distribute
using NDTensors.BlockSparseArrays: BlockSparseArray, BlockZero, block_size

# Adaptor that distributes each stored block with `Dagger.distribute`.
struct DArrayAdaptor end
function Adapt.adapt_storage(::DArrayAdaptor, a::AbstractArray)
  return distribute(a)
end

# Distribute a BlockSparseArray by adapting its stored blocks.
function Dagger.distribute(a::BlockSparseArray)
  return adapt(DArrayAdaptor(), a)
end

# Block-diagonal BlockSparseMatrix with two stored n×n blocks.
function randn_block_diag(n)
  return BlockSparseArray(
    [Block(1, 1), Block(2, 2)],
    [randn(Float32, n, n), randn(Float32, n, n)],
    blockedrange.(([n, n], [n, n])),
  )
end

# Customize how non-stored (zero) blocks of the matrix multiplication
# output get allocated so that they are `DArray`s as well.
function (f::BlockZero)(arraytype::Type{<:DArray}, I::CartesianIndex)
  blck_size = block_size(f.axes, Block(Tuple(I)))
  return zeros(AutoBlocks(), eltype(arraytype), blck_size...)
end

n = 4
a = randn_block_diag(n)
b = randn_block_diag(n)
c = a * b
a_d = distribute(a)
b_d = distribute(b)
c_d = a_d * b_d
c ≈ c_d  # the distributed result matches the serial one
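
The BlockZero overload above is the customization point mentioned earlier: it controls how non-stored (zero) blocks of the matrix multiplication output get allocated, so other block types, such as GPU arrays, can hook in the same way.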

@mtfishman changed the title from "[WIP] [BlockSparseArrays] Initial GPU support" to "[BlockSparseArrays] Initial GPU support" on Nov 8, 2024
@mtfishman marked this pull request as ready for review on November 8, 2024 01:44
@mtfishman (Member, Author) commented

I've converted some of the BlockSparseArray tests to optionally run on GPU backends and, on CPU, to run with the JLArray backend. This caught a few scalar indexing bugs which we can investigate in future PRs (@kmp5VT), but basic slicing, scalar multiplication, addition, and matrix multiplication all work on GPU (see the sketch below). I'll merge this once tests pass, and it can be used as a starting point for future work.
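
Roughly, the JLArray-backed usage looks like the following; a minimal sketch, assuming this PR's Adapt overloads map adapt over the stored blocks (randn_block_diag is the helper from the first comment):

using Adapt: adapt
using BlockArrays: Block, blockedrange
using JLArrays: JLArray
using NDTensors.BlockSparseArrays: BlockSparseArray

# `randn_block_diag` as defined in the first comment above.
function randn_block_diag(n)
  return BlockSparseArray(
    [Block(1, 1), Block(2, 2)],
    [randn(Float32, n, n), randn(Float32, n, n)],
    blockedrange.(([n, n], [n, n])),
  )
end

a = randn_block_diag(4)
b = randn_block_diag(4)

# Adapt the stored blocks to JLArray, a CPU-backed array type from
# JLArrays.jl that flags scalar indexing the way a real GPU array would.
a_jl = adapt(JLArray, a)
b_jl = adapt(JLArray, b)

c_jl = a_jl * b_jl        # matrix multiplication
d_jl = 2 * a_jl + b_jl    # scalar multiplication and addition
e_jl = a_jl[Block(1, 1)]  # block slicing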

@mtfishman merged commit d1547b4 into main on Nov 8, 2024
13 checks passed
@mtfishman deleted the BlockSparseArrays_gpu branch on November 8, 2024 14:46
@mtfishman changed the title from "[BlockSparseArrays] Initial GPU support" to "[BlockSparseArrays] Initial support for more general blocks, such as GPU blocks" on Nov 8, 2024