Skip to content

Commit

Permalink
Merge pull request #67 from qutech/feature/len_and_getitem
Browse files Browse the repository at this point in the history
Add __len__ and __getitem__ methods to PulseSequence
  • Loading branch information
thangleiter authored May 12, 2021
2 parents 8bf6159 + d43e184 commit df0e553
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 7 deletions.
9 changes: 7 additions & 2 deletions doc/source/examples/advanced_concatenation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,12 @@
"FF_X2 = {key: val.get_filter_function(omega[key]) for key, val in X2.items()}\n",
"FF_Y2 = {key: val.get_filter_function(omega[key]) for key, val in Y2.items()}\n",
"H = {key: ff.concatenate((Y2, X2, X2), calc_pulse_correlation_FF=True)\n",
" for (key, X2), (key, Y2) in zip(X2.items(), Y2.items())}"
" for (key, X2), (key, Y2) in zip(X2.items(), Y2.items())}\n",
"\n",
"# Note that we can also slice PulseSequence objects, eg\n",
"# X = H['primitive'][1:]\n",
"# or\n",
"# segments = [segment for segment in H['primitive']]"
]
},
{
Expand Down Expand Up @@ -246,7 +251,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.9.4"
}
},
"nbformat": 4,
Expand Down
31 changes: 27 additions & 4 deletions filter_functions/pulse_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,28 @@ def __eq__(self, other: object) -> bool:

return True

def __len__(self) -> int:
return len(self.dt)

def __getitem__(self, key) -> 'PulseSequence':
"""Return a slice of the PulseSequence."""
new_dt = np.atleast_1d(self.dt[key])
if not new_dt.size:
raise IndexError('Cannot create empty PulseSequence')

new = self.__class__(
c_opers=self.c_opers,
n_opers=self.n_opers,
c_oper_identifiers=self.c_oper_identifiers,
n_oper_identifiers=self.n_oper_identifiers,
c_coeffs=np.atleast_2d(self.c_coeffs.T[key]).T,
n_coeffs=np.atleast_2d(self.n_coeffs.T[key]).T,
dt=new_dt,
d=self.d,
basis=self.basis
)
return new

def __copy__(self) -> 'PulseSequence':
"""Return shallow copy of self"""
cls = self.__class__
Expand Down Expand Up @@ -1488,14 +1510,15 @@ def concatenate_without_filter_function(pulses: Iterable[PulseSequence],
concatenate: Concatenate PulseSequences including filter functions.
concatenate_periodic: Concatenate PulseSequences periodically.
"""
pulses = tuple(pulses)
try:
# Do awkward checking for type
if not all(hasattr(pls, 'c_opers') for pls in pulses):
raise TypeError('Can only concatenate PulseSequences!')
pulses = tuple(pulses)
except TypeError:
raise TypeError(f'Expected pulses to be iterable, not {type(pulses)}')

if not all(hasattr(pls, 'c_opers') for pls in pulses):
# Do awkward checking for type
raise TypeError('Can only concatenate PulseSequences!')

# Check if the Hamiltonians' shapes are compatible, ie the set of all
# shapes has length 1
if len(set(pulse.c_opers.shape[1:] for pulse in pulses)) != 1:
Expand Down
40 changes: 39 additions & 1 deletion tests/test_sequencing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,44 @@ def test_concatenate_base(self):
pulse_2.omega = [3, 4]
ff.concatenate([pulse_1, pulse_2], calc_filter_function=True)

def test_slicing(self):
"""Tests _getitem__."""
for d, n in zip(rng.integers(2, 5, 20), rng.integers(3, 51, 20)):
pulse = testutil.rand_pulse_sequence(d, n)
parts = np.array([part for part in pulse], dtype=object).squeeze()

# Iterable
self.assertEqual(pulse, ff.concatenate(parts))
self.assertEqual(len(pulse), n)

# Slices
ix = rng.integers(1, n-1)
part = pulse[ix]
self.assertEqual(part, parts[ix])
self.assertEqual(pulse, ff.concatenate([pulse[:ix], pulse[ix:]]))

# More complicated slices
self.assertEqual(pulse[:len(pulse) // 2 * 2],
ff.concatenate([p for zipped in zip(pulse[::2], pulse[1::2])
for p in zipped]))
self.assertEqual(pulse[::-1], ff.concatenate(parts[::-1]))

# Boolean indices
ix = rng.integers(0, 2, size=n, dtype=bool)
if not ix.any():
with self.assertRaises(IndexError):
pulse[ix]
else:
self.assertEqual(pulse[ix], ff.concatenate(parts[ix]))

# Raises
with self.assertRaises(IndexError):
pulse[:0]
with self.assertRaises(IndexError):
pulse[1, 3]
with self.assertRaises(IndexError):
pulse['a']

def test_concatenate_without_filter_function(self):
"""Concatenate two Spin Echos without filter functions."""
tau = 10
Expand Down Expand Up @@ -103,7 +141,7 @@ def test_concatenate_without_filter_function(self):

with self.assertRaises(TypeError):
# Not iterable
pulse_sequence.concatenate_without_filter_function(pulse)
pulse_sequence.concatenate_without_filter_function(1)

with self.assertRaises(ValueError):
# Incompatible Hamiltonian shapes
Expand Down

0 comments on commit df0e553

Please sign in to comment.