Differentiating through an ak.mean #2638

Closed
Tracked by #3
alexander-held opened this issue Aug 10, 2023 · 5 comments · Fixed by #3020
Labels: autodiff (Issue related to auto-differentiation), feature (New feature or request)

Comments

@alexander-held
Member

Version of Awkward Array

main branch

Description and code to reproduce

This is a follow-up to #2591 with a simpler setup. It should be conceptually possible to differentiate through taking a mean, but this currently does not work.

Reproducer:

import awkward as ak
import jax
import uproot

ak.jax.register_and_check()

ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
    "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"

def mean_jet_pt(jets):
    return ak.mean(jets.pt)

with uproot.open(ttbar_file) as f:
    arr = f["Events"].arrays(["Jet_pt","Jet_eta", "Jet_phi", "Jet_mass"])
    evtfilter = ak.num(arr["Jet_pt"]) >= 2
    jets = ak.zip(dict(zip(["pt","eta", "phi", "mass"], ak.unzip(arr))), with_name="Momentum4D")[evtfilter]
    jets = ak.to_backend(jets, "jax")


jax.value_and_grad(mean_jet_pt, argnums=0)(jets)

Result:

RuntimeError: Cannot differentiate through count_zero

This error occurred while calling

    ak.mean(
        <Array [[...], [...], ..., [...], [...]] type='140 * var * float32'>
    )

A standalone jax version of taking a mean works fine:

import jax.numpy as jnp

def mean(j):
    return jnp.mean(j)
    
data = jnp.array([1, 7, 3, 5],dtype=float)

jax.value_and_grad(mean, argnums=0)(data)

@alexander-held added the "bug (unverified)" label (The problem described would be a bug, but needs to be triaged) on Aug 10, 2023
@jpivarski jpivarski added the autodiff Issue related to auto-differentiation label Oct 2, 2023
@jpivarski
Member

Another autodiff issue to self-assign, @Saransh-cpp. Thanks!

@Saransh-cpp Saransh-cpp self-assigned this Jan 20, 2024
@Saransh-cpp
Member

Saransh-cpp commented Jan 20, 2024

Thanks for the tags, self-assigned!

@Saransh-cpp
Member

Hi @alexander-held, I've been looking at this issue, and it seems more of a new feature request for the Jax backend.

The implementation (_impl) of ak.mean calls _impl of ak.count -

sumw = ak.operations.ak_count._impl(

but, ak.count is not implemented for the Jax backend -

@overloads(_reducers.Count)
class Count(JAXReducer):
    name: Final = "count"
    preferred_dtype: Final = np.float64
    needs_position: Final = False

    @classmethod
    def from_kernel_reducer(cls, reducer: Reducer) -> Self:
        assert isinstance(reducer, _reducers.Count)
        return cls()

    @classmethod
    def _return_dtype(cls, given_dtype):
        return np.int64

    def apply(
        self,
        array: ak.contents.NumpyArray,
        parents: ak.index.Index,
        starts: ak.index.Index,
        shifts: ak.index.Index | None,
        outlength: ShapeItem,
    ) -> ak.contents.NumpyArray:
        raise RuntimeError("Cannot differentiate through count_zero")

I could trace back the history of this file, and it looks like the intentional error has always been there.
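
For illustration only, here is a minimal sketch of what a count reduction computes from structural information alone. It assumes that `parents[i]` names the output slot that element `i` belongs to; the helper `count_from_parents` is hypothetical and is not Awkward Array's kernel or the fix that eventually closed this issue.

```python
import numpy as np

def count_from_parents(parents, outlength):
    """Count how many elements fall into each output slot.

    Assumption: parents[i] is the index of the output slot of element i.
    Only integer structure is used; no array values are touched.
    """
    parents = np.asarray(parents, dtype=np.int64)
    return np.bincount(parents, minlength=outlength).astype(np.int64)

# Six elements spread over four output slots; slot 2 is empty.
parents = np.array([0, 0, 1, 1, 1, 3])
counts = count_from_parents(parents, outlength=4)
print(counts)  # [2 3 0 1]
```

Because nothing here depends on the array values, the result could be handed to later steps as a constant rather than as a traced, differentiable quantity.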

@jpivarski, were count, argmin, argmax, and count_nonzero not implemented for the Jax backend because they were not feasible for some reason, or were they just left for the future? Will adding their implementations be a good starting point for my project? I can work on these this week.

@alexander-held
Member Author

I vaguely remember some past conversations about which kinds of derivatives we would want to be able to evaluate and which would be less useful. For discrete quantities like ak.count, I believe we would need some relaxation to define a derivative, but we might not need to support cases where the number of elements over which we take the mean changes. Since the mean is the sum over the elements divided by the count, the derivative I was ultimately after in the code snippet above is just the derivative of the sum, divided by a constant.
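
Written out, under the assumption that the number of elements $N$ is held fixed (the symbols $N$, $x_i$, and $\bar{x}$ are introduced here for illustration only):

```latex
\bar{x} = \frac{1}{N}\sum_{i=1}^{N} x_i
\quad\Longrightarrow\quad
\frac{\partial \bar{x}}{\partial x_j}
  = \frac{1}{N}\,\frac{\partial}{\partial x_j}\sum_{i=1}^{N} x_i
  = \frac{1}{N}.
```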

As soon as the number of elements changes (which might happen in practice with selection cuts), handling that seems like a broader issue that the user might need to take care of externally. I'll point some more people to this issue to invite other opinions on what would be most useful.

@jpivarski
Member

> @jpivarski, were count, argmin, argmax, and count_nonzero not implemented for the Jax backend because they were not feasible for some reason, or were they just left for the future?

These seem to be fundamentally non-differentiable—in my understanding, at least. ak.count, at least, depends only on the array's structure, and we don't even represent the array structure (the ak.index.Index parts) using JAX because you can't differentiate through that. It would be equivalent to differentiating through variables that are used in an if predicate. It's less obvious for ak.count_nonzero (the delta function is differentiable in Lebesgue measure but not Riemannian measure—I don't think that's relevant) and ak.argmin/ak.argmax.
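
The comparison with an `if` predicate can be illustrated in plain JAX. This is a small sketch, not Awkward Array code, and the function `piecewise` is made up for the example: the predicate decides which branch is evaluated, but the gradient comes only from the selected branch.

```python
import jax
import jax.numpy as jnp

def piecewise(x):
    # The predicate `x > 1.0` only selects a branch; it contributes no gradient.
    return jnp.where(x > 1.0, 2.0 * x, x)

print(jax.grad(piecewise)(3.0))  # slope of the first branch: 2.0
print(jax.grad(piecewise)(0.5))  # slope of the second branch: 1.0
```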

However, differentiable arrays should carry through some of this stuff. For instance, ak.count should be treated as a constant, so that ak.mean can be implemented by using the implementation of ak.sum and dividing it by the constant that comes out of ak.count. (The derivative of ak.mean is a scaled version of the derivative of ak.sum.) I don't know about the others, ak.count_nonzero and ak.argmin/ak.argmax, though.
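
A minimal sketch of this idea in plain JAX, outside of Awkward Array. It assumes a ragged array stored as differentiable flat `content` plus integer `offsets`; the functions `ragged_mean` and `per_list_mean` are hypothetical names for this example, not the library's implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical ragged layout: differentiable flat content plus integer offsets.
content = jnp.array([1.0, 7.0, 3.0, 5.0, 4.0])
offsets = np.array([0, 2, 5])  # sublists [1, 7] and [3, 5, 4]

def ragged_mean(content, offsets):
    # The count comes from the structure only, so it is a plain constant.
    count = int(offsets[-1] - offsets[0])
    return jnp.sum(content) / count

def per_list_mean(content, offsets):
    # Per-sublist counts are constants as well; only the sums are differentiable.
    counts = np.diff(offsets)
    segment_ids = np.repeat(np.arange(len(counts)), counts)
    sums = jax.ops.segment_sum(content, segment_ids, num_segments=len(counts))
    return sums / counts

# Differentiate with respect to the content only (argnums=0).
value, grad = jax.value_and_grad(ragged_mean, argnums=0)(content, offsets)
```

Here `value` is the mean over all five elements and every entry of `grad` equals one over the number of elements, which is the gradient of the sum scaled by the constant count.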

@Saransh-cpp added the "feature" label (New feature or request) and removed the "bug (unverified)" label (The problem described would be a bug, but needs to be triaged) on Feb 12, 2024