diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 8567aa0..ab177ce 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -30,6 +30,6 @@ jobs: - name: Base Setup uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1 - name: Install and test - run: + run: | pip install -e .[test] - pytest + pytest . diff --git a/comm/base_comm.py b/comm/base_comm.py index fb92408..30b2125 100644 --- a/comm/base_comm.py +++ b/comm/base_comm.py @@ -8,12 +8,13 @@ import logging from traitlets.utils.importstring import import_item +from traitlets.config import LoggingConfigurable logger = logging.getLogger('Comm') -class BaseComm: +class BaseComm(LoggingConfigurable): """Class for communicating between a Frontend and a Kernel Must be subclassed with a publish_msg method implementation which @@ -92,7 +93,6 @@ def open(self, data=None, metadata=None, buffers=None): def close(self, data=None, metadata=None, buffers=None, deleting=False): """Close the frontend-side version of this comm""" - from comm import get_comm_manager if self._closed: # only close once return @@ -107,6 +107,7 @@ def close(self, data=None, metadata=None, buffers=None, deleting=False): ) if not deleting: # If deleting, the comm can't be registered + from comm import get_comm_manager get_comm_manager().unregister_comm(self) def send(self, data=None, metadata=None, buffers=None): @@ -160,15 +161,17 @@ def handle_msg(self, msg): shell.events.trigger("post_execute") -class CommManager: +class CommManager(LoggingConfigurable): """Default CommManager singleton implementation for Comms in the Kernel""" # Public APIs - def __init__(self): + def __init__(self, *args, **kwargs): self.comms = {} self.targets = {} + super(CommManager, self).__init__(*args, **kwargs) + def register_target(self, target_name, f): """Register a callable f for a given target name diff --git a/tests/test_comm.py b/tests/test_comm.py index 9e8f8e1..a297f80 100644 --- a/tests/test_comm.py +++ b/tests/test_comm.py @@ -1,3 +1,6 @@ +from traitlets import Any +from traitlets.config import LoggingConfigurable + from comm.base_comm import CommManager, BaseComm @@ -7,6 +10,11 @@ def publish_msg(self, msg_type, data=None, metadata=None, buffers=None, **keys): pass +class CustomCommManager(CommManager): + + parent = Any() + + def test_comm_manager(): test = CommManager() assert test.targets == {} @@ -14,4 +22,9 @@ def test_comm_manager(): def test_base_comm(): test = MyComm() - assert test.target_name == "comm" \ No newline at end of file + assert test.target_name == "comm" + + +def test_custom_comm_manager(): + test = CustomCommManager(parent=LoggingConfigurable()) + assert test.parent is not None