diff --git a/ophyd/tests/test_utils.py b/ophyd/tests/test_utils.py index faf1ff6cc..2fd1a26d4 100644 --- a/ophyd/tests/test_utils.py +++ b/ophyd/tests/test_utils.py @@ -133,3 +133,44 @@ def get(self): def test_set_signal_to_None(): s = Signal(value="0", name="bob") s.set(None).wait(timeout=1) + + +@pytest.mark.parametrize( + "a, b, enums, atol, rtol, expected", + [ + # compare enums + [0, 0, "NO YES".split(), None, None, True], + [0, 1, "NO YES".split(), None, None, False], + ["NO", 0, "NO YES".split(), None, None, True], + ["NO", 1, "NO YES".split(), None, None, False], + [0, "YES", "NO YES".split(), None, None, False], + [1, "YES", "NO YES".split(), None, None, True], + ["NO", "YES", "NO YES".split(), None, None, False], + ["YES", "YES", "NO YES".split(), None, None, True], + # compare array shapes + [[1, 2, 3], [1, 2, 3], [], None, None, True], # identical + [[1, 2, 3], [0, 0, 0], [], None, None, False], # b has different values + [[1, 2, 3], [1, 2, 3, 0], [], None, None, False], # different shape + # [[1, 2, 3], [1, 2, 3, 0], [], None, 1e-8, False], # raises ValueError + [5, [1, 2, 3], [], None, None, False], # not the same type + [[1, 2, 3], 5, [], None, None, False], # not the same type + # numpy arrays + [[1, 2, 3], np.array([1, 2, 3]), [], None, None, True], # identical + # tuple + [(1, 2, 3), np.array([1, 2, 3]), [], None, None, True], # identical + # with tolerance + [[1, 2, 3.0], [1, 2, 3.12345], [], 0.01, None, False], + [[1, 2, 3.0], [1, 2, 3.12345], [], 0.2, None, True], + # absolute tolerance + [3, 3.12345, [], None, 0.01, False], + [3, 3.12345, [], None, 0.2, True], + [True, False, [], None, None, False], # booleans + [1, 1, [], None, None, True], # integers + # relative tolerance + [3, 3.12345, [], 0.01, None, False], + [3, 3.12345, [], 0.2, None, True], + ], +) +def test_compare_maybe_enum(a, b, enums, atol, rtol, expected): + result = epics_utils._compare_maybe_enum(a, b, enums, atol, rtol) + assert result == expected diff --git a/ophyd/utils/epics_pvs.py b/ophyd/utils/epics_pvs.py index 1edd66adf..ce3733e0c 100644 --- a/ophyd/utils/epics_pvs.py +++ b/ophyd/utils/epics_pvs.py @@ -267,7 +267,10 @@ def _wait_for_value(signal, val, poll_time=0.01, timeout=10, rtol=None, atol=Non TimeoutError if timeout is exceeded """ expiration_time = ttime.time() + timeout if timeout is not None else None - current_value = signal.get() + get_kwargs = {} + if isinstance(val, (list, np.ndarray, tuple)): + get_kwargs["count"] = len(val) + current_value = signal.get(**get_kwargs) if atol is None and hasattr(signal, "tolerance"): atol = signal.tolerance @@ -305,7 +308,7 @@ def _wait_for_value(signal, val, poll_time=0.01, timeout=10, rtol=None, atol=Non ttime.sleep(poll_time) if poll_time < 0.1: poll_time *= 2 # logarithmic back-off - current_value = signal.get() + current_value = signal.get(**get_kwargs) if expiration_time is not None and ttime.time() > expiration_time: raise TimeoutError( "Attempted to set %r to value %r and timed "