Multi output #360
Conversation
@henrymoss Can you review this? The main adaptations are in …. Would need to add the changes to ….
Is there a way for me to just see the MO differences? At the moment it's all mixed in with the masking stuff.
Yes, just go to …. Maybe I was a bit too optimistic with the mask PR being merged into main ;)
I'm not sure how to do this using a wrapper such as ….
I'll think about a wrapper solution, as I also think it's much nicer. If you have “how” ideas for the design, I'd be very happy.
Mask-missing is now merged, so all changes you see in diffs are against main. @henrymoss
Ok, I think I can draft a solution that is explicit like Ti suggested. For the likelihood it can, after all, be a pretty simple wrapper I think. For the Posterior, things get slightly more involved, but you can introduce classes that encapsulate the output logic. Here is a draft. It's rough around the edges, and the method naming is not right, but it illustrates the point so we can discuss:
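(Editorial note: the draft code referenced here did not survive extraction. The following is a minimal re-sketch of what such output-encapsulating classes could look like; the class and method names, `OutputHandler`, `SingleOutput`, `MultiOutput`, `reshape`, are invented for illustration and are not the PR's actual code.)

```python
from dataclasses import dataclass

import jax.numpy as jnp


class OutputHandler:
    """Encapsulates how a posterior's flat latent quantities map onto outputs."""

    def reshape(self, flat_mean):
        raise NotImplementedError


class SingleOutput(OutputHandler):
    def reshape(self, flat_mean):
        # Nothing to do for a single output dimension.
        return flat_mean


@dataclass
class MultiOutput(OutputHandler):
    num_outputs: int

    def reshape(self, flat_mean):
        # A flat (N * P,) latent mean becomes an (N, P) array of per-output means.
        return jnp.reshape(flat_mean, (-1, self.num_outputs))
```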
Then the predict method of the posterior can look something like this, under the assumption that the posterior is constructed with an ….
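(Again an editorial re-sketch rather than the PR's code: it assumes hypothetical `in_kernel`, `out_kernel`, `obs_noise`, and `output_handler` attributes on the posterior, ignores mean functions, and forms the Kronecker matrices densely for clarity.)

```python
import jax.numpy as jnp


def predict(self, test_inputs, train_data):
    X, Y = train_data.X, train_data.y       # Y has shape (N, P): one column per output
    n, p = Y.shape

    Kxx = self.in_kernel(X, X)              # (N, N) covariance over training inputs
    Kpp = self.out_kernel(jnp.arange(p))    # (P, P) output-correlation matrix
    Kxt = self.in_kernel(X, test_inputs)    # (N, T)
    Ktt = self.in_kernel(test_inputs, test_inputs)  # (T, T)

    # Coregionalization prior: cov((x_i, p), (x_j, q)) = Kxx[i, j] * Kpp[p, q],
    # i.e. the joint covariance is kron(Kxx, Kpp) under a row-major vec of Y.
    # A real implementation would exploit this structure instead of densifying.
    Sigma = jnp.kron(Kxx, Kpp) + self.obs_noise * jnp.eye(n * p)
    cross = jnp.kron(Kxt, Kpp)              # (N*P, T*P)

    alpha = jnp.linalg.solve(Sigma, Y.reshape(-1))
    mean = cross.T @ alpha                                        # (T*P,)
    cov = jnp.kron(Ktt, Kpp) - cross.T @ jnp.linalg.solve(Sigma, cross)

    # The output-handling class decides how the flat latent quantities are
    # presented to the user, e.g. the mean reshaped to (T, P).
    return self.output_handler.reshape(mean), cov
```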
What do you think? For reference, the GPyTorch API provides ….
Hey @ingmarschuster, for me, the ….
Hey, nice work @ingmarschuster. Great to see MOGPs coming along. Just perhaps a naive question though: is there a benefit of this PR over using another input dimension to the kernel (defining output indices) and a Kronecker product kernel defined as k((i, x), (j, y)) = k(i, j) k(x, y)? Here we would actually be computationally efficient with regards to LinearOperator structure, while your code assumes everything is dense. (We actually nearly support this: following the CoLA PR merger #370, all we need is a Kronecker compute engine #371, which is a quick PR.) We could create an MOGP wrapper to do so in 5 lines of code, and it would retain all the computational benefits, would be simpler to read, and would be extensible.
As this seems to be the approach you are trying to mimic? Or maybe I am missing something?
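(Editorial illustration of the augmented-input formulation described above; the helper names and the toy coregionalisation matrix `B` are made up for the example and are not GPJax API. On the full grid of (location, output) pairs the product kernel's Gram matrix is exactly a Kronecker product, which is the structure a LinearOperator backend could exploit.)

```python
import itertools

import jax.numpy as jnp


def k_x(x, y):                         # base kernel on the inputs (RBF here)
    return jnp.exp(-0.5 * jnp.sum((x - y) ** 2))


def k_o(i, j, B):                      # output-index kernel: coregionalisation matrix B
    return B[i, j]


X = jnp.linspace(0.0, 1.0, 4)[:, None]               # 4 input locations
B = jnp.array([[1.0, 0.5], [0.5, 1.0]])              # 2 outputs, correlated

# Gram matrix over the augmented points (x, i) for every (location, output) pair,
# using the separable product kernel k((i, x), (j, y)) = k_o(i, j) * k_x(x, y).
pairs = list(itertools.product(range(len(X)), range(len(B))))
K_aug = jnp.array([[k_x(X[a], X[b]) * k_o(i, j, B) for (b, j) in pairs]
                   for (a, i) in pairs])

# With this ordering, the augmented Gram equals kron(Kxx, B).
Kxx = jnp.array([[k_x(a, b) for b in X] for a in X])
assert jnp.allclose(K_aug, jnp.kron(Kxx, B))
```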
@thomaspinder: passing ….
Sorry it's taken so long for me to get to this. It's been my teaching weeks....
I think the approach is nice. We just need to think about how to do the efficient Kronecker matrix inverse (see the note below).
I also wonder if the ReshapedDistribution is a bit too far? You could just have some code that maps stuff into the right shapes where we need it (i.e. in the likelihood and posteriors?).
Also, perhaps I'm being stupid, but is it possible with this setup to have different kernels for each output?
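(Editorial note on the Kronecker-inverse question above: this is just the standard identity, not code from the PR.) For matrices $A$ and $B$ and a column-stacking $\operatorname{vec}$,

$$(A \otimes B)^{-1} = A^{-1} \otimes B^{-1}, \qquad (A \otimes B)^{-1}\operatorname{vec}(X) = \operatorname{vec}\!\left(B^{-1} X A^{-\mathsf{T}}\right),$$

so with $A = K_{xx}$ ($N \times N$) and $B = K_{yy}$ ($P \times P$) only an $N \times N$ and a $P \times P$ factorisation are needed rather than an $NP \times NP$ one. With observation noise, $K_{xx} \otimes K_{yy} + \sigma^2 I$ is typically handled via the eigendecompositions of the two factors, which keeps the same per-factor cost.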
gpjax/dataset.py
mask: (Optional[Union[Bool[Array, "N Q"], Literal["infer automatically"]]]): mask for the output data.
    Users can optionally specify a pre-computed mask, or explicitly pass `None`, which
    means no mask will be used. Defaults to `"infer automatically"`, which means that
    the mask will be computed from the output data, or set to `None` if no output data is provided.
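(Editorial illustration of the three documented options for this parameter. It assumes missing outputs are encoded as NaN, which is one plausible reading of "computed from the output data"; the exact inference rule and mask polarity are whatever the mask-missing changes implement.)

```python
import jax.numpy as jnp

from gpjax import Dataset

X = jnp.array([[0.0], [1.0], [2.0]])
y = jnp.array([[1.0, jnp.nan], [0.5, 2.0], [jnp.nan, 3.0]])  # (N=3, Q=2) outputs

d_auto = Dataset(X=X, y=y)                     # default: mask inferred from the output data
d_none = Dataset(X=X, y=y, mask=None)          # explicitly no mask
d_user = Dataset(X=X, y=y, mask=jnp.isnan(y))  # user-supplied Bool[Array, "N Q"] mask
```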
I guess this will be its own PR?
@ingmarschuster, @henrymoss, @thomaspinder, thinking about this further, these are my current thoughts on this PR. I find some aspects of this approach quite appealing, and I believe there are strong elements we can build upon. However, I have reservations about merging this PR directly into the main GPJax codebase in its current state, with the implication that it would become a stable code design for several releases.

My primary concern is that this implementation seems tailored exclusively for conjugate priors, rendering it incompatible with various other components of the library, including non-conjugate priors, variational families, approximations like RFF, and pathwise sampling. To make it compatible, substantial code modifications would be necessary. However, I believe there are simpler ways to achieve the same functionality introduced by this PR. These approaches would not necessitate as extensive changes to existing code, and would offer enhanced efficiency and extensibility for various modelling scenarios, including heterotopic and isotopic cases. A logical starting point would be to restructure the kernel and incorporate structured likelihoods. In my view, some of the functionality found in your … could fit naturally into such a design.

Nonetheless, I appreciate the innovative aspects of your work, @ingmarschuster. There is undoubtedly neat computational structure exemplified here, and I believe this can be effectively utilised with careful kernel and likelihood design. To keep the momentum going, I would recommend that this PR, in its current state, be designated for a `gpjax.experimental` module.
@daniel-dodd The current implementation is only for conjugate priors mostly because I wanted to be sure that the design is well-received before continuing implementation. However, the approach is easy to implement for non-conjugate models etc., since it's basically just making the output dimension index into a kernel input/coregionalization. Variational, pathwise sampling and RFF I didn't think about too deeply, but I don't see a problem. If you want to integrate structured likelihoods and rework the code like that, of course this is also a viable way, but I will probably not be able to do it. I'm open to gpjax.experimental integration on the other hand. Just maybe you and @thomaspinder decide either way, because I wouldn't want to spend time on it if it's not used.
Hey @ingmarschuster, thanks for your quick reply and apologies for my slow one. That sounds good to me. The main intention of the previous messages was to assess how this fits into the broader developmental plan. Of course, a minimal working implementation is the best place to start, but nowhere did you describe this as such; please briefly mention such known limitations in the PR notes in future so we can open corresponding issues. Maybe a broader concern I did have is that the prior and conjugate posterior objects are starting to get busy with code, perhaps a bit out of touch with didacticity.

But, to get the ball rolling, please feel free to merge as is. I would, however, recommend an experimental flag, to avoid frustrating users by saying we support a stable MOGP version when we are still yet to finalise things and address its integration into the rest of the library. During your rebase you may find it useful to use CoLA's ….

I will shortly open a GitHub discussion where I believe we can simplify your implementation and be agnostic to the type of MOGP while achieving full compatibility with the rest of the library. I would be happy to lead such developmental work and would greatly appreciate your insights, as there are certain things my mental model could do with learning and improving.
* using Co-regionalization & trainable output correlations
* prediction needs to be reshaped currently to reflect multiple outputs
Necessitates correcting type hints (Num instead of Float arrays)
… to include reshaped distributions
Type of changes

Checklist

* Run `poetry run pre-commit run --all-files --show-diff-on-failure` before committing.

Description

Added multi-output capability with coregionalization, where the dataset is expected to have multiple columns in `y` for the different output dimensions. Trainable output correlations can be achieved via a categorical `CatKernel` used as the `out_kernel` field in the prior.

The advantage of this approach is that the Kronecker structure of coregionalization, i.e. the fact that `K = kron(Kxx, Kyy)`, can be used to efficiently invert `K`. When instead we leave the `Dataset.y` array to be a single dimension and artificially add the index of the output dimension to the input points, the information about the Kronecker structure is lost and the efficiency goes away. Unless doing some smart trick of course.

Issue Number: N/A
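(Editorial aside: a small plain-JAX sketch of the efficiency claim above. Nothing here is the PR's code; it just demonstrates that a solve against `kron(Kxx, Kyy)` can be done via the two small factors instead of the dense `(N*P, N*P)` matrix.)

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)

N, P = 5, 3
A = random.normal(k1, (N, N))
B = random.normal(k2, (P, P))
Kxx = A @ A.T + N * jnp.eye(N)          # SPD input covariance (N, N)
Kyy = B @ B.T + P * jnp.eye(P)          # SPD output covariance (P, P)
Y = random.normal(k3, (N, P))           # multi-output targets, one column per output

# Dense solve: O((N*P)^3).
dense = jnp.linalg.solve(jnp.kron(Kxx, Kyy), Y.reshape(-1))

# Structured solve: with row-major vec, kron(Kxx, Kyy)^{-1} vec(Y)
# = vec(Kxx^{-1} Y Kyy^{-T}), i.e. only O(N^3 + P^3) factorisation work.
Z = jnp.linalg.solve(Kxx, Y)                     # Kxx^{-1} Y
structured = jnp.linalg.solve(Kyy, Z.T).T.reshape(-1)

assert jnp.allclose(dense, structured, atol=1e-4)
```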