Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi output #360

Merged
merged 26 commits into from
Sep 24, 2023
Merged

Multi output #360

merged 26 commits into from
Sep 24, 2023

Conversation

ingmarschuster
Copy link
Contributor

@ingmarschuster ingmarschuster commented Aug 14, 2023

Type of changes

  • Bug fix
  • New feature
  • Documentation / docstrings
  • Tests
  • Other

Checklist

  • I've formatted the new code by running poetry run pre-commit run --all-files --show-diff-on-failure before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

Added multi-output capability with coregionalization, where the dataset is expected to have multiple columns in y for the different output dimensions. Trainable output correlations can be achieved via a categorical CatKernel used as the out_kernel field in the prior.

The advantage of this approach is that the Kronecker structure of coregionalization, i.e. the fact that K = kron(Kxx, Kyy) can be used to efficiently invert K. When instead we leave the Dataset.y array to be a single dimension and artificially add the index of the output dimension to the input points, the information about the Kronecker structure is lost and efficiency goes away. Unless doing some smart trick of course.

Issue Number: N/A

@ingmarschuster
Copy link
Contributor Author

ingmarschuster commented Aug 24, 2023

@henrymoss Can you review this? The regression_mo.py can be used to try things out.

Main adaptations in the predict() method of Prior, ConjugatePosterior and likelihood Gaussian.

Would need to add the changes toNonConjugatePosterior and likelihoods Bernoulli, Poisson as well and write tests. If somebody can help out here that would be great!

@ingmarschuster ingmarschuster marked this pull request as ready for review August 24, 2023 08:25
@henrymoss
Copy link
Contributor

Is there a way for me to just see the MO differences? At the moment its all mixed in with the masking stuff

@ingmarschuster
Copy link
Contributor Author

Yes, just go to
mask-missing...multi-output

Maybe I was a bit too optimistic with the mask PR being merged into main ;)

@ingmarschuster
Copy link
Contributor Author

one thing I [dislike is trying] to do too much out of the box and then things [get] messy

I'm not sure how to do this using a wrapper such as MultiOutputLikelihood around univariate likelihoods as suggested by Ti on slack. Mainly because I think we need access to not only the abstract distribution of a univariate likelihood, but also to parameters like the covariance matrix if the likelihood is gaussian.
One obvious option to require the user to be more explicit would be to have multivariate versions of the different likelihoods.

@ingmarschuster
Copy link
Contributor Author

I’ll think about a wrapper solution, as I also think its much nicer. If you have “how” ideas for the design, I’d be very happy

@ingmarschuster
Copy link
Contributor Author

Mask-missing is now merged, so all changes you see in diffs are agains main @henrymoss

@ingmarschuster
Copy link
Contributor Author

ingmarschuster commented Aug 30, 2023

Ok, I think I can draft a solution that is explicit like Ti suggested. For the likelihood it can, after all, be a pretty simple wrapper I think. For the Posterior, things get slightly more involved, but you can introduce classes that encapsulate the output logic.

Here is a draft. Its rough around the edges, and the method naming is not right, but it illustrates the point so we can discuss:

class AbstractOutput:
    def calculate_Sigma(self, x, y, n, obs_noise, jitter, mask, Kxx):
        raise NotImplementedError()

    def calculate_Ktt_and_Kxt(self, t, x, n_test, mask):
        raise NotImplementedError()

    def calculate_distribution(self, n_test, m, mean, covariance):
        raise NotImplementedError()

class UnivariateOutput(AbstractOutput):
    def calculate_Sigma(self, x, y, n, obs_noise, jitter, mask, Kxx):
        Sigma = Kxx + identity(n) * (jitter + obs_noise)

        if mask is not None:
            y = jnp.where(mask, 0.0, y)
            mx = jnp.where(mask, 0.0, mx)
            Sigma_masked = jnp.where(mask + mask.T, 0.0, Sigma.matrix)
            Sigma = Sigma.replace(
                matrix=jnp.where(
                    jnp.diag(jnp.squeeze(mask)), 1 / (2 * jnp.pi), Sigma_masked
                )
            )

        return Sigma

    def calculate_Ktt_and_Kxt(self, t, x, n_test, mask):
        Ktt = self.prior.kernel.gram(t)
        Kxt = self.prior.kernel.cross_covariance(x, t)

        if mask is not None:
            Kxt = jnp.where(mask * jnp.ones((1, n_test), dtype=bool), 0.0, Kxt)

        return Ktt, Kxt

    def calculate_distribution(self, n_test, m, mean, covariance):
        return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)

class MultivariateOutput(AbstractOutput):
    def __init__(self, out_kernel):
        self.out_kernel = out_kernel

    def calculate_Sigma(self, x, y, n, obs_noise, jitter, mask, Kxx):
        m = y.shape[1]
        n_X_m = n * m
        Kyy = self.out_kernel.gram(jnp.arange(m)[:, jnp.newaxis])
        Sigma = DenseLinearOperator(jnp.kron(Kxx.to_dense(), Kyy.to_dense()))
        Sigma += identity(n_X_m) * (jitter + obs_noise)

        if mask is not None:
            y = jnp.where(mask, 0.0, y)
            mx = jnp.where(mask, 0.0, mx)
            Sigma_masked = jnp.where(mask + mask.T, 0.0, Sigma.matrix)
            Sigma = Sigma.replace(
                matrix=jnp.where(
                    jnp.diag(jnp.squeeze(mask)), 1 / (2 * jnp.pi), Sigma_masked
                )
            )

        return Sigma

    def calculate_Ktt_and_Kxt(self, t, x, n_test, mask):
        m = y.shape[1]
        Kyy = self.out_kernel.gram(jnp.arange(m)[:, jnp.newaxis])
        Ktt = jnp.kron(self.prior.kernel.gram(t).to_dense(), Kyy.to_dense())
        Kxt = jnp.kron(self.prior.kernel.cross_covariance(x, t), Kyy.to_dense())

        if mask is not None:
            Kxt = jnp.where(mask * jnp.ones((1, n_test), dtype=bool), 0.0, Kxt)

        return Ktt, Kxt

    def calculate_distribution(self, n_test, m, mean, covariance):
        rval = GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
        if m == 1:
            return rval
        else:
            return ReshapedDistribution(rval, (n_test, m))

Then the predict method of the posterior can look something like this, under the assumption that the posterior is constructed with an output_model (defaults to UnivariateOutput):

def predict(self, test_inputs: Num[Array, "N D"], train_data: Dataset) -> GaussianDistribution:
    x, y, n, mask = train_data.X, train_data.y, train_data.n, train_data.mask
    m = y.shape[1]
    if m > 1 and mask is not None:
        mask = mask.flatten()

    t, n_test = test_inputs, test_inputs.shape[0]

    obs_noise = self.likelihood.obs_noise
    mx = self.prior.mean_function(x)
    mean_t = self.prior.mean_function(t)
    Kxx = self.prior.kernel.gram(x)

    Sigma = self.output_model.calculate_Sigma(x, y, n, obs_noise, self.prior.jitter, mask, Kxx)
    Ktt, Kxt = self.output_model.calculate_Ktt_and_Kxt(t, x, n_test, mask)

    Sigma_inv_Kxt = Sigma.solve(Kxt)

    mean = mean_t.flatten() + jnp.matmul(Sigma_inv_Kxt.T, (y - mx).flatten())

    covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
    covariance += identity(n_test * m) * self.prior.jitter

    return self.output_model.calculate_distribution(n_test, m, mean, covariance)

What do you think?

For reference, the GPytorch API provides

gpytorch.means.MultitaskMean(gpytorch.means.ConstantMean(), num_tasks=2 )
gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(), num_tasks=2, rank=1)

@thomaspinder
Copy link
Collaborator

Hey @ingmarschuster, for me, the AbstractOutput approach is cleaner. I feel it demands less of the user, is slightly more intuitive, and keeps the underlying code more didactic. Perhaps this was always your intention, but maybe the user does not even need to specify the output class i.e., MultivariateOutput and it could instead be formed by passing num_output or num_task into the prior initialiser? I agree on tidying up the naming e.g., calculate_Sigma -> gram or calculate_covariance.

@daniel-dodd
Copy link
Member

Hey nice work @ingmarschuster. Great to see MOGPs coming along.

Just perhaps a naive question though, is there a benefit of this PR, over using another input dimension to the Kernel (defining output indicies) and using a Kronecker product kernel defined on k((i, x), (j, y)) = k(i, j) k(x, y) - just here we would actually be computationally efficient with regards to LinearOperator structure, while your code assumes everything is dense. (We actually nearly support this - following CoLA PR merger #370 , all we need is a Kronecker compute engine #371 - which is a quick PR). We could create an MOGP wrapper to do so in 5 lines of code, and would have all the computational benefits retained, would be simpler to read, would be extensible.

y # Flattened to [N, 1]
x # [output_dimension, covariates]

k_out = Kernel(active_dims = output_dimension)
k_cov = Kernel(active_dims = covariates)
k = Kronecker([k_out, k_cov])
gp = Prior(kernel = k)

As this seems to be the approach you are trying to mimic? Or maybe I am missing something?

@ingmarschuster
Copy link
Contributor Author

ingmarschuster commented Sep 2, 2023

@thomaspinder: passing num_output / num_task into the prior would then force one single multi-output handling onto the user like e.g. coregionalization. I'd be ok with this in principle, just to say that explicitly passing the multi-output handling object gives more flexibility (and we can have a sensible default).
@daniel-dodd: Denseness was used because I was too lazy for a Kronecker operator in the current LinOps architecture and knew Cola integration was coming along, would make this very easy, and have the computational benefits. I totally understand how the CoLa Kronecker would work on the matrix level – on the kernel level I do not see it working at the moment. But maybe you do, since you worked on the CoLa PR? Actually the problem I saw with the approach of making output index an input dimension is exactly the fact that guaranteeing efficient Kronecker calculation would be difficult if some output indices only occur occasionally in the input.

Copy link
Contributor

@henrymoss henrymoss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry its taken so long for me to get to this. Its been my teaching weeks....

I think the approach is nice. We just need to think about how to do the efficient kronecker matrix inverse.

I also wonder if the ReshapedDistribution is a bit too far? You could just have some code that maps stuff into the right shapes where we need it (i.e. in the likelhodd and posteriors?).

Also, perhaps Im being stupid. But is it possible with this setup to have different kernels for each output?

gpjax/dataset.py Outdated
Comment on lines 42 to 49
mask: (Optional[Union[Bool[Array, "N Q"], Literal["infer automatically"]]]): mask for the output data.
Users can optionally specify a pre-computed mask, or explicitly pass `None` which
means no mask will be used. Defaults to `"infer automatically"` which means that
the mask will be computed from the output data, or set to `None` if no output data is provided.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this will be its own PR?

@daniel-dodd
Copy link
Member

@ingmarschuster, @henrymoss, @thomaspinder, thinking about this further, these are my current thoughts on this PR.

I find some aspects of this approach quite appealing, and I believe there are strong elements we can build upon. However, I have reservations about merging this PR directly into the main GPJax codebase in its current state, with the implication that it would become a stable code design for several releases.

My primary concern is that this implementation seems tailored exclusively for conjugate priors, rendering it incompatible with various other components of the library, including non-conjugate priors, variational families, approximations like RFF, and pathwise sampling. To make it compatible, substantial code modifications would be necessary.

However, I believe there are simpler ways to achieve the same functionality introduced by this PR. These approaches would not necessitate as extensive changes to existing code, and would offer enhanced efficiency and extensibility for various modelling scenarios, including heterotopic and isotopic cases.

A logical starting point would be to restructure the kernel and incorporate structured likelihoods. In my view, some of the functionality found in your AbstractOutput's methods should be distributed between these components, and I find specifying the out_kernel within the Prior feels overly restrictive; it aligns better with the kernel's responsibilities. This reorganisation of the code would provide greater flexibility and result in minimal compatibility issues with the rest of the codebase. Now with the (albeit minimal) CoLA integration, we are in a good position to start thinking about this.

Nonetheless, I appreciate the innovative aspects of your work, @ingmarschuster. There are undoubtedly neat computational structure exampled here, and I believe this can be effectively utilised with careful kernel and likelihood design. To keep the momentum going, I would recommend that this PR in its current state, be designated for a gpjax.experimental submodule. This would allow us to develop and refine it more freely until we achieve a sufficient level of compatibility and extensibility.

@ingmarschuster
Copy link
Contributor Author

@daniel-dodd The current implementation is only for conjugate priors mostly because I wanted to be sure that the design is well-received before continuing implementation. However the approach is easy to implement for Non-Conjugate models etc, since its basically just making the output dimension index into a kernel input/coregionalization. Variational, pathwise sampling and RFF I didn't think about too deeply, but don't see a problem.

If you think you want to integrate structured likelihoods and rework the code like that, of course this is also a viable way, but I will probably not be able to do it. I'm open to gpjax.experimental integration on the other hand. Just maybe you and @thomaspinder decide either way, because I wouldn't want to spend time on it if its not used.

@daniel-dodd
Copy link
Member

Hey @ingmarschuster, thanks for your quick reply and apologies for my slow one.

That sounds good to me. The main intention of the previous messages was to assess how this fits into the broader developmental plan. Of course, a minimal working implementation is the best place to start - but nowhere did you describe this as such - please briefly mention such known limitations on the PR notes in future so we can open corresponding issues.

Maybe a broader concern I did have is that the prior and conjugate posterior objects are starting to get busy with code - perhaps a bit out of touch with didacticity. But, to get the ball rolling, please feel free to merge as is, but I would recommend an experimental flag to avoid frustrating users by saying we support a stable MOGP version when we are still yet to finalise things and address its integration into the rest of the library.

During your rebase you may find it useful to use CoLA’s Kronecker linear operator in place of your dense jnp.kron for efficiency - to facilitate an actual computational benefit here of your approach. :)

I will shortly open a GitHub discussion where I believe we can simplify your implementation and be agnostic to the type of MOGP while achieving full compatibility with the rest of the library. I would be happy to lead such developmental work and would greatly appreciate your insights — as there are certain things my mental model could do with learning and improving.

@ingmarschuster ingmarschuster merged commit 83fdab2 into main Sep 24, 2023
14 checks passed
@st-- st-- deleted the multi-output branch November 30, 2023 10:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants