-
-
Notifications
You must be signed in to change notification settings - Fork 142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
change _integrate.py to make differentiation possible over vmapped diffeqsolve #569
Conversation
Thanks for identifying a fix! This looks great. Can you add a test asserting that the bug in #568 isn't hit? I'd hate to accidentally break this for you because of some other change. |
Hope this works! :) |
seems like theres an issue. ill have a look tomorrow. sorry my bad probably |
Okay catches the issue now and test doesn't throw an error. Sorry it was throwing an error on the first go! |
Looks like pre-commits are failing. Take a look at CONTRIBUTING.md for reference on how to set this up locally, so that the same set of checks run locally whenever you make a commit. :) |
Thanks :) sorry for that! |
Sorry did'nt have time to look into it earlier but seems good now:
:) Just missed an assertion |
Do we have an issue vmapping over Newton root finding that I should look into over at optimistix? |
Ah no I don't think the root finder was not the issue itself. I just changed to a diffrax root finder since _integrate.py did not import optimistix before so I also changed my test to use the diffrax root finder used in the other tests in that file. The actual issue the pre commit hook had was, that it wanted me to include an assertion to check if sol.ys is not None before accessing its last element. I just didn't catch that problem before bc i didn't read the CONTRIBUTING.md . Sorry didn't want to cause confusion. |
Thanks for prompt the clarification! |
Okay seems like there is another issue with my fix than just the pre-commits :/ |
Oh no, this isn't you... looks like JAX just did a 0.5.0 release, and it's completely destroyed their pseudorandom generation! The failing tests are from our statistical tests that our random number generation is, well, random. I've been able to reproduce this locally, with setting I've just opened jax-ml/jax#26019 to alert the upstream JAX folks. Hopefully they can identify this / revert in the mean time. If they're able to revert quickly then we can retrigger the tests here as a validation that things are back to being statistically correct. Let's wait and see. |
Okay, returning to this! As discussed in the thread, this isn't a JAX issue and is essentially a surprising case of correlated randomness :D #575 (just merged) should fix us up to having our tests pass under JAX 0.5.0. If you'd like to rebase this branch on top of the latest |
Thanks :)! |
Closing in favour of #578 then! |
No description provided.