From 89c29c45c6837c8e4f9a7f8da9ca08dc955b26dd Mon Sep 17 00:00:00 2001 From: Tobias Hangleiter Date: Sun, 11 Apr 2021 17:54:02 +0200 Subject: [PATCH 1/6] Implement __len__ and __getitem__ Todo: tests --- filter_functions/pulse_sequence.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/filter_functions/pulse_sequence.py b/filter_functions/pulse_sequence.py index a1f4b8e..9a2620f 100644 --- a/filter_functions/pulse_sequence.py +++ b/filter_functions/pulse_sequence.py @@ -368,6 +368,28 @@ def __eq__(self, other: object) -> bool: return True + def __len__(self) -> int: + return len(self.dt) + + def __getitem__(self, key) -> 'PulseSequence': + """Return a slice of the PulseSequence.""" + new_dt = np.atleast_1d(self.dt[key]) + if not new_dt.size: + raise IndexError('Cannot create empty PulseSequence') + + new = self.__class__( + c_opers=self.c_opers, + n_opers=self.n_opers, + c_oper_identifiers=self.c_oper_identifiers, + n_oper_identifiers=self.n_oper_identifiers, + c_coeffs=np.atleast_2d(self.c_coeffs.T[key]).T, + n_coeffs=np.atleast_2d(self.c_coeffs.T[key]).T, + dt=new_dt, + d=self.d, + basis=self.basis + ) + return new + def __copy__(self) -> 'PulseSequence': """Return shallow copy of self""" cls = self.__class__ From 5bdeb47df66438f8eb34f22f7e640748e1185050 Mon Sep 17 00:00:00 2001 From: Tobias Hangleiter Date: Wed, 5 May 2021 20:45:55 +0200 Subject: [PATCH 2/6] Fix copy-paste bug --- filter_functions/pulse_sequence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/filter_functions/pulse_sequence.py b/filter_functions/pulse_sequence.py index 9a2620f..de88d9f 100644 --- a/filter_functions/pulse_sequence.py +++ b/filter_functions/pulse_sequence.py @@ -383,7 +383,7 @@ def __getitem__(self, key) -> 'PulseSequence': c_oper_identifiers=self.c_oper_identifiers, n_oper_identifiers=self.n_oper_identifiers, c_coeffs=np.atleast_2d(self.c_coeffs.T[key]).T, - n_coeffs=np.atleast_2d(self.c_coeffs.T[key]).T, + n_coeffs=np.atleast_2d(self.n_coeffs.T[key]).T, dt=new_dt, d=self.d, basis=self.basis From 9201c19affe6c39df7a70cbe76bc044b045ea595 Mon Sep 17 00:00:00 2001 From: Tobias Hangleiter Date: Thu, 6 May 2021 17:00:37 +0200 Subject: [PATCH 3/6] Improve error handling in concatenate_without_filter_function --- filter_functions/pulse_sequence.py | 9 +++++---- tests/test_sequencing.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/filter_functions/pulse_sequence.py b/filter_functions/pulse_sequence.py index de88d9f..2082c24 100644 --- a/filter_functions/pulse_sequence.py +++ b/filter_functions/pulse_sequence.py @@ -1510,14 +1510,15 @@ def concatenate_without_filter_function(pulses: Iterable[PulseSequence], concatenate: Concatenate PulseSequences including filter functions. concatenate_periodic: Concatenate PulseSequences periodically. """ - pulses = tuple(pulses) try: - # Do awkward checking for type - if not all(hasattr(pls, 'c_opers') for pls in pulses): - raise TypeError('Can only concatenate PulseSequences!') + pulses = tuple(pulses) except TypeError: raise TypeError(f'Expected pulses to be iterable, not {type(pulses)}') + if not all(hasattr(pls, 'c_opers') for pls in pulses): + # Do awkward checking for type + raise TypeError('Can only concatenate PulseSequences!') + # Check if the Hamiltonians' shapes are compatible, ie the set of all # shapes has length 1 if len(set(pulse.c_opers.shape[1:] for pulse in pulses)) != 1: diff --git a/tests/test_sequencing.py b/tests/test_sequencing.py index 69efafc..1c63385 100644 --- a/tests/test_sequencing.py +++ b/tests/test_sequencing.py @@ -103,7 +103,7 @@ def test_concatenate_without_filter_function(self): with self.assertRaises(TypeError): # Not iterable - pulse_sequence.concatenate_without_filter_function(pulse) + pulse_sequence.concatenate_without_filter_function(1) with self.assertRaises(ValueError): # Incompatible Hamiltonian shapes From b5c50df551e3c6ce48e18ffd5fb560e6d1471b42 Mon Sep 17 00:00:00 2001 From: Tobias Hangleiter Date: Mon, 10 May 2021 15:21:50 +0200 Subject: [PATCH 4/6] Add tests for __len__ and __getitem__ --- tests/test_sequencing.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/test_sequencing.py b/tests/test_sequencing.py index 1c63385..3e89396 100644 --- a/tests/test_sequencing.py +++ b/tests/test_sequencing.py @@ -62,6 +62,40 @@ def test_concatenate_base(self): pulse_2.omega = [3, 4] ff.concatenate([pulse_1, pulse_2], calc_filter_function=True) + def test_slicing(self): + """Tests _getitem__.""" + for d, n in zip(rng.integers(2, 5, 20), rng.integers(3, 51, 20)): + pulse = testutil.rand_pulse_sequence(d, n) + parts = np.array([part for part in pulse], dtype=object).squeeze() + + # Iterable + self.assertEqual(pulse, ff.concatenate(parts)) + self.assertEqual(len(pulse), n) + + # Slices + ix = rng.integers(1, n-1) + part = pulse[ix] + self.assertEqual(part, parts[ix]) + self.assertEqual(pulse, ff.concatenate([pulse[:ix], pulse[ix:]])) + + # More complicated slices + self.assertEqual(pulse[:len(pulse) // 2 * 2], + ff.concatenate([p for zipped in zip(pulse[::2], pulse[1::2]) + for p in zipped])) + self.assertEqual(pulse[::-1], ff.concatenate(parts[::-1])) + + # Boolean indices + ix = rng.integers(0, 2, size=n, dtype=bool) + self.assertEqual(pulse[ix], ff.concatenate(parts[ix])) + + # Raises + with self.assertRaises(IndexError): + pulse[:0] + with self.assertRaises(IndexError): + pulse[1, 3] + with self.assertRaises(IndexError): + pulse['a'] + def test_concatenate_without_filter_function(self): """Concatenate two Spin Echos without filter functions.""" tau = 10 From 7c6155dc1f9ea269d59802d9c3abb04e23f1c0e2 Mon Sep 17 00:00:00 2001 From: Tobias Hangleiter Date: Mon, 10 May 2021 15:22:03 +0200 Subject: [PATCH 5/6] Add comment about slicing to notebook --- doc/source/examples/advanced_concatenation.ipynb | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/doc/source/examples/advanced_concatenation.ipynb b/doc/source/examples/advanced_concatenation.ipynb index 00cd43c..f30063d 100644 --- a/doc/source/examples/advanced_concatenation.ipynb +++ b/doc/source/examples/advanced_concatenation.ipynb @@ -164,7 +164,12 @@ "FF_X2 = {key: val.get_filter_function(omega[key]) for key, val in X2.items()}\n", "FF_Y2 = {key: val.get_filter_function(omega[key]) for key, val in Y2.items()}\n", "H = {key: ff.concatenate((Y2, X2, X2), calc_pulse_correlation_FF=True)\n", - " for (key, X2), (key, Y2) in zip(X2.items(), Y2.items())}" + " for (key, X2), (key, Y2) in zip(X2.items(), Y2.items())}\n", + "\n", + "# Note that we can also slice PulseSequence objects, eg\n", + "# X = H['primitive'][1:]\n", + "# or\n", + "# segments = [segment for segment in H['primitive']]" ] }, { @@ -246,7 +251,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.9.4" } }, "nbformat": 4, From d43e18477097c6c2d01fdf77f11d65ba04c6231e Mon Sep 17 00:00:00 2001 From: Tobias Hangleiter Date: Mon, 10 May 2021 18:48:41 +0200 Subject: [PATCH 6/6] Catch edge case --- tests/test_sequencing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_sequencing.py b/tests/test_sequencing.py index 3e89396..0e30411 100644 --- a/tests/test_sequencing.py +++ b/tests/test_sequencing.py @@ -86,7 +86,11 @@ def test_slicing(self): # Boolean indices ix = rng.integers(0, 2, size=n, dtype=bool) - self.assertEqual(pulse[ix], ff.concatenate(parts[ix])) + if not ix.any(): + with self.assertRaises(IndexError): + pulse[ix] + else: + self.assertEqual(pulse[ix], ff.concatenate(parts[ix])) # Raises with self.assertRaises(IndexError):