Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-eq committed Sep 23, 2024
1 parent 8b3e759 commit 6587920
Showing 1 changed file with 147 additions and 83 deletions.
230 changes: 147 additions & 83 deletions src/_ert/forward_model_runner/job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import io
import json
import logging
import os
Expand All @@ -10,7 +11,7 @@
from datetime import datetime as dt
from pathlib import Path
from subprocess import Popen, run
from typing import Optional, Tuple
from typing import Dict, Optional, Sequence, Tuple

from psutil import AccessDenied, NoSuchProcess, Process, TimeoutExpired, ZombieProcess

Expand Down Expand Up @@ -89,7 +90,7 @@ def run(self):
except Exception as e:
yield Exited(self, exit_code=1).with_error(str(e))

def _run(self):
def create_start_message_and_check_job_files(self) -> Start:
start_message = Start(self)

errors = self._check_job_files()
Expand All @@ -99,19 +100,23 @@ def _run(self):
self._dump_exec_env()

if errors:
yield start_message.with_error("\n".join(errors))
return

yield start_message
start_message = start_message.with_error("\n".join(errors))
return start_message

def _build_arg_list(self):
    """Assemble the argv for the forward model step.

    The executable (validated to exist and be executable) comes first,
    followed by the optional configured "argList" entries.
    """
    executable = self.job_data.get("executable")
    assert_file_executable(executable)

    argv = [executable]
    extra_args = self.job_data.get("argList")
    if extra_args:
        argv += extra_args
    return argv

# stdin/stdout/stderr are closed at the end of this function
def _open_file_handles(
self,
) -> Tuple[
io.TextIOWrapper | None, io.TextIOWrapper | None, io.TextIOWrapper | None
]:
if self.job_data.get("stdin"):
stdin = open(self.job_data.get("stdin"), encoding="utf-8") # noqa
else:
Expand All @@ -135,56 +140,50 @@ def _run(self):
else:
stdout = None

target_file = self.job_data.get("target_file")
target_file_mtime: int = 0
if target_file and os.path.exists(target_file):
stat = os.stat(target_file)
target_file_mtime = stat.st_mtime_ns
return (stdin, stdout, stderr)

max_running_minutes = self.job_data.get("max_running_minutes")
run_start_time = dt.now()
def _create_environment(self) -> Dict:
environment = self.job_data.get("environment")
if environment is not None:
environment = {**os.environ, **environment}
return environment

def _run(self):
start_message = self.create_start_message_and_check_job_files()

yield start_message
if not start_message.success:
return

def ensure_file_handles_closed():
if stdin is not None:
stdin.close()
if stdout is not None:
stdout.close()
if stderr is not None:
stderr.close()
arg_list = self._build_arg_list()

(stdin, stdout, stderr) = self._open_file_handles()
# stdin/stdout/stderr are closed at the end of this function

target_file = self.job_data.get("target_file")
target_file_mtime: int = _get_target_file_ntime(target_file)

try:
proc = Popen(
arg_list,
stdin=stdin,
stdout=stdout,
stderr=stderr,
env=environment,
env=self._create_environment(),
)
process = Process(proc.pid)
except OSError as e:
msg = f"{e.strerror} {e.filename}"
if e.strerror == "Exec format error" and e.errno == 8:
msg = (
f"Missing execution format information in file: {e.filename!r}."
f"Most likely you are missing and should add "
f"'#!/usr/bin/env python' to the top of the file: "
)
if stderr:
stderr.write(msg)
ensure_file_handles_closed()
yield Exited(self, e.errno).with_error(msg)
exited_message = self._handle_process_io_error_and_create_exited_message(
e, stderr
)
ensure_file_handles_closed([stdin, stdout, stderr])
yield exited_message
return

exit_code = None

# All child pids for the forward model step. Need to track these in order to be able
# to detect OOM kills in case of failure.
fm_step_pids = {process.pid}

max_memory_usage = 0
fm_step_pids = {int(process.pid)}
while exit_code is None:
(memory_rss, cpu_seconds, oom_score) = _get_processtree_data(process)
max_memory_usage = max(memory_rss, max_memory_usage)
Expand All @@ -203,67 +202,118 @@ def ensure_file_handles_closed():
try:
exit_code = process.wait(timeout=self.MEMORY_POLL_PERIOD)
except TimeoutExpired:
exited_msg = self.handle_process_timeout_and_create_exited_msg(
process, proc
)
fm_step_pids |= {
int(child.pid) for child in process.children(recursive=True)
}
run_time = dt.now() - run_start_time
if (
max_running_minutes is not None
and run_time.seconds > max_running_minutes * 60
):
# If the spawned process is not in the same process group as
# the callee (job_dispatch), we will kill the process group
# explicitly.
#
# Propagating the unsuccessful Exited message will kill the
# callee group. See job_dispatch.py.
process_group_id = os.getpgid(proc.pid)
this_group_id = os.getpgid(os.getpid())
if process_group_id != this_group_id:
os.killpg(process_group_id, signal.SIGKILL)

yield Exited(self, exit_code).with_error(
(
f"Job:{self.name()} has been running "
f"for more than {max_running_minutes} "
"minutes - explicitly killed."
)
)
if isinstance(exited_msg, Exited):
yield exited_msg
return

exited_message = Exited(self, exit_code)
ensure_file_handles_closed([stdin, stdout, stderr])
exited_message = self._create_exited_message_based_on_exit_code(
max_memory_usage, target_file_mtime, exit_code, fm_step_pids
)
yield exited_message

def _create_exited_message_based_on_exit_code(
    self,
    max_memory_usage: int,
    target_file_mtime: Optional[int],
    exit_code: int,
    fm_step_pids: Sequence[int],
) -> Exited:
    """Translate the finished process into an Exited message.

    A non-zero exit code is delegated to
    _create_exited_msg_for_non_zero_exit_code (which distinguishes OOM
    kills). A zero exit code may still be a failure per the job contract:
    the presence of an error file, or a configured target file that was
    not (re)written, both turn success into an error message.
    """
    if exit_code != 0:
        return self._create_exited_msg_for_non_zero_exit_code(
            max_memory_usage, exit_code, fm_step_pids
        )

    # exit_code is 0 — bind the success message up front so every
    # failure branch below has it in scope. (Fixes a NameError: the
    # target-file branch previously referenced `exited_message` without
    # it being bound on that path.)
    exited_message = Exited(self, exit_code)

    if self.job_data.get("error_file") and os.path.exists(
        self.job_data["error_file"]
    ):
        return exited_message.with_error(
            f'Found the error file:{self.job_data["error_file"]} - job failed.'
        )

    if target_file_mtime:
        target_file_error = self._check_target_file_is_written(target_file_mtime)
        if target_file_error:
            return exited_message.with_error(target_file_error)

    return exited_message

def _create_exited_msg_for_non_zero_exit_code(
    self,
    max_memory_usage: int,
    exit_code: int,
    fm_step_pids: Sequence[int],
) -> Exited:
    """Build the Exited error message for a failed (non-zero) exit code.

    Uses the tracked forward-model-step pids to distinguish an
    out-of-memory kill from an ordinary failure, so the error text is
    more actionable.
    """
    failure = Exited(self, exit_code)
    if not killed_by_oom(fm_step_pids):
        return failure.with_error(
            f"Process exited with status code {failure.exit_code}"
        )
    return failure.with_error(
        f"Forward model step {self.job_data.get('name')} "
        "was killed due to out-of-memory. "
        "Max memory usage recorded by Ert for the "
        f"realization was {max_memory_usage//1024//1024} MB"
    )

def handle_process_timeout_and_create_exited_msg(
    self, process: Process, proc: Popen, run_start_time: Optional[dt] = None
) -> Exited | None:
    """Kill the step if it has exceeded ``max_running_minutes``.

    Returns an unsuccessful Exited message when the limit is exceeded
    (after killing the process group), and None when no limit is
    configured or the limit is not yet reached.

    :param run_start_time: timestamp taken when the process was spawned.
        BUG(review): the original captured ``dt.now()`` *inside* this
        handler, so the measured run time was always ~0 and the
        max_running_minutes kill-path could never trigger. Callers should
        pass the spawn timestamp; the ``None`` default preserves the old
        behavior for callers not yet updated.
    """
    max_running_minutes = self.job_data.get("max_running_minutes")
    if run_start_time is None:
        run_start_time = dt.now()

    run_time = dt.now() - run_start_time
    # NOTE(review): ``.seconds`` ignores whole days (wraps at 24h);
    # ``run_time.total_seconds()`` would be safer for multi-day limits —
    # kept as-is to preserve behavior.
    if max_running_minutes is None or run_time.seconds <= max_running_minutes * 60:
        return None

    # If the spawned process is not in the same process group as the
    # callee (job_dispatch), we will kill the process group explicitly.
    #
    # Propagating the unsuccessful Exited message will kill the callee
    # group. See job_dispatch.py.
    process_group_id = os.getpgid(proc.pid)
    this_group_id = os.getpgid(os.getpid())
    if process_group_id != this_group_id:
        os.killpg(process_group_id, signal.SIGKILL)

    return Exited(self, proc.returncode).with_error(
        (
            f"Job:{self.name()} has been running "
            f"for more than {max_running_minutes} "
            "minutes - explicitly killed."
        )
    )

def _handle_process_io_error_and_create_exited_message(
    self, e: OSError, stderr: io.TextIOWrapper | None
) -> Exited:
    """Turn an OSError raised while spawning the step into an Exited
    message, echoing the message to the step's stderr file when one is
    open."""
    msg = f"{e.strerror} {e.filename}"
    # errno 8 (ENOEXEC): executable lacks a recognized format / shebang.
    if e.errno == 8 and e.strerror == "Exec format error":
        msg = (
            f"Missing execution format information in file: {e.filename!r}."
            f"Most likely you are missing and should add "
            f"'#!/usr/bin/env python' to the top of the file: "
        )
    if stderr:
        stderr.write(msg)
    return Exited(self, e.errno).with_error(msg)

def _assert_arg_list(self):
errors = []
Expand Down Expand Up @@ -357,6 +407,20 @@ def _check_target_file_is_written(self, target_file_mtime: int, timeout=5):
return f"Could not find target_file:{target_file}"


def _get_target_file_ntime(file) -> Optional[int]:
mtime = None
if file and os.path.exists(file):
stat = os.stat(file)
mtime = stat.st_mtime_ns
return mtime


def ensure_file_handles_closed(file_handles: Sequence[io.TextIOWrapper | None]):
    """Close every handle in *file_handles*, skipping None placeholders."""
    for handle in file_handles:
        if handle is None:
            continue
        handle.close()


def _get_processtree_data(
process: Process,
) -> Tuple[int, float, Optional[int]]:
Expand Down

0 comments on commit 6587920

Please sign in to comment.