Skip to content

Commit

Permalink
to be squashed
Browse files Browse the repository at this point in the history
  • Loading branch information
JHolba committed Aug 23, 2024
1 parent 0792552 commit 442eaf0
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 49 deletions.
20 changes: 13 additions & 7 deletions src/_ert/events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import sys
from datetime import datetime
from typing import Annotated, Any, Dict, Literal, Union
from typing import Any, Dict, Literal, Union

if sys.version_info < (3, 9):
from typing_extensions import Annotated
else:
from typing import Annotated

from pydantic import BaseModel, Field, TypeAdapter

Expand Down Expand Up @@ -66,8 +72,7 @@ class ForwardModelChecksum(BaseEvent):
event_type: _FORWARD_MODEL_CHECKSUM = "com.equinor.ert.forward_model_job.checksum"
ensemble: Union[str, None] = None
real: str
run_path: str
data: Any
checksums: Dict[str, Dict[str, Any]]


class RealizationBaseEvent(BaseEvent):
Expand Down Expand Up @@ -137,7 +142,6 @@ class EESnapshotUpdate(BaseEvent):

class EETerminated(BaseEvent):
event_type: _EE_TERMINATED = "com.equinor.ert.ee.terminated"
data: Any
ensemble: Union[str, None] = None


Expand Down Expand Up @@ -173,19 +177,21 @@ class EEUserDone(BaseEvent):

Event = Union[FMEvent, ForwardModelChecksum, RealizationEvent, EEEvent, EnsembleEvent]

DispatchEvent = Union[FMEvent, ForwardModelChecksum, RealizationEvent]
DispatchEvent = Union[
FMEvent, ForwardModelChecksum, RealizationEvent, EnsembleSucceeded, EnsembleFailed
]

_ALL_EVENTS_ANNOTATION = Annotated[Event, Field(discriminator="event_type")]

EventAdapter: TypeAdapter[Event] = TypeAdapter(_ALL_EVENTS_ANNOTATION)


def event_from_json(raw_msg: Union[str, bytes]) -> Event:
return EventAdapter.validate_json(raw_msg)
return EventAdapter.validate_json(raw_msg, strict=True)


def event_from_dict(dict_msg: Dict[str, Any]) -> Event:
return EventAdapter.validate_python(dict_msg)
return EventAdapter.validate_python(dict_msg, strict=True)


def event_to_json(event: Event) -> str:
Expand Down
46 changes: 20 additions & 26 deletions src/_ert_forward_model_runner/reporting/event.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import datetime
import logging
import queue
import threading
from datetime import datetime, timedelta
from pathlib import Path
from typing import Union

Expand Down Expand Up @@ -98,7 +98,7 @@ def _event_publisher(self):
with self._timestamp_lock:
if (
self._timeout_timestamp is not None
and datetime.datetime.now() > self._timeout_timestamp
and datetime.now() > self._timeout_timestamp
):
self._timeout_timestamp = None
break
Expand Down Expand Up @@ -142,15 +142,16 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):
}
if isinstance(msg, Start):
logger.debug(f"Job {job_name} was successfully started")
start_msg = job_msg.copy()
start_msg["std_out"] = str(Path(msg.job.std_out).resolve())
start_msg["std_err"] = str(Path(msg.job.std_err).resolve())
self._dump_event(ForwardModelStart(**start_msg))
event = ForwardModelStart(
**job_msg,
std_out=str(Path(msg.job.std_out).resolve()),
std_err=str(Path(msg.job.std_err).resolve()),
)
self._dump_event(event)
if not msg.success():
logger.error(f"Job {job_name} FAILED to start")
self._dump_event(
ForwardModelFailure(**job_msg, error_msg=msg.error_message)
)
event = ForwardModelFailure(**job_msg, error_msg=msg.error_message)
self._dump_event(event)

elif isinstance(msg, Exited):
if msg.success():
Expand All @@ -164,30 +165,24 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):
error_message=msg.error_message,
)
)
self._dump_event(
ForwardModelFailure(
**job_msg, exit_code=msg.exit_code, error_msg=msg.error_message
)
event = ForwardModelFailure(
**job_msg, exit_code=msg.exit_code, error_msg=msg.error_message
)
self._dump_event(event)

elif isinstance(msg, Running):
logger.debug(f"{job_name} job is running")
job_msg["data"] = {
"max_memory_usage": msg.memory_status.max_rss,
"current_memory_usage": msg.memory_status.rss,
}
self._dump_event(
ForwardModelRunning(
**job_msg,
max_memory_usage=msg.memory_status.max_rss,
current_memory_usage=msg.memory_status.rss,
)
event = ForwardModelRunning(
**job_msg,
max_memory_usage=msg.memory_status.max_rss,
current_memory_usage=msg.memory_status.rss,
)
self._dump_event(event)

def _finished_handler(self, msg: Finish):
self._event_queue.put(self._sentinel)
with self._timestamp_lock:
self._timeout_timestamp = datetime.datetime.now() + datetime.timedelta(
self._timeout_timestamp = datetime.now() + timedelta(
seconds=self._reporter_timeout
)
if self._event_publisher_thread.is_alive():
Expand All @@ -197,7 +192,6 @@ def _checksum_handler(self, msg: Checksum):
fm_checksum = ForwardModelChecksum(
ensemble=self._ens_id,
real=self._real_id,
run_path=msg.run_path,
data=msg.data,
checksums={msg.run_path: msg.data},
)
self._dump_event(fm_checksum)
9 changes: 2 additions & 7 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._events_to_send: asyncio.Queue[Event] = asyncio.Queue()
self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue()

self._result = None

self._ee_tasks: List[asyncio.Task[None]] = []
self._server_started: asyncio.Event = asyncio.Event()
self._server_done: asyncio.Event = asyncio.Event()
Expand Down Expand Up @@ -163,7 +161,6 @@ async def _stopped_handler(self, events: Sequence[EnsembleSucceeded]) -> None:
if self.ensemble.status == ENSEMBLE_STATE_FAILED:
return

# self._result = events[0]["data"] # normal termination
max_memory_usage = -1
for job in self.ensemble.snapshot.get_all_forward_models().values():
memory_usage = job.get(ids.MAX_MEMORY_USAGE) or "-1"
Expand Down Expand Up @@ -208,7 +205,7 @@ def store_client(
async def handle_client(self, websocket: WebSocketServerProtocol) -> None:
with self.store_client(websocket):
current_snapshot_dict = self._ensemble.snapshot.to_dict()
event = EESnapshot(
event: Event = EESnapshot(
snapshot=current_snapshot_dict, ensemble=self.ensemble.id_
)
await websocket.send(event_to_json(event))
Expand Down Expand Up @@ -270,8 +267,6 @@ async def handle_dispatch(self, websocket: WebSocketServerProtocol) -> None:
)

async def forward_checksum(self, event: Event) -> None:
event = cast(ForwardModelChecksum, event)
event.data[event.run_path] = event.data.copy()
# clients still need to receive events via ws
await self._events_to_send.put(event)
await self._manifest_queue.put(event)
Expand Down Expand Up @@ -326,7 +321,7 @@ async def _server(self) -> None:

logger.debug("Sending termination-message to clients...")

event = EETerminated(data=self._result, ensemble=self._ensemble.id_)
event = EETerminated(ensemble=self._ensemble.id_)
await self._events_to_send.put(event)
await self._events.join()
await self._batch_processing_queue.join()
Expand Down
9 changes: 5 additions & 4 deletions src/ert/ensemble_evaluator/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,16 @@ def update_from_event(
end_time = None
callback_status_message = None

if e_type == RealizationRunning:
if e_type is RealizationRunning:
start_time = convert_iso8601_to_datetime(timestamp)
elif e_type in {
RealizationSuccess,
RealizationFailed,
RealizationTimeout,
}:
if type(event) is RealizationFailed:
callback_status_message = event.callback_status_message
end_time = convert_iso8601_to_datetime(timestamp)
if type(event) is RealizationFailed:
callback_status_message = event.callback_status_message
self.update_realization(
event.real,
status,
Expand Down Expand Up @@ -382,7 +382,8 @@ def update_from_event(
)

elif e_type in get_args(EnsembleEvent):
self._ensemble_state = _ENSEMBLE_TYPE_EVENT_TO_STATUS[e_type]
event = cast(EnsembleEvent, event)
self._ensemble_state = _ENSEMBLE_TYPE_EVENT_TO_STATUS[type(event)]
elif type(event) is EESnapshotUpdate:
self.merge_snapshot(Snapshot.from_nested_dict(event.snapshot))
elif type(event) is EESnapshot:
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ async def run_monitor(
)
# Allow track() to emit an EndEvent.
return False
elif type(event) == EETerminated:
elif type(event) is EETerminated:
logger.debug("got terminator event")

if not self._end_queue.empty():
Expand Down
2 changes: 1 addition & 1 deletion src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ async def _checksum_consumer(self) -> None:
while True:
event = await self._manifest_queue.get()
if type(event) is ForwardModelChecksum:
self.checksum.update(event.data)
self.checksum.update(event.checksums)
self._manifest_queue.task_done()

async def _publisher(self) -> None:
Expand Down
4 changes: 1 addition & 3 deletions tests/unit_tests/ensemble_evaluator/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,14 @@ async def load_successful(_):
ecl_config = Mock()
ecl_config.assert_restart = Mock()

ensemble = LegacyEnsemble(
return LegacyEnsemble(
realizations,
{},
queue_config,
0,
"0",
)

return ensemble

return _make_ensemble_builder


Expand Down

0 comments on commit 442eaf0

Please sign in to comment.