From ef593653fe286b7294a6acb51e11cc5cd2feda79 Mon Sep 17 00:00:00 2001 From: Gonzalo Martinez Lema Date: Fri, 16 Feb 2024 16:59:55 +0100 Subject: [PATCH] Refactor `count_masked` --- invisible_cities/reco/xy_algorithms.py | 22 +++++++++++++-------- invisible_cities/reco/xy_algorithms_test.py | 10 ++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/invisible_cities/reco/xy_algorithms.py b/invisible_cities/reco/xy_algorithms.py index 35e5f4066..c94d4b1d3 100644 --- a/invisible_cities/reco/xy_algorithms.py +++ b/invisible_cities/reco/xy_algorithms.py @@ -102,12 +102,19 @@ def get_nearby_sipm_inds( center : np.ndarray # shape (2,) """ return np.where(np.linalg.norm(pos - center, axis=1) <= d)[0] -def count_masked(cs, d, datasipm, is_masked): - if is_masked is None: return 0 - - pos = np.stack([datasipm.X.values, datasipm.Y.values], axis=1) - indices = get_nearby_sipm_inds(cs, d, pos) - return np.count_nonzero(~is_masked.astype(bool)[indices]) +def count_masked( center : np.ndarray # shape (2,) + , d : float + , all_sipms : pd.DataFrame + ): + """ + Count the number of masked (inactive) SiPMs within a distance `d` + of `center`. Note that, unlike `get_nearby_sipm_inds`, this + function is meant to be called with *all* SiPMs. + """ + pos = all_sipms.filter(list("XY")) + masked = ~all_sipms.Active.values.astype(bool) # True if masked + indices = get_nearby_sipm_inds(center, d, pos) + return masked[indices].sum() @check_annotations @@ -222,7 +229,6 @@ def corona( pos : np.ndarray # (n, 2) assert new_lm_radius >= 0, "new_lm_radius must be non-negative" pos, qs = threshold_check(pos, qs, Qthr) - masked = all_sipms.Active.values.astype(bool) if consider_masked else None c = [] # While there are more local maxima @@ -237,7 +243,7 @@ def corona( pos : np.ndarray # (n, 2) # find the SiPMs within new_lm_radius of the new local maximum of charge within_new_lm_radius = get_nearby_sipm_inds(new_local_maximum, new_lm_radius, pos ) - n_masked_neighbours = count_masked (new_local_maximum, new_lm_radius, all_sipms, masked) + n_masked_neighbours = count_masked (new_local_maximum, new_lm_radius, all_sipms) if consider_masked else 0 # if there are at least msipms within_new_lm_radius, taking # into account any masked channel, get the barycenter diff --git a/invisible_cities/reco/xy_algorithms_test.py b/invisible_cities/reco/xy_algorithms_test.py index 4578e6226..fab3656c8 100644 --- a/invisible_cities/reco/xy_algorithms_test.py +++ b/invisible_cities/reco/xy_algorithms_test.py @@ -322,12 +322,7 @@ def test_count_masked_all_active(datasipm_all_active): is_masked = datasipm_all_active.Active.values # All sipms are active in run number 1 - assert count_masked(xy0, np.inf, datasipm_all_active, is_masked) == 0 - - -def test_count_masked_is_masked_None(): - dummy = None - assert count_masked(dummy, dummy, dummy, None) == 0 + assert count_masked(xy0, np.inf, datasipm_all_active) == 0 @mark.parametrize("sipm_id radius expected_nmasked".split(), @@ -341,11 +336,10 @@ def test_count_masked_near_masked(datasipm_5000, sipm_id, radius, expected_nmask sipm_indx = np.argwhere(datasipm_5000.SensorID.values == sipm_id)[0][0] masked_sipm = datasipm_5000.iloc[sipm_indx] masked_xy = np.array([masked_sipm.X, masked_sipm.Y]) - is_masked = datasipm_5000.Active.values # small smear so the search point doesn't fall exactly at sipm position masked_xy += np.random.normal(0, 0.001 * radius, size=2) - assert count_masked(masked_xy, radius, datasipm_5000, is_masked) == expected_nmasked + assert count_masked(masked_xy, radius, datasipm_5000) == expected_nmasked def test_masked_channels(datasipm_3x5):