diff --git a/tests/conftest.py b/tests/conftest.py index c697fe1d..d06729be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -68,10 +68,19 @@ def distmap(boxarr): @pytest.fixture(scope="function", params=list(itertools.product([1, 3], [0, 1]))) -def mfab(boxarr, distmap, request): - """MultiFab for tests""" - num_components = request.param[0] - num_ghost = request.param[1] - mfab = amrex.MultiFab(boxarr, distmap, num_components, num_ghost) - mfab.set_val(0.0, 0, num_components) - return mfab +def make_mfab(boxarr, distmap, request): + """MultiFab that is either managed or device: + The MultiFab object itself is not a fixture because we want to avoid caching + it between amrex.initialize/finalize calls of various tests. + https://github.com/pytest-dev/pytest/discussions/10387 + https://github.com/pytest-dev/pytest/issues/5642#issuecomment-1279612764 + """ + + def create(): + num_components = request.param[0] + num_ghost = request.param[1] + mfab = amrex.MultiFab(boxarr, distmap, num_components, num_ghost) + mfab.set_val(0.0, 0, num_components) + return mfab + + return create diff --git a/tests/test_multifab.py b/tests/test_multifab.py index d86af212..fc92c912 100644 --- a/tests/test_multifab.py +++ b/tests/test_multifab.py @@ -6,8 +6,8 @@ import amrex -@pytest.mark.parametrize("nghost", [0, 1]) -def test_mfab_loop(mfab, nghost): +def test_mfab_loop(make_mfab): + mfab = make_mfab() ngv = mfab.nGrowVect print(f"\n mfab={mfab}, mfab.nGrowVect={ngv}") @@ -77,7 +77,8 @@ def test_mfab_loop(mfab, nghost): # TODO -def test_mfab_simple(mfab): +def test_mfab_simple(make_mfab): + mfab = make_mfab() assert mfab.is_all_cell_centered # assert(all(not mfab.is_nodal(i) for i in [-1, 0, 1, 2])) # -1?? assert all(not mfab.is_nodal(i) for i in [0, 1, 2]) @@ -142,8 +143,8 @@ def test_mfab_ops(boxarr, distmap, nghost): np.testing.assert_allclose(dst.max(0), 150.0) -@pytest.mark.parametrize("nghost", [0, 1]) -def test_mfab_mfiter(mfab, nghost): +def test_mfab_mfiter(make_mfab): + mfab = make_mfab() assert iter(mfab).is_valid assert iter(mfab).length == 8