Skip to content

Commit

Permalink
[Draft] MultiFab: CuPy Test
Browse files Browse the repository at this point in the history
  • Loading branch information
ax3l committed Oct 7, 2022
1 parent 5865c3e commit 5395043
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 4 deletions.
21 changes: 20 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,28 @@ def distmap(boxarr):

@pytest.fixture(params=list(itertools.product([1, 3], [0, 1])))
def mfab(boxarr, distmap, request):
"""MultiFab for tests"""
"""MultiFab that is either managed or device"""
num_components = request.param[0]
num_ghost = request.param[1]
mfab = amrex.MultiFab(boxarr, distmap, num_components, num_ghost)
mfab.set_val(0.0, 0, num_components)
return mfab


@pytest.mark.skipif(
amrex.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
)
@pytest.fixture(params=list(itertools.product([1, 3], [0, 1])))
def mfab_device(boxarr, distmap, request):
"""MultiFab that resides purely on the device"""
num_components = request.param[0]
num_ghost = request.param[1]
mfab = amrex.MultiFab(
boxarr,
distmap,
num_components,
num_ghost,
amrex.MFInfo().set_arena(amrex.The_Device_Arena()),
)
mfab.set_val(0.0, 0, num_components)
return mfab
82 changes: 79 additions & 3 deletions tests/test_multifab.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,89 @@ def test_mfab_ops_cuda_numba():
@pytest.mark.skipif(
amrex.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
)
def test_mfab_ops_cuda_cupy():
@pytest.mark.parametrize("nghost", [0, 1])
def test_mfab_ops_cuda_cupy(mfab_device, nghost):
# https://docs.cupy.dev/en/stable/user_guide/interoperability.html
import cupy as cp
import cupy.prof

# AMReX -> cupy
# arr_numba = cuda.as_cuda_array(arr4)
# TODO
ngv = mfab_device.nGrowVect
print(f"\n mfab_device={mfab_device}, mfab_device.nGrowVect={ngv}")

# assign 3
with cupy.prof.time_range("assign 3 [()]", color_id=0):
for mfi in mfab_device:
bx = mfi.tilebox().grow(ngv)
marr = mfab_device.array(mfi)
marr_cupy = cp.array(marr, copy=False)
# print(marr_cupy.shape) # 1, 32, 32, 32
# print(marr_cupy.dtype) # float64

# write and read into the marr_cupy
marr_cupy[()] = 3.0

# verify result with a .sum_unique
with cupy.prof.time_range("verify 3", color_id=0):
shape = 32**3 * 8
# print(mfab_device.shape)
sum_threes = mfab_device.sum_unique(comp=0, local=False)
assert sum_threes == shape * 3

# assign 2
with cupy.prof.time_range("assign 2 (set_val)", color_id=1):
mfab_device.set_val(2.0)
with cupy.prof.time_range("verify 2", color_id=1):
sum_twos = mfab_device.sum_unique(comp=0, local=False)
assert sum_twos == shape * 2

# assign 5
with cupy.prof.time_range("assign 5 (ones-like)", color_id=2):

def set_to_five(mm):
xp = cp.get_array_module(mm)
assert xp.__name__ == "cupy"
mm = xp.ones_like(mm) * 10.0
mm /= 2.0
return mm

for mfi in mfab_device:
bx = mfi.tilebox().grow(ngv)
marr = mfab_device.array(mfi)
marr_cupy = cp.array(marr, copy=False)

# write and read into the marr_cupy
fives_cp = set_to_five(marr_cupy)
marr_cupy[()] = 0.0
marr_cupy += fives_cp

# verify
with cupy.prof.time_range("verify 5", color_id=2):
sum = mfab_device.sum_unique(comp=0, local=False)
assert sum == shape * 5

# assign 7
with cupy.prof.time_range("assign 7 (fuse)", color_id=3):

@cp.fuse(kernel_name="set_to_seven")
def set_to_seven(x):
x += 7.0

for mfi in mfab_device:
bx = mfi.tilebox().grow(ngv)
marr = mfab_device.array(mfi)
marr_cupy = cp.array(marr, copy=False)

# write and read into the marr_cupy
marr_cupy[()] = 0.0
set_to_seven(marr_cupy)

# verify
with cupy.prof.time_range("verify 7", color_id=3):
sum = mfab_device.sum_unique(comp=0, local=False)
assert sum == shape * 7

# TODO: @jit.rawkernel()


@pytest.mark.skipif(
Expand Down

0 comments on commit 5395043

Please sign in to comment.