In order to enable more flexible losses and handling of return values, I would suggest the following changes.
**Getting rid of `B`**

I suggest that the `forward` function of the inference network only receives what we currently call `A`, instead of `A` and `B`. Likewise, `LogRatioEstimator` only receives `A` components. The split into marginal and joint examples is then handled within that object, instead of within `SwyftModule`
.

Pro: The interface of the inference network becomes simpler, and handling the marginal/joint split inside the `LogRatioEstimator` enables more flexible losses.

Con: We lose the ability to pass a small number of observations in `A` together with a large number of prior samples in `B`; instead everything has to go through `A`. In order to make log-ratio evaluation efficient, we probably should allow different components of `A` to have different batch sizes (such that the observation key can have batch size one and the parameter key batch size 1024, for instance). That might lead to unforeseen problems.
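To make the batch-size point concrete, here is a minimal sketch of how a `forward` that receives only `A` could broadcast a single observation against many parameter samples. All names below (`embedding`, `ratio_head`, the keys `x` and `z`) are hypothetical, not current swyft API:

```python
import torch

class RatioNet(torch.nn.Module):
    """Hypothetical inference network whose forward receives only A."""

    def __init__(self, n_features: int, n_params: int):
        super().__init__()
        self.embedding = torch.nn.LazyLinear(n_features)
        self.ratio_head = torch.nn.Linear(n_features + n_params, 1)

    def forward(self, A: dict) -> torch.Tensor:
        x = A["x"]  # shape (1, obs_dim): a single observation
        z = A["z"]  # shape (N, n_params): many prior samples
        f = self.embedding(x)          # (1, n_features)
        f = f.expand(z.shape[0], -1)   # broadcast to (N, n_features), no copy
        # Evaluate the ratio head once per parameter sample.
        return self.ratio_head(torch.cat([f, z], dim=-1)).squeeze(-1)  # (N,) log ratios
```

Usage would be something like `net({"x": x_obs.unsqueeze(0), "z": prior_samples})`, with one observation and, say, 1024 prior samples.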
**More informative return types & flexible losses**

Instead of returning `LogRatioSamples` from the inference network, we directly return the `LogRatioEstimator` objects (or similar other objects). We add a new method, `calc_loss`, to these objects, which is then called during training. During evaluation, which does not require the contrastive examples, we can call another method, like `get_log_ratio_samples`.
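A strawman for what such a returned object could look like; the class name and loss form below are assumptions for discussion, not existing swyft code:

```python
import torch
import torch.nn.functional as F

class LogRatioOutput:
    """Hypothetical object returned by the inference network."""

    def __init__(self, logratios_joint: torch.Tensor, logratios_marginal: torch.Tensor):
        self.logratios_joint = logratios_joint        # (N,): matched (x, z) pairs
        self.logratios_marginal = logratios_marginal  # (N,): contrastive (scrambled) pairs

    def calc_loss(self) -> torch.Tensor:
        # NRE-style binary cross-entropy: joint pairs labelled 1, marginal 0.
        # Note softplus(-r) = -log sigmoid(r).
        return (F.softplus(-self.logratios_joint).mean()
                + F.softplus(self.logratios_marginal).mean())

    def get_log_ratio_samples(self) -> torch.Tensor:
        # Evaluation path; contrastive examples are not needed here.
        return self.logratios_joint
```

`training_step` in `SwyftModule` would then reduce to summing `calc_loss()` over all returned objects.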
**More general sampling objects**

Right now the main sampling object is `LogRatioSamples`, which is expected to contain weighted posterior samples, generated from unweighted prior samples plus `logratios` as weights. If we want to allow for NPE, or for results based on, for instance, nested sampling, we need the ability to store and handle more general samples. In the case of NPE, we could easily handle this by setting `logratios = 0` and weighting all samples the same. But nested sampling leads to more involved reweighting factors for the dead points. I would suggest that in general we introduce a `Sample` object that can cover all these cases. It would be enough to introduce two kinds of weights, one being `p(x|z)/p(x)` (our `logratios`) and the other being `p(z)` (effectively constant for what we currently have). We can call those `logratio` and `logprior`. Setting one or both of them to `None` would imply that they are constant. Thoughts about that?
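As a concrete strawman for that `Sample` object (the field names `logratio` and `logprior` follow the proposal above; everything else is made up for illustration):

```python
from dataclasses import dataclass
from typing import Dict, Optional
import numpy as np

@dataclass
class Sample:
    """General weighted samples covering NRE, NPE and nested sampling."""
    values: Dict[str, np.ndarray]          # parameter arrays, keyed by name
    logratio: Optional[np.ndarray] = None  # log p(x|z)/p(x); None = constant (e.g. NPE)
    logprior: Optional[np.ndarray] = None  # log p(z); None = constant (current case)

    def log_weights(self) -> np.ndarray:
        """Unnormalized log importance weights; None terms contribute a constant."""
        n = len(next(iter(self.values.values())))
        lw = np.zeros(n)
        for term in (self.logratio, self.logprior):
            if term is not None:
                lw = lw + term
        return lw
```

The NPE case then falls out naturally: `Sample(values=posterior_draws)` with both weights set to `None` gives equally weighted samples.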
**Infer function**

Right now inference is handled through the `infer` method, which essentially calls Lightning's `predict` under the hood. It evaluates the network many times, with different batches of prior samples. If we generalise the framework to other settings, it becomes less obvious how to do that, as the optimal solution appears case-dependent. Other settings include:

- NPE: we add a `get_samples` method to the NPE object. Usage would then be to call the inference network once with the observed result, grab the returned NPE objects, and call `get_samples`. I'm not entirely sure how to generate "truncated samples" in the case of NPE.
- Slice sampling: the `get_samples` method would call a slice sampler. That `get_samples` method would also require input regarding prior densities or hyper-cube mapping. We could also return truncated samples automatically. It is not clear to me whether this should be the same return type as the posterior samples (probably not, in order to avoid confusion).
- GEDA: `get_samples` calls the GEDA sampler. Prior information has to be provided as well. "Truncated samples" would here be samples with a tempered likelihood. Again, it should probably be a different return type than posterior samples, in order to avoid confusion.
- Our current use-case: here the `get_samples` option does not really work, since the entire point of the framework has been up to now to move simulator-generated prior samples through the ratio estimator. Doing that with a `get_samples` function is less obvious, since it would then happen outside of the network. If we want to look at posteriors for derived parameters, those parameter derivations could not happen inside the network anymore, but would need to be passed as some kind of transformation hook to the `LogRatioEstimator`, in order to be able to apply this transformation on-the-fly when calling `get_samples` on prior samples from the simulator (a possible shape for such a hook is sketched after this list).
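One possible shape for such a transformation hook; the constructor argument and the `get_samples` signature are invented for this sketch:

```python
import torch

class LogRatioEstimator(torch.nn.Module):
    """Sketch: ratio estimator with an on-the-fly derived-parameter transform."""

    def __init__(self, net: torch.nn.Module, transform=None):
        super().__init__()
        self.net = net
        # e.g. transform=lambda z: {"m_tot": z["m1"] + z["m2"]} for a derived parameter
        self.transform = transform

    def get_samples(self, obs, prior_samples: dict):
        # Apply the derived-parameter transformation outside the network,
        # right before evaluating the ratio on prior samples.
        z = self.transform(prior_samples) if self.transform else prior_samples
        logratios = self.net(obs, z)
        return z, logratios  # weighted posterior samples for the derived parameters
```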
All of the above procedures should ideally return the same `PosteriorSamples` object, in order to enable a uniform plotting and testing framework, and the same type of truncated samples, in order to feed them back to the simulator.

It looks like we might end up with heterogeneous APIs for different use-cases. Is that acceptable or even desirable? What are the problems that this could bring further down the line? @NoemiAM @james-alvey-42, I would appreciate your thoughts about this.
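To make the shared-return-type requirement concrete, the different back-ends could be tied together by a small protocol. `PosteriorSamples`, `TruncatedSamples` and the `get_samples` signature below are placeholders for discussion, not a proposed final API:

```python
from dataclasses import dataclass
from typing import Dict, Protocol, Tuple
import numpy as np

@dataclass
class PosteriorSamples:
    """Common container enabling a uniform plotting and testing framework."""
    values: Dict[str, np.ndarray]  # parameter arrays, keyed by name
    log_weights: np.ndarray        # per-sample log weights

@dataclass
class TruncatedSamples:
    """Common container fed back to the simulator for the next round."""
    bounds: Dict[str, np.ndarray]  # e.g. per-parameter truncation bounds

class SamplingBackend(Protocol):
    """Implemented by NPE objects, slice/GEDA sampler wrappers, and NRE."""
    def get_samples(self, obs, num_samples: int) -> Tuple[PosteriorSamples, TruncatedSamples]: ...
```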
---

**Reply** (1 comment): Just responding with some questions that we can discuss.