Skip to content

Commit

Permalink
Reordered how local sizes are deduced to account for local memory requirements better
Browse files Browse the repository at this point in the history
  • Loading branch information
will-saunders-ukaea committed Nov 21, 2024
1 parent 580e473 commit de61a2e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
30 changes: 23 additions & 7 deletions include/neso_particles/loop/particle_loop_iteration_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,25 @@ class ParticleLoopBlockIterationSet {
return outer_size;
}

/**
 * Deduce the local size (number of blocks per work-group) for an iteration
 * set such that the local memory required by one work-group fits in the
 * local memory reported by the device.
 *
 * Each block requires `stride * num_bytes_local` bytes of local memory. The
 * returned value is rounded down with `get_prev_power_of_two` and never
 * exceeds the requested `local_size`.
 *
 * @param local_size Requested local size, treated as an upper bound.
 * @param num_bytes_local Number of bytes of local memory required per
 * particle. Zero indicates no local memory is required.
 * @param stride Multiplier applied to `num_bytes_local` to give the local
 * memory required per block. Must be non-zero when local memory is
 * requested.
 * @returns Local size satisfying
 * `local_size * stride * num_bytes_local <= local_mem_size`.
 */
inline std::size_t
get_local_size(std::size_t local_size,
               const std::size_t num_bytes_local, // this is per particle
               const std::size_t stride) {
  const std::size_t local_mem_size =
      this->sycl_target->device_limits.local_mem_size;

  if (num_bytes_local == 0) {
    // No local memory is requested, hence only the requested size bounds
    // the result.
    local_size = get_prev_power_of_two(local_size);
  } else {
    // A zero stride would make the bytes per block zero and the division
    // below undefined.
    NESOASSERT(stride > 0,
               "Stride must be non-zero when local memory is requested.");
    const std::size_t num_bytes_per_block = stride * num_bytes_local;
    NESOASSERT(num_bytes_per_block <= local_mem_size,
               "Impossible to create a local range for this stride and local "
               "memory size.");
    const std::size_t max_num_blocks_per_workgroup =
        local_mem_size / num_bytes_per_block;
    // Honour the requested local size as an upper bound. Without this cap
    // the memory-derived bound alone could give a work-group much larger
    // than the caller asked for.
    const std::size_t capped_num_blocks =
        (local_size < max_num_blocks_per_workgroup)
            ? local_size
            : max_num_blocks_per_workgroup;
    local_size = get_prev_power_of_two(capped_num_blocks);
  }

  NESOASSERT(local_size * stride * num_bytes_local <= local_mem_size,
             "Failure to determine a local size for iteration set.");
  // A zero local size cannot form a valid local range.
  NESOASSERT(local_size > 0, "Cannot deduce a local size.");
  return local_size;
}

public:
/// The last iteration set produced
std::vector<ParticleLoopBlockHost> iteration_set;
Expand Down Expand Up @@ -284,9 +303,8 @@ class ParticleLoopBlockIterationSet {
get_all_cells(std::size_t nbin = 16, std::size_t local_size = 256,
const std::size_t num_bytes_local = 0,
const std::size_t stride = 1) {
local_size = this->sycl_target->get_num_local_work_items(num_bytes_local,
local_size);
local_size /= stride;

local_size = this->get_local_size(local_size, num_bytes_local, stride);
nbin = std::min(nbin, this->ncell);
this->iteration_set.clear();

Expand Down Expand Up @@ -353,10 +371,8 @@ class ParticleLoopBlockIterationSet {
get_single_cell(const std::size_t cell, std::size_t local_size = 256,
const std::size_t num_bytes_local = 0,
const std::size_t stride = 1) {
local_size = this->sycl_target->get_num_local_work_items(num_bytes_local,
local_size);
local_size /= stride;
NESOASSERT(local_size > 0, "Cannot deduce a local size.");

local_size = this->get_local_size(local_size, num_bytes_local, stride);
this->iteration_set.clear();

const std::size_t npart =
Expand Down
4 changes: 2 additions & 2 deletions test/test_particle_loop_iteration_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ TEST_P(ParticleLoopLocalMem, iteration_set_base_stride) {
ASSERT_EQ(range_global.get(0), 1);
ASSERT_EQ(range_local.get(0), 1);
if (local_mem_required == 0) {
EXPECT_EQ(range_local.get(1), local_size / stride);
EXPECT_EQ(range_local.get(1), local_size);
} else {
ASSERT_TRUE(range_local.get(1) * local_mem_required * stride <=
local_mem_limit);
Expand Down Expand Up @@ -362,7 +362,7 @@ TEST_P(ParticleLoopLocalMem, iteration_set_base_stride) {
EXPECT_EQ(range_local.get(0), 1);

if (local_mem_required == 0) {
EXPECT_EQ(range_local.get(1), local_size / stride);
EXPECT_EQ(range_local.get(1), local_size);
} else {
ASSERT_TRUE(range_local.get(1) * local_mem_required * stride <=
local_mem_limit);
Expand Down

0 comments on commit de61a2e

Please sign in to comment.