Skip to content

Commit

Permalink
fix: let JetDefinition be pickleable (#224)
Browse files Browse the repository at this point in the history
* let JetDefinition be pickleable

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* install distributed for tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* try using dask client in test

* assign Client as _ since we never call it directly

* actually set args/kwargs in class :-)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more correct

* no item assignment for tuple

* actually remove assignment statement...

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lgray and pre-commit-ci[bot] committed Jun 13, 2023
1 parent 9d3c90b commit 531188d
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 49 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ test =
uproot>=5
dask>=2023.4.0;python_version>"3.7"
dask-awkward>=2023.4.2;python_version>"3.7"
distributed>=2023.4.0;python_version>"3.7"

[tool:pytest]
addopts = -vv -rs -Wd
Expand Down
37 changes: 26 additions & 11 deletions src/fastjet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,7 @@


class JetDefinition(JetDefinitionNoCast):
def __init__(
self,
jet_algorithm_in,
R_in,
recomb_scheme_in=0,
strategy_in=1,
nparameters_in=1,
):
def __init__(self, *args, **kwargs):
r"""
`JetDefinition(JetAlgorithm jet_algorithm_in, double R_in, RecombinationScheme
Expand All @@ -206,6 +199,14 @@ def __init__(
how algorithically to run it).
"""

R_in = kwargs.pop("R_in", None)
as_kwargs = False
if R_in is None:
R_in = args[1]
else:
as_kwargs = True

if not isinstance(R_in, (float, int)):
raise ValueError(
f"R_in should be a real number, got {R_in} of type {type(R_in)}"
Expand All @@ -214,9 +215,23 @@ def __init__(
if isinstance(R_in, int):
R_in = float(R_in)

super().__init__(
jet_algorithm_in, R_in, recomb_scheme_in, strategy_in, nparameters_in
)
new_args = args
new_kwargs = kwargs
if as_kwargs:
new_kwargs = kwargs.copy()
new_kwargs["R_in"] = R_in
else:
new_args = (args[0], R_in, *args[2:])

self.args = new_args
self.kwargs = new_kwargs
super().__init__(*new_args, **kwargs)

def __setstate__(self, state):
self.__init__(*state["args"], **state["kwargs"])

def __getstate__(self):
return {"args": self.args, "kwargs": self.kwargs}


class ClusterSequence: # The super class
Expand Down
80 changes: 42 additions & 38 deletions tests/test_008-dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,57 +5,61 @@
import fastjet._pyjet # noqa: F401

dak = pytest.importorskip("dask_awkward") # noqa: F401
distributed = pytest.importorskip("distributed")
vector = pytest.importorskip("vector") # noqa: F401


def test_multi():
array = ak.Array(
[
from distributed import Client

with Client() as _:
array = ak.Array(
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
],
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
],
]
)
darray = dak.from_awkward(array, 1)
jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.6)
cluster = fastjet._pyjet.DaskAwkwardClusterSequence(darray, jetdef)
inclusive_jets_out = [
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
{"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
],
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
{"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
],
]
)
darray = dak.from_awkward(array, 1)
jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.6)
cluster = fastjet._pyjet.DaskAwkwardClusterSequence(darray, jetdef)
inclusive_jets_out = [
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
{"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
],
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
{"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
],
]
assert inclusive_jets_out == cluster.inclusive_jets().compute().to_list()
constituents_output = [
[
[{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
assert inclusive_jets_out == cluster.inclusive_jets().compute().to_list()
constituents_output = [
[
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
[{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
[
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
],
],
],
[
[{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
[
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
[{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
[
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
],
],
],
]
assert cluster.constituents().compute().to_list() == constituents_output
constituent_index_out = [[[0], [1, 2]], [[0], [1, 2]]]
assert constituent_index_out == cluster.constituent_index().compute().to_list()
]
assert cluster.constituents().compute().to_list() == constituents_output
constituent_index_out = [[[0], [1, 2]], [[0], [1, 2]]]
assert constituent_index_out == cluster.constituent_index().compute().to_list()


def test_single():
Expand Down

0 comments on commit 531188d

Please sign in to comment.