Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v2: Refactor and Test DeviceCollector Error Handling #1151

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions ophyd/v2/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,25 +327,24 @@ async def _on_exit(self) -> None:

async def _wait_for_tasks(self, tasks: Dict[asyncio.Task, str]):
done, pending = await asyncio.wait(tasks, timeout=self._timeout)

# Handle all devices where connection has timed out
if pending:
msg = f"{len(pending)} Devices did not connect:"
logging.error(f"{len(pending)} Devices did not connect:")
for t in pending:
t.cancel()
with suppress(Exception):
await t
e = t.exception()
msg += f"\n {tasks[t]}: {type(e).__name__}"
lines = str(e).splitlines()
if len(lines) <= 1:
msg += f": {e}"
else:
msg += "".join(f"\n {line}" for line in lines)
logging.error(msg)
logging.exception(f" {tasks[t]}:", exc_info=t.exception())

# Handle all devices where connection has raised an error before
# timeout
raised = [t for t in done if t.exception()]
if raised:
logging.error(f"{len(raised)} Devices raised an error:")
for t in raised:
logging.exception(f" {tasks[t]}:", exc_info=t.exception())

if pending or raised:
raise NotConnected("Not all Devices connected")

Expand Down
162 changes: 161 additions & 1 deletion ophyd/v2/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import logging
import re
import time
import traceback
from enum import Enum
from typing import Any, Callable, Sequence, Tuple, Type
from typing import Any, Callable, Sequence, Tuple, Type, cast
from unittest.mock import Mock

import bluesky.plan_stubs as bps
Expand All @@ -18,10 +19,12 @@
Device,
DeviceCollector,
DeviceVector,
NotConnected,
Signal,
SignalBackend,
SignalRW,
SimSignalBackend,
StandardReadable,
T,
get_device_children,
set_and_wait_for_value,
Expand Down Expand Up @@ -195,6 +198,19 @@ async def connect(self, sim=False):
self.connected = True


class DummyDeviceThatErrorsWhenConnecting(Device):
async def connect(self, sim: bool = False):
raise IOError("Connection failed")


class DummyDeviceThatTimesOutWhenConnecting(StandardReadable):
async def connect(self, sim: bool = False):
try:
await asyncio.Future()
except asyncio.CancelledError:
raise NotConnected("source: foo")


class DummyDeviceGroup(Device):
def __init__(self, name: str) -> None:
self.child1 = DummyBaseDevice()
Expand All @@ -205,6 +221,25 @@ def __init__(self, name: str) -> None:
self.set_name(name)


class DummyDeviceGroupThatTimesOut(Device):
def __init__(self, name: str) -> None:
self.child1 = DummyDeviceThatTimesOutWhenConnecting()
self.set_name(name)


class DummyDeviceGroupThatErrors(Device):
def __init__(self, name: str) -> None:
self.child1 = DummyDeviceThatErrorsWhenConnecting()
self.set_name(name)


class DummyDeviceGroupThatErrorsAndTimesOut(Device):
def __init__(self, name: str) -> None:
self.child1 = DummyDeviceThatErrorsWhenConnecting()
self.child2 = DummyDeviceThatTimesOutWhenConnecting()
self.set_name(name)


def test_get_device_children():
parent = DummyDeviceGroup("parent")

Expand Down Expand Up @@ -246,6 +281,131 @@ async def test_device_with_device_collector():
assert parent.dict_with_children[123].connected


@pytest.mark.parametrize(
"device_constructor",
[
DummyDeviceThatErrorsWhenConnecting,
DummyDeviceThatTimesOutWhenConnecting,
DummyDeviceGroupThatErrors,
DummyDeviceGroupThatTimesOut,
DummyDeviceGroupThatErrorsAndTimesOut,
],
)
async def test_device_collector_propagates_errors_and_timeouts(
device_constructor: Callable[[str], Device]
):
await _assert_failing_device_does_not_connect(device_constructor)


@pytest.mark.parametrize(
"device_constructor_1,device_constructor_2",
[
(DummyDeviceThatErrorsWhenConnecting, DummyDeviceThatTimesOutWhenConnecting),
(DummyDeviceGroupThatErrors, DummyDeviceGroupThatTimesOut),
(DummyDeviceGroupThatErrors, DummyDeviceGroupThatErrorsAndTimesOut),
(DummyDeviceThatErrorsWhenConnecting, DummyDeviceGroupThatErrors),
],
)
async def test_device_collector_propagates_errors_and_timeouts_from_multiple_devices(
device_constructor_1: Callable[[str], Device],
device_constructor_2: Callable[[str], Device],
):
await _assert_failing_devices_do_not_connect(
device_constructor_1,
device_constructor_2,
)


async def test_device_collector_logs_exceptions_for_raised_errors(
caplog: pytest.LogCaptureFixture,
):
caplog.set_level(logging.INFO)
await _assert_failing_device_does_not_connect(DummyDeviceGroupThatErrors)
assert caplog.records[0].message == "1 Devices raised an error:"
assert caplog.records[1].message == " should_fail:"
assert_exception_type_and_message(
caplog.records[1],
OSError,
"Connection failed",
)


async def test_device_collector_logs_exceptions_for_timeouts(
caplog: pytest.LogCaptureFixture,
):
caplog.set_level(logging.INFO)
await _assert_failing_device_does_not_connect(DummyDeviceGroupThatTimesOut)
assert caplog.records[0].message == "1 Devices did not connect:"
assert caplog.records[1].message == " should_fail:"
assert_exception_type_and_message(
caplog.records[1],
NotConnected,
"child1: source: foo",
)


async def test_device_collector_logs_exceptions_for_multiple_devices(
caplog: pytest.LogCaptureFixture,
):
caplog.set_level(logging.INFO)
await _assert_failing_devices_do_not_connect(
DummyDeviceGroupThatErrorsAndTimesOut, DummyDeviceGroupThatErrors
)
assert caplog.records[0].message == "1 Devices did not connect:"
assert caplog.records[1].message == " should_fail_1:"
assert_exception_type_and_message(
caplog.records[1],
OSError,
"Connection failed",
)
assert caplog.records[2].message == "1 Devices raised an error:"
assert caplog.records[3].message == " should_fail_2:"
assert_exception_type_and_message(
caplog.records[3],
OSError,
"Connection failed",
)


async def _assert_failing_device_does_not_connect(
device_constructor: Callable[[str], Device]
) -> pytest.ExceptionInfo[NotConnected]:
with pytest.raises(NotConnected) as excepton_info:
async with DeviceCollector(
sim=False,
timeout=1.0,
):
should_fail = device_constructor("should_fail") # noqa: F841
return excepton_info


async def _assert_failing_devices_do_not_connect(
device_constructor_1: Callable[[str], Device],
device_constructor_2: Callable[[str], Device],
) -> pytest.ExceptionInfo[NotConnected]:
with pytest.raises(NotConnected) as excepton_info:
async with DeviceCollector(
sim=False,
timeout=1.0,
):
should_fail_1 = device_constructor_1("should_fail_1") # noqa: F841
should_fail_2 = device_constructor_2("should_fail_2") # noqa: F841
return excepton_info


def assert_exception_type_and_message(
record: logging.LogRecord,
expected_type: Type[Exception],
expected_message: str,
):
exception_type, exception, _ = cast(
Tuple[Type[Exception], Exception, str],
record.exc_info,
)
assert expected_type is exception_type
assert (expected_message,) == exception.args


async def normal_coroutine(time: float):
await asyncio.sleep(time)

Expand Down
36 changes: 22 additions & 14 deletions ophyd/v2/tests/test_epicsdemo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Dict
from unittest.mock import Mock, call, patch
import logging
from typing import Dict, Tuple, Type, cast
from unittest.mock import Mock, call

import pytest
from bluesky.protocols import Reading
Expand Down Expand Up @@ -130,18 +131,25 @@ async def test_mover_disconncted():
assert m.name == "mover"


async def test_sensor_disconncted():
with patch("ophyd.v2.core.logging") as mock_logging:
with pytest.raises(NotConnected, match="Not all Devices connected"):
async with DeviceCollector(timeout=0.1):
s = epicsdemo.Sensor("ca://PRE:", name="sensor")
mock_logging.error.assert_called_once_with(
"""\
1 Devices did not connect:
s: NotConnected
value: ca://PRE:Value
mode: ca://PRE:Mode"""
)
async def test_sensor_disconncted(caplog: pytest.LogCaptureFixture):
caplog.set_level(logging.INFO)
with pytest.raises(NotConnected, match="Not all Devices connected"):
async with DeviceCollector(timeout=0.1):
s = epicsdemo.Sensor("ca://PRE:", name="sensor")

# Check log messages
assert caplog.records[0].message == "1 Devices did not connect:"
assert caplog.records[1].message == " s:"

# Check logged exception
exception_type, exception, _ = cast(
Tuple[Type[Exception], Exception, str],
caplog.records[1].exc_info,
)
assert NotConnected is exception_type
assert ("value: ca://PRE:Value", "mode: ca://PRE:Mode") == exception.args

# Ensure correct device
assert s.name == "sensor"


Expand Down