Skip to content

Commit

Permalink
Use kernel launchers instead of macros (celeritas-project#1477)
Browse files Browse the repository at this point in the history
* Define CPU kernel launcher function
* Refactor reseed to use executor
* Refactor RNG state init to use executor
  • Loading branch information
sethrj authored Nov 1, 2024
1 parent a70ca85 commit 4dfb22e
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 132 deletions.
6 changes: 4 additions & 2 deletions src/celeritas/random/CuHipRngData.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ void resize(CuHipRngStateData<Ownership::value, M>* state,
resize(&state->rng, size);
detail::CuHipRngInitData<Ownership::value, M> init_data;
init_data.seeds = host_seeds;
detail::rng_state_init(
make_const_ref(data), make_ref(*state), make_const_ref(init_data));
detail::rng_state_init(make_const_ref(data),
make_ref(*state),
make_const_ref(init_data),
stream);
}

//---------------------------------------------------------------------------//
Expand Down
22 changes: 4 additions & 18 deletions src/celeritas/random/RngReseed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
//---------------------------------------------------------------------------//
#include "RngReseed.hh"

#include "corecel/cont/Range.hh"
#include "corecel/sys/ThreadId.hh"
#include "corecel/sys/KernelLauncher.hh"

#include "RngEngine.hh"
#include "detail/RngReseedExecutor.hh"

namespace celeritas
{
Expand All @@ -28,21 +27,8 @@ void reseed_rng(HostCRef<RngParamsData> const& params,
StreamId,
UniqueEventId event_id)
{
CELER_EXPECT(event_id);
static_assert(sizeof(ull_int) == sizeof(UniqueEventId::size_type));

ull_int size = state.size();
#if CELERITAS_OPENMP == CELERITAS_OPENMP_TRACK
# pragma omp parallel for
#endif
for (TrackSlotId::size_type i = 0; i < size; ++i)
{
RngEngine::Initializer_t init;
init.seed = params.seed;
init.subsequence = event_id.unchecked_get() * size + i;
RngEngine engine(params, state, TrackSlotId{i});
engine = init;
}
launch_kernel(state.size(),
detail::RngReseedExecutor{params, state, event_id});
}

//---------------------------------------------------------------------------//
Expand Down
52 changes: 7 additions & 45 deletions src/celeritas/random/RngReseed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,13 @@
//---------------------------------------------------------------------------//
#include "RngReseed.hh"

#include "corecel/DeviceRuntimeApi.hh"
#include "corecel/Types.hh"
#include "corecel/sys/KernelLauncher.device.hh"

#include "corecel/Assert.hh"
#include "corecel/sys/Device.hh"
#include "corecel/sys/KernelParamCalculator.device.hh"
#include "corecel/sys/Stream.hh"

#include "RngEngine.hh"
#include "detail/RngReseedExecutor.hh"

namespace celeritas
{
namespace
{
//---------------------------------------------------------------------------//
// KERNELS
//---------------------------------------------------------------------------//
/*!
* Reinitialize the RNG states on device at the start of an event.
*/
__global__ void reseed_rng_kernel(DeviceCRef<RngParamsData> const params,
DeviceRef<RngStateData> const state,
UniqueEventId::size_type event_id)
{
auto tid = TrackSlotId{
celeritas::KernelParamCalculator::thread_id().unchecked_get()};
if (tid.get() < state.size())
{
TrackSlotId tsid{tid.unchecked_get()};
RngEngine::Initializer_t init;
init.seed = params.seed;
init.subsequence = event_id * state.size() + tsid.get();
RngEngine rng(params, state, tsid);
rng = init;
}
}

//---------------------------------------------------------------------------//
} // namespace

//---------------------------------------------------------------------------//
// KERNEL INTERFACE
//---------------------------------------------------------------------------//
Expand All @@ -61,16 +29,10 @@ void reseed_rng(DeviceCRef<RngParamsData> const& params,
StreamId stream,
UniqueEventId event_id)
{
CELER_EXPECT(state);
CELER_EXPECT(params);
CELER_EXPECT(stream);

CELER_LAUNCH_KERNEL(reseed_rng,
state.size(),
celeritas::device().stream(stream).get(),
params,
state,
event_id.get());
detail::RngReseedExecutor execute_thread{params, state, event_id};
static KernelLauncher<decltype(execute_thread)> const launch_kernel(
"rng-reseed");
launch_kernel(state.size(), stream, execute_thread);
}

//---------------------------------------------------------------------------//
Expand Down
1 change: 0 additions & 1 deletion src/celeritas/random/RngReseed.hh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "corecel/Assert.hh"
#include "corecel/Macros.hh"
#include "corecel/Types.hh"
#include "corecel/data/Collection.hh"
#include "celeritas/Types.hh"

#include "RngData.hh"
Expand Down
19 changes: 6 additions & 13 deletions src/celeritas/random/detail/CuHipRngStateInit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@
//---------------------------------------------------------------------------//
#include "CuHipRngStateInit.hh"

#include "corecel/cont/Range.hh"
#include "corecel/sys/ThreadId.hh"

#include "../CuHipRngData.hh"
#include "../CuHipRngEngine.hh"
#include "corecel/Assert.hh"
#include "corecel/sys/KernelLauncher.hh"

namespace celeritas
{
Expand All @@ -23,15 +20,11 @@ namespace detail
*/
void rng_state_init(HostCRef<CuHipRngParamsData> const& params,
HostRef<CuHipRngStateData> const& state,
HostCRef<CuHipRngInitData> const& seeds)
HostCRef<CuHipRngInitData> const& seeds,
StreamId)
{
for (auto tid : range(TrackSlotId{seeds.size()}))
{
CuHipRngInitializer init;
init.seed = seeds.seeds[tid];
CuHipRngEngine engine(params, state, tid);
engine = init;
}
CELER_EXPECT(state.size() == seeds.size());
launch_kernel(state.size(), RngSeedExecutor{params, state, seeds});
}

//---------------------------------------------------------------------------//
Expand Down
43 changes: 7 additions & 36 deletions src/celeritas/random/detail/CuHipRngStateInit.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,13 @@
//---------------------------------------------------------------------------//
#include "CuHipRngStateInit.hh"

#include "corecel/DeviceRuntimeApi.hh"

#include "corecel/Assert.hh"
#include "corecel/sys/Device.hh"
#include "corecel/sys/KernelParamCalculator.device.hh"

#include "../CuHipRngEngine.hh"
#include "corecel/sys/KernelLauncher.device.hh"

namespace celeritas
{
namespace detail
{
namespace
{
//---------------------------------------------------------------------------//
// KERNELS
//---------------------------------------------------------------------------//
/*!
* Initialize the RNG states on device from seeds randomly generated on host.
*/
__global__ void
rng_state_init_kernel(DeviceCRef<CuHipRngParamsData> const params,
DeviceRef<CuHipRngStateData> const state,
DeviceCRef<CuHipRngInitData> const seeds)
{
auto tid = TrackSlotId{
celeritas::KernelParamCalculator::thread_id().unchecked_get()};
if (tid.get() < state.size())
{
TrackSlotId tsid{tid.unchecked_get()};
CuHipRngInitializer init;
init.seed = seeds.seeds[tsid];
CuHipRngEngine rng(params, state, tsid);
rng = init;
}
}

//---------------------------------------------------------------------------//
} // namespace

//---------------------------------------------------------------------------//
// KERNEL INTERFACE
//---------------------------------------------------------------------------//
Expand All @@ -55,10 +22,14 @@ rng_state_init_kernel(DeviceCRef<CuHipRngParamsData> const params,
*/
void rng_state_init(DeviceCRef<CuHipRngParamsData> const& params,
DeviceRef<CuHipRngStateData> const& state,
DeviceCRef<CuHipRngInitData> const& seeds)
DeviceCRef<CuHipRngInitData> const& seeds,
StreamId stream)
{
CELER_EXPECT(state.size() == seeds.size());
CELER_LAUNCH_KERNEL(rng_state_init, seeds.size(), 0, params, state, seeds);
// Bundle the per-slot seeding functor with the params/state/seed references.
detail::RngSeedExecutor execute_thread{params, state, seeds};
// The launcher is a function-local static, so it is constructed only on the
// first call. The string is presumably the kernel's diagnostic label (confirm
// against KernelLauncher); it is given a name distinct from the "rng-reseed"
// launcher in RngReseed.cu so the two kernels are not reported under one name.
static KernelLauncher<decltype(execute_thread)> const launch_kernel(
"rng-state-init");
// Launch one thread per RNG state on the requested stream.
launch_kernel(state.size(), stream, execute_thread);
}

//---------------------------------------------------------------------------//
Expand Down
37 changes: 34 additions & 3 deletions src/celeritas/random/detail/CuHipRngStateInit.hh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "corecel/data/Collection.hh"

#include "../CuHipRngData.hh"
#include "../CuHipRngEngine.hh"

namespace celeritas
{
Expand Down Expand Up @@ -42,15 +43,44 @@ struct CuHipRngInitData
}
};

//---------------------------------------------------------------------------//
/*!
 * Functor that initializes one track slot's RNG from per-slot seed data.
 *
 * The three members are public and declared in the order (params, state,
 * seeds) so that the struct can be brace-initialized as
 * \c RngSeedExecutor{params, state, seeds}.
 */
struct RngSeedExecutor
{
    NativeCRef<CuHipRngParamsData> const params;
    NativeRef<CuHipRngStateData> const state;
    NativeCRef<CuHipRngInitData> const seeds;

    //! Initialize the RNG of the given track slot from its stored seed
    inline CELER_FUNCTION void operator()(TrackSlotId tid) const
    {
        CELER_EXPECT(tid < state.size());

        // Look up the seed stored for this slot
        CuHipRngInitializer initializer;
        initializer.seed = seeds.seeds[tid];

        // Assigning the initializer to the engine presumably (re)initializes
        // this slot's stored state (see CuHipRngEngine)
        CuHipRngEngine engine{params, state, tid};
        engine = initializer;
    }

    //! Initialize the track slot that has the same index as the given thread
    CELER_FORCEINLINE_FUNCTION void operator()(ThreadId tid) const
    {
        TrackSlotId const slot{tid.unchecked_get()};
        (*this)(slot);
    }
};

//---------------------------------------------------------------------------//
// Initialize the RNG state on host/device
void rng_state_init(DeviceCRef<CuHipRngParamsData> const& params,
DeviceRef<CuHipRngStateData> const& state,
DeviceCRef<CuHipRngInitData> const& seeds);
DeviceCRef<CuHipRngInitData> const& seeds,
StreamId stream);

void rng_state_init(HostCRef<CuHipRngParamsData> const& params,
HostRef<CuHipRngStateData> const& state,
HostCRef<CuHipRngInitData> const& seeds);
HostCRef<CuHipRngInitData> const& seeds,
StreamId);

#if !CELER_USE_DEVICE
//---------------------------------------------------------------------------//
Expand All @@ -59,7 +89,8 @@ void rng_state_init(HostCRef<CuHipRngParamsData> const& params,
*/
inline void rng_state_init(DeviceCRef<CuHipRngParamsData> const&,
DeviceRef<CuHipRngStateData> const&,
DeviceCRef<CuHipRngInitData> const&)
DeviceCRef<CuHipRngInitData> const&,
StreamId)
{
CELER_ASSERT_UNREACHABLE();
}
Expand Down
83 changes: 83 additions & 0 deletions src/celeritas/random/detail/RngReseedExecutor.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//----------------------------------*-C++-*----------------------------------//
// Copyright 2024 UT-Battelle, LLC, and other Celeritas developers.
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file celeritas/random/detail/RngReseedExecutor.hh
//---------------------------------------------------------------------------//
#pragma once

#include "corecel/Types.hh"
#include "corecel/sys/ThreadId.hh"
#include "celeritas/Types.hh"
#include "celeritas/random/RngData.hh"
#include "celeritas/random/RngEngine.hh"

namespace celeritas
{
namespace detail
{
//---------------------------------------------------------------------------//
/*!
 * Functor that reinitializes a track slot's RNG state for a unique event ID.
 *
 * Every slot is given the params' seed and the subsequence
 * <code>event_id * state.size() + slot_index</code>. The event-dependent
 * product is computed once at construction and stored as a stride.
 */
class RngReseedExecutor
{
  public:
    //!@{
    //! \name Type aliases
    using ParamsCRef = NativeCRef<RngParamsData>;
    using StateRef = NativeRef<RngStateData>;
    //!@}

    // Construct with params, state, and event ID
    inline CELER_FUNCTION
    RngReseedExecutor(ParamsCRef const&, StateRef const&, UniqueEventId id);

    // Reinitialize the RNG of the given track slot
    inline CELER_FUNCTION void operator()(TrackSlotId tid) const;

    //! Reinitialize the track slot that has the same index as the thread
    CELER_FORCEINLINE_FUNCTION void operator()(ThreadId tid) const
    {
        TrackSlotId const slot{tid.unchecked_get()};
        (*this)(slot);
    }

  private:
    // Shared RNG parameters (provides the seed)
    ParamsCRef const params_;
    // Per-slot RNG states to overwrite
    StateRef const state_;
    // Subsequence offset for this event: event ID times number of states
    UniqueEventId::size_type stride_;
};

//---------------------------------------------------------------------------//
// INLINE DEFINITIONS
//---------------------------------------------------------------------------//
/*!
 * Construct with params, state, and event ID.
 *
 * The per-event subsequence offset (event ID multiplied by the number of RNG
 * states) is computed here once so that each per-slot call only has to add
 * its own slot index.
 */
CELER_FUNCTION RngReseedExecutor::RngReseedExecutor(ParamsCRef const& params,
                                                    StateRef const& state,
                                                    UniqueEventId id)
    : params_(params)
    , state_(state)
    , stride_(id.unchecked_get() * state.size())
{
    // Compile-time check: the event ID's integer type is as wide as ull_int
    static_assert(sizeof(ull_int) == sizeof(UniqueEventId::size_type));

    // Preconditions: valid params/state references and a valid event ID
    CELER_EXPECT(params_ && state_);
    CELER_EXPECT(id);
}

//---------------------------------------------------------------------------//
/*!
 * Reinitialize the RNG of the given track slot.
 *
 * The initializer takes its seed from the params and its subsequence from the
 * per-event stride plus the slot index; it is then assigned to an engine bound
 * to that slot.
 */
CELER_FUNCTION void RngReseedExecutor::operator()(TrackSlotId tid) const
{
    CELER_EXPECT(tid < state_.size());

    // Fill in the initializer: per-slot subsequence and the shared seed
    RngEngine::Initializer_t initializer;
    initializer.subsequence = stride_ + tid.unchecked_get();
    initializer.seed = params_.seed;

    // Bind an engine to this slot and apply the initializer to it
    RngEngine rng{params_, state_, tid};
    rng = initializer;
}

//---------------------------------------------------------------------------//
} // namespace detail
} // namespace celeritas
Loading

0 comments on commit 4dfb22e

Please sign in to comment.