Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add automated mask shrinking #791

Merged
merged 7 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes.d/791.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Measurement windows can now automatically shrank in case of overlap to counteract small numeric errors.
5 changes: 2 additions & 3 deletions qupulse/hardware/dacs/alazar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from qupulse.utils.types import TimeType
from qupulse.hardware.dacs.dac_base import DAC
from qupulse.hardware.util import traced
from qupulse.utils.performance import time_windows_to_samples
from qupulse.utils.performance import time_windows_to_samples, shrink_overlapping_windows

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -283,8 +283,7 @@ def _make_mask(self, mask_id: str, begins, lengths) -> Mask:
if mask_type not in ('auto', 'cross_buffer', None):
warnings.warn("Currently only CrossBufferMask is implemented.")

if np.any(begins[:-1]+lengths[:-1] > begins[1:]):
raise ValueError('Found overlapping windows in begins')
begins, lengths = shrink_overlapping_windows(begins, lengths)

mask = CrossBufferMask()
mask.identifier = mask_id
Expand Down
71 changes: 71 additions & 0 deletions qupulse/utils/performance.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Tuple
import numpy as np

Expand All @@ -24,6 +25,76 @@ def _is_monotonic_numpy(arr: np.ndarray) -> bool:
return np.all(arr[1:] >= arr[:-1])


def _shrink_overlapping_windows_numpy(begins, lengths) -> bool:
supported_dtypes = ('int64', 'uint64')
if begins.dtype.name not in supported_dtypes or lengths.dtype.name not in supported_dtypes:
raise NotImplementedError("This function only supports 64 bit integer types yet.")

ends = begins + lengths

overlaps = np.zeros_like(ends, dtype=np.int64)
np.maximum(ends[:-1].view(np.int64) - begins[1:].view(np.int64), 0, out=overlaps[1:])

if np.any(overlaps >= lengths):
raise ValueError("Overlap is bigger than measurement window")
if np.any(overlaps > 0):
begins += overlaps.view(begins.dtype)
lengths -= overlaps.view(lengths.dtype)
return True
return False


@njit
def _shrink_overlapping_windows_numba(begins, lengths) -> bool:
shrank = False
for idx in range(len(begins) - 1):
end = begins[idx] + lengths[idx]
next_begin = begins[idx + 1]

if end > next_begin:
overlap = end - next_begin
shrank = True
if lengths[idx + 1] > overlap:
begins[idx + 1] += overlap
lengths[idx + 1] -= overlap
else:
raise ValueError("Overlap is bigger than measurement window")
return shrank


class WindowOverlapWarning(RuntimeWarning):
COMMENT = (" This warning is an error by default. "
"Call 'warnings.simplefilter(WindowOverlapWarning, \"always\")' "
"to demote it to a regular warning.")

def __str__(self):
return super().__str__() + self.COMMENT


warnings.simplefilter(category=WindowOverlapWarning, action='error')


def shrink_overlapping_windows(begins, lengths, use_numba: bool = numba is not None) -> Tuple[np.array, np.array]:
"""Shrink windows in place if they overlap. Emits WindowOverlapWarning if a window was shrunk.

Raises:
ValueError: if the overlap is bigger than a window.

Warnings:
WindowOverlapWarning
"""
if use_numba:
backend = _shrink_overlapping_windows_numba
else:
backend = _shrink_overlapping_windows_numpy
begins = begins.copy()
lengths = lengths.copy()
if backend(begins, lengths):
warnings.warn("Found overlapping measurement windows which can be automatically shrunken if possible.",
category=WindowOverlapWarning)
return begins, lengths


@njit
def _time_windows_to_samples_sorted_numba(begins, lengths,
sample_rate: float) -> Tuple[np.ndarray, np.ndarray]:
Expand Down
4 changes: 2 additions & 2 deletions tests/hardware/alazar_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..hardware import *
from qupulse.hardware.dacs.alazar import AlazarCard, AlazarProgram
from qupulse.utils.types import TimeType

from qupulse.utils.performance import WindowOverlapWarning

class AlazarProgramTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_make_mask(self):
with self.assertRaises(KeyError):
card._make_mask('N', begins, lengths)

with self.assertRaises(ValueError):
with self.assertWarns(WindowOverlapWarning):
card._make_mask('M', begins, lengths*3)

mask = card._make_mask('M', begins, lengths)
Expand Down
57 changes: 55 additions & 2 deletions tests/utils/performance_tests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import unittest
import warnings

import numpy as np

from qupulse.utils.performance import (_time_windows_to_samples_numba, _time_windows_to_samples_numpy,
_average_windows_numba, _average_windows_numpy, average_windows)
from qupulse.utils.performance import (
_time_windows_to_samples_numba, _time_windows_to_samples_numpy,
_average_windows_numba, _average_windows_numpy, average_windows,
shrink_overlapping_windows, WindowOverlapWarning)


class TimeWindowsToSamplesTest(unittest.TestCase):
Expand Down Expand Up @@ -55,3 +58,53 @@ def test_single_channel(self):

def test_dual_channel(self):
self.assert_implementations_equal(self.time, self.values, self.begins, self.ends)


class TestOverlappingWindowReduction(unittest.TestCase):
def setUp(self):
self.shrank = np.array([1, 4, 8], dtype=np.uint64), np.array([3, 4, 4], dtype=np.uint64)
self.to_shrink = np.array([1, 4, 7], dtype=np.uint64), np.array([3, 4, 5], dtype=np.uint64)

def assert_noop(self, shrink_fn):
begins = np.array([1, 3, 5], dtype=np.uint64)
lengths = np.array([2, 1, 6], dtype=np.uint64)
result = shrink_fn(begins, lengths)
np.testing.assert_equal((begins, lengths), result)

begins = (np.arange(100) * 176.5).astype(dtype=np.uint64)
lengths = (np.ones(100) * 10 * np.pi).astype(dtype=np.uint64)
result = shrink_fn(begins, lengths)
np.testing.assert_equal((begins, lengths), result)

begins = np.arange(15, dtype=np.uint64)*16
lengths = 1+np.arange(15, dtype=np.uint64)
result = shrink_fn(begins, lengths)
np.testing.assert_equal((begins, lengths), result)

def assert_shrinks(self, shrink_fn):
with warnings.catch_warnings():
warnings.simplefilter("always", WindowOverlapWarning)
with self.assertWarns(WindowOverlapWarning):
shrank = shrink_fn(*self.to_shrink)
np.testing.assert_equal(self.shrank, shrank)

def assert_empty_window_error(self, shrink_fn):
invalid = np.array([1, 2], dtype=np.uint64), np.array([5, 1], dtype=np.uint64)
with self.assertRaisesRegex(ValueError, "Overlap is bigger than measurement window"):
shrink_fn(*invalid)

def test_shrink_overlapping_windows_numba(self):
def shrink_fn(begins, lengths):
return shrink_overlapping_windows(begins, lengths, use_numba=True)

self.assert_noop(shrink_fn)
self.assert_shrinks(shrink_fn)
self.assert_empty_window_error(shrink_fn)

def test_shrink_overlapping_windows_numpy(self):
def shrink_fn(begins, lengths):
return shrink_overlapping_windows(begins, lengths, use_numba=False)

self.assert_noop(shrink_fn)
self.assert_shrinks(shrink_fn)
self.assert_empty_window_error(shrink_fn)
Loading