Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: let JetDefinition be pickleable #224

Merged
merged 11 commits into from
Jun 13, 2023
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ test =
uproot>=5
dask>=2023.4.0;python_version>"3.7"
dask-awkward>=2023.4.2;python_version>"3.7"
distributed>=2023.4.0;python_version>"3.7"

[tool:pytest]
addopts = -vv -rs -Wd
Expand Down
37 changes: 26 additions & 11 deletions src/fastjet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,7 @@


class JetDefinition(JetDefinitionNoCast):
def __init__(
self,
jet_algorithm_in,
R_in,
recomb_scheme_in=0,
strategy_in=1,
nparameters_in=1,
):
def __init__(self, *args, **kwargs):
r"""

`JetDefinition(JetAlgorithm jet_algorithm_in, double R_in, RecombinationScheme
Expand All @@ -206,6 +199,14 @@ def __init__(
how algorithically to run it).

"""

R_in = kwargs.pop("R_in", None)
as_kwargs = False
if R_in is None:
R_in = args[1]
else:
as_kwargs = True

if not isinstance(R_in, (float, int)):
raise ValueError(
f"R_in should be a real number, got {R_in} of type {type(R_in)}"
Expand All @@ -214,9 +215,23 @@ def __init__(
if isinstance(R_in, int):
R_in = float(R_in)

super().__init__(
jet_algorithm_in, R_in, recomb_scheme_in, strategy_in, nparameters_in
)
new_args = args
new_kwargs = kwargs
if as_kwargs:
new_kwargs = kwargs.copy()
new_kwargs["R_in"] = R_in
else:
new_args = (args[0], R_in, *args[2:])

self.args = new_args
self.kwargs = new_kwargs
super().__init__(*new_args, **kwargs)

def __setstate__(self, state):
self.__init__(*state["args"], **state["kwargs"])

def __getstate__(self):
return {"args": self.args, "kwargs": self.kwargs}


class ClusterSequence: # The super class
Expand Down
80 changes: 42 additions & 38 deletions tests/test_008-dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,57 +5,61 @@
import fastjet._pyjet # noqa: F401

dak = pytest.importorskip("dask_awkward") # noqa: F401
distributed = pytest.importorskip("distributed")
vector = pytest.importorskip("vector") # noqa: F401


def test_multi():
array = ak.Array(
[
from distributed import Client

with Client() as _:
array = ak.Array(
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
],
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
],
]
)
darray = dak.from_awkward(array, 1)
jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.6)
cluster = fastjet._pyjet.DaskAwkwardClusterSequence(darray, jetdef)
inclusive_jets_out = [
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
{"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
],
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
{"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
],
]
)
darray = dak.from_awkward(array, 1)
jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.6)
cluster = fastjet._pyjet.DaskAwkwardClusterSequence(darray, jetdef)
inclusive_jets_out = [
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
{"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
],
[
{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
{"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
],
]
assert inclusive_jets_out == cluster.inclusive_jets().compute().to_list()
constituents_output = [
[
[{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
assert inclusive_jets_out == cluster.inclusive_jets().compute().to_list()
constituents_output = [
[
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
[{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
[
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
],
],
],
[
[{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
[
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
[{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
[
{"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
{"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
],
],
],
]
assert cluster.constituents().compute().to_list() == constituents_output
constituent_index_out = [[[0], [1, 2]], [[0], [1, 2]]]
assert constituent_index_out == cluster.constituent_index().compute().to_list()
]
assert cluster.constituents().compute().to_list() == constituents_output
constituent_index_out = [[[0], [1, 2]], [[0], [1, 2]]]
assert constituent_index_out == cluster.constituent_index().compute().to_list()


def test_single():
Expand Down