Jax autodiff incompatible with SVD? #147
Comments
Hi @nikitn2, turning on compressed contraction through the usual (exact) The general problem here is the difficulty of choosing minimal default arguments/options for these advanced algorithms which don't hide the details. E.g. compressed contraction shouldn't be used with the current default exact Thanks for the kind words about quimb!
Hi @jcmgray, Thanks for the reply; your explanation makes perfect sense. It's probably too much to expect an automatic-differentiation framework to easily handle dynamically sized arrays, and I can indeed see how difficult it is to integrate these considerations into a numerical library as advanced as yours, while still keeping quimb simple to use. Dilemmas...! Your
Is there perhaps a way for me to just globally turn off dynamic shaping of tensors when using autodiff?
Hi @jcmgray, I've been playing around a bit with this issue in my code and I find that JAX basically doesn't like it when the bond dimensions are in any way dynamically allocated. I've found two fixes so far. In my example above, if you pass Do you think there is anything I can do to get the JAX automatic differentiation to work with compression in a more stable manner?
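To illustrate the point about dynamically allocated bond dimensions, here is a minimal sketch in plain JAX, not quimb's own compression code. The function names `svd_fixed_rank` and `svd_cutoff_rank` and the toy loss functions are hypothetical, chosen only for this example. Truncating to a fixed Python integer keeps every shape known at trace time, while a cutoff-based rank depends on the data and cannot be turned into a concrete integer while tracing.

```python
import jax
import jax.numpy as jnp


def svd_fixed_rank(a, max_bond):
    """Truncate to a fixed number of singular values (static shapes)."""
    u, s, vh = jnp.linalg.svd(a, full_matrices=False)
    # max_bond is a plain Python int, so all output shapes are known when tracing
    return u[:, :max_bond], s[:max_bond], vh[:max_bond, :]


def svd_cutoff_rank(a, cutoff):
    """Truncate by a cutoff: the kept rank depends on the data."""
    u, s, vh = jnp.linalg.svd(a, full_matrices=False)
    # while tracing, the count is an abstract value, so int(...) cannot be evaluated
    k = int(jnp.sum(s > cutoff))
    return u[:, :k], s[:k], vh[:k, :]


def loss_fixed(a):
    u, s, vh = svd_fixed_rank(a, 2)
    approx = (u * s) @ vh  # rank-2 approximation of a
    return jnp.sum((a - approx) ** 2)


def loss_cutoff(a):
    u, s, vh = svd_cutoff_rank(a, 0.5)
    approx = (u * s) @ vh
    return jnp.sum((a - approx) ** 2)


a = jax.random.normal(jax.random.PRNGKey(0), (4, 4))

# fixed rank: tracing and differentiation go through
grad_fixed = jax.jit(jax.grad(loss_fixed))(a)

# data-dependent rank: tracing fails before any computation happens
try:
    jax.jit(jax.grad(loss_cutoff))(a)
    cutoff_failed = False
except TypeError as err:
    # JAX's concretization errors are subclasses of TypeError
    cutoff_failed = True
    print(type(err).__name__)
```

The sketch only shows where the tracing error comes from; whether a fixed rank is acceptable for a given calculation is a separate question.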
Just a couple of things:
Hi @jcmgray, Thanks very much for the reply! You're right, it does indeed make perfect sense that the above example should be numerically unstable. Though the above example is not my main use case. My main use case involves applying operators My first idea was to contract
After applying the fix I mentioned in my previous post, I now get a Jax "ConcretizationTypeError", originating at the Giving up on this, I tried just computing the four terms of
The only way I can compute and minimise Do you have any idea what I could do to proceed? I tried using TensorFlow instead of JAX, and it actually worked with compression turned on, although it's almost two orders of magnitude slower than JAX... EDIT: I'm thinking of just restricted
I see, so it's a fitting task (
This is saying that
Are you sure the complexity of even the zip-up algorithm is not similar to that of exact contraction
What the compressed contraction algorithm does depends entirely on the contraction path ( That being said, if the MPO has a really large bond dimension, maybe that's different enough from '1D and tree like' to apply some compression somewhere, you could run a
Hi @jcmgray, Thanks so much for your reply!
It's a bit worse than that I'm afraid, as I'm basically dealing with a highly nonlinear problem. In my case Therefore, even if I represent So to calculate
I'll try to investigate this Thanks again for your replies. You've no idea how much time they save me, and for that I'm incredibly grateful :)
I see, yes I just meant if you were only interested in finding
Only that you need to find where your implementation of the algorithm calls
Ah, I see now the confusion re the loss function: my variational function is Psi, not Phi! Sorry, I forgot to mention that. So I need to always hard-set the bond dimension, or it might perform compression even if cutoff=0.0? That makes sense, actually. Thanks a lot!
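One way to combine a hard-set bond dimension with a cutoff while keeping all shapes static is to mask small singular values instead of slicing them away. The following is a sketch under that assumption, written in plain JAX; `svd_masked` and `reconstruct` are hypothetical names, and this is not how quimb implements compression.

```python
import jax
import jax.numpy as jnp


def svd_masked(a, max_bond, cutoff):
    """Fixed-shape truncation: keep max_bond values, zero those below cutoff."""
    u, s, vh = jnp.linalg.svd(a, full_matrices=False)
    # static slice: max_bond is a Python int
    u, s, vh = u[:, :max_bond], s[:max_bond], vh[:max_bond, :]
    # jnp.where leaves the shape unchanged; small singular values become zero
    s = jnp.where(s > cutoff, s, 0.0)
    return u, s, vh


@jax.jit
def reconstruct(a):
    u, s, vh = svd_masked(a, max_bond=3, cutoff=1e-3)
    return (u * s) @ vh


# a rank-2 matrix: the third kept singular value is numerically zero
x = jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0], [2.0, -1.0]])
a = x @ x.T  # shape (4, 4), rank 2
approx = reconstruct(a)
```

The bond stays at size `max_bond` regardless of the data, so the function traces under `jax.jit`. Note that gradients of an SVD can become ill-conditioned when singular values are degenerate or close to zero, so this sketch addresses only the shape problem, not numerical stability.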
What is your issue?
Hello,
Lately I've been trying to use JAX automatic differentiation to minimise a certain loss function of mine, but I keep running into errors like “unhashable type: 'DeviceArray'” or “ConcretizationTypeError”. They happen when I try to use compression or tensor.split(), which makes me think that JAX has some sort of incompatibility with SVDs, since SVDs are typically used to compress bonds and split tensors.
Below is a reproduction of the issue based on Chapter 9 in your user guide. By setting chi = None, this code will run, but when chi = 8 and compression is performed, the code crashes with the following error message:
I find this particularly strange given that the code in Chapter 4.8 works just fine for me even though it also uses JAX automatic differentiation along with compression (max_bond is limited to 32 in compute_local_expectation()).
Do you have any idea what might be wrong?
And thanks very much for this excellent numerical library – I LOVE quimb!
Cheers