Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add __len__ and __getitem__ methods to PulseSequence #67

Merged
merged 6 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions doc/source/examples/advanced_concatenation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,12 @@
"FF_X2 = {key: val.get_filter_function(omega[key]) for key, val in X2.items()}\n",
"FF_Y2 = {key: val.get_filter_function(omega[key]) for key, val in Y2.items()}\n",
"H = {key: ff.concatenate((Y2, X2, X2), calc_pulse_correlation_FF=True)\n",
" for (key, X2), (key, Y2) in zip(X2.items(), Y2.items())}"
" for (key, X2), (key, Y2) in zip(X2.items(), Y2.items())}\n",
"\n",
"# Note that we can also slice PulseSequence objects, eg\n",
"# X = H['primitive'][1:]\n",
"# or\n",
"# segments = [segment for segment in H['primitive']]"
]
},
{
Expand Down Expand Up @@ -246,7 +251,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.9.4"
}
},
"nbformat": 4,
Expand Down
31 changes: 27 additions & 4 deletions filter_functions/pulse_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,28 @@ def __eq__(self, other: object) -> bool:

return True

def __len__(self) -> int:
return len(self.dt)

def __getitem__(self, key) -> 'PulseSequence':
"""Return a slice of the PulseSequence."""
new_dt = np.atleast_1d(self.dt[key])
if not new_dt.size:
raise IndexError('Cannot create empty PulseSequence')

new = self.__class__(
c_opers=self.c_opers,
n_opers=self.n_opers,
c_oper_identifiers=self.c_oper_identifiers,
n_oper_identifiers=self.n_oper_identifiers,
c_coeffs=np.atleast_2d(self.c_coeffs.T[key]).T,
n_coeffs=np.atleast_2d(self.n_coeffs.T[key]).T,
dt=new_dt,
d=self.d,
basis=self.basis
)
return new

def __copy__(self) -> 'PulseSequence':
"""Return shallow copy of self"""
cls = self.__class__
Expand Down Expand Up @@ -1488,14 +1510,15 @@ def concatenate_without_filter_function(pulses: Iterable[PulseSequence],
concatenate: Concatenate PulseSequences including filter functions.
concatenate_periodic: Concatenate PulseSequences periodically.
"""
pulses = tuple(pulses)
try:
# Do awkward checking for type
if not all(hasattr(pls, 'c_opers') for pls in pulses):
raise TypeError('Can only concatenate PulseSequences!')
pulses = tuple(pulses)
except TypeError:
raise TypeError(f'Expected pulses to be iterable, not {type(pulses)}')

if not all(hasattr(pls, 'c_opers') for pls in pulses):
# Do awkward checking for type
raise TypeError('Can only concatenate PulseSequences!')

# Check if the Hamiltonians' shapes are compatible, ie the set of all
# shapes has length 1
if len(set(pulse.c_opers.shape[1:] for pulse in pulses)) != 1:
Expand Down
40 changes: 39 additions & 1 deletion tests/test_sequencing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,44 @@ def test_concatenate_base(self):
pulse_2.omega = [3, 4]
ff.concatenate([pulse_1, pulse_2], calc_filter_function=True)

def test_slicing(self):
"""Tests _getitem__."""
for d, n in zip(rng.integers(2, 5, 20), rng.integers(3, 51, 20)):
pulse = testutil.rand_pulse_sequence(d, n)
parts = np.array([part for part in pulse], dtype=object).squeeze()

# Iterable
self.assertEqual(pulse, ff.concatenate(parts))
self.assertEqual(len(pulse), n)

# Slices
ix = rng.integers(1, n-1)
part = pulse[ix]
self.assertEqual(part, parts[ix])
self.assertEqual(pulse, ff.concatenate([pulse[:ix], pulse[ix:]]))

# More complicated slices
self.assertEqual(pulse[:len(pulse) // 2 * 2],
ff.concatenate([p for zipped in zip(pulse[::2], pulse[1::2])
for p in zipped]))
self.assertEqual(pulse[::-1], ff.concatenate(parts[::-1]))

# Boolean indices
ix = rng.integers(0, 2, size=n, dtype=bool)
if not ix.any():
with self.assertRaises(IndexError):
pulse[ix]
else:
self.assertEqual(pulse[ix], ff.concatenate(parts[ix]))

# Raises
with self.assertRaises(IndexError):
pulse[:0]
with self.assertRaises(IndexError):
pulse[1, 3]
with self.assertRaises(IndexError):
pulse['a']

def test_concatenate_without_filter_function(self):
"""Concatenate two Spin Echos without filter functions."""
tau = 10
Expand Down Expand Up @@ -103,7 +141,7 @@ def test_concatenate_without_filter_function(self):

with self.assertRaises(TypeError):
# Not iterable
pulse_sequence.concatenate_without_filter_function(pulse)
pulse_sequence.concatenate_without_filter_function(1)

with self.assertRaises(ValueError):
# Incompatible Hamiltonian shapes
Expand Down