Pascal, Hilbert and/or other special matrices in jax.scipy.linalg #10144

Open
pnkraemer opened this issue Apr 5, 2022 · 5 comments
Labels
enhancement New feature or request good first issue Good for newcomers P3 (no schedule) We have no plan to work on this and, if it is unassigned, we would be happy to review a PR

Comments

@pnkraemer
Contributor

pnkraemer commented Apr 5, 2022

Hi!

scipy.linalg implements a number of special matrices (for example hilbert, hankel and pascal). At the moment, jax.scipy.linalg does not provide them, but it would be useful for me to have access to Pascal and Hilbert matrices via JAX. (I need them for some state-space modelling stuff, if that is relevant information.)

I could make a pull request if it helps :)
As a starting point, I explain below what I am currently doing, along with some difficulties with jitting that might need to be resolved before a pull request.

Related issues

Hilbert matrix

For example, a Hilbert matrix would be rather straightforward to implement via

import jax.numpy as jnp

def hilbert(n):
    a = jnp.arange(n)
    # Entry (i, j) is 1 / (i + j + 1) for 0-based indices i, j.
    return 1 / (a[:, None] + a[None, :] + 1)
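For reference, the formula this code evaluates, with 0-based indices matching jnp.arange(n) (equivalently 1/(i + j - 1) with 1-based indices):

```latex
H_{ij} = \frac{1}{i + j + 1}, \qquad i, j = 0, \dots, n - 1
```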

A problem I see with this implementation is its jit and vmap behaviour: the shape of the output depends on the value of n, so n cannot be an abstract (traced) value :( . Unfortunately, this is true for most special matrices.
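To make the jit side concrete, here is a minimal sketch (the helper `fails_under_jit` is hypothetical, for illustration only): calling `hilbert` eagerly works, but wrapping it in jax.jit turns n into an abstract tracer, and jnp.arange then cannot determine the output length.

```python
import jax
import jax.numpy as jnp


def hilbert(n):
    a = jnp.arange(n)
    return 1 / (a[:, None] + a[None, :] + 1)


def fails_under_jit(f, *args):
    """Hypothetical helper: return True if jax.jit(f)(*args) cannot be traced."""
    try:
        jax.jit(f)(*args)
    except TypeError:
        # JAX raises a ConcretizationTypeError (a TypeError subclass): under jit, n is an
        # abstract tracer, so jnp.arange cannot know how long its output should be.
        return True
    return False


print(hilbert(3).shape)             # (3, 3): fine when called eagerly
print(fails_under_jit(hilbert, 3))  # True
```

The same reasoning applies to vmap over n: each batch element would produce a matrix of a different size, which a single batched array cannot hold.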

Therefore, in my own code, I have split the function into something like

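import jax
import jax.numpy as jnp

def hilbert(n):
    # Building the index vector needs a concrete n, so this part stays outside of jit.
    a = jnp.arange(n)
    return _hilbert(a)

@jax.jit
def _hilbert(a):
    # The output shape now follows from the shape of `a`, so this part can be jitted.
    return 1 / (a[:, None] + a[None, :] + 1)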

which at least allows jitting part of the implementation.
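One way to see what this split buys: jit caches a compiled `_hilbert` per input shape and dtype, so repeated calls with the same n reuse the compiled code, and only a new n triggers a retrace. A minimal sketch; the module-level `trace_count` counter is a hypothetical instrumentation device that relies on Python side effects inside a jitted function running only while JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, for illustration only


@jax.jit
def _hilbert(a):
    global trace_count
    trace_count += 1  # Python side effect: executes only when JAX (re)traces _hilbert
    return 1 / (a[:, None] + a[None, :] + 1)


def hilbert(n):
    return _hilbert(jnp.arange(n))


hilbert(3)
hilbert(3)  # same input shape and dtype: cached, no retrace
hilbert(4)  # new input shape: traced and compiled again
print(trace_count)  # 2
```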
An alternative would be to provide a Hankel matrix (whose shape is determined by the shapes of the input vectors, which should make it easier to jit) and then let users like me do the Hilbert-via-Hankel conversion for each application. SciPy's Hankel matrix, however, uses np.lib.stride_tricks.as_strided, and I understand from issue #3171 that this is not intended to be supported.
But this can probably be resolved if one decides to go down that road.
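To illustrate that the Hankel route does not need as_strided, here is a minimal sketch of a jit-friendly Hankel constructor that gathers from the concatenated anti-diagonal values instead. `hankel` and `hilbert_via_hankel` are hypothetical names; the argument convention (first column c, last row r, r[0] ignored, r defaulting to zeros) is assumed to mirror scipy.linalg.hankel.

```python
import jax
import jax.numpy as jnp


def hankel(c, r=None):
    """Hypothetical jit-friendly Hankel matrix: first column c, last row r (r[0] ignored)."""
    c = jnp.ravel(jnp.asarray(c))
    r = jnp.zeros_like(c) if r is None else jnp.ravel(jnp.asarray(r))
    # All distinct values along the anti-diagonals: c followed by r without its first entry.
    vals = jnp.concatenate([c, r[1:]])
    # Entry (i, j) lies on anti-diagonal i + j; the shape comes only from c.shape and r.shape.
    i = jnp.arange(c.shape[0])[:, None]
    j = jnp.arange(r.shape[0])[None, :]
    return vals[i + j]


def hilbert_via_hankel(n):
    """Hilbert matrix as a Hankel matrix: anti-diagonal i + j holds 1 / (i + j + 1)."""
    k = jnp.arange(1, 2 * n, dtype=jnp.float32)  # 1, 2, ..., 2n - 1
    # First column: 1/1, ..., 1/n; last row: 1/n, ..., 1/(2n - 1).
    return hankel(1.0 / k[:n], 1.0 / k[n - 1:])


print(jax.jit(hankel)(jnp.array([1, 2, 3]), jnp.array([4, 5, 6, 7])))
```

With c = [1, 2, 3] and r = [4, 5, 6, 7], the jitted call returns [[1, 2, 3, 5], [2, 3, 5, 6], [3, 5, 6, 7]]. Because the output shape depends only on c.shape and r.shape, the whole constructor can be jitted; only the Hilbert wrapper still needs a concrete n.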

Pascal matrix

Pascal matrices are a little less obvious to implement than Hankel/Hilbert matrices, because they use binomial coefficients. For my applications, I have so far computed the factorials via jax.lax.exp(jax.lax.lgamma(n + 1.)) (because jnp.prod(jnp.arange(1, n + 1)) has jit issues again; if I remember correctly, I took the gamma-function idea from jet.py) and then vmapped a binomial-coefficient implementation over all index pairs via

import jax
import jax.numpy as jnp

def pascal(n):
    return _pascal(jnp.arange(n))

@jax.jit
def _pascal(a):
    # Entry (i, j) is binom(a[i], a[j]); passing 1-D vectors yields an (n, n) result.
    return _binom(a, a)


def _broadcast_to_matrix(k):
    """Take a function f(scalar, scalar) -> scalar and return a function g((n,) array, (m,) array) -> (n, m) array"""
    k_vmapped_x = jax.vmap(k, in_axes=(0, None), out_axes=-1)
    k_vmapped_xy = jax.vmap(k_vmapped_x, in_axes=(None, 0), out_axes=-1)
    return jax.jit(k_vmapped_xy)


@_broadcast_to_matrix
@jax.jit
def _binom(n, k):
    a = _factorial(n)
    b = _factorial(n - k)
    c = _factorial(k)
    return a / (b * c)


@jax.jit
def _factorial(n):
    # n! = Gamma(n + 1) = exp(lgamma(n + 1)); adding 1.0 also casts integer input to float.
    return jax.lax.exp(jax.lax.lgamma(n + 1.0))
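A caveat with forming each factorial separately via exp(lgamma(n + 1)): in JAX's default float32, n! overflows to inf once n >= 35 (34! is about 2.95e38, just below the float32 maximum of roughly 3.4e38), so entries of larger Pascal matrices become inf or nan. A minimal sketch of a log-space alternative, assuming jax.scipy.special.gammaln; `binom_logspace` is a hypothetical name, and the rounding step assumes the results are meant to be integers.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


@jax.jit
def binom_logspace(n, k):
    """Binomial coefficients C(n, k) for broadcastable arrays n, k, computed as
    exp(log n! - log k! - log (n - k)!) so no individual factorial is ever formed."""
    n = jnp.asarray(n, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)
    valid = (k >= 0) & (k <= n)
    # Clamp invalid entries to k = 0 so gammaln never sees a pole; they are zeroed below.
    k_safe = jnp.where(valid, k, 0.0)
    log_c = gammaln(n + 1.0) - gammaln(k_safe + 1.0) - gammaln(n - k_safe + 1.0)
    return jnp.where(valid, jnp.round(jnp.exp(log_c)), 0.0)


a = jnp.arange(5)
print(binom_logspace(a[:, None], a[None, :]))  # lower-triangular Pascal matrix, as floats
```

Plain broadcasting replaces the nested vmap here, since gammaln is elementwise. Float32 still limits precision for large entries, so this addresses overflow, not exactness.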

This differs slightly from scipy.linalg.pascal, however, because my implementation always returns floats, whereas a Pascal matrix should contain integers, at least for scipy.linalg.pascal(..., exact=True). The implementation would also be fairly different from how SciPy does it, which might not be desired.
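On the exact-integer point: the `_pascal` above evaluates C(i, j), i.e. the lower-triangular variant (what scipy.linalg.pascal returns with kind='lower'), while SciPy's default kind='symmetric' has entries C(i + j, i). A minimal integer-only sketch, assuming n is static and int32 suffices (the symmetric case overflows int32 from n = 18 on): build the lower-triangular matrix row by row with Pascal's rule inside jax.lax.scan, then use the identity P = L Lᵀ for the symmetric one. `pascal_int` is a hypothetical name; only the `kind` naming follows SciPy.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n", "kind"))
def pascal_int(n, kind="symmetric"):
    """Hypothetical exact int32 Pascal matrix; kind follows scipy.linalg.pascal's naming."""

    def next_row(row, _):
        # Pascal's rule: C(i, j) = C(i - 1, j) + C(i - 1, j - 1).
        shifted = jnp.concatenate([jnp.zeros(1, dtype=row.dtype), row[:-1]])
        return row + shifted, row  # carry the next row, emit the current one

    first_row = jnp.zeros(n, dtype=jnp.int32).at[0].set(1)  # C(0, j)
    _, lower = jax.lax.scan(next_row, first_row, xs=None, length=n)
    if kind == "lower":
        return lower
    if kind == "upper":
        return lower.T
    if kind == "symmetric":
        # The symmetric Pascal matrix C(i + j, i) factors as L @ L.T.
        return lower @ lower.T
    raise ValueError(f"kind must be 'symmetric', 'lower' or 'upper', got {kind!r}")


print(pascal_int(4))
```

Here pascal_int(4) gives [[1, 1, 1, 1], [1, 2, 3, 4], [1, 3, 6, 10], [1, 4, 10, 20]]. Everything stays in integer arithmetic, at the cost of one compilation per distinct (n, kind).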

Long story short

Hilbert and Pascal matrices would be useful to have, and if the jitting issues can be resolved (or do not matter too much?) I can make a pull request.

What do you think?

@pnkraemer pnkraemer added the enhancement New feature or request label Apr 5, 2022
@jakevdp
Collaborator

jakevdp commented Apr 5, 2022

I think these would be useful contributions. Regarding jit compatibility, you might think about pre-compiling these with static_argnames (search through the JAX source to see other examples of this).
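A minimal sketch of that suggestion for the Hilbert matrix, using the functools.partial(jax.jit, static_argnames=...) pattern: marking n as static makes it a plain Python int at trace time, so jnp.arange(n) has a known length. This is an illustration only, not the eventual jax.scipy.linalg implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n",))
def hilbert(n):
    # n is static: a concrete Python int during tracing, so the (n, n) output shape is known.
    a = jnp.arange(n)
    return 1.0 / (a[:, None] + a[None, :] + 1)


print(hilbert(3))
```

Each new value of n triggers a fresh trace and compilation, while repeated calls with the same n reuse the cached executable. Static arguments must be hashable, so n has to be a Python int rather than a JAX array, and vmapping over n remains impossible.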

If you're willing to put together a PR, I'd be happy to review!

@jakevdp jakevdp added the P2 (eventual) This ought to be addressed, but has no schedule at the moment. (Assignee optional) label Apr 5, 2022
@zhangqiaorjc
Collaborator

We seem to have a static_argnames example in the jax.jit documentation itself; see the last example at https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html?highlight=static_argnames#jax.jit

@pnkraemer
Contributor Author

Thanks for the suggestions! I will open a PR soon then, and let you know in case I have more questions.

@pnkraemer
Contributor Author

#10161 adds the Hilbert matrix. Once this is through, I will add the Pascal matrix as well (since its implementation might be less obvious, I decided to split them into two PRs).

@jakevdp jakevdp added good first issue Good for newcomers P3 (no schedule) We have no plan to work on this and, if it is unassigned, we would be happy to review a PR and removed P2 (eventual) This ought to be addressed, but has no schedule at the moment. (Assignee optional) labels Nov 8, 2023
@aaronatp

Hi @jakevdp, I hope you're doing well. Can I open a PR for this issue?
