Skip to content

Commit

Permalink
Use pydantic for messages sent over websockets
Browse files Browse the repository at this point in the history
  • Loading branch information
JHolba committed Aug 20, 2024
1 parent ac32a87 commit 080d7d4
Show file tree
Hide file tree
Showing 21 changed files with 829 additions and 698 deletions.
88 changes: 46 additions & 42 deletions src/_ert_forward_model_runner/reporting/event.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

import datetime
import logging
import queue
import threading
from pathlib import Path
from typing import Any

import orjson

from _ert.threading import ErtThread
from _ert_forward_model_runner.client import (
Expand All @@ -25,6 +24,14 @@
Start,
)
from _ert_forward_model_runner.reporting.statemachine import StateMachine
from ert.ensemble_evaluator import message
from ert.ensemble_evaluator.message import (
ForwardModelChecksum,
ForwardModelFailure,
ForwardModelRunning,
ForwardModelStart,
ForwardModelSuccess,
)

_FORWARD_MODEL_START = "com.equinor.ert.forward_model_job.start"
_FORWARD_MODEL_RUNNING = "com.equinor.ert.forward_model_job.running"
Expand All @@ -43,10 +50,10 @@ class Event(Reporter):
The Event reporter forwards events, coming from the running job, added with
"report" to the given connection information.
An Init event must provided as the first message, which starts reporting,
An Init event must be provided as the first message, which starts reporting,
and a Finish event will signal the reporter that the last event has been reported.
If event fails to be sent (eg. due to connection error) it does not proceed to the
If event fails to be sent (e.g. due to connection error) it does not proceed to the
next event but instead tries to re-send the same event.
Whenever the Finish event (when all the jobs have exited) is provided
Expand All @@ -71,9 +78,9 @@ def __init__(self, evaluator_url, token=None, cert_path=None):

self._ens_id = None
self._real_id = None
self._event_queue = queue.Queue()
self._event_queue: queue.Queue[message.Message] | None = queue.Queue()
self._event_publisher_thread = ErtThread(target=self._event_publisher)
self._sentinel = object() # notifying the queue's ended
self._sentinel = None
self._timeout_timestamp = None
self._timestamp_lock = threading.Lock()
# seconds to timeout the reporter the thread after Finish() was received
Expand Down Expand Up @@ -102,7 +109,9 @@ def _event_publisher(self):
if event is self._sentinel:
break
try:
client.send(event)
logger.debug(f"sending {type(event.event)}")
client.send(event.model_dump_json())
logger.debug(f"sent {type(event.event)}")
event = None
except ClientConnectionError as exception:
# Possible intermittent failure, we retry sending the event
Expand All @@ -116,11 +125,9 @@ def _event_publisher(self):
def report(self, msg):
self._statemachine.transition(msg)

def _dump_event(self, event: Any = None):
event["time"] = datetime.datetime.now()
event["data"] = event.get("data", None)
logger.debug(f'Schedule "{event["type"]}" for delivery')
self._event_queue.put(orjson.dumps(event))
def _dump_event(self, event: message.ALL_EVENTS):
logger.debug(f'Schedule "{type(event)}" for delivery')
self._event_queue.put(message.Message(ensemble=self._ens_id, event=event))

def _init_handler(self, msg):
self._ens_id = str(msg.ens_id)
Expand All @@ -131,32 +138,26 @@ def _job_handler(self, msg: Message):
assert msg.job
job_name = msg.job.name()
job_msg = {
_JOB_MSG_TYPE: None,
"ensemble": self._ens_id,
"real": self._real_id,
"fm_step": str(msg.job.index),
"index": str(msg.job.index),
}
if isinstance(msg, Start):
logger.debug(f"Job {job_name} was successfully started")
start_msg = job_msg.copy()
start_msg[_JOB_MSG_TYPE] = _FORWARD_MODEL_START
start_msg["data"] = {
"stdout": str(Path(msg.job.std_out).resolve()),
"stderr": str(Path(msg.job.std_err).resolve()),
}
self._dump_event(start_msg)
start_msg["std_out"] = str(Path(msg.job.std_out).resolve())
start_msg["std_err"] = str(Path(msg.job.std_err).resolve())
self._dump_event(ForwardModelStart(**start_msg))
if not msg.success():
logger.error(f"Job {job_name} FAILED to start")
fail_msg = job_msg.copy()
fail_msg[_JOB_MSG_TYPE] = _FORWARD_MODEL_FAILURE
fail_msg["data"] = {"error_msg": msg.error_message}
self._dump_event(fail_msg)
self._dump_event(
ForwardModelFailure(**job_msg, error_msg=msg.error_message)
)

elif isinstance(msg, Exited):
if msg.success():
logger.debug(f"Job {job_name} exited successfully")
job_msg[_JOB_MSG_TYPE] = _FORWARD_MODEL_SUCCESS
self._dump_event(ForwardModelSuccess(**job_msg))
else:
logger.error(
_JOB_EXIT_FAILED_STRING.format(
Expand All @@ -165,21 +166,25 @@ def _job_handler(self, msg: Message):
error_message=msg.error_message,
)
)
job_msg[_JOB_MSG_TYPE] = _FORWARD_MODEL_FAILURE
job_msg["data"] = {
"exit_code": msg.exit_code,
"error_msg": msg.error_message,
}
self._dump_event(job_msg)
self._dump_event(
ForwardModelFailure(
**job_msg, exit_code=msg.exit_code, error_msg=msg.error_message
)
)

elif isinstance(msg, Running):
logger.debug(f"{job_name} job is running")
job_msg[_JOB_MSG_TYPE] = _FORWARD_MODEL_RUNNING
job_msg["data"] = {
"max_memory_usage": msg.memory_status.max_rss,
"current_memory_usage": msg.memory_status.rss,
}
self._dump_event(job_msg)
self._dump_event(
ForwardModelRunning(
**job_msg,
max_memory_usage=msg.memory_status.max_rss,
current_memory_usage=msg.memory_status.rss,
)
)

def _finished_handler(self, msg):
self._event_queue.put(self._sentinel)
Expand All @@ -191,11 +196,10 @@ def _finished_handler(self, msg):
self._event_publisher_thread.join()

def _checksum_handler(self, msg):
job_msg = {
_JOB_MSG_TYPE: _FORWARD_MODEL_CHECKSUM,
"ensemble": self._ens_id,
"real": self._real_id,
_RUN_PATH: msg.run_path,
"data": msg.data,
}
self._dump_event(job_msg)
fm_checksum = ForwardModelChecksum(
ensemble=self._ens_id,
real=self._real_id,
run_path=msg.run_path,
data=msg.data,
)
self._dump_event(fm_checksum)
50 changes: 27 additions & 23 deletions src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

import asyncio
import logging
import traceback
import uuid
from dataclasses import dataclass
from datetime import datetime
from functools import partialmethod
from typing import (
Any,
Expand All @@ -18,8 +17,6 @@
Union,
)

import orjson

from _ert_forward_model_runner.client import Client
from ert.config import ForwardModelStep, QueueConfig
from ert.run_arg import RunArg
Expand All @@ -28,6 +25,7 @@
from ._wait_for_evaluator import wait_for_evaluator
from .config import EvaluatorServerConfig
from .identifiers import EVTYPE_ENSEMBLE_FAILED, EVTYPE_ENSEMBLE_STARTED
from .message import Message
from .snapshot import (
ForwardModel,
RealizationSnapshot,
Expand Down Expand Up @@ -122,6 +120,7 @@ def __post_init__(self) -> None:
self._status_tracker = _EnsembleStateTracker(self.snapshot.status)
else:
self._status_tracker = _EnsembleStateTracker()
self.scheduler_started: asyncio.Event = asyncio.Event()

@property
def active_reals(self) -> Sequence[Realization]:
Expand Down Expand Up @@ -151,7 +150,7 @@ def _create_snapshot(self) -> Snapshot:
def get_successful_realizations(self) -> List[int]:
return self.snapshot.get_successful_realizations()

def update_snapshot(self, events: List[Dict]) -> Snapshot:
def update_snapshot(self, events: List[Message]) -> Snapshot:
snapshot_mutate_event = Snapshot()
for event in events:
snapshot_mutate_event = snapshot_mutate_event.update_from_cloudevent(
Expand All @@ -165,27 +164,23 @@ def update_snapshot(self, events: List[Dict]) -> Snapshot:
async def send_event( # noqa: PLR6301
self,
url: str,
event: Dict,
event: Message,
token: Optional[str] = None,
cert: Optional[Union[str, bytes]] = None,
retries: int = 10,
) -> None:
async with Client(url, token, cert, max_retries=retries) as client:
await client._send(orjson.dumps(event))
await client._send(event.model_dump_json())

def generate_event_creator(
self, experiment_id: Optional[str] = None
) -> Callable[[str, Optional[int]], Dict]:
def event_builder(status: str, real_id: Optional[int] = None) -> Dict:
def generate_event_creator(self) -> Callable[[str], Message]:
def event_builder(status: str) -> Message:
msg = {
"type": status,
"time": datetime.now(),
"event": {
"event_type": status,
},
"ensemble": self.id_,
"id": str(uuid.uuid1()),
}
if real_id is not None:
msg["real"] = str(real_id)
return msg
return Message.model_validate(msg)

return event_builder

Expand Down Expand Up @@ -213,8 +208,7 @@ async def evaluate(self, config: EvaluatorServerConfig) -> None:

async def _evaluate_inner( # pylint: disable=too-many-branches
self,
event_unary_send: Callable[[Dict], Awaitable[None]],
experiment_id: Optional[str] = None,
event_unary_send: Callable[[Message], Awaitable[None]],
) -> None:
"""
This (inner) coroutine does the actual work of evaluating the ensemble. It
Expand All @@ -228,7 +222,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
is a function (or bound method) that only takes a CloudEvent as a positional
argument.
"""
event_creator = self.generate_event_creator(experiment_id=experiment_id)
event_creator = self.generate_event_creator()

if not self.id_:
raise ValueError("Ensemble id not set")
Expand All @@ -252,7 +246,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
f"Experiment ran on ORCHESTRATOR: scheduler on {self._queue_config.queue_system} queue"
)

await event_unary_send(event_creator(EVTYPE_ENSEMBLE_STARTED, None))
await event_unary_send(event_creator(EVTYPE_ENSEMBLE_STARTED))

min_required_realizations = (
self.min_required_realizations
Expand All @@ -261,7 +255,9 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
)

self._scheduler.add_dispatch_information_to_jobs_file()
sched_run = asyncio.create_task(self.update_scheduler_running())
result = await self._scheduler.execute(min_required_realizations)
await sched_run

except Exception as exc:
logger.exception(
Expand All @@ -272,13 +268,18 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
),
exc_info=True,
)
await event_unary_send(event_creator(EVTYPE_ENSEMBLE_FAILED, None))
await event_unary_send(event_creator(EVTYPE_ENSEMBLE_FAILED))
return

logger.info(f"Experiment ran on QUEUESYSTEM: {self._queue_config.queue_system}")

# Dispatch final result from evaluator - FAILED, CANCEL or STOPPED
await event_unary_send(event_creator(result, None))
await event_unary_send(event_creator(result))

async def update_scheduler_running(self) -> None:
assert self._scheduler is not None
await self._scheduler.running.wait()
self.scheduler_started.set()

@property
def cancellable(self) -> bool:
Expand All @@ -292,6 +293,9 @@ def cancel(self) -> None:

class _KillAllJobs(Protocol):
def kill_all_jobs(self) -> None: ...
@property
def running(self) -> asyncio.Event:
pass


@dataclass
Expand Down
Loading

0 comments on commit 080d7d4

Please sign in to comment.