Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix excessive compilations in the MC sampler #79

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

tszoldra
Copy link
Contributor

There was an issue with a slow but steady memory growth during TDVP as described here: [#78]

If one executes jax.clear_cache() every few steps of the TDVP, this issue is fixed. This points to an unnecessary JIT compilation happening somewhere during the TDVP step.
One can check that the program indeed performs some compilations during TDVP steps by setting the environment variable JAX_LOG_COMPILES=1 and observing the output.

I found that the problem was caused by the MCSampler object through self._mc_init(params) function which calculates self.logAccProb. A pmap'd version of the function is compiled every time one asks the sampler for samples . A solution is to compile this function once and store as the object property. This PR realizes that.

In a typical use-case, the speed-up due to avoidance of recompilations is minor, but after this change the consumed RAM memory does not grow indefinitely after each sampler.sample() call.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant