diff --git a/include/neso_particles/containers/sym_vector.hpp b/include/neso_particles/containers/sym_vector.hpp index aded638f..e0a39afc 100644 --- a/include/neso_particles/containers/sym_vector.hpp +++ b/include/neso_particles/containers/sym_vector.hpp @@ -122,7 +122,7 @@ inline void create_kernel_arg(ParticleLoopIteration &iterationx, T *const *const **rhs, Access::SymVector::Read &lhs) { lhs.cell = iterationx.cellx; - lhs.layer = iterationx.loop_layerx; + lhs.layer = iterationx.layerx; lhs.ptr = rhs; } /** @@ -132,7 +132,7 @@ template inline void create_kernel_arg(ParticleLoopIteration &iterationx, T ****rhs, Access::SymVector::Write &lhs) { lhs.cell = iterationx.cellx; - lhs.layer = iterationx.loop_layerx; + lhs.layer = iterationx.layerx; lhs.ptr = rhs; } diff --git a/test/test_particle_loop_sym_vector.cpp b/test/test_particle_loop_sym_vector.cpp index 2c628226..f2078045 100644 --- a/test/test_particle_loop_sym_vector.cpp +++ b/test/test_particle_loop_sym_vector.cpp @@ -300,3 +300,70 @@ TEST(ParticleLoop, sym_vector) { A->free(); mesh->free(); } + +TEST(ParticleLoop, sub_group_sym_vector) { + auto A = particle_loop_common(); + auto domain = A->domain; + auto mesh = domain->mesh; + const int cell_count = mesh->get_cell_count(); + + auto aa = particle_sub_group( + A, [=](auto ID) { return ID.at(0) % 2 == 0; }, + Access::read(Sym("ID"))); + + auto si = sym_vector(aa, {Sym("ID")}); + + particle_loop( + A, [=](auto LOOP_INDEX) { LOOP_INDEX.at(0) = -1; }, + Access::write(Sym("LOOP_INDEX"))) + ->execute(); + + particle_loop( + aa, + [=](auto dats_int, auto LOOP_INDEX) { + LOOP_INDEX.at(0) = dats_int.at(0, 0); + }, + Access::read(si), Access::write(Sym("LOOP_INDEX"))) + ->execute(); + + for (int cellx = 0; cellx < cell_count; cellx++) { + auto id = A->get_cell(Sym("ID"), cellx); + auto loop_index = A->get_cell(Sym("LOOP_INDEX"), cellx); + const int nrow = id->nrow; + + // for each particle in the cell + for (int rowx = 0; rowx < nrow; rowx++) { + if (id->at(rowx, 0) % 2 == 0) { + ASSERT_EQ(loop_index->at(rowx, 0), id->at(rowx, 0)); + } else { + ASSERT_EQ(loop_index->at(rowx, 0), -1); + } + } + } + + particle_loop( + aa, + [=](auto dats_int, auto LOOP_INDEX) { + dats_int.at(0, 0) = LOOP_INDEX.at(0) + 2; + }, + Access::write(si), Access::read(Sym("LOOP_INDEX"))) + ->execute(); + + for (int cellx = 0; cellx < cell_count; cellx++) { + auto id = A->get_cell(Sym("ID"), cellx); + auto loop_index = A->get_cell(Sym("LOOP_INDEX"), cellx); + const int nrow = id->nrow; + + // for each particle in the cell + for (int rowx = 0; rowx < nrow; rowx++) { + if (id->at(rowx, 0) % 2 == 0) { + ASSERT_EQ(loop_index->at(rowx, 0) + 2, id->at(rowx, 0)); + } else { + ASSERT_EQ(loop_index->at(rowx, 0), -1); + } + } + } + + A->free(); + mesh->free(); +}