Skip to content

Commit

Permalink
partial set of bindings for fastjet into dask awkward - needs completion
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Mar 28, 2023
1 parent 499c127 commit 60f1d03
Show file tree
Hide file tree
Showing 4 changed files with 362 additions and 1 deletion.
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ package_dir =
where = src

[options.extras_require]
dask =
dask-awkward>=2023.3.2
dev =
pytest>=4.6
docs =
Expand All @@ -52,6 +54,7 @@ docs =
sphinx-rtd-theme>=0.5.0
test =
pytest>=4.6
dask-awkward>=2023.3.2;python_version>"3.7"

[tool:pytest]
addopts = -vv -rs -Wd
Expand Down
16 changes: 15 additions & 1 deletion src/fastjet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,23 @@ def __init__(self, data, jetdef):
fastjet._pyjet.AwkwardClusterSequence.__init__(
self, data=data, jetdef=jetdef
)
if isinstance(data, list):
elif isinstance(data, list):
self.__class__ = fastjet._swig.ClusterSequence
fastjet._swig.ClusterSequence.__init__(self, data, jetdef)
else:
try:
import dask_awkward as dak
except ImportError:
dak = None
if dak is not None and isinstance(data, dak.Array):
self.__class__ = fastjet._pyjet.DaskAwkwardClusterSequence
fastjet._pyjet.DaskAwkwardClusterSequence.__init__(
self, data=data, jetdef=jetdef
)
else:
raise TypeError(
f"{data} must be an awkward.Array, dask_awkward.Array, or list!"
)

def jet_def(self) -> JetDefinition:
"""Returns the Jet Definition Object associated with the instance
Expand Down
257 changes: 257 additions & 0 deletions src/fastjet/_pyjet.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,260 @@ def get_child(self, data):
if not isinstance(data, ak.Array):
raise TypeError("The input data is not an Awkward Array")
return self._internalrep.get_child(data)


class _FnDelayedInternalRepCaller:
    """Callable that runs one ``AwkwardClusterSequence`` method on an array.

    The method name, jet definition and keyword arguments are stored at
    construction time. Calling the instance with an array builds an
    ``AwkwardClusterSequence`` from that array and invokes the stored method
    with the stored keyword arguments.
    """

    def __init__(self, method_name, jetdef, **kwargs):
        """Store the method name, jet definition and call keywords."""
        self.name, self.jetdef, self.kwargs = method_name, jetdef, kwargs

    def __call__(self, array):
        """Apply the stored method to ``array`` and return its result.

        When ``array`` is typetracer-backed, the method is run on a
        length-zero array built from the same form, and the result is
        converted back to a typetracer array with its length forgotten.
        """
        if ak.backend(array) != "typetracer":
            # Concrete data: cluster it directly.
            sequence = AwkwardClusterSequence(array, self.jetdef)
            return getattr(sequence, self.name)(**self.kwargs)

        # Typetracer input carries no data, so run on an empty stand-in
        # that has the same form and behavior.
        empty = array.layout.form.length_zero_array(behavior=array.behavior)
        sequence = AwkwardClusterSequence(empty, self.jetdef)
        result = getattr(sequence, self.name)(**self.kwargs)
        traced_layout = result.layout.to_typetracer(forget_length=True)
        return ak.Array(traced_layout)


def _dak_dispatch(cluseq, method_name, **kwargs):
    """Map a cluster-sequence method over the partitions of ``cluseq._data``.

    Wraps ``method_name``, ``cluseq._jetdef`` and ``kwargs`` in a
    ``_FnDelayedInternalRepCaller`` and hands it to ``map_partitions``,
    labelling the operation with the hyphenized method name.
    """
    from dask_awkward.utils import hyphenize

    map_partitions = cluseq._data.map_partitions
    partition_fn = _FnDelayedInternalRepCaller(
        method_name, cluseq._jetdef, **kwargs
    )
    return map_partitions(partition_fn, label=hyphenize(method_name))


class DaskAwkwardClusterSequence(ClusterSequence):
    """Cluster sequence backed by a ``dask_awkward.Array`` collection.

    ``inclusive_jets``, ``unclustered_particles``, ``exclusive_jets``,
    ``exclusive_jets_ycut``, ``constituent_index`` and ``constituents`` are
    mapped over the partitions of the collection through ``_dak_dispatch``.

    NOTE: every other query still forwards to ``self._internalrep``, which
    ``__init__`` builds from a length-zero array derived from the
    collection's meta rather than from the partitions themselves.
    """

    def __init__(self, data, jetdef):
        """Validate the inputs and pick the internal representation.

        Args:
            data: the ``dask_awkward.Array`` to cluster.
            jetdef: the ``fastjet._swig.JetDefinition`` to cluster with.

        Raises:
            TypeError: if ``data`` is not a ``dask_awkward.Array`` or
                ``jetdef`` is not a ``fastjet._swig.JetDefinition``.
        """
        import dask_awkward as dak

        if not isinstance(data, dak.Array):
            raise TypeError("The input data is not an Dask Array!")
        if not isinstance(jetdef, fastjet._swig.JetDefinition):
            raise TypeError("JetDefinition is not of valid type")
        self._jetdef = jetdef
        self._data = data
        # All layout inspection below goes through the collection's meta so
        # that every check looks at the same object.
        meta = data._meta
        self._jagedness = self._check_jaggedness(meta)
        self._flag = 1
        length_zero_data = meta.layout.form.length_zero_array(
            behavior=data.behavior
        )
        if (self._check_listoffset(meta) and self._jagedness == 2) or (
            self._check_listoffset_index(meta)
        ):
            self._flag = 0
            self._internalrep = fastjet._multievent._classmultievent(
                length_zero_data, self._jetdef
            )
        elif self._jagedness == 1 and meta.layout.is_record:
            self._internalrep = fastjet._singleevent._classsingleevent(
                length_zero_data, self._jetdef
            )
        elif self._jagedness >= 3 or self._check_general(meta):
            self._internalrep = fastjet._generalevent._classgeneralevent(
                length_zero_data, jetdef
            )
        # NOTE: there is deliberately no ``else`` here. A layout matching
        # none of the branches leaves ``_internalrep`` unset, so the methods
        # that forward to it raise ``AttributeError``; the dispatched
        # methods do not use ``_internalrep`` and are unaffected.

    @staticmethod
    def _require_awkward(data):
        """Raise ``TypeError`` unless ``data`` is an ``awkward.Array``."""
        if not isinstance(data, ak.Array):
            raise TypeError("The input data is not an Awkward Array")

    def _check_jaggedness(self, data):
        """Return the nesting depth counted over ``data``'s layout tree.

        Each option/indexed/list node adds one level and recurses into its
        single content; a union or record node adds one level plus the
        deepest of its contents; anything else counts as zero.
        """
        layout = data.layout
        if self._check_general_jaggedness(data) or self._check_listoffset(data):
            return 1 + self._check_jaggedness(ak.Array(layout.content))
        if layout.is_union or layout.is_record:
            return 1 + max(
                self._check_jaggedness(ak.Array(x)) for x in layout.contents
            )
        return 0

    def _check_listoffset_index(self, data):
        """Return True for list -> record/numpy or list -> indexed -> record/numpy.

        Always returns a ``bool``; an indexed content that does not wrap a
        record/numpy node yields ``False``.
        """
        layout = data.layout
        if not self._check_listoffset_subtree(ak.Array(layout)):
            return False
        content = ak.Array(layout.content)
        if self._check_record(content):
            return True
        if self._check_indexed(content):
            return self._check_record(ak.Array(content.layout.content))
        return False

    def _check_record(self, data):
        """Return True if the layout is a record or a numpy node."""
        return data.layout.is_record or data.layout.is_numpy

    def _check_indexed(self, data):
        """Return True if the layout reports ``is_indexed``."""
        return data.layout.is_indexed

    def _check_listoffset_subtree(self, data):
        """Return True if the layout reports ``is_list``."""
        return data.layout.is_list

    def _check_general(self, data):
        """Return True if the layout is an option, indexed, union or Record node."""
        return isinstance(
            data.layout,
            (
                ak.contents.BitMaskedArray,
                ak.contents.ByteMaskedArray,
                ak.contents.IndexedArray,
                ak.contents.IndexedOptionArray,
                ak.contents.UnionArray,
                ak.contents.UnmaskedArray,
                ak.record.Record,
            ),
        )

    def _check_general_jaggedness(self, data):
        """Same check as ``_check_general`` but without ``UnionArray``."""
        return isinstance(
            data.layout,
            (
                ak.contents.BitMaskedArray,
                ak.contents.ByteMaskedArray,
                ak.contents.IndexedArray,
                ak.contents.IndexedOptionArray,
                ak.contents.UnmaskedArray,
                ak.record.Record,
            ),
        )

    def _check_listoffset(self, data):
        """Return True if the layout is a List, ListOffset or Regular array."""
        return isinstance(
            data.layout,
            (
                ak.contents.ListArray,
                ak.contents.ListOffsetArray,
                ak.contents.RegularArray,
            ),
        )

    def jet_def(self):
        """Return the jet definition passed to the constructor."""
        return self._jetdef

    # --- methods mapped over the partitions of the collection -------------

    def inclusive_jets(self, min_pt=0):
        """Map ``inclusive_jets(min_pt=...)`` over the partitions."""
        return _dak_dispatch(self, "inclusive_jets", min_pt=min_pt)

    def unclustered_particles(self):
        """Map ``unclustered_particles()`` over the partitions."""
        return _dak_dispatch(self, "unclustered_particles")

    def exclusive_jets(self, n_jets=-1, dcut=-1):
        """Map ``exclusive_jets(n_jets=..., dcut=...)`` over the partitions."""
        return _dak_dispatch(self, "exclusive_jets", n_jets=n_jets, dcut=dcut)

    def exclusive_jets_ycut(self, ycut=-1):
        """Map ``exclusive_jets_ycut(ycut=...)`` over the partitions."""
        return _dak_dispatch(self, "exclusive_jets_ycut", ycut=ycut)

    def constituent_index(self, min_pt=0):
        """Map ``constituent_index(min_pt=...)`` over the partitions."""
        return _dak_dispatch(self, "constituent_index", min_pt=min_pt)

    def constituents(self, min_pt=0):
        """Map ``constituents(min_pt=...)`` over the partitions."""
        return _dak_dispatch(self, "constituents", min_pt=min_pt)

    # --- methods still forwarded to the length-zero internal representation

    def exclusive_jets_constituent_index(self, njets=10):
        """Forward to ``_internalrep.exclusive_jets_constituent_index``."""
        return self._internalrep.exclusive_jets_constituent_index(njets)

    def exclusive_jets_constituents(self, njets=10):
        """Forward to ``_internalrep.exclusive_jets_constituents``."""
        return self._internalrep.exclusive_jets_constituents(njets)

    def exclusive_jets_lund_declusterings(self, njets=10):
        """Forward to ``_internalrep.exclusive_jets_lund_declusterings``."""
        return self._internalrep.exclusive_jets_lund_declusterings(njets)

    def exclusive_dmerge(self, njets=10):
        """Forward to ``_internalrep.exclusive_dmerge``."""
        return self._internalrep.exclusive_dmerge(njets)

    def exclusive_dmerge_max(self, njets=10):
        """Forward to ``_internalrep.exclusive_dmerge_max``."""
        return self._internalrep.exclusive_dmerge_max(njets)

    def exclusive_ymerge_max(self, njets=10):
        """Forward to ``_internalrep.exclusive_ymerge_max``."""
        return self._internalrep.exclusive_ymerge_max(njets)

    def exclusive_ymerge(self, njets=10):
        """Forward to ``_internalrep.exclusive_ymerge``."""
        return self._internalrep.exclusive_ymerge(njets)

    def Q(self):
        """Forward to ``_internalrep.Q``."""
        return self._internalrep.Q()

    def Q2(self):
        """Forward to ``_internalrep.Q2``."""
        return self._internalrep.Q2()

    def exclusive_subjets(self, data, dcut=-1, nsub=-1):
        """Forward to ``_internalrep.exclusive_subjets``; ``data`` must be an ``ak.Array``."""
        self._require_awkward(data)
        return self._internalrep.exclusive_subjets(data, dcut, nsub)

    def exclusive_subjets_up_to(self, data, nsub=0):
        """Forward to ``_internalrep.exclusive_subjets_up_to``; ``data`` must be an ``ak.Array``."""
        self._require_awkward(data)
        return self._internalrep.exclusive_subjets_up_to(data, nsub)

    def exclusive_subdmerge(self, data, nsub=0):
        """Forward to ``_internalrep.exclusive_subdmerge``; ``data`` must be an ``ak.Array``."""
        self._require_awkward(data)
        return self._internalrep.exclusive_subdmerge(data, nsub)

    def exclusive_subdmerge_max(self, data, nsub=0):
        """Forward to ``_internalrep.exclusive_subdmerge_max``; ``data`` must be an ``ak.Array``."""
        self._require_awkward(data)
        return self._internalrep.exclusive_subdmerge_max(data, nsub)

    def n_exclusive_subjets(self, data, dcut=0):
        """Forward to ``_internalrep.n_exclusive_subjets``; ``data`` must be an ``ak.Array``."""
        self._require_awkward(data)
        return self._internalrep.n_exclusive_subjets(data, dcut)

    def has_parents(self, data):
        """Forward to ``_internalrep.has_parents``; ``data`` must be an ``ak.Array``."""
        self._require_awkward(data)
        return self._internalrep.has_parents(data)

    def has_child(self, data):
        """Forward to ``_internalrep.has_child``; ``data`` must be an ``ak.Array``."""
        self._require_awkward(data)
        return self._internalrep.has_child(data)

    def jet_scale_for_algorithm(self, data):
        """Forward to ``_internalrep.jet_scale_for_algorithm``; ``data`` must be an ``ak.Array``."""
        self._require_awkward(data)
        return self._internalrep.jet_scale_for_algorithm(data)

    def unique_history_order(self):
        """Forward to ``_internalrep.unique_history_order``."""
        return self._internalrep.unique_history_order()

    def n_particles(self):
        """Forward to ``_internalrep.n_particles``."""
        return self._internalrep.n_particles()

    def n_exclusive_jets(self, dcut=0):
        """Forward to ``_internalrep.n_exclusive_jets``."""
        return self._internalrep.n_exclusive_jets(dcut)

    def childless_pseudojets(self):
        """Forward to ``_internalrep.childless_pseudojets``."""
        return self._internalrep.childless_pseudojets()

    def jets(self):
        """Forward to ``_internalrep.jets``."""
        return self._internalrep.jets()

    def get_parents(self, data):
        """Forward to ``_internalrep.get_parents``; ``data`` must be an ``ak.Array``."""
        self._require_awkward(data)
        return self._internalrep.get_parents(data)

    def get_child(self, data):
        """Forward to ``_internalrep.get_child``; ``data`` must be an ``ak.Array``."""
        self._require_awkward(data)
        return self._internalrep.get_child(data)
87 changes: 87 additions & 0 deletions tests/test_008-dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import awkward as ak # noqa: F401
import numpy as np # noqa: F401
import pytest # noqa: F401

import fastjet._pyjet # noqa: F401

dak = pytest.importorskip("dask_awkward") # noqa: F401
vector = pytest.importorskip("vector") # noqa: F401


def test_multi():
    """Cluster two identical events held in a dask-awkward collection."""
    # One event's particles; the input repeats this event twice.
    particles = [
        {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
        {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
        {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
    ]
    n_events = 2
    darray = dak.from_awkward(ak.Array([particles] * n_events), 1)
    jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.6)
    cluster = fastjet._pyjet.DaskAwkwardClusterSequence(darray, jetdef)

    # Expected per-event results; every event yields the same answer.
    jets_per_event = [
        {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
        {"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
    ]
    constituents_per_event = [particles[:1], particles[1:]]
    index_per_event = [[0], [1, 2]]

    inclusive_jets_out = [jets_per_event] * n_events
    assert inclusive_jets_out == cluster.inclusive_jets().compute().to_list()

    constituents_output = [constituents_per_event] * n_events
    assert cluster.constituents().compute().to_list() == constituents_output

    constituent_index_out = [index_per_event] * n_events
    assert constituent_index_out == cluster.constituent_index().compute().to_list()


def test_single():
    """Cluster a single event held in a dask-awkward collection."""
    particles = [
        {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
        {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
        {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
    ]
    darray = dak.from_awkward(ak.Array(particles), 1)
    jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.6)
    cluster = fastjet._pyjet.DaskAwkwardClusterSequence(darray, jetdef)

    inclusive_jets_out = [
        {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
        {"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
    ]
    assert inclusive_jets_out == cluster.inclusive_jets().compute().to_list()

    # The expected constituents carry only the four-momentum fields, so the
    # extra "ex" field is dropped from the input particles.
    four_momenta = [
        {key: value for key, value in particle.items() if key != "ex"}
        for particle in particles
    ]
    constituent_output = [four_momenta[:1], four_momenta[1:]]
    assert constituent_output == cluster.constituents().compute().to_list()

    constituent_index_output = [[0], [1, 2]]
    assert constituent_index_output == cluster.constituent_index().compute().to_list()

0 comments on commit 60f1d03

Please sign in to comment.