Performance & runtime improvements to info-theoretic acquisition functions (1/N) #2748
base: main
Conversation
improve performance and runtime of PES/JES
@sdaulton has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Thanks! It seems like
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main    #2748   +/-   ##
=======================================
  Coverage   99.99%   99.99%
=======================================
  Files         203      203
  Lines       18690    18701    +11
=======================================
+ Hits        18689    18700    +11
  Misses          1        1
```

View full report in Codecov by Sentry.
@sdaulton for sure! I currently observe similar things for JES, but I'm not sure whether the found points are actually higher in acquisition function value or not (for either LogEI or JES)
That would be interesting to see |
Hi Carl! This seems like a decent improvement. Just a few comments in-line
```python
raw_samples: int = 2048,
num_restarts: int = 4,
```
Is the motivation that `raw_samples` are cheap, so we can use more of them to get better restart points, which in turn helps us reduce `num_restarts`, which can be rather expensive in relation?
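The trade-off the comment describes can be illustrated with a minimal, hypothetical sketch (names and the toy score function are not from the PR): cheap candidate evaluations are used to screen many points, so only a handful survive to the expensive multi-start optimization.

```python
import random

def cheap_score(x):
    # Stand-in for a cheap evaluation (e.g. scoring a posterior sample at x).
    return -(x - 0.3) ** 2

def pick_restart_points(raw_samples, num_restarts):
    # Draw many cheap candidates, keep only the best few as restart points
    # for the expensive local optimization: raw_samples can be large because
    # each evaluation is cheap, which lets num_restarts stay small.
    candidates = [random.random() for _ in range(raw_samples)]
    candidates.sort(key=cheap_score, reverse=True)
    return candidates[:num_restarts]

starts = pick_restart_points(raw_samples=2048, num_restarts=4)
```

With 2048 screened candidates, the four survivors are almost surely clustered near the (toy) optimum at 0.3, so few expensive restarts are wasted.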
```diff
@@ -1008,13 +1012,17 @@ def optimize_posterior_samples(
             negate the objective or otherwise transform the output.
         return_transformed: A boolean indicating whether to return the transformed
             or non-transformed samples.
         suggested_points
```
Let's complete the docstring here
```python
    bounds=bounds, n=round(raw_samples * frac_random), q=1
).squeeze(-2)
if suggested_points is not None:
    from botorch.optim.initializers import sample_truncated_normal_perturbations
```
Is this a local import because a module-level one would lead to cyclical dependencies? If so, we could move `sample_truncated_normal_perturbations` under utils (assuming it doesn't depend on other code in optim).
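For context on why a function-local import breaks a cycle: the import is resolved at call time, when both modules have finished initializing. A self-contained illustration (the modules `a` and `b` below are synthetic stand-ins built in memory, not botorch code):

```python
import sys
import types

# Build two in-memory modules that reference each other. Module a imports b
# at module level; b defers its import of a into a function body, so loading
# either module never touches the unfinished half of the cycle.
b = types.ModuleType("b")
exec(
    "def helper():\n"
    "    import a          # deferred: a is fully initialized by call time\n"
    "    return a.VALUE * 2\n",
    b.__dict__,
)
sys.modules["b"] = b

a = types.ModuleType("a")
exec("import b\nVALUE = 21\n", a.__dict__)
sys.modules["a"] = a

result = b.helper()
```

Moving the shared helper into a leaf module (as suggested, e.g. under utils) removes the cycle entirely and lets both imports be top-level again.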
```python
perturbed_suggestions = sample_truncated_normal_perturbations(
    X=suggested_points,
    n_discrete_points=round(raw_samples * (1 - frac_random)),
```
What happens if `n` here (or in `candidate_set` above) is `0`? Should we protect against it or add an informative error? Since `raw_samples` is typically quite large, this should be unlikely to happen, so not really critical to address here.
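One way to add the informative error the comment suggests is to validate the split up front. A minimal sketch with a hypothetical helper (`validate_sample_split` is not part of the PR):

```python
def validate_sample_split(raw_samples: int, frac_random: float):
    # Split raw_samples into random and perturbation-based counts, raising
    # an informative error instead of silently passing n=0 downstream.
    n_random = round(raw_samples * frac_random)
    n_perturbed = raw_samples - n_random
    if n_random <= 0 or n_perturbed <= 0:
        raise ValueError(
            f"raw_samples={raw_samples} with frac_random={frac_random} yields "
            f"{n_random} random and {n_perturbed} perturbed points; "
            "both must be positive."
        )
    return n_random, n_perturbed
```

As the comment notes, with a large default like `raw_samples=2048` this only triggers for degenerate `frac_random` values near 0 or 1.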
```python
candidate_set = draw_sobol_samples(
    bounds=bounds, n=round(raw_samples * frac_random), q=1
).squeeze(-2)
if suggested_points is not None:
```
If `suggested_points` is `None`, we end up with a `candidate_set` of size smaller than `raw_samples`. Should we make sure we always use `raw_samples` points?
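One way to keep the total fixed, sketched in 1-D with plain-Python stand-ins (uniform draws in place of Sobol sampling, Gaussian jitter in place of `sample_truncated_normal_perturbations`; the helper name is hypothetical):

```python
import random

def build_candidate_set(raw_samples, frac_random, suggested_points=None):
    # Always return exactly raw_samples points: when there are no
    # suggestions, the whole budget goes to the (quasi-)random draw.
    if suggested_points is None:
        n_random = raw_samples
    else:
        n_random = round(raw_samples * frac_random)
    candidates = [random.random() for _ in range(n_random)]
    if suggested_points is not None:
        # Fill the remaining budget with perturbations of the suggestions,
        # clipped back into the unit bounds.
        for _ in range(raw_samples - n_random):
            x = random.choice(suggested_points) + random.gauss(0.0, 0.05)
            candidates.append(min(1.0, max(0.0, x)))
    return candidates
```

Computing the second count as `raw_samples - n_random` (rather than rounding `raw_samples * (1 - frac_random)` independently) also guarantees the two parts sum exactly to `raw_samples`.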
```python
weights = (
    candidate_queries - candidate_queries.mean(dim=-1, keepdim=True)
) / candidate_queries.std(dim=-1, keepdim=True)
eta = options.get("eta", 2.0)
weights = torch.exp(eta * weights)

# weights can be more than 2D, which is not supported by torch.multinomial
# the argsort picks out the indices that are nonzero, i.e. those that are drawn
# (without replacement, so we will always have num_restarts nonzero ones)
idx = (
    Multinomial(num_restarts, probs=weights)
    .sample()
    .argsort(descending=True)[..., :num_restarts]
)
```
This seems like very similar logic to `initialize_q_batch`. Would it make sense to re-use that?
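For readers following along, the selection logic under discussion amounts to: standardize the scores, exponentiate them into Boltzmann-style weights, then draw `num_restarts` distinct indices with probability proportional to those weights. A plain-Python sketch of that scheme (illustrative only, not the torch implementation above):

```python
import math
import random

def soft_select(scores, num_restarts, eta=2.0):
    # Standardize, then exponentiate: eta controls how greedily the
    # selection concentrates on high-scoring candidates.
    mean = sum(scores) / len(scores)
    std = (sum((s - mean) ** 2 for s in scores) / (len(scores) - 1)) ** 0.5
    weights = [math.exp(eta * (s - mean) / std) for s in scores]
    chosen = []
    for _ in range(num_restarts):
        # Sample one index without replacement, proportional to its weight.
        eligible = [i for i in range(len(weights)) if i not in chosen]
        r = random.uniform(0.0, sum(weights[i] for i in eligible))
        pick = eligible[-1]  # fallback guards against float round-off
        for i in eligible:
            r -= weights[i]
            if r <= 0:
                pick = i
                break
        chosen.append(pick)
    return chosen

idx = soft_select([0.1, 2.0, 0.3, 1.5, 0.2], num_restarts=2)
```

This is essentially the same softmax-weighted initialization idea the reviewer attributes to `initialize_q_batch`, which is why re-use seems plausible.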
A series of improvements directed towards improving the performance of PES & JES, as well as their MultiObj counterparts.
Motivation
As pointed out by @SebastianAment in this paper, the BoTorch variant of JES, and to an extent PES, is brutally slow and suspiciously ill-performing. To bring them up to their potential, I've added a series of performance improvements:
1. Improvement to `get_optimal_samples` and `optimize_posterior_samples`: As this is an integral part of their efficiency, I've added suggestions (similar to `sample_around_best`) to `optimize_posterior_samples`.

Marginal runtime improvement in acquisition optimization (sampling time practically unchanged): [benchmark figure]

Substantial performance improvement: [benchmark figure]
2. Added initializer to acquisition function optimization: Similar to KG, ES methods have sensible suggestions for acquisition function optimization in the form of the sampled optima. This drastically reduces the time of acquisition function optimization, which could on occasion take 30+ seconds when `num_restarts` was large (>4). Benchmarking INC
2b. Multi-objective support for initializer: By renaming arguments of the multi-objective variants, we get consistency and support for MO variants.
3. Enabled gradient-based optimization for PES: The current implementation contains a while-loop which forces the gradients to be recursively computed. This commonly causes NaN gradients, which is why the recommended option is `"with_grad": False` in the tutorial. One `detach()` alleviates this issue, enabling gradient-based optimization. NOTE: this has NOT been ablated, since the non-grad optimization is extremely computationally demanding.
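The mechanism behind item 3 can be shown without torch: each operation in an autograd-style graph links back to its inputs, so an update inside a loop grows an ever-deeper graph unless the history is cut. A toy `Node` class below (a synthetic stand-in, not the PES code) makes the effect of a single `detach()` concrete:

```python
class Node:
    # Minimal stand-in for a tensor in an autograd graph: each op records
    # its inputs, so looped updates grow an ever-deeper backward graph.
    def __init__(self, value, parents=()):
        self.value, self.parents = value, parents

    def __mul__(self, scalar):
        return Node(self.value * scalar, (self,))

    def detach(self):
        # Cut the graph: same value, no history to backpropagate through.
        return Node(self.value)

    def depth(self):
        return 1 + max((p.depth() for p in self.parents), default=0)

x = Node(1.0)
for _ in range(10):
    x = x * 0.5             # graph depth grows by one per iteration

y = Node(1.0)
for _ in range(10):
    y = (y * 0.5).detach()  # depth stays constant; values are identical
```

In a real autograd engine, the deep recursive graph is what produces the expensive (and numerically fragile, NaN-prone) backward pass; detaching inside the loop keeps the forward values unchanged while bounding the graph.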
Test Plan
Unit tests and benchmarking.
Related PRs
First of a couple!
Bonus: while benchmarking, I had issues repro'ing the LogEI performance initially. I found that `sample_around_best` made LogEI worse on Mich5. All experiments are otherwise a repro of the settings used in the LogEI paper. [benchmark figure]