Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extend_params fix after n_particles removed #62

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

andrewdipper
Copy link

Fix examples after change to extend_params in blackjax-devs/blackjax#694.

Additionally in the TemperedSMC example max_num_doublings was changed to 6 instead of the default 10 since we regularly hit max_num_doublings due to the small step size (I believe this is for illustrative purposes). On a gpu device the example is extraordinarily slow without the change - and still takes ~2 mins with it. It seems far too slow but I haven't been able to find any explanation.

For reference:

CPU (10000 samples, max_num_doublings=10):
step_size = 1e-2:
HMC: 50 steps / 1.14s
NUTS: 30 steps / .964s

step_size = 1e-3
HMC: 50 / 1.14s
NUTS: 273 / 1.9s

step_size = 1e-4
HMC: 50 / 1.18s
NUTS: 926 / 4.23s

GPU (1000 samples - 10x fewer samples..., max_num_doublings=10):
step_size = 1e-2:
HMC: 50 / 3.31s
NUTS: 30 / 7.3s

step_size = 1e-3:
HMC: 50 / 3.32s
NUTS: 267 / 63s

step_size = 1e-4
HMC: 50 / 3.31s
NUTS: 926.4 / 237s

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant