Fix excessive compilations in the MC sampler #79
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was an issue with a slow but steady memory growth during TDVP as described here: [#78]
If one executes
jax.clear_cache()
every few steps of the TDVP, this issue is fixed. This points to an unnecessary JIT compilation happening somewhere during the TDVP step.One can check that the program indeed performs some compilations during TDVP steps by setting the environment variable
JAX_LOG_COMPILES=1
and observing the output.I found that the problem was caused by the
MCSampler
object throughself._mc_init(params)
function which calculatesself.logAccProb
. A pmap'd version of the function is compiled every time one asks the sampler for samples . A solution is to compile this function once and store as the object property. This PR realizes that.In a typical use-case, the speed-up due to avoidance of recompilations is minor, but after this change the consumed RAM memory does not grow indefinitely after each
sampler.sample()
call.