Skip to content

Commit

Permalink
use pcg32 for pssmlt
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Jun 2, 2024
1 parent 163bb6f commit 5e5f383
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/integrators/pssmlt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class PSSMLTSampler {

public:
struct State {
UInt rng_state;
PCG32 rng;
ULong current_iteration;
Bool large_step;
ULong last_large_step_iteration;
Expand All @@ -62,7 +62,7 @@ class PSSMLTSampler {
private:
uint _chains{};
uint _pss_dim{};
Buffer<uint> _rng_buffer;
Buffer<ulong2> _rng_buffer;
Buffer<ulong> _current_iteration_buffer;
Buffer<uint> _large_step_and_initialized_dimensions_buffer;
Buffer<ulong> _last_large_step_iteration_buffer;
Expand All @@ -77,7 +77,7 @@ class PSSMLTSampler {
"Too many primary samples.");
command_buffer << synchronize();
if (auto n = next_pow2(chains); n > _rng_buffer.size()) {
_rng_buffer = _device.create_buffer<uint>(n);
_rng_buffer = _device.create_buffer<ulong2>(n);
_current_iteration_buffer = _device.create_buffer<ulong>(n);
_large_step_and_initialized_dimensions_buffer = _device.create_buffer<uint>(n);
_last_large_step_iteration_buffer = _device.create_buffer<ulong>(n);
Expand Down Expand Up @@ -152,19 +152,19 @@ class PSSMLTSampler {
};
// Reset Xi if a large step took place in the meantime
$if (Xi.last_modification < _state->last_large_step_iteration) {
Xi.value = lcg(_state->rng_state);
Xi.value = _state->rng.uniform_float();
Xi.last_modification = _state->last_large_step_iteration;
};
// Apply remaining sequence of mutations to _sample_
Xi->backup();
$if (_state->large_step) {
Xi.value = lcg(_state->rng_state);
Xi.value = _state->rng.uniform_float();
}
$else {
auto nSmall = compute::cast<uint>(_state->current_iteration - Xi.last_modification);
// Apply _nSmall_ small step mutations
// Sample the standard normal distribution N(0, 1)
auto normalSample = sqrt_two * _erf_inv(2.f * lcg(_state->rng_state) - 1.f);
auto normalSample = sqrt_two * _erf_inv(2.f * _state->rng.uniform_float() - 1.f);
// Compute the effective standard deviation and apply perturbation to Xi
auto effSigma = _sigma * sqrt(cast<float>(nSmall));
Xi.value = fract(Xi.value + normalSample * effSigma);
Expand All @@ -181,7 +181,7 @@ class PSSMLTSampler {

void create(Expr<uint> chain_index, Expr<uint> rng_sequence) noexcept {
_state = luisa::make_unique<State>(State{
.rng_state = xxhash32(rng_sequence),
.rng = PCG32{rng_sequence},
.current_iteration = 0ull,
.large_step = true,
.last_large_step_iteration = 0ull,
Expand All @@ -196,7 +196,7 @@ class PSSMLTSampler {
auto large_step_and_dimensions = _large_step_and_initialized_dimensions_buffer->read(chain_index);
auto last_large_step_iteration = _last_large_step_iteration_buffer->read(chain_index);
_state = luisa::make_unique<State>(State{
.rng_state = rng_state,
.rng = PCG32{rng_state.x, rng_state.y},
.current_iteration = current_iteration,
.large_step = (large_step_and_dimensions & 1u) != 0u,
.last_large_step_iteration = last_large_step_iteration,
Expand All @@ -206,7 +206,7 @@ class PSSMLTSampler {
}

void save() noexcept {
_rng_buffer->write(_state->chain_index, _state->rng_state);
_rng_buffer->write(_state->chain_index, make_ulong2(_state->rng.state(), _state->rng.inc()));
_current_iteration_buffer->write(_state->chain_index, _state->current_iteration);
_large_step_and_initialized_dimensions_buffer->write(
_state->chain_index, ite(_state->large_step, 1u, 0u) | (_state->initialized_dimensions << 1u));
Expand Down Expand Up @@ -249,7 +249,7 @@ class PSSMLTSampler {

void start_iteration() noexcept {
_state->current_iteration = _state->current_iteration + 1ull;
_state->large_step = lcg(_state->rng_state) < _large_step_probability;
_state->large_step = _state->rng.uniform_float() < _large_step_probability;
}
};

Expand Down

0 comments on commit 5e5f383

Please sign in to comment.