From 5e5f383387ed4f7922a4fde05390762a5ef3c9f9 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Sun, 2 Jun 2024 16:20:11 +0800 Subject: [PATCH] use pcg32 for pssmlt --- src/integrators/pssmlt.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/integrators/pssmlt.cpp b/src/integrators/pssmlt.cpp index 3ee56d6e..ef911836 100644 --- a/src/integrators/pssmlt.cpp +++ b/src/integrators/pssmlt.cpp @@ -44,7 +44,7 @@ class PSSMLTSampler { public: struct State { - UInt rng_state; + PCG32 rng; ULong current_iteration; Bool large_step; ULong last_large_step_iteration; @@ -62,7 +62,7 @@ class PSSMLTSampler { private: uint _chains{}; uint _pss_dim{}; - Buffer _rng_buffer; + Buffer _rng_buffer; Buffer _current_iteration_buffer; Buffer _large_step_and_initialized_dimensions_buffer; Buffer _last_large_step_iteration_buffer; @@ -77,7 +77,7 @@ class PSSMLTSampler { "Too many primary samples."); command_buffer << synchronize(); if (auto n = next_pow2(chains); n > _rng_buffer.size()) { - _rng_buffer = _device.create_buffer(n); + _rng_buffer = _device.create_buffer(n); _current_iteration_buffer = _device.create_buffer(n); _large_step_and_initialized_dimensions_buffer = _device.create_buffer(n); _last_large_step_iteration_buffer = _device.create_buffer(n); @@ -152,19 +152,19 @@ class PSSMLTSampler { }; // Reset Xi if a large step took place in the meantime $if (Xi.last_modification < _state->last_large_step_iteration) { - Xi.value = lcg(_state->rng_state); + Xi.value = _state->rng.uniform_float(); Xi.last_modification = _state->last_large_step_iteration; }; // Apply remaining sequence of mutations to _sample_ Xi->backup(); $if (_state->large_step) { - Xi.value = lcg(_state->rng_state); + Xi.value = _state->rng.uniform_float(); } $else { auto nSmall = compute::cast(_state->current_iteration - Xi.last_modification); // Apply _nSmall_ small step mutations // Sample the standard normal distribution N(0, 1) - auto normalSample = sqrt_two * _erf_inv(2.f * lcg(_state->rng_state) - 1.f); + auto normalSample = sqrt_two * _erf_inv(2.f * _state->rng.uniform_float() - 1.f); // Compute the effective standard deviation and apply perturbation to Xi auto effSigma = _sigma * sqrt(cast(nSmall)); Xi.value = fract(Xi.value + normalSample * effSigma); @@ -181,7 +181,7 @@ class PSSMLTSampler { void create(Expr chain_index, Expr rng_sequence) noexcept { _state = luisa::make_unique(State{ - .rng_state = xxhash32(rng_sequence), + .rng = PCG32{rng_sequence}, .current_iteration = 0ull, .large_step = true, .last_large_step_iteration = 0ull, @@ -196,7 +196,7 @@ class PSSMLTSampler { auto large_step_and_dimensions = _large_step_and_initialized_dimensions_buffer->read(chain_index); auto last_large_step_iteration = _last_large_step_iteration_buffer->read(chain_index); _state = luisa::make_unique(State{ - .rng_state = rng_state, + .rng = PCG32{rng_state.x, rng_state.y}, .current_iteration = current_iteration, .large_step = (large_step_and_dimensions & 1u) != 0u, .last_large_step_iteration = last_large_step_iteration, @@ -206,7 +206,7 @@ class PSSMLTSampler { } void save() noexcept { - _rng_buffer->write(_state->chain_index, _state->rng_state); + _rng_buffer->write(_state->chain_index, make_ulong2(_state->rng.state(), _state->rng.inc())); _current_iteration_buffer->write(_state->chain_index, _state->current_iteration); _large_step_and_initialized_dimensions_buffer->write( _state->chain_index, ite(_state->large_step, 1u, 0u) | (_state->initialized_dimensions << 1u)); @@ -249,7 +249,7 @@ class PSSMLTSampler { void start_iteration() noexcept { _state->current_iteration = _state->current_iteration + 1ull; - _state->large_step = lcg(_state->rng_state) < _large_step_probability; + _state->large_step = _state->rng.uniform_float() < _large_step_probability; } };