Skip to content

Commit

Permalink
Remove get_communicator calls
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 21, 2024
1 parent 28cdb1c commit 5d59e6a
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 62 deletions.
8 changes: 3 additions & 5 deletions src/aiida/brokers/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import abc
import typing as t


if t.TYPE_CHECKING:
from aiida.manage.configuration.profile import Profile
from plumpy.coordinator import Coordinator

__all__ = ('Broker',)

Expand All @@ -20,11 +22,7 @@ def __init__(self, profile: 'Profile') -> None:
self._profile = profile

@abc.abstractmethod
def get_communicator(self):
"""Return an instance of :class:`kiwipy.Communicator`."""

@abc.abstractmethod
def get_coordinator(self):
def get_coordinator(self) -> 'Coordinator':
"""Return an instance of coordinator."""

@abc.abstractmethod
Expand Down
17 changes: 7 additions & 10 deletions src/aiida/brokers/rabbitmq/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, profile: Profile) -> None:
:param profile: The profile.
"""
self._profile = profile
self._communicator: 'RmqThreadCommunicator' | None = None
self._communicator: 'RmqThreadCommunicator | None' = None
self._prefix = f'aiida-{self._profile.uuid}'

def __str__(self):
Expand All @@ -48,19 +48,16 @@ def close(self):

def iterate_tasks(self):
"""Return an iterator over the tasks in the launch queue."""
for task in self.get_communicator().task_queue(get_launch_queue_name(self._prefix)):
for task in self.get_coordinator().communicator.task_queue(get_launch_queue_name(self._prefix)):
yield task

def get_communicator(self) -> 'RmqThreadCommunicator':
def get_coordinator(self):
if self._communicator is None:
self._communicator = self._create_communicator()
# Check whether a compatible version of RabbitMQ is being used.
self.check_rabbitmq_version()

return self._communicator

def get_coordinator(self):
coordinator = RmqCoordinator(self.get_communicator())
coordinator = RmqCoordinator(self._communicator)

return coordinator

Expand All @@ -70,7 +67,7 @@ def _create_communicator(self) -> 'RmqThreadCommunicator':

from aiida.orm.utils import serialize

self._communicator = RmqThreadCommunicator.connect(
_communicator = RmqThreadCommunicator.connect(
connection_params={'url': self.get_url()},
message_exchange=get_message_exchange_name(self._prefix),
encoder=functools.partial(serialize.serialize, encoding='utf-8'),
Expand All @@ -84,7 +81,7 @@ def _create_communicator(self) -> 'RmqThreadCommunicator':
testing_mode=self._profile.is_test_profile,
)

return self._communicator
return _communicator

def check_rabbitmq_version(self):
"""Check the version of RabbitMQ that is being connected to and emit warning if it is not compatible."""
Expand Down Expand Up @@ -128,4 +125,4 @@ def get_rabbitmq_version(self):
"""
from packaging.version import parse

return parse(self.get_communicator().server_properties['version'])
return parse(self.get_coordinator().communicator.server_properties['version'])
11 changes: 6 additions & 5 deletions src/aiida/cmdline/commands/cmd_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import click

from aiida.brokers.broker import Broker
from aiida.cmdline.commands.cmd_verdi import verdi
from aiida.cmdline.params import arguments, options, types
from aiida.cmdline.utils import decorators, echo
Expand Down Expand Up @@ -416,7 +417,7 @@ def process_play(processes, all_entries, timeout, wait):
@decorators.with_dbenv()
@decorators.with_broker
@decorators.only_if_daemon_running(echo.echo_warning, 'daemon is not running, so process may not be reachable')
def process_watch(broker, processes, most_recent_node):
def process_watch(broker: Broker, processes, most_recent_node):
"""Watch the state transitions of processes.
Watch the state transitions for one or multiple running processes."""
Expand All @@ -436,7 +437,7 @@ def process_watch(broker, processes, most_recent_node):

from kiwipy import BroadcastFilter

def _print(communicator, body, sender, subject, correlation_id):
def _print(coordinator, body, sender, subject, correlation_id):
"""Format the incoming broadcast data into a message and echo it to stdout."""
if body is None:
body = 'No message specified'
Expand All @@ -446,7 +447,7 @@ def _print(communicator, body, sender, subject, correlation_id):

echo.echo(f'Process<{sender}> [{subject}|{correlation_id}]: {body}')

communicator = broker.get_communicator()
coordinator = broker.get_coordinator()
echo.echo_report('watching for broadcasted messages, press CTRL+C to stop...')

if most_recent_node:
Expand All @@ -457,7 +458,7 @@ def _print(communicator, body, sender, subject, correlation_id):
echo.echo_error(f'Process<{process.pk}> is already terminated')
continue

communicator.add_broadcast_subscriber(BroadcastFilter(_print, sender=process.pk))
coordinator.add_broadcast_subscriber(BroadcastFilter(_print, sender=process.pk))

try:
# Block this thread indefinitely until interrupt
Expand All @@ -467,7 +468,7 @@ def _print(communicator, body, sender, subject, correlation_id):
echo.echo('') # add a new line after the interrupt character
echo.echo_report('received interrupt, exiting...')
try:
communicator.close()
coordinator.close()
except RuntimeError:
pass

Expand Down
6 changes: 4 additions & 2 deletions src/aiida/cmdline/commands/cmd_rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from aiida.cmdline.commands.cmd_devel import verdi_devel
from aiida.cmdline.params import arguments, options
from aiida.cmdline.utils import decorators, echo, echo_tabulate
from aiida.manage.manager import Manager

if t.TYPE_CHECKING:
import requests
Expand Down Expand Up @@ -131,12 +132,13 @@ def with_client(ctx, wrapped, _, args, kwargs):

@cmd_rabbitmq.command('server-properties')
@decorators.with_manager
def cmd_server_properties(manager):
def cmd_server_properties(manager: Manager):
"""List the server properties."""
import yaml

data = {}
for key, value in manager.get_communicator().server_properties.items():
# FIXME: server_properties as an common API for coordinator?
for key, value in manager.get_coordinator().communicator.server_properties.items():
data[key] = value.decode('utf-8') if isinstance(value, bytes) else value
click.echo(yaml.dump(data, indent=4))

Expand Down
2 changes: 1 addition & 1 deletion src/aiida/cmdline/commands/cmd_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def verdi_status(print_traceback, no_rmq):

if broker:
try:
broker.get_communicator()
broker.get_coordinator()
except Exception as exc:
message = f'Unable to connect to broker: {broker}'
print_status(ServiceStatus.ERROR, 'broker', message, exception=exc, print_traceback=print_traceback)
Expand Down
23 changes: 12 additions & 11 deletions src/aiida/engine/processes/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Optional, Union

import kiwipy
from plumpy.coordinator import Coordinator

from aiida.orm import Node, load_node

Expand All @@ -28,36 +29,36 @@ def __init__(
pk: int,
loop: Optional[asyncio.AbstractEventLoop] = None,
poll_interval: Union[None, int, float] = None,
communicator: Optional[kiwipy.Communicator] = None,
coordinator: Optional[Coordinator] = None,
):
"""Construct a future for a process node being finished.
If a None poll_interval is supplied polling will not be used.
If a communicator is supplied it will be used to listen for broadcast messages.
If a coordinator is supplied it will be used to listen for broadcast messages.
:param pk: process pk
:param loop: An event loop
:param poll_interval: optional polling interval, if None, polling is not activated.
:param communicator: optional communicator, if None, will not subscribe to broadcasts.
:param coordinator: optional coordinator, if None, will not subscribe to broadcasts.
"""
from .process import ProcessState

# create future in specified event loop
loop = loop if loop is not None else asyncio.get_event_loop()
super().__init__(loop=loop)

assert not (poll_interval is None and communicator is None), 'Must poll or have a communicator to use'
assert not (poll_interval is None and coordinator is None), 'Must poll or have a coordinator to use'

node = load_node(pk=pk)

if node.is_terminated:
self.set_result(node)
else:
self._communicator = communicator
self._coordinator = coordinator
self.add_done_callback(lambda _: self.cleanup())

# Try setting up a filtered broadcast subscriber
if self._communicator is not None:
if self._coordinator is not None:

def _subscriber(*args, **kwargs):
if not self.done():
Expand All @@ -66,17 +67,17 @@ def _subscriber(*args, **kwargs):
broadcast_filter = kiwipy.BroadcastFilter(_subscriber, sender=pk)
for state in [ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED]:
broadcast_filter.add_subject_filter(f'state_changed.*.{state.value}')
self._broadcast_identifier = self._communicator.add_broadcast_subscriber(broadcast_filter)
self._broadcast_identifier = self._coordinator.add_broadcast_subscriber(broadcast_filter)

# Start polling
if poll_interval is not None:
loop.create_task(self._poll_process(node, poll_interval))

def cleanup(self) -> None:
"""Clean up the future by removing broadcast subscribers from the communicator if it still exists."""
if self._communicator is not None:
self._communicator.remove_broadcast_subscriber(self._broadcast_identifier)
self._communicator = None
"""Clean up the future by removing broadcast subscribers from the coordinator if it still exists."""
if self._coordinator is not None:
self._coordinator.remove_broadcast_subscriber(self._broadcast_identifier)
self._coordinator = None
self._broadcast_identifier = None

async def _poll_process(self, node: Node, poll_interval: Union[int, float]) -> None:
Expand Down
18 changes: 0 additions & 18 deletions src/aiida/manage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,24 +326,6 @@ def get_persister(self) -> 'AiiDAPersister':

return self._persister

def get_communicator(self) -> 'RmqThreadCommunicator':
"""Return the communicator
:return: a global communicator instance
"""
from aiida.common import ConfigurationError

broker = self.get_broker()

if broker is None:
assert self._profile is not None
raise ConfigurationError(
f'profile `{self._profile.name}` does not provide a communicator because it does not define a broker'
)

return broker.get_communicator()

def get_coordinator(self) -> 'Coordinator':
"""Return the coordinator
Expand Down
10 changes: 5 additions & 5 deletions tests/brokers/test_rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def raise_connection_error():
broker = manager.get_broker()
assert 'RabbitMQ v' in str(broker)

monkeypatch.setattr(broker, 'get_communicator', raise_connection_error)
monkeypatch.setattr(broker, 'get_coordinator', raise_connection_error)
assert 'RabbitMQ @' in str(broker)


Expand Down Expand Up @@ -92,14 +92,14 @@ def test_communicator(url):
RmqThreadCommunicator.connect(connection_params={'url': url})


def test_add_rpc_subscriber(communicator):
def test_add_rpc_subscriber(coordinator):
"""Test ``add_rpc_subscriber``."""
communicator.add_rpc_subscriber(None)
coordinator.add_rpc_subscriber(None)


def test_add_broadcast_subscriber(communicator):
def test_add_broadcast_subscriber(coordinator):
"""Test ``add_broadcast_subscriber``."""
communicator.add_broadcast_subscriber(None)
coordinator.add_broadcast_subscriber(None)


@pytest.mark.usefixtures('aiida_profile_clean')
Expand Down
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from aiida.common.folders import Folder
from aiida.common.links import LinkType
from aiida.manage.configuration import Profile, get_config, load_profile
from aiida.manage.manager import Manager

if t.TYPE_CHECKING:
from aiida.manage.configuration.config import Config
Expand Down Expand Up @@ -540,9 +541,9 @@ def backend(manager):


@pytest.fixture
def communicator(manager):
"""Get the ``Communicator`` instance of the currently loaded profile to communicate with RabbitMQ."""
return manager.get_communicator()
def coordinator(manager: Manager):
"""Get the ``Coordinator`` instance of the currently loaded profile to communicate with RabbitMQ."""
return manager.get_coordinator()


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/engine/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_calculation_future_broadcasts(self):

# No polling
future = processes.futures.ProcessFuture(
pk=process.pid, loop=runner.loop, communicator=manager.get_communicator()
pk=process.pid, loop=runner.loop, coordinator=manager.get_coordinator()
)

run(process)
Expand Down
6 changes: 5 additions & 1 deletion tests/manage/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ def test_disconnect():
demonstrate the problematic behavior. Getting the communicator and then disconnecting it (through calling
:meth:`aiida.manage.manager.Manager.reset_profile`) works fine. However, if a process is a run before closing it,
for example running a calcfunction, the closing of the communicator will raise a ``TimeoutError``.
The problem was solved by:
- https://github.com/aiidateam/aiida-core/pull/6672
- https://github.com/mosquito/aiormq/pull/208
"""
from aiida.manage import get_manager

manager = get_manager()
manager.get_communicator()
_ = manager.get_coordinator()
manager.reset_profile() # This returns just fine

result, node = add_calcfunction.run_get_node(1)
Expand Down

0 comments on commit 5d59e6a

Please sign in to comment.