
Commit

fix #1261: npe iid handling, remove batched x warning (#1262)
* remove warning for batched x

* raise ValueError on batch observation
janfb authored Sep 5, 2024
1 parent 829817b commit c33a855
Showing 4 changed files with 19 additions and 23 deletions.
22 changes: 17 additions & 5 deletions sbi/inference/posteriors/direct_posterior.py
@@ -104,12 +104,19 @@ def sample(
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
show_progress_bars: Whether to show sampling progress monitor.
"""

num_samples = torch.Size(sample_shape).numel()
x = self._x_else_default_x(x)
x = reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)
+        if x.shape[0] > 1:
+            raise ValueError(
+                ".sample() supports only `batchsize == 1`. If you intend "
+                "to sample multiple observations, use `.sample_batched()`. "
+                "If you intend to sample i.i.d. observations, set up the "
+                "posterior density estimator with an appropriate permutation "
+                "invariant embedding net."
+            )

max_sampling_batch_size = (
self.max_sampling_batch_size
@@ -132,7 +139,7 @@
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"condition": x},
alternative_method="build_posterior(..., sample_with='mcmc')",
-        )[0]
+        )[0]  # [0] to return only samples, not acceptance probabilities.

return samples[:, 0] # Remove batch dimension.

@@ -221,9 +228,14 @@ def log_prob(
x_density_estimator = reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)
-        assert (
-            x_density_estimator.shape[0] == 1
-        ), ".log_prob() supports only `batchsize == 1`."
+        if x_density_estimator.shape[0] > 1:
+            raise ValueError(
+                ".log_prob() supports only `batchsize == 1`. If you intend "
+                "to evaluate given multiple observations, use `.log_prob_batched()`. "
+                "If you intend to evaluate given i.i.d. observations, set up the "
+                "posterior density estimator with an appropriate permutation "
+                "invariant embedding net."
+            )

self.posterior_estimator.eval()

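With this change, passing a batch of observations to `DirectPosterior.sample()` or `.log_prob()` raises a `ValueError`; batched evaluation goes through `.sample_batched()` and `.log_prob_batched()` instead. Below is a minimal usage sketch: the toy prior, simulator, and `SNPE` training call are illustrative assumptions and not part of this diff; only the `.sample()` restriction and the `.sample_batched()` alternative come from the changed code above.

import torch

from sbi.inference import SNPE
from sbi.utils import BoxUniform

# Toy setup (illustrative): 2-d parameters, Gaussian noise around them.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)

inference = SNPE(prior=prior)
density_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(density_estimator)

x_o = torch.zeros(1, 2)  # single observation -> batch size 1, allowed
samples = posterior.sample((100,), x=x_o)

x_batch = torch.zeros(3, 2)  # three observations -> batch size 3
# posterior.sample((100,), x=x_batch)  # raises ValueError after this commit
batched_samples = posterior.sample_batched((100,), x=x_batch)

`.log_prob_batched()` is the evaluation counterpart named in the new error message; for genuinely i.i.d. trials of a single underlying parameter, the message instead points to a permutation-invariant embedding net for the density estimator's condition.
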
15 changes: 0 additions & 15 deletions sbi/utils/sbiutils.py
@@ -402,21 +402,6 @@ def nle_nre_apt_msg_on_invalid_x(
)


-def warn_on_batched_x(batch_size):
-    """Warn if more than one x was passed."""
-
-    if batch_size > 1:
-        warnings.warn(
-            f"An x with a batch size of {batch_size} was passed. "
-            "Unless you are using `sample_batched` or `log_prob_batched`, this will "
-            "be interpreted as a batch of independent and identically distributed data"
-            " X={x_1, ..., x_n}, i.e., data generated based on the same underlying"
-            "(unknown) parameter. The resulting posterior will be with respect to"
-            " the entire batch, i.e,. p(theta | X).",
-            stacklevel=2,
-        )


def check_warn_and_setstate(
state_dict: Dict,
key_name: str,
3 changes: 1 addition & 2 deletions sbi/utils/user_input_checks.py
@@ -12,7 +12,7 @@
from torch.distributions import Distribution, Uniform

from sbi.sbi_types import Array
-from sbi.utils.sbiutils import warn_on_batched_x, within_support
+from sbi.utils.sbiutils import within_support
from sbi.utils.torchutils import BoxUniform, atleast_2d
from sbi.utils.user_input_checks_utils import (
CustomPriorWrapper,
@@ -582,7 +582,6 @@ def process_x(x: Array, x_event_shape: Optional[torch.Size] = None) -> Tensor:
x = x.unsqueeze(0)

input_x_shape = x.shape
-    warn_on_batched_x(batch_size=input_x_shape[0])

if x_event_shape is not None:
# Number of trials can change for every new x, but single trial x shape must
2 changes: 1 addition & 1 deletion tests/posterior_nn_test.py
@@ -35,7 +35,7 @@
pytest.param(
2,
marks=pytest.mark.xfail(
-                raises=AssertionError,
+                raises=ValueError,
reason=".log_prob() supports only batch size 1 for x_o.",
),
),
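The test change above swaps the expected exception type in an `xfail` mark. For reference, a self-contained sketch of that pytest pattern (the test name, parameter, and check below are placeholders, not the actual sbi test):

import pytest

@pytest.mark.parametrize(
    "x_o_batch_dim",
    (
        0,
        1,
        pytest.param(
            2,
            marks=pytest.mark.xfail(
                raises=ValueError,
                reason="batch size > 1 is expected to raise ValueError.",
            ),
        ),
    ),
)
def test_single_observation_only(x_o_batch_dim: int) -> None:
    # Stand-in for the real posterior call: batch sizes above 1 now raise.
    if x_o_batch_dim > 1:
        raise ValueError(".log_prob() supports only `batchsize == 1`.")
    assert x_o_batch_dim in (0, 1)

With `raises=ValueError`, the case counts as an expected failure only if that exact exception type is raised, so the test now tracks the new error type introduced in `direct_posterior.py`.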
