Skip to content

Commit

Permalink
Merge pull request #224 from sot/improve-reduce-states
Browse files Browse the repository at this point in the history
Allow outputting all keys in reduce_states
  • Loading branch information
taldcroft authored Apr 19, 2022
2 parents b91afc3 + d1ac976 commit 3401c62
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
19 changes: 14 additions & 5 deletions kadi/commands/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,11 +1415,14 @@ def get_states(start=None, stop=None, state_keys=None, cmds=None, continuity=Non
return out


def reduce_states(states, state_keys, merge_identical=False):
def reduce_states(states, state_keys, merge_identical=False, all_keys=False):
"""
Reduce the input ``states`` so that only transitions in the ``state_keys``
are in the output. This also reduces the states table to only include
columns for those ``state_keys``.
are in the output.
By default this also reduces the states table to only include
columns for those ``state_keys``, but if the ``all_keys`` argument is
``True`` then all columns are included in the output.
By default, the output table will reflect every state transition
generated by commands even if this does not change the state value. This
Expand All @@ -1434,6 +1437,7 @@ def reduce_states(states, state_keys, merge_identical=False):
:param states: table of states
:param state_keys: notice transitions in this list of state keys
:param merge_identical: merge adjacent identical states
:param all_keys: if True, then all state keys are included in the output
:returns: numpy recarray of reduced states
"""
Expand All @@ -1458,8 +1462,13 @@ def reduce_states(states, state_keys, merge_identical=False):
# Master array if *any* key has a transition
has_transition |= has_transitions[key]

# Create output with only desired state keys and only states with a transition
out = states[['datestart', 'datestop', 'tstart', 'tstop'] + list(state_keys)][has_transition]
# Create output with desired state keys and only states with a transition
if all_keys:
out = states
else:
out = states[['datestart', 'datestop', 'tstart', 'tstop'] + list(state_keys)]
out = out[has_transition]

for dt in ('date', 't'):
out[f'{dt}stop'][:-1] = out[f'{dt}start'][1:]
out[f'{dt}stop'][-1] = states[f'{dt}stop'][-1]
Expand Down
24 changes: 18 additions & 6 deletions kadi/commands/tests/test_states.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import os
import hashlib
from pathlib import Path
Expand Down Expand Up @@ -477,34 +478,45 @@ def test_get_continuity_fail():
assert 'did not find transitions' in str(err)


def test_reduce_states_merge_identical():
@pytest.mark.parametrize('all_keys', [True, False])
def test_reduce_states_merge_identical(all_keys):
tstart = np.arange(0, 5)
tstop = np.arange(1, 6)
datestart = DateTime(tstart).date
datestop = DateTime(tstop).date
dat0 = Table([datestart, datestop, tstart, tstop],
names=['datestart', 'datestop', 'tstart', 'tstop'])
reduce_states = functools.partial(states.reduce_states, all_keys=all_keys)

# Table with something that changes every time
dat = dat0.copy()
dat['vals'] = np.arange(5)
dat['val1'] = 1
dr = states.reduce_states(dat, ['vals', 'val1'], merge_identical=True)
assert np.all(dr[dat.colnames] == dat)
dat['val_not_key'] = 2 # Not part of the key
dr = reduce_states(dat, ['vals', 'val1'], merge_identical=True)
reduce_names = ['datestart', 'datestop', 'tstart', 'tstop', 'vals', 'val1']
if all_keys:
# All the original cols + trans_keys
assert dr.colnames == dat.colnames + ['trans_keys']
assert np.all(dr[dat.colnames] == dat)
else:
# No `val_not_key` column
assert dr.colnames == reduce_names + ['trans_keys']
assert np.all(dr[reduce_names] == dat[reduce_names])

# Table with nothing that changes
dat = dat0.copy()
dat['vals'] = 1
dat['val1'] = 1
dr = states.reduce_states(dat, ['vals', 'val1'], merge_identical=True)
dr = reduce_states(dat, ['vals', 'val1'], merge_identical=True)
assert len(dr) == 1
assert dr['datestart'][0] == dat['datestart'][0]
assert dr['datestop'][0] == dat['datestop'][-1]

# Table with edge changes
dat = dat0.copy()
dat['vals'] = [1, 0, 0, 0, 1]
dr = states.reduce_states(dat, ['vals'], merge_identical=True)
dr = reduce_states(dat, ['vals'], merge_identical=True)
assert len(dr) == 3
assert np.all(dr['datestart'] == dat['datestart'][[0, 1, 4]])
assert np.all(dr['datestop'] == dat['datestop'][[0, 3, 4]])
Expand All @@ -514,7 +526,7 @@ def test_reduce_states_merge_identical():
dat = dat0.copy()
dat['val1'] = [1, 0, 1, 1, 1]
dat['val2'] = [1, 1, 1, 1, 0]
dr = states.reduce_states(dat, ['val1', 'val2'], merge_identical=True)
dr = reduce_states(dat, ['val1', 'val2'], merge_identical=True)
assert len(dr) == 4
assert np.all(dr['datestart'] == dat['datestart'][[0, 1, 2, 4]])
assert np.all(dr['datestop'] == dat['datestop'][[0, 1, 3, 4]])
Expand Down

0 comments on commit 3401c62

Please sign in to comment.