Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-eq committed Oct 3, 2024
1 parent 35fcc8a commit 7049ec0
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 84 deletions.
161 changes: 81 additions & 80 deletions src/_ert/forward_model_runner/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from datetime import datetime as dt
from pathlib import Path
from subprocess import Popen, run
from typing import Dict, Generator, Optional, Sequence, Tuple
from typing import Dict, Generator, List, Optional, Sequence, Tuple, cast

from psutil import AccessDenied, NoSuchProcess, Process, TimeoutExpired, ZombieProcess

from ert.config.forward_model_step import ForwardModelStepJSON

from .io import check_executable
from .reporting.message import (
Exited,
Expand All @@ -26,7 +28,7 @@
logger = logging.getLogger(__name__)


def killed_by_oom(pids: set[int]) -> bool:
def killed_by_oom(pids: Sequence[int]) -> bool:
"""Will try to detect if a process (or any of its descendants) was killed
by the Linux OOM-killer.
Expand Down Expand Up @@ -76,14 +78,16 @@ def killed_by_oom(pids: set[int]) -> bool:
class Job:
MEMORY_POLL_PERIOD = 5 # Seconds between memory polls

def __init__(self, job_data, index, sleep_interval=1) -> None:
def __init__(
self, job_data: ForwardModelStepJSON, index: int, sleep_interval: int = 1
) -> None:
self.sleep_interval = sleep_interval
self.job_data: Dict[str, str] = job_data
self.job_data = job_data
self.index = index
self.std_err = job_data.get("stderr")
self.std_out = job_data.get("stdout")

def run(self) -> Generator[Start | Exited | Running]:
def run(self) -> Generator[Start | Exited | Running | None]:
try:
for msg in self._run():
yield msg
def create_start_message_and_check_job_files(self) -> "Start":
    """Build the Start message, folding in any pre-run validation errors.

    Runs the file checks and the argument-list checks, dumps the step's
    exec environment to disk, and attaches all collected error strings
    (newline-joined) to the returned Start message.
    """
    start_message = Start(self)
    errors = self._check_job_files()
    errors.extend(self._assert_arg_list())
    self._dump_exec_env()
    if errors:
        start_message = start_message.with_error("\n".join(errors))
    return start_message

def _build_arg_list(self):
def _build_arg_list(self) -> List[str]:
executable = self.job_data.get("executable")

# assert executable is not None
combined_arg_list = [executable]
if arg_list := self.job_data.get("argList"):
combined_arg_list += arg_list
Expand All @@ -117,7 +119,7 @@ def _open_file_handles(
io.TextIOWrapper | None, io.TextIOWrapper | None, io.TextIOWrapper | None
]:
if self.job_data.get("stdin"):
stdin = open(self.job_data.get("stdin"), encoding="utf-8") # noqa
stdin = open(cast(Path, self.job_data.get("stdin")), encoding="utf-8") # noqa
else:
stdin = None

Expand All @@ -141,13 +143,13 @@ def _open_file_handles(

return (stdin, stdout, stderr)

def _create_environment(self) -> Dict:
environment = self.job_data.get("environment")
if environment is not None:
environment = {**os.environ, **environment}
return environment
def _create_environment(self) -> Optional[Dict[str, str]]:
combined_environment = None
if environment := self.job_data.get("environment"):
combined_environment = {**os.environ, **environment}
return combined_environment

def _run(self) -> contextlib.Generator[Start | Exited | Running]:
def _run(self) -> Generator[Start | Exited | Running | None]:
start_message = self.create_start_message_and_check_job_files()

yield start_message
Expand All @@ -160,7 +162,7 @@ def _run(self) -> contextlib.Generator[Start | Exited | Running]:
# stdin/stdout/stderr are closed at the end of this function

target_file = self.job_data.get("target_file")
target_file_mtime: int = _get_target_file_ntime(target_file)
target_file_mtime: Optional[int] = _get_target_file_ntime(target_file)

try:
proc = Popen(
Expand Down Expand Up @@ -201,20 +203,22 @@ def _run(self) -> contextlib.Generator[Start | Exited | Running]:
try:
exit_code = process.wait(timeout=self.MEMORY_POLL_PERIOD)
except TimeoutExpired:
exited_msg = self.handle_process_timeout_and_create_exited_msg(
process, proc
potential_exited_msg = (
self.handle_process_timeout_and_create_exited_msg(process, proc)
)
fm_step_pids |= {
int(child.pid) for child in process.children(recursive=True)
}
if isinstance(exited_msg, Exited):
yield exited_msg
if isinstance(potential_exited_msg, Exited):
yield potential_exited_msg

return

ensure_file_handles_closed([stdin, stdout, stderr])
exited_message = self._create_exited_message_based_on_exit_code(
max_memory_usage, target_file_mtime, exit_code, fm_step_pids
)
assert exited_message.job
yield exited_message

def _create_exited_message_based_on_exit_code(
Expand All @@ -224,20 +228,16 @@ def _create_exited_message_based_on_exit_code(
exit_code: int,
fm_step_pids: Sequence[int],
) -> Exited:
# exit_code = proc.returncode

if exit_code != 0:
exited_message = self._create_exited_msg_for_non_zero_exit_code(
max_memory_usage, exit_code, fm_step_pids
)
return exited_message

# exit_code is 0

exited_message = Exited(self, exit_code)
if self.job_data.get("error_file") and os.path.exists(
self.job_data["error_file"]
):
exited_message = Exited(self, exit_code)
return exited_message.with_error(
f'Found the error file:{self.job_data["error_file"]} - job failed.'
)
Expand Down Expand Up @@ -271,34 +271,33 @@ def _create_exited_msg_for_non_zero_exit_code(
)

def handle_process_timeout_and_create_exited_msg(
self, process: Process, proc: Popen
self, process: Process, proc: Popen[Process]
) -> Exited | None:
max_running_minutes = self.job_data.get("max_running_minutes")
run_start_time = dt.now()

run_time = dt.now() - run_start_time
if (
max_running_minutes is not None
and run_time.seconds > max_running_minutes * 60
):
# If the spawned process is not in the same process group as
# the callee (job_dispatch), we will kill the process group
# explicitly.
#
# Propagating the unsuccessful Exited message will kill the
# callee group. See job_dispatch.py.
process_group_id = os.getpgid(proc.pid)
this_group_id = os.getpgid(os.getpid())
if process_group_id != this_group_id:
os.killpg(process_group_id, signal.SIGKILL)

return Exited(self, proc.returncode).with_error(
(
f"Job:{self.name()} has been running "
f"for more than {max_running_minutes} "
"minutes - explicitly killed."
)
if max_running_minutes is None or run_time.seconds > max_running_minutes * 60:
return None

# If the spawned process is not in the same process group as
# the callee (job_dispatch), we will kill the process group
# explicitly.
#
# Propagating the unsuccessful Exited message will kill the
# callee group. See job_dispatch.py.
process_group_id = os.getpgid(proc.pid)
this_group_id = os.getpgid(os.getpid())
if process_group_id != this_group_id:
os.killpg(process_group_id, signal.SIGKILL)

return Exited(self, proc.returncode).with_error(
(
f"Job:{self.name()} has been running "
f"for more than {max_running_minutes} "
"minutes - explicitly killed."
)
)

def _handle_process_io_error_and_create_exited_message(
self, e: OSError, stderr: io.TextIOWrapper | None
Expand All @@ -314,44 +313,19 @@ def _handle_process_io_error_and_create_exited_message(
stderr.write(msg)
return Exited(self, e.errno).with_error(msg)

def _assert_arg_list(self) -> list[str]:
errors: list[str] = []
if "arg_types" in self.job_data:
arg_types = self.job_data["arg_types"]
arg_list = self.job_data.get("argList")
for index, arg_type in enumerate(arg_types):
if arg_type == "RUNTIME_FILE":
file_path = os.path.join(os.getcwd(), arg_list[index])
if not os.path.isfile(file_path):
errors.append(
f"In job {self.name()}: RUNTIME_FILE {arg_list[index]} "
"does not exist."
)
if arg_type == "RUNTIME_INT":
try:
int(arg_list[index])
except ValueError:
errors.append(
(
f"In job {self.name()}: argument with index {index} "
"is of incorrect type, should be integer."
)
)
return errors

def name(self) -> str:
    """Return the step's name, taken from the "name" key of the job data."""
    return self.job_data["name"]

def _dump_exec_env(self) -> None:
exec_env = self.job_data.get("exec_env")
if exec_env:
exec_name, _ = os.path.splitext(
os.path.basename(self.job_data.get("executable"))
os.path.basename(cast(Path, self.job_data.get("executable")))
)
with open(f"{exec_name}_exec_env.json", "w", encoding="utf-8") as f_handle:
f_handle.write(json.dumps(exec_env, indent=4))

def _check_job_files(self)-> list[str]:
def _check_job_files(self) -> list[str]:
"""
Returns the empty list if no failed checks, or a list of errors in case
of failed checks.
Expand All @@ -361,21 +335,23 @@ def _check_job_files(self)-> list[str]:
errors.append(f'Could not locate stdin file: {self.job_data["stdin"]}')

if self.job_data.get("start_file") and not os.path.exists(
self.job_data["start_file"]
cast(Path, self.job_data["start_file"])
):
errors.append(f'Could not locate start_file:{self.job_data["start_file"]}')

if self.job_data.get("error_file") and os.path.exists(
self.job_data.get("error_file")
cast(Path, self.job_data.get("error_file"))
):
os.unlink(self.job_data.get("error_file"))
os.unlink(cast(Path, self.job_data.get("error_file")))

if executable_error := check_executable(self.job_data.get("executable")):
errors.append(str(executable_error))
errors.append(executable_error)

return errors

def _check_target_file_is_written(self, target_file_mtime: int, timeout: int =5) -> None | str:
def _check_target_file_is_written(
self, target_file_mtime: int, timeout: int = 5
) -> None | str:
"""
Check whether or not a target_file eventually appear. Returns None in
case of success, an error message in the case of failure.
Expand Down Expand Up @@ -408,6 +384,31 @@ def _check_target_file_is_written(self, target_file_mtime: int, timeout: int =5)
)
return f"Could not find target_file:{target_file}"

def _assert_arg_list(self):
errors = []
if "arg_types" in self.job_data:
arg_types = self.job_data["arg_types"]
arg_list = self.job_data.get("argList")
for index, arg_type in enumerate(arg_types):
if arg_type == "RUNTIME_FILE":
file_path = os.path.join(os.getcwd(), arg_list[index])
if not os.path.isfile(file_path):
errors.append(
f"In job {self.name()}: RUNTIME_FILE {arg_list[index]} "
"does not exist."
)
if arg_type == "RUNTIME_INT":
try:
int(arg_list[index])
except ValueError:
errors.append(
(
f"In job {self.name()}: argument with index {index} "
"is of incorrect type, should be integer."
)
)
return errors


def _get_target_file_ntime(file: Optional[str]) -> Optional[int]:
mtime = None
Expand Down
6 changes: 3 additions & 3 deletions src/_ert/forward_model_runner/reporting/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def with_error(self, message: str):
self.error_message = message
return self

def success(self) -> bool:
    """Return True when no error message has been attached to this message."""
    return self.error_message is None


Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(self):


class Start(Message):
    """Message emitted when a forward model step is about to start."""

    def __init__(self, job: "Job"):
        super().__init__(job)


Expand All @@ -127,7 +127,7 @@ def __init__(self, job: "Job", memory_status: ProcessTreeStatus):


class Exited(Message):
    """Message emitted when a forward model step's process has exited."""

    def __init__(self, job: "Job", exit_code: int):
        super().__init__(job)
        # Process exit status; 0 conventionally means success.
        self.exit_code = exit_code

Expand Down
1 change: 0 additions & 1 deletion src/_ert/forward_model_runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def run(self, names_of_jobs_to_run):
for job in job_queue:
for status_update in job.run():
yield status_update

if not status_update.success():
yield Checksum(checksum_dict={}, run_path=os.getcwd())
yield Finish().with_error("Not all jobs completed successfully.")
Expand Down

0 comments on commit 7049ec0

Please sign in to comment.