From 531188d06a4681e567c2e8613d61d846faa81cd8 Mon Sep 17 00:00:00 2001
From: Lindsey Gray
Date: Tue, 13 Jun 2023 10:38:35 -0500
Subject: [PATCH] fix: let JetDefinition be pickleable (#224)

* let JetDefinition be pickleable

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* install distributed for tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* try using dask client in test

* assign Client as _ since we never call it directly

* actually set args/kwargs in class :-)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more correct

* no item assignment for tuple

* actually remove assignment statement...

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
Review notes (this section is ignored by `git am`):

- JetDefinition.__init__ now forwards `new_kwargs` to super().__init__().
  Forwarding `kwargs` silently dropped a keyword `R_in`, because it had
  already been popped for validation; the checked / float-cast value only
  lives in `new_kwargs`. __setstate__ replays the stored kwargs through
  __init__, so unpickling hit the same problem. Positional calls are
  unaffected (`new_kwargs is kwargs` on that path).
- Whether the wrapped base constructor accepts keyword arguments at all is
  not visible from this patch -- TODO confirm against the generated bindings.
- Line breaks and indentation were restored to match the hunk line counts;
  the placement of blank context lines inside the docstring is inferred.

 setup.cfg               |  1 +
 src/fastjet/__init__.py | 37 +++++++++++++------
 tests/test_008-dask.py  | 80 +++++++++++++++++++++--------------------
 3 files changed, 69 insertions(+), 49 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 7531869b..13989ad8 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -58,6 +58,7 @@ test =
     uproot>=5
     dask>=2023.4.0;python_version>"3.7"
     dask-awkward>=2023.4.2;python_version>"3.7"
+    distributed>=2023.4.0;python_version>"3.7"
 
 [tool:pytest]
 addopts = -vv -rs -Wd
diff --git a/src/fastjet/__init__.py b/src/fastjet/__init__.py
index b75450e8..ad81f1bb 100644
--- a/src/fastjet/__init__.py
+++ b/src/fastjet/__init__.py
@@ -189,14 +189,7 @@
 
 
 class JetDefinition(JetDefinitionNoCast):
-    def __init__(
-        self,
-        jet_algorithm_in,
-        R_in,
-        recomb_scheme_in=0,
-        strategy_in=1,
-        nparameters_in=1,
-    ):
+    def __init__(self, *args, **kwargs):
         r"""
 
         `JetDefinition(JetAlgorithm jet_algorithm_in, double R_in, RecombinationScheme
@@ -206,6 +199,14 @@ def __init__(
         how algorithically to run it).
 
         """
+
+        R_in = kwargs.pop("R_in", None)
+        as_kwargs = False
+        if R_in is None:
+            R_in = args[1]
+        else:
+            as_kwargs = True
+
         if not isinstance(R_in, (float, int)):
             raise ValueError(
                 f"R_in should be a real number, got {R_in} of type {type(R_in)}"
@@ -214,9 +215,23 @@ def __init__(
         if isinstance(R_in, int):
             R_in = float(R_in)
 
-        super().__init__(
-            jet_algorithm_in, R_in, recomb_scheme_in, strategy_in, nparameters_in
-        )
+        new_args = args
+        new_kwargs = kwargs
+        if as_kwargs:
+            new_kwargs = kwargs.copy()
+            new_kwargs["R_in"] = R_in
+        else:
+            new_args = (args[0], R_in, *args[2:])
+
+        self.args = new_args
+        self.kwargs = new_kwargs
+        super().__init__(*new_args, **new_kwargs)
+
+    def __setstate__(self, state):
+        self.__init__(*state["args"], **state["kwargs"])
+
+    def __getstate__(self):
+        return {"args": self.args, "kwargs": self.kwargs}
 
 
 class ClusterSequence:  # The super class
diff --git a/tests/test_008-dask.py b/tests/test_008-dask.py
index 9012893c..cd6a81f5 100644
--- a/tests/test_008-dask.py
+++ b/tests/test_008-dask.py
@@ -5,57 +5,61 @@
 import fastjet._pyjet  # noqa: F401
 
 dak = pytest.importorskip("dask_awkward")  # noqa: F401
+distributed = pytest.importorskip("distributed")
 vector = pytest.importorskip("vector")  # noqa: F401
 
 
 def test_multi():
-    array = ak.Array(
-        [
+    from distributed import Client
+
+    with Client() as _:
+        array = ak.Array(
             [
-                {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
-                {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
-                {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
+                [
+                    {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
+                    {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
+                    {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
+                ],
+                [
+                    {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
+                    {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
+                    {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
+                ],
+            ]
+        )
+        darray = dak.from_awkward(array, 1)
+        jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.6)
+        cluster = fastjet._pyjet.DaskAwkwardClusterSequence(darray, jetdef)
+        inclusive_jets_out = [
+            [
+                {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
+                {"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
             ],
             [
-                {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78},
-                {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
-                {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
+                {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
+                {"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
             ],
         ]
-    )
-    darray = dak.from_awkward(array, 1)
-    jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.6)
-    cluster = fastjet._pyjet.DaskAwkwardClusterSequence(darray, jetdef)
-    inclusive_jets_out = [
-        [
-            {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
-            {"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
-        ],
-        [
-            {"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5},
-            {"px": 64.65, "py": 127.41999999999999, "pz": 1086.48, "E": 48.68},
-        ],
-    ]
-    assert inclusive_jets_out == cluster.inclusive_jets().compute().to_list()
-    constituents_output = [
-        [
-            [{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
+        assert inclusive_jets_out == cluster.inclusive_jets().compute().to_list()
+        constituents_output = [
             [
-                {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
-                {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
+                [{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
+                [
+                    {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
+                    {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
+                ],
             ],
-        ],
-        [
-            [{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
             [
-                {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
-                {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
+                [{"px": 1.2, "py": 3.2, "pz": 5.4, "E": 2.5, "ex": 0.78}],
+                [
+                    {"px": 32.2, "py": 64.21, "pz": 543.34, "E": 24.12, "ex": 0.35},
+                    {"px": 32.45, "py": 63.21, "pz": 543.14, "E": 24.56, "ex": 0.0},
+                ],
             ],
-        ],
-    ]
-    assert cluster.constituents().compute().to_list() == constituents_output
-    constituent_index_out = [[[0], [1, 2]], [[0], [1, 2]]]
-    assert constituent_index_out == cluster.constituent_index().compute().to_list()
+        ]
+        assert cluster.constituents().compute().to_list() == constituents_output
+        constituent_index_out = [[[0], [1, 2]], [[0], [1, 2]]]
+        assert constituent_index_out == cluster.constituent_index().compute().to_list()
 
 
 def test_single():