Pascal, Hilbert and/or other special matrices in jax.scipy.linalg
#10144
Comments
I think these would be useful contributions. Regarding jit-compatibility, you might think about pre-compiling these with … If you're willing to put together a PR, I'd be happy to review!
We seem to have a …
Thanks for the suggestions! I will open a PR soon then, and let you know in case I have more questions.
#10161 adds the Hilbert matrix. Once it is merged, I will add the Pascal matrix as well (since its implementation might be less obvious, I decided to split them into two PRs).
Hi @jakevdp, I hope you're doing well. Can I open a PR for this issue?
Hi!
`scipy.linalg` implements special matrices. At the moment, `jax.scipy.linalg` does not, but it would be useful for me to have access to Pascal and Hilbert matrices via `jax`. (I need them for some state-space modelling, if that is relevant information.) I could make a pull request if it helps :)
As a starting point, I explain below what I am currently doing, as well as some difficulties with `jit`ting that might need to be resolved before a pull request.
Related issues
Hilbert matrix
For example, a Hilbert matrix would be rather straightforward to implement via …
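A minimal sketch of such a Hilbert matrix in JAX, assuming SciPy's definition H[i, j] = 1 / (i + j + 1) with 0-based indices. This is an illustration, not necessarily the code from the original post, and the name `hilbert` is hypothetical. Note that wrapping it in `jax.jit` only works if `n` is marked static, which is the shape problem discussed in this section:

```python
import jax
import jax.numpy as jnp


def hilbert(n):
    """Return the n x n Hilbert matrix H[i, j] = 1 / (i + j + 1), 0-based."""
    # `n` must be a concrete Python int because it determines the output shape.
    idx = jnp.arange(n)
    return 1.0 / (idx[:, None] + idx[None, :] + 1)


# Under jit, n has to be static; every new value of n triggers a recompilation,
# and n cannot be a traced value (e.g. a vmapped argument).
hilbert_jit = jax.jit(hilbert, static_argnums=0)

print(hilbert_jit(3))
```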
A problem I see with this implementation is its `jit` and `vmap` behaviour: the size of the output depends on the value of `n` :( . Unfortunately, this is true for most special matrices.

Therefore, in my own code, I have split the function into something like … which at least allows `jit`ting parts of the implementation.

An alternative would be to provide a Hankel matrix (whose shape is determined by the shapes of the input vectors, and which should be easier to `jit`), and then let users like me do the Hilbert-via-Hankel conversion depending on the application. SciPy's Hankel matrix, however, uses `np.lib.as_strided`, and I understand from issue #3171 that this is not intended to be supported. This can probably be worked around, though, if one decides to go down that road.
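A minimal sketch of the Hankel route, assuming SciPy's convention (first column `c`, last row `r`, with `r[0]` ignored). Instead of `np.lib.as_strided`, it gathers from the concatenated vector `[c, r[1:]]` using the anti-diagonal index `i + j`, so the output shape follows from the input shapes and the function works under `jax.jit` and `jax.vmap` without static arguments. The names `hankel` and `hilbert_via_hankel` are hypothetical and only used for this sketch:

```python
import jax
import jax.numpy as jnp


def hankel(c, r=None):
    """Hankel matrix with first column `c` and last row `r` (SciPy convention).

    The output shape (len(c), len(r)) depends only on the input *shapes*,
    so this traces under jax.jit and jax.vmap without static arguments.
    """
    c = jnp.asarray(c)
    r = jnp.zeros_like(c) if r is None else jnp.asarray(r)
    # Entry (i, j) lies on anti-diagonal i + j; r[0] is ignored, as in SciPy.
    vals = jnp.concatenate([c, r[1:]])
    i = jnp.arange(c.shape[0])[:, None]
    j = jnp.arange(r.shape[0])[None, :]
    return vals[i + j]


def hilbert_via_hankel(n):
    """n x n Hilbert matrix built from a Hankel matrix; only n itself is static."""
    k = jnp.arange(n, dtype=jnp.float32)
    c = 1.0 / (k + 1.0)  # first column: 1/1, ..., 1/n
    r = 1.0 / (k + n)    # last row: 1/n, ..., 1/(2n - 1)
    return hankel(c, r)


hankel_jit = jax.jit(hankel)

# Batch over several (c, r) pairs of the same length.
c = 1.0 / (jnp.arange(4.0) + 1.0)
r = 1.0 / (jnp.arange(4.0) + 4.0)
batched = jax.vmap(hankel)(jnp.stack([c, 2.0 * c]), jnp.stack([r, 2.0 * r]))
print(batched.shape)  # (2, 4, 4)
```

In this split, the integer `n` only fixes the lengths of `c` and `r` (the part that cannot be traced), while `hankel` itself is compiled once per input shape and never needs a static integer argument.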
Pascal matrix
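To make the binomial-coefficient approach in this section concrete, here is a minimal sketch, assuming SciPy's symmetric convention P[i, j] = C(i + j, i). The float version uses the `lgamma` trick described below, with broadcasting instead of `vmap`; the final rounding is an added choice, not from the original. A separate integer version covers the `exact=True` case using Pascal's rule and the identity P = L @ L.T, where L is the lower-triangular Pascal matrix. All function names are hypothetical:

```python
import jax
import jax.numpy as jnp


def log_factorial(x):
    # log(x!) = lgamma(x + 1); works on traced float arrays,
    # unlike jnp.prod(jnp.arange(1, x + 1)), whose length depends on x.
    return jax.lax.lgamma(x + 1.0)


def binom(n, k):
    # Floating-point binomial coefficient C(n, k) via log-factorials.
    return jnp.exp(log_factorial(n) - log_factorial(k) - log_factorial(n - k))


def pascal(n):
    """Symmetric n x n Pascal matrix P[i, j] = C(i + j, i), as floats."""
    idx = jnp.arange(n, dtype=jnp.float32)
    i, j = idx[:, None], idx[None, :]
    # exp(lgamma(...)) is only approximately integral, so round the result.
    # float32 still loses exactness for moderately large n.
    return jnp.round(binom(i + j, i))


def pascal_lower_exact(n):
    """Lower-triangular Pascal matrix L[i, j] = C(i, j), integer arithmetic."""
    def step(row, _):
        # Pascal's rule: C(i + 1, j) = C(i, j) + C(i, j - 1).
        shifted = jnp.concatenate([jnp.zeros((1,), row.dtype), row[:-1]])
        return row + shifted, row  # carry the next row, emit the current one

    first = jnp.zeros((n,), dtype=jnp.int32).at[0].set(1)
    _, rows = jax.lax.scan(step, first, None, length=n)
    return rows


def pascal_exact(n):
    """Symmetric Pascal matrix with int32 entries; overflows once n >= 18."""
    L = pascal_lower_exact(n)
    return L @ L.T


print(pascal(4))
print(pascal_exact(4))
```

Both versions still need `n` to be static under `jax.jit` (e.g. `static_argnums=0`), for the same shape reason as the Hilbert matrix.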
Pascal matrices are a little less obvious to implement than Hankel/Hilbert matrices, because they use binomial coefficients. For my applications, I have so far computed the required factorials via `jax.lax.exp(jax.lax.lgamma(n+1.))` (because `jnp.prod(jnp.arange(1, n+1))` has `jit` issues again; if I remember correctly, I took the gamma-function idea from `jet.py`) and then vmapped a factorial implementation via …

This is slightly different from `scipy.linalg.pascal`, however, because my implementation returns only floats, whereas a general Pascal matrix should support integers, at least for `scipy.linalg.pascal(..., exact=True)`. The implementation would also be fairly different from …, which might not be desired either.

Long story short
Hilbert and Pascal matrices would be useful to have, and if the `jit`ting issues can be resolved (or do not matter too much), I can make a pull request.
What do you think?