Differentiating through an ak.mean #2638
Another
Thanks for the tags, self-assigned!
Hi @alexander-held, I've been looking at this issue, and it seems to be more of a new feature request for the JAX backend. The implementation (`awkward/src/awkward/operations/ak_mean.py`, line 195 at commit c1e4f9f)
but `awkward/src/awkward/_connect/jax/reducers.py` (lines 91 to 114 at commit c1e4f9f)
I traced back the history of this file, and it looks like the intentional error has always been there. @jpivarski, were
I vaguely remember some past conversations about which derivatives we might want to be able to evaluate and which might not be as useful. For discrete things like As soon as the number of elements changes (which can happen in practice with selection cuts), handling that seems like a broader issue that the user might need to take care of externally. I'll point a few more people to this issue to invite other opinions on what would be most useful.
These seem to be fundamentally non-differentiable, at least in my understanding. However, differentiable arrays should carry through some of this. For instance,
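To make the distinction concrete, here is a minimal sketch in plain JAX (not Awkward's implementation) of a per-sublist mean on a jagged array. It assumes a hypothetical layout of flat `content` plus `offsets`, and the names `ragged_mean` and `loss` are illustrative only. The integer counts come from the offsets, so they are structure rather than data and receive no gradient, while the float values are differentiated normally. Unlike `ak.mean`, which returns `None` for an empty sublist, this sketch returns 0 there to keep the arithmetic finite.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical jagged array [[1, 2, 3], [], [4, 5]] as flat content + offsets
content = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
offsets = np.array([0, 3, 3, 5])

# Integer bookkeeping derived only from the offsets (no gradient flows here)
counts = np.diff(offsets)                                  # [3, 0, 2]
segment_ids = np.repeat(np.arange(len(counts)), counts)   # [0, 0, 0, 2, 2]

def ragged_mean(values):
    # Sum each sublist, then divide by its (constant) length
    sums = jax.ops.segment_sum(values, segment_ids, num_segments=len(counts))
    # Empty sublists would be 0/0; divide by 1 instead so they come out as 0
    safe_counts = np.maximum(counts, 1)
    return sums / safe_counts

def loss(values):
    # Scalar objective so jax.grad can be applied
    return jnp.sum(ragged_mean(values))

means = ragged_mean(content)
grads = jax.grad(loss)(content)
# Each element of sublist k gets gradient 1/n_k: [1/3, 1/3, 1/3, 1/2, 1/2]
```

If a selection cut changes `offsets`, the counts jump discretely; that change is not something this gradient describes, which matches the point above about handling varying element counts externally.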
Version of Awkward Array
main branch
Description and code to reproduce
This is a follow-up to #2591 with a slightly more simplified setup. It should be conceptually possible to differentiate through taking a mean. Currently this does not work.
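For reference, the derivative being asked for is simple. For a list of $n$ values the mean is linear, and for a jagged array reduced per sublist (as with `ak.mean(..., axis=1)`), each sublist $k$ has its own length $n_k$, which is fixed by the array's structure rather than by its values:

```latex
\bar{x} = \frac{1}{n}\sum_{i=1}^{n} x_i,
\qquad
\frac{\partial \bar{x}}{\partial x_j} = \frac{1}{n},
\qquad
\bar{x}_k = \frac{1}{n_k}\sum_{j=1}^{n_k} x_{k,j},
\qquad
\frac{\partial \bar{x}_k}{\partial x_{k,j}} = \frac{1}{n_k}
```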
Reproducer:
Result:
A standalone JAX version of taking a mean works fine:
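The original snippet is not preserved here; as a stand-in, this is a minimal sketch (not the code from the issue) of differentiating through `jnp.mean` on a flat array with `jax.grad`. The function name `mean_of` is illustrative.

```python
import jax
import jax.numpy as jnp

def mean_of(x):
    # Plain JAX mean over a flat (non-jagged) array
    return jnp.mean(x)

x = jnp.array([1.0, 2.0, 3.0, 4.0])
grad = jax.grad(mean_of)(x)
# Every entry of grad is 1/n = 0.25
```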