Skip to content

Commit

Permalink
MultiFab: Fix Fixture Lifetime (#84)
Browse files Browse the repository at this point in the history
pytest fixtures are cached by default. That's not ideal for us,
since we want to clean up and re-init AMReX and its memory
arenas between tests.

Thus, we now create a multifab fixture that is just a generator,
so that the returned/yielded object can be cached - but the
generated MultiFab is actually destroyed as we would expect once
the test finishes.
  • Loading branch information
ax3l authored Oct 17, 2022
1 parent 86273bd commit fc9d0ea
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
23 changes: 16 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,19 @@ def distmap(boxarr):


@pytest.fixture(scope="function", params=list(itertools.product([1, 3], [0, 1])))
def mfab(boxarr, distmap, request):
"""MultiFab for tests"""
num_components = request.param[0]
num_ghost = request.param[1]
mfab = amrex.MultiFab(boxarr, distmap, num_components, num_ghost)
mfab.set_val(0.0, 0, num_components)
return mfab
def make_mfab(boxarr, distmap, request):
"""MultiFab that is either managed or device:
The MultiFab object itself is not a fixture because we want to avoid caching
it between amrex.initialize/finalize calls of various tests.
https://github.com/pytest-dev/pytest/discussions/10387
https://github.com/pytest-dev/pytest/issues/5642#issuecomment-1279612764
"""

def create():
num_components = request.param[0]
num_ghost = request.param[1]
mfab = amrex.MultiFab(boxarr, distmap, num_components, num_ghost)
mfab.set_val(0.0, 0, num_components)
return mfab

return create
11 changes: 6 additions & 5 deletions tests/test_multifab.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import amrex


@pytest.mark.parametrize("nghost", [0, 1])
def test_mfab_loop(mfab, nghost):
def test_mfab_loop(make_mfab):
mfab = make_mfab()
ngv = mfab.nGrowVect
print(f"\n mfab={mfab}, mfab.nGrowVect={ngv}")

Expand Down Expand Up @@ -77,7 +77,8 @@ def test_mfab_loop(mfab, nghost):
# TODO


def test_mfab_simple(mfab):
def test_mfab_simple(make_mfab):
mfab = make_mfab()
assert mfab.is_all_cell_centered
# assert(all(not mfab.is_nodal(i) for i in [-1, 0, 1, 2])) # -1??
assert all(not mfab.is_nodal(i) for i in [0, 1, 2])
Expand Down Expand Up @@ -142,8 +143,8 @@ def test_mfab_ops(boxarr, distmap, nghost):
np.testing.assert_allclose(dst.max(0), 150.0)


@pytest.mark.parametrize("nghost", [0, 1])
def test_mfab_mfiter(mfab, nghost):
def test_mfab_mfiter(make_mfab):
mfab = make_mfab()
assert iter(mfab).is_valid
assert iter(mfab).length == 8

Expand Down

0 comments on commit fc9d0ea

Please sign in to comment.