Skip to content

Commit

Permalink
Merge changes in
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait committed Dec 12, 2022
2 parents d2218c2 + c4af791 commit ba2b38c
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 19 deletions.
16 changes: 7 additions & 9 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7634,15 +7634,13 @@ def log_event(self, topic: str | Collection[str], msg: Any) -> None:
logger.info("Plugin failed with exception", exc_info=True)

def _report_event(self, name, event):
    """Forward ``event`` to every client subscribed to topic ``name``.

    Parameters
    ----------
    name
        Topic the event was logged under; used to look up the subscribed
        clients in ``self.event_subscriber``.
    event
        Payload placed under the ``"event"`` key of the outgoing message.
    """
    msg = {
        "op": "event",
        "topic": name,
        "event": event,
    }
    # One batched dispatch for all subscribers.  The listing this block was
    # taken from also carried the superseded per-client ``self.report(...)``
    # loop; executing both would deliver every event twice.
    client_msgs = {client: [msg] for client in self.event_subscriber[name]}
    self.send_all(client_msgs, worker_msgs={})

def subscribe_topic(self, topic, client):
self.event_subscriber[topic].add(client)
Expand Down
12 changes: 4 additions & 8 deletions distributed/shuffle/_scheduler_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class ShuffleSchedulerExtension(SchedulerPlugin):
heartbeats: defaultdict[ShuffleId, dict]
tombstones: set[ShuffleId]
erred_shuffles: dict[ShuffleId, Exception]
barriers: dict[ShuffleId, str]

def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
Expand All @@ -59,7 +58,6 @@ def __init__(self, scheduler: Scheduler):
self.states = {}
self.tombstones = set()
self.erred_shuffles = {}
self.barriers = {}
self.scheduler.add_plugin(self)

def shuffle_ids(self) -> set[ShuffleId]:
Expand Down Expand Up @@ -95,7 +93,6 @@ def get(
output_workers = set()

name = barrier_key(id)
self.barriers[id] = name
mapping = {}

for ts in self.scheduler.tasks[name].dependents:
Expand Down Expand Up @@ -150,7 +147,7 @@ async def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
contact_workers = state.participating_workers.copy()
contact_workers.discard(worker)
affected_shuffles.add(shuffle_id)
name = self.barriers[shuffle_id]
name = barrier_key(shuffle_id)
barrier_task = self.scheduler.tasks.get(name)
if barrier_task:
barriers.append(barrier_task)
Expand Down Expand Up @@ -195,11 +192,11 @@ def transition(
) -> None:
if finish != "forgotten":
return
if key not in self.barriers.values():

if not key.startswith("shuffle-barrier-"):
return

shuffle_id = id_from_key(key)
if shuffle_id not in self.states:
return
participating_workers = self.states[shuffle_id].participating_workers
worker_msgs = {
worker: [
Expand Down Expand Up @@ -227,7 +224,6 @@ def _clean_on_scheduler(self, id: ShuffleId) -> None:
self.tombstones.add(id)
del self.states[id]
self.erred_shuffles.pop(id, None)
del self.barriers[id]
with contextlib.suppress(KeyError):
del self.heartbeats[id]

Expand Down
2 changes: 1 addition & 1 deletion distributed/shuffle/_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,5 +184,5 @@ def barrier_key(shuffle_id: ShuffleId) -> str:


def id_from_key(key: str) -> ShuffleId:
    """Recover the shuffle ID from a barrier task key.

    ``key`` must start with ``_BARRIER_PREFIX``; the remainder of the string
    after that prefix is returned wrapped in ``ShuffleId``.

    Raises
    ------
    AssertionError
        If ``key`` does not start with ``_BARRIER_PREFIX``.
    """
    assert key.startswith(_BARRIER_PREFIX)
    # Slice off only the leading prefix.  ``str.replace`` would also delete
    # any later occurrence of the prefix text inside the ID itself.
    return ShuffleId(key[len(_BARRIER_PREFIX) :])
1 change: 0 additions & 1 deletion distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ async def clean_scheduler(
while extension.states and not deadline.expired:
await asyncio.sleep(interval)
assert not extension.states
assert not extension.barriers
assert not extension.heartbeats


Expand Down
30 changes: 30 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6763,6 +6763,36 @@ def log_scheduler(dask_scheduler):
assert events[1][1] == ("alice", "bob")


@gen_cluster(client=True, nthreads=[])
async def test_log_event_multiple_clients(c, s):
    """An event logged to a topic reaches every subscribed client.

    Two of the three connected clients subscribe to ``"test-topic"``.  After a
    single event is logged, both handlers must have been invoked and the
    ``distributed.client`` logger must not mention ``"ValueError"``.
    """
    topic = "test-topic"
    deliveries = []

    def make_recorder(tag):
        # Tag each delivery so the assertions can tell the handlers apart.
        def record(event):
            deliveries.append((tag, event))

        return record

    async with Client(s.address, asynchronous=True) as c2:
        # ``c3`` connects to the scheduler but never subscribes to the topic.
        async with Client(s.address, asynchronous=True) as c3:
            c.subscribe_topic(topic, make_recorder(1))
            c2.subscribe_topic(topic, make_recorder(2))

            # Wait for the scheduler to register both subscriptions.
            while len(s.event_subscriber[topic]) != 2:
                await asyncio.sleep(0.01)

            client_logger = logging.getLogger("distributed.client")
            with captured_logger(client_logger) as logger:
                await c.log_event(topic, {})

                # Poll until both handlers have fired.
                while len(deliveries) < 2:
                    await asyncio.sleep(0.01)

                assert len(deliveries) == 2
                tags = {tag for tag, _ in deliveries}
                assert tags == {1, 2}
                assert "ValueError" not in logger.getvalue()


@gen_cluster(client=True)
async def test_annotations_task_state(c, s, a, b):
da = pytest.importorskip("dask.array")
Expand Down

0 comments on commit ba2b38c

Please sign in to comment.