diff --git a/setup.cfg b/setup.cfg index e7321ef3..acbe093b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,8 @@ package_dir = where = src [options.extras_require] +dask = + dask-awkward>=2023.3.2 dev = pytest>=4.6 docs = @@ -52,6 +54,7 @@ docs = sphinx-rtd-theme>=0.5.0 test = pytest>=4.6 + dask-awkward>=2023.3.2;python_version>"3.7" [tool:pytest] addopts = -vv -rs -Wd diff --git a/src/fastjet/__init__.py b/src/fastjet/__init__.py index d0ef91cf..7e8354f3 100644 --- a/src/fastjet/__init__.py +++ b/src/fastjet/__init__.py @@ -204,9 +204,23 @@ def __init__(self, data, jetdef): fastjet._pyjet.AwkwardClusterSequence.__init__( self, data=data, jetdef=jetdef ) - if isinstance(data, list): + elif isinstance(data, list): self.__class__ = fastjet._swig.ClusterSequence fastjet._swig.ClusterSequence.__init__(self, data, jetdef) + else: + try: + import dask_awkward as dak + except ImportError: + dak = None + if dak is not None and isinstance(data, dak.Array): + self.__class__ = fastjet._pyjet.DaskAwkwardClusterSequence + fastjet._pyjet.DaskAwkwardClusterSequence.__init__( + self, data=data, jetdef=jetdef + ) + else: + raise TypeError( + f"{data} must be an awkward.Array, dask_awkward.Array, or list!" + ) def jet_def(self) -> JetDefinition: """Returns the Jet Definition Object associated with the instance diff --git a/src/fastjet/_pyjet.py b/src/fastjet/_pyjet.py index e9ae66a2..258c8af6 100644 --- a/src/fastjet/_pyjet.py +++ b/src/fastjet/_pyjet.py @@ -228,3 +228,260 @@ def get_child(self, data): if not isinstance(data, ak.Array): raise TypeError("The input data is not an Awkward Array") return self._internalrep.get_child(data) + + +class _FnDelayedInternalRepCaller: + def __init__(self, method_name, jetdef, **kwargs): + self.name = method_name + self.jetdef = jetdef + self.kwargs = kwargs + + def __call__(self, array): + if ak.backend(array) == "typetracer": + length_zero_array = array.layout.form.length_zero_array( + behavior=array.behavior + ) + seq = AwkwardClusterSequence(length_zero_array, self.jetdef) + out = getattr(seq, self.name)(**self.kwargs) + return ak.Array(out.layout.to_typetracer(forget_length=True)) + seq = AwkwardClusterSequence(array, self.jetdef) + return getattr(seq, self.name)(**self.kwargs) + + +def _dak_dispatch(cluseq, method_name, **kwargs): + from dask_awkward.utils import hyphenize + + return cluseq._data.map_partitions( + _FnDelayedInternalRepCaller(method_name, cluseq._jetdef, **kwargs), + label=hyphenize(method_name), + ) + + +class DaskAwkwardClusterSequence(ClusterSequence): + def __init__(self, data, jetdef): + import dask_awkward as dak + + if not isinstance(data, dak.Array): + raise TypeError("The input data is not an Dask Array!") + if not isinstance(jetdef, fastjet._swig.JetDefinition): + raise TypeError("JetDefinition is not of valid type") + self._jetdef = jetdef + self._data = data + self._jagedness = self._check_jaggedness(data._meta) + self._flag = 1 + length_zero_data = data._meta.layout.form.length_zero_array( + behavior=data.behavior + ) + if (self._check_listoffset(data._meta) and self._jagedness == 2) or ( + self._check_listoffset_index(data._meta) + ): + self._flag = 0 + self._internalrep = fastjet._multievent._classmultievent( + length_zero_data, self._jetdef + ) + elif self._jagedness == 1 and data.layout.is_record: + self._internalrep = fastjet._singleevent._classsingleevent( + length_zero_data, self._jetdef + ) + elif self._jagedness >= 3 or self._check_general(data): + self._internalrep = fastjet._generalevent._classgeneralevent( + length_zero_data, jetdef + ) + + # else: + # raise TypeError( + # "This kind of Awkward Array is not supported yet. Please contact the maintainers for further action." + # ) + + def _check_jaggedness(self, data): + if self._check_general_jaggedness(data) or self._check_listoffset(data): + return 1 + self._check_jaggedness(ak.Array(data.layout.content)) + if data.layout.is_union: + return 1 + max( + self._check_jaggedness(ak.Array(x)) for x in data.layout.contents + ) + if data.layout.is_record: + return 1 + max( + self._check_jaggedness(ak.Array(x)) for x in data.layout.contents + ) + return 0 + + def _check_listoffset_index(self, data): + if self._check_listoffset_subtree(ak.Array(data.layout)): + if self._check_record( + ak.Array(ak.Array(data.layout.content)), + ): + return True + elif self._check_indexed( + ak.Array(data.layout.content), + ): + if self._check_record( + ak.Array(ak.Array(data.layout.content).layout.content) + ): + return True + else: + return False + else: + return False + + def _check_record(self, data): + return data.layout.is_record or data.layout.is_numpy + + def _check_indexed(self, data): + return data.layout.is_indexed + + def _check_listoffset_subtree(self, data): + return data.layout.is_list + + def _check_general(self, data): + out = isinstance( + data.layout, + ( + ak.contents.BitMaskedArray, + ak.contents.ByteMaskedArray, + ak.contents.IndexedArray, + ak.contents.IndexedOptionArray, + ak.contents.UnionArray, + ak.contents.UnmaskedArray, + ak.record.Record, + ), + ) + return out + + def _check_general_jaggedness(self, data): + out = isinstance( + data.layout, + ( + ak.contents.BitMaskedArray, + ak.contents.ByteMaskedArray, + ak.contents.IndexedArray, + ak.contents.IndexedOptionArray, + ak.contents.UnmaskedArray, + ak.record.Record, + ), + ) + return out + + def _check_listoffset(self, data): + out = isinstance( + data.layout, + ( + ak.contents.ListArray, + ak.contents.ListOffsetArray, + ak.contents.RegularArray, + ), + ) + return out + + def jet_def(self): + return self._jetdef + + def inclusive_jets(self, min_pt=0): + return _dak_dispatch(self, "inclusive_jets", min_pt=min_pt) + + def unclustered_particles(self): + return _dak_dispatch(self, "unclustered_particles") + + def exclusive_jets(self, n_jets=-1, dcut=-1): + return _dak_dispatch(self, "exclusive_jets", n_jets=n_jets, dcut=dcut) + + def exclusive_jets_ycut(self, ycut=-1): + return _dak_dispatch(self, "exclusive_jets_ycut", ycut=ycut) + + def constituent_index(self, min_pt=0): + return _dak_dispatch(self, "constituent_index", min_pt=min_pt) + + def constituents(self, min_pt=0): + return _dak_dispatch(self, "constituents", min_pt=min_pt) + + def exclusive_jets_constituent_index(self, njets=10): + return self._internalrep.exclusive_jets_constituent_index(njets) + + def exclusive_jets_constituents(self, njets=10): + return self._internalrep.exclusive_jets_constituents(njets) + + def exclusive_jets_lund_declusterings(self, njets=10): + return self._internalrep.exclusive_jets_lund_declusterings(njets) + + def exclusive_dmerge(self, njets=10): + return self._internalrep.exclusive_dmerge(njets) + + def exclusive_dmerge_max(self, njets=10): + return self._internalrep.exclusive_dmerge_max(njets) + + def exclusive_ymerge_max(self, njets=10): + return self._internalrep.exclusive_ymerge_max(njets) + + def exclusive_ymerge(self, njets=10): + return self._internalrep.exclusive_ymerge(njets) + + def Q(self): + return self._internalrep.Q() + + def Q2(self): + return self._internalrep.Q2() + + def exclusive_subjets(self, data, dcut=-1, nsub=-1): + if not isinstance(data, ak.Array): + raise TypeError("The input data is not an Awkward Array") + return self._internalrep.exclusive_subjets(data, dcut, nsub) + + def exclusive_subjets_up_to(self, data, nsub=0): + if not isinstance(data, ak.Array): + raise TypeError("The input data is not an Awkward Array") + return self._internalrep.exclusive_subjets_up_to(data, nsub) + + def exclusive_subdmerge(self, data, nsub=0): + if not isinstance(data, ak.Array): + raise TypeError("The input data is not an Awkward Array") + return self._internalrep.exclusive_subdmerge(data, nsub) + + def exclusive_subdmerge_max(self, data, nsub=0): + if not isinstance(data, ak.Array): + raise TypeError("The input data is not an Awkward Array") + return self._internalrep.exclusive_subdmerge_max(data, nsub) + + def n_exclusive_subjets(self, data, dcut=0): + if not isinstance(data, ak.Array): + raise TypeError("The input data is not an Awkward Array") + return self._internalrep.n_exclusive_subjets(data, dcut) + + def has_parents(self, data): + if not isinstance(data, ak.Array): + raise TypeError("The input data is not an Awkward Array") + return self._internalrep.has_parents(data) + + def has_child(self, data): + if not isinstance(data, ak.Array): + raise TypeError("The input data is not an Awkward Array") + return self._internalrep.has_child(data) + + def jet_scale_for_algorithm(self, data): + if not isinstance(data, ak.Array): + raise TypeError("The input data is not an Awkward Array") + return self._internalrep.jet_scale_for_algorithm(data) + + def unique_history_order(self): + return self._internalrep.unique_history_order() + + def n_particles(self): + return self._internalrep.n_particles() + + def n_exclusive_jets(self, dcut=0): + return self._internalrep.n_exclusive_jets(dcut) + + def childless_pseudojets(self): + return self._internalrep.childless_pseudojets() + + def jets(self): + return self._internalrep.jets() + + def get_parents(self, data): + if not isinstance(data, ak.Array): + raise TypeError("The input data is not an Awkward Array") + return self._internalrep.get_parents(data) + + def get_child(self, data): + if not isinstance(data, ak.Array): + raise TypeError("The input data is not an Awkward Array") + return self._internalrep.get_child(data) diff --git a/tests/test_008-dask.py b/tests/test_008-dask.py new file mode 100644 index 00000000..48f598a3 --- /dev/null +++ b/tests/test_008-dask.py @@ -0,0 +1,87 @@ +import awkward as ak # noqa: F401 +import numpy as np # noqa: F401 +import pytest # noqa: F401 + +import fastjet._pyjet # noqa: F401 + +dak = pytest.importorskip("dask_awkward") # noqa: F401 +vector = pytest.importorskip("vector") # noqa: F401 + + +def test_multi(): + array = ak.Array( + [ + [ + {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}, + {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35}, + {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0}, + ], + [ + {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}, + {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35}, + {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0}, + ], + ] + ) + darray = dak.from_awkward(array, 1) + jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.6) + cluster = fastjet._pyjet.DaskAwkwardClusterSequence(darray, jetdef) + inclusive_jets_out = [ + [ + {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5}, + {"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68}, + ], + [ + {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5}, + {"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68}, + ], + ] + assert inclusive_jets_out == cluster.inclusive_jets().compute().to_list() + constituents_output = [ + [ + [{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}], + [ + {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35}, + {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0}, + ], + ], + [ + [{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}], + [ + {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35}, + {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0}, + ], + ], + ] + assert cluster.constituents().compute().to_list() == constituents_output + constituent_index_out = [[[0], [1, 2]], [[0], [1, 2]]] + assert constituent_index_out == cluster.constituent_index().compute().to_list() + + +def test_single(): + array = ak.Array( + [ + {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}, + {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35}, + {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0}, + ] + ) + darray = dak.from_awkward(array, 1) + jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.6) + cluster = fastjet._pyjet.DaskAwkwardClusterSequence(darray, jetdef) + inclusive_jets_out = [ + {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5}, + {"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68}, + ] + + assert inclusive_jets_out == cluster.inclusive_jets().compute().to_list() + constituent_output = [ + [{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5}], + [ + {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12}, + {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56}, + ], + ] + assert constituent_output == cluster.constituents().compute().to_list() + constituent_index_output = [[0], [1, 2]] + assert constituent_index_output == cluster.constituent_index().compute().to_list()