From 7d8924665398a58fe2f9725fe0b39a609eb30f9a Mon Sep 17 00:00:00 2001 From: Levi Zhou <31941107+ZhouXY-PKU@users.noreply.github.com> Date: Wed, 11 Oct 2023 05:37:07 +0800 Subject: [PATCH] Update local_context.py/ssh_context.py (#370) To support wildcards for backward_files and backward_common_files. Now only support for local_context and ssh_context. #371 --------- Signed-off-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng --- dpdispatcher/local_context.py | 54 ++++++++++++++++++-- dpdispatcher/ssh_context.py | 93 ++++++++++++++++++++++++++--------- tests/sample_class.py | 16 ++++-- tests/test_ssh_context.py | 8 ++- 4 files changed, 140 insertions(+), 31 deletions(-) diff --git a/dpdispatcher/local_context.py b/dpdispatcher/local_context.py index 6fffb473..0d466e6b 100644 --- a/dpdispatcher/local_context.py +++ b/dpdispatcher/local_context.py @@ -152,9 +152,34 @@ def download( for ii in submission.belonging_tasks: local_job = os.path.join(self.local_root, ii.task_work_path) remote_job = os.path.join(self.remote_root, ii.task_work_path) - flist = ii.backward_files + flist = [] + for kk in ii.backward_files: + abs_flist_r = glob(os.path.join(remote_job, kk)) + abs_flist_l = glob(os.path.join(local_job, kk)) + if not abs_flist_r and not abs_flist_l: + if check_exists: + if mark_failure: + tag_file_path = os.path.join( + self.local_root, + ii.task_work_path, + "tag_failure_download_%s" % kk, + ) + with open(tag_file_path, "w") as fp: + pass + else: + pass + else: + raise RuntimeError( + "cannot find download file " + os.path.join(remote_job, kk) + ) + rel_flist = [ + os.path.relpath(ii, start=remote_job) for ii in abs_flist_r + ] + flist.extend(rel_flist) if back_error: - flist += glob(os.path.join(remote_job, "error*")) + abs_flist = glob(os.path.join(remote_job, "error*")) + rel_flist = [os.path.relpath(ii, start=remote_job) for ii in abs_flist] + flist.extend(rel_flist) for jj in flist: rfile = os.path.join(remote_job, jj) lfile = os.path.join(local_job, jj) @@ -198,9 +223,30 @@ def download( pass local_job = self.local_root remote_job = self.remote_root - flist = submission.backward_common_files + flist = [] + for kk in submission.backward_common_files: + abs_flist_r = glob(os.path.join(remote_job, kk)) + abs_flist_l = glob(os.path.join(local_job, kk)) + if not abs_flist_r and not abs_flist_l: + if check_exists: + if mark_failure: + tag_file_path = os.path.join( + self.local_root, "tag_failure_download_%s" % kk + ) + with open(tag_file_path, "w") as fp: + pass + else: + pass + else: + raise RuntimeError( + "cannot find download file " + os.path.join(remote_job, kk) + ) + rel_flist = [os.path.relpath(ii, start=remote_job) for ii in abs_flist_r] + flist.extend(rel_flist) if back_error: - flist += glob(os.path.join(remote_job, "error*")) + abs_flist = glob(os.path.join(remote_job, "error*")) + rel_flist = [os.path.relpath(ii, start=remote_job) for ii in abs_flist] + flist.extend(rel_flist) for jj in flist: rfile = os.path.join(remote_job, jj) lfile = os.path.join(local_job, jj) diff --git a/dpdispatcher/ssh_context.py b/dpdispatcher/ssh_context.py index 0c831326..849b66b4 100644 --- a/dpdispatcher/ssh_context.py +++ b/dpdispatcher/ssh_context.py @@ -1,5 +1,6 @@ #!/usr/bin/env python +import fnmatch import os import pathlib import shlex @@ -10,6 +11,7 @@ import uuid from functools import lru_cache from glob import glob +from stat import S_ISDIR, S_ISREG from typing import List import paramiko @@ -414,7 +416,7 @@ def __init__( assert os.path.isabs(remote_root), "remote_root must be a abspath" self.temp_remote_root = remote_root self.remote_profile = remote_profile - self.remote_root = None + self.remote_root = "" # self.job_uuid = None self.clean_asynchronously = clean_asynchronously @@ -634,6 +636,18 @@ def upload( tar_compress=self.remote_profile.get("tar_compress", None), ) + def list_remote_dir(self, sftp, remote_dir, ref_remote_root, result_list): + for entry in sftp.listdir_attr(remote_dir): + remote_name = pathlib.PurePath( + os.path.join(remote_dir, entry.filename) + ).as_posix() + st_mode = entry.st_mode + if S_ISDIR(st_mode): + self.list_remote_dir(sftp, remote_name, ref_remote_root, result_list) + elif S_ISREG(st_mode): + rel_remote_name = os.path.relpath(remote_name, start=ref_remote_root) + result_list.append(rel_remote_name) + def download( self, submission, @@ -646,31 +660,66 @@ def download( self.ssh_session.ensure_alive() file_list = [] # for ii in job_dirs : - for task in submission.belonging_tasks: - for jj in task.backward_files: - file_name = pathlib.PurePath( - os.path.join(task.task_work_path, jj) - ).as_posix() + for ii in submission.belonging_tasks: + remote_file_list = None + for jj in ii.backward_files: + if "*" in jj or "?" in jj: + if remote_file_list is not None: + abs_file_list = fnmatch.filter(remote_file_list, jj) + else: + remote_file_list = [] + remote_job = pathlib.PurePath( + os.path.join(self.remote_root, ii.task_work_path) + ).as_posix() + self.list_remote_dir( + self.sftp, remote_job, remote_job, remote_file_list + ) + + abs_file_list = fnmatch.filter(remote_file_list, jj) + rel_file_list = [ + pathlib.PurePath(os.path.join(ii.task_work_path, kk)).as_posix() + for kk in abs_file_list + ] + + else: + rel_file_list = [ + pathlib.PurePath(os.path.join(ii.task_work_path, jj)).as_posix() + ] if check_exists: - if self.check_file_exists(file_name): - file_list.append(file_name) - elif mark_failure: - with open( - os.path.join( - self.local_root, - task.task_work_path, - "tag_failure_download_%s" % jj, - ), - "w", - ) as fp: + for file_name in rel_file_list: + if self.check_file_exists(file_name): + file_list.append(file_name) + elif mark_failure: + with open( + os.path.join( + self.local_root, + ii.task_work_path, + "tag_failure_download_%s" % jj, + ), + "w", + ) as fp: + pass + else: pass - else: - pass else: - file_list.append(file_name) + file_list.extend(rel_file_list) if back_error: - errors = glob(os.path.join(task.task_work_path, "error*")) - file_list.extend(errors) + if remote_file_list is not None: + abs_errors = fnmatch.filter(remote_file_list, "error*") + else: + remote_file_list = [] + remote_job = pathlib.PurePath( + os.path.join(self.remote_root, ii.task_work_path) + ).as_posix() + self.list_remote_dir( + self.sftp, remote_job, remote_job, remote_file_list + ) + abs_errors = fnmatch.filter(remote_file_list, "error*") + rel_errors = [ + pathlib.PurePath(os.path.join(ii.task_work_path, kk)).as_posix() + for kk in abs_errors + ] + file_list.extend(rel_errors) file_list.extend(submission.backward_common_files) if len(file_list) > 0: self._get_files( diff --git a/tests/sample_class.py b/tests/sample_class.py index 57a9d58e..7c663094 100644 --- a/tests/sample_class.py +++ b/tests/sample_class.py @@ -83,7 +83,7 @@ def get_sample_task_dict(cls): return task_dict @classmethod - def get_sample_task_list(cls): + def get_sample_task_list(cls, backward_wildcard=False): task1 = Task( command="lmp -i input.lammps", task_work_path="bct-1/", @@ -109,6 +109,16 @@ def get_sample_task_list(cls): backward_files=["log.lammps"], ) task_list = [task1, task2, task3, task4] + if backward_wildcard: + task_wildcard = Task( + command="lmp -i input.lammps", + task_work_path="bct-backward_wildcard/", + forward_files=[], + backward_files=["test*/test*"], + outlog="wildcard.log", + errlog="wildcard.err", + ) + task_list.append(task_wildcard) return task_list @classmethod @@ -127,9 +137,9 @@ def get_sample_empty_submission(cls): return empty_submission @classmethod - def get_sample_submission(cls): + def get_sample_submission(cls, backward_wildcard=False): submission = cls.get_sample_empty_submission() - task_list = cls.get_sample_task_list() + task_list = cls.get_sample_task_list(backward_wildcard=backward_wildcard) submission.register_task_list(task_list) submission.generate_jobs() return submission diff --git a/tests/test_ssh_context.py b/tests/test_ssh_context.py index 83ca23f2..05a1e986 100644 --- a/tests/test_ssh_context.py +++ b/tests/test_ssh_context.py @@ -41,7 +41,7 @@ def setUpClass(cls): cls.machine = Machine.load_from_dict(mdata) except (SSHException, socket.timeout): raise unittest.SkipTest("SSHException ssh cannot connect") - cls.submission = SampleClass.get_sample_submission() + cls.submission = SampleClass.get_sample_submission(backward_wildcard=True) cls.submission.bind_machine(cls.machine) cls.submission_hash = cls.submission.submission_hash file_list = [ @@ -50,6 +50,8 @@ def setUpClass(cls): "bct-3/log.lammps", "bct-4/log.lammps", "dir with space/file with space", + "bct-backward_wildcard/test456", + "bct-backward_wildcard/test123/test123", ] for file in file_list: cls.machine.context.sftp.mkdir( @@ -187,7 +189,7 @@ def setUpClass(cls): cls.machine = Machine.load_from_dict(mdata) except (SSHException, socket.timeout): raise unittest.SkipTest("SSHException ssh cannot connect") - cls.submission = SampleClass.get_sample_submission() + cls.submission = SampleClass.get_sample_submission(backward_wildcard=True) cls.submission.bind_machine(cls.machine) cls.submission_hash = cls.submission.submission_hash file_list = [ @@ -196,6 +198,8 @@ def setUpClass(cls): "bct-3/log.lammps", "bct-4/log.lammps", "dir with space/file with space", + "bct-backward_wildcard/test456", + "bct-backward_wildcard/test123/test123", ] for file in file_list: cls.machine.context.sftp.mkdir(