From b59d99919cdafd437613f553e2bce0f12ad768a9 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sun, 9 Jul 2023 23:00:45 +0200 Subject: [PATCH 01/29] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Allow=20for=20file?= =?UTF-8?q?=20uploads/downloads=20to=20be=20async?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aiida/engine/daemon/execmanager.py | 14 +++++++------- aiida/engine/processes/calcjobs/calcjob.py | 14 +++++++------- aiida/engine/processes/calcjobs/tasks.py | 6 +++--- aiida/transports/transport.py | 21 +++++++++++++++++++++ environment.yml | 2 +- pyproject.toml | 2 +- requirements/requirements-py-3.10.txt | 2 +- requirements/requirements-py-3.11.txt | 2 +- requirements/requirements-py-3.9.txt | 2 +- 9 files changed, 43 insertions(+), 22 deletions(-) diff --git a/aiida/engine/daemon/execmanager.py b/aiida/engine/daemon/execmanager.py index 74468eb34c..aab9c48512 100644 --- a/aiida/engine/daemon/execmanager.py +++ b/aiida/engine/daemon/execmanager.py @@ -62,7 +62,7 @@ def _find_data_node(inputs: MappingType[str, Any], uuid: str) -> Optional[Node]: return data_node -def upload_calculation( +async def upload_calculation( node: CalcJobNode, transport: Transport, calc_info: CalcInfo, @@ -242,7 +242,7 @@ def upload_calculation( if not dry_run: for filename in folder.get_content_list(): logger.debug(f'[submission of calculation {node.pk}] copying file/folder {filename}...') - transport.put(folder.get_abs_path(filename), filename) + await transport.put_async(folder.get_abs_path(filename), filename) for (remote_computer_uuid, remote_abs_path, dest_rel_path) in remote_copy_list: if remote_computer_uuid == computer.uuid: @@ -449,7 +449,7 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: remote_stash.base.links.add_incoming(calculation, link_type=LinkType.CREATE, link_label='remote_stash') -def retrieve_calculation(calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str) -> None: 
+async def retrieve_calculation(calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str) -> None: """Retrieve all the files of a completed job calculation using the given transport. If the job defined anything in the `retrieve_temporary_list`, those entries will be stored in the @@ -488,14 +488,14 @@ def retrieve_calculation(calculation: CalcJobNode, transport: Transport, retriev retrieve_temporary_list = calculation.get_retrieve_temporary_list() with SandboxFolder(filepath_sandbox) as folder: - retrieve_files_from_list(calculation, transport, folder.abspath, retrieve_list) + await retrieve_files_from_list(calculation, transport, folder.abspath, retrieve_list) # Here I retrieved everything; now I store them inside the calculation retrieved_files.base.repository.put_object_from_tree(folder.abspath) # Retrieve the temporary files in the retrieved_temporary_folder if any files were # specified in the 'retrieve_temporary_list' key if retrieve_temporary_list: - retrieve_files_from_list(calculation, transport, retrieved_temporary_folder, retrieve_temporary_list) + await retrieve_files_from_list(calculation, transport, retrieved_temporary_folder, retrieve_temporary_list) # Log the files that were retrieved in the temporary folder for filename in os.listdir(retrieved_temporary_folder): @@ -553,7 +553,7 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: ) -def retrieve_files_from_list( +async def retrieve_files_from_list( calculation: CalcJobNode, transport: Transport, folder: str, retrieve_list: List[Union[str, Tuple[str, str, int], list]] ) -> None: @@ -613,4 +613,4 @@ def retrieve_files_from_list( for rem, loc in zip(remote_names, local_names): transport.logger.debug(f"[retrieval of calc {calculation.pk}] Trying to retrieve remote item '{rem}'") - transport.get(rem, os.path.join(folder, loc), ignore_nonexisting=True) + await transport.get_async(rem, os.path.join(folder, loc), ignore_nonexisting=True) diff --git 
a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index 475fa94e4d..da5c54a14d 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -529,7 +529,7 @@ def on_terminated(self) -> None: super().on_terminated() @override - def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]: + async def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]: """Run the calculation job. This means invoking the `presubmit` and storing the temporary folder in the node's repository. Then we move the @@ -540,11 +540,11 @@ def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wa """ if self.inputs.metadata.dry_run: - self._perform_dry_run() + await self._perform_dry_run() return plumpy.process_states.Stop(None, True) if 'remote_folder' in self.inputs: - exit_code = self._perform_import() + exit_code = await self._perform_import() return exit_code # The following conditional is required for the caching to properly work. Even if the source node has a process @@ -598,7 +598,7 @@ def _setup_inputs(self) -> None: if not self.node.computer: self.node.computer = self.inputs.code.computer - def _perform_dry_run(self): + async def _perform_dry_run(self): """Perform a dry run. 
Instead of performing the normal sequence of steps, just the `presubmit` is called, which will call the method @@ -615,13 +615,13 @@ def _perform_dry_run(self): with SubmitTestFolder() as folder: calc_info = self.presubmit(folder) transport.chdir(folder.abspath) - upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True) + await upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True) self.node.dry_run_info = { # type: ignore 'folder': folder.abspath, 'script_filename': self.node.get_option('submit_script_filename') } - def _perform_import(self): + async def _perform_import(self): """Perform the import of an already completed calculation. The inputs contained a `RemoteData` under the key `remote_folder` signalling that this is not supposed to be run @@ -641,7 +641,7 @@ def _perform_import(self): with SandboxFolder(filepath_sandbox) as retrieved_temporary_folder: self.presubmit(folder) self.node.set_remote_workdir(self.inputs.remote_folder.get_remote_path()) - retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath) + await retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath) self.node.set_state(CalcJobState.PARSING) self.node.base.attributes.set(orm.CalcJobNode.IMMIGRATED_KEY, True) return self.parse(retrieved_temporary_folder.abspath) diff --git a/aiida/engine/processes/calcjobs/tasks.py b/aiida/engine/processes/calcjobs/tasks.py index 5d7e14eaf2..9256f6b598 100644 --- a/aiida/engine/processes/calcjobs/tasks.py +++ b/aiida/engine/processes/calcjobs/tasks.py @@ -92,7 +92,7 @@ async def do_upload(): except Exception as exception: # pylint: disable=broad-except raise PreSubmitException('exception occurred in presubmit call') from exception else: - execmanager.upload_calculation(node, transport, calc_info, folder) + await execmanager.upload_calculation(node, transport, calc_info, folder) skip_submit = calc_info.skip_submit or False return 
skip_submit @@ -310,7 +310,7 @@ async def do_retrieve(): if node.get_job_id() is None: logger.warning(f'there is no job id for CalcJobNoe<{node.pk}>: skipping `get_detailed_job_info`') - return execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) + return await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) try: detailed_job_info = scheduler.get_detailed_job_info(node.get_job_id()) @@ -320,7 +320,7 @@ async def do_retrieve(): else: node.set_detailed_job_info(detailed_job_info) - return execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) + return await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) try: logger.info(f'scheduled request to retrieve CalcJob<{node.pk}>') diff --git a/aiida/transports/transport.py b/aiida/transports/transport.py index 2e4484cdba..51bc701679 100644 --- a/aiida/transports/transport.py +++ b/aiida/transports/transport.py @@ -449,6 +449,16 @@ def get(self, remotepath, localpath, *args, **kwargs): :param localpath: (str) local_folder_path """ + async def get_async(self, remotepath, localpath, *args, **kwargs): + """ + Retrieve a file or folder from remote source to local destination + dst must be an absolute path (src not necessarily) + + :param remotepath: (str) remote_folder_path + :param localpath: (str) local_folder_path + """ + return self.get(remotepath, localpath, *args, **kwargs) + @abc.abstractmethod def getfile(self, remotepath, localpath, *args, **kwargs): """ @@ -622,6 +632,17 @@ def put(self, localpath, remotepath, *args, **kwargs): :param str remotepath: path to remote destination """ + async def put_async(self, localpath, remotepath, *args, **kwargs): + """ + Put a file or a directory from local src to remote dst. + src must be an absolute path (dst not necessarily)) + Redirects to putfile and puttree. 
+ + :param str localpath: absolute path to local source + :param str remotepath: path to remote destination + """ + return self.put(localpath, remotepath, *args, **kwargs) + @abc.abstractmethod def putfile(self, localpath, remotepath, *args, **kwargs): """ diff --git a/environment.yml b/environment.yml index 135248c015..4718d3a241 100644 --- a/environment.yml +++ b/environment.yml @@ -23,7 +23,7 @@ dependencies: - importlib-metadata~=4.13 - numpy~=1.21 - paramiko>=2.7.2,~=2.7 -- plumpy~=0.21.6 +- plumpy@ git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy - pgsu~=0.2.1 - psutil~=5.6 - psycopg2-binary~=2.8 diff --git a/pyproject.toml b/pyproject.toml index aebcba57a3..27f1ceb1f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "importlib-metadata~=4.13", "numpy~=1.21", "paramiko~=2.7,>=2.7.2", - "plumpy~=0.21.6", + "plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy", "pgsu~=0.2.1", "psutil~=5.6", "psycopg2-binary~=2.8", diff --git a/requirements/requirements-py-3.10.txt b/requirements/requirements-py-3.10.txt index 99a8938f59..667154f6cb 100644 --- a/requirements/requirements-py-3.10.txt +++ b/requirements/requirements-py-3.10.txt @@ -120,7 +120,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy==0.21.8 +plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 diff --git a/requirements/requirements-py-3.11.txt b/requirements/requirements-py-3.11.txt index 92afc7c8b0..e6efe3aae2 100644 --- a/requirements/requirements-py-3.11.txt +++ b/requirements/requirements-py-3.11.txt @@ -119,7 +119,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy==0.21.8 +plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt index 
cb62fa681d..05293c9a15 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -122,7 +122,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy==0.21.8 +plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 From 0c4284165e553bc41825757b1c55915b19fcd42e Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sun, 9 Jul 2023 23:29:04 +0200 Subject: [PATCH 02/29] Update test_execmanager.py --- tests/engine/daemon/test_execmanager.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/engine/daemon/test_execmanager.py b/tests/engine/daemon/test_execmanager.py index 823c42fc4f..88ff68ba7c 100644 --- a/tests/engine/daemon/test_execmanager.py +++ b/tests/engine/daemon/test_execmanager.py @@ -148,7 +148,8 @@ def test_hierarchy_utility(file_hierarchy, tmp_path): (['file_a.txt', 'file_u.txt', 'path/file_u.txt', ('path/sub/file_u.txt', '.', 3)], {'file_a.txt': 'file_a'}), )) # yapf: enable -def test_retrieve_files_from_list( +@pytest.mark.asyncio +async def test_retrieve_files_from_list( tmp_path_factory, generate_calculation_node, file_hierarchy, retrieve_list, expected_hierarchy ): """Test the `retrieve_files_from_list` function.""" @@ -160,7 +161,7 @@ def test_retrieve_files_from_list( with LocalTransport() as transport: node = generate_calculation_node() transport.chdir(source) - execmanager.retrieve_files_from_list(node, transport, target, retrieve_list) + await execmanager.retrieve_files_from_list(node, transport, target, retrieve_list) assert serialize_file_hierarchy(target) == expected_hierarchy @@ -178,7 +179,8 @@ def test_retrieve_files_from_list( (['sub', 'target'], {'target': {'b': 'file_b'}}), )) # yapf: enable -def test_upload_local_copy_list( +@pytest.mark.asyncio +async def test_upload_local_copy_list( fixture_sandbox, node_and_calc_info, file_hierarchy_simple, tmp_path, local_copy_list, 
expected_hierarchy ): """Test the ``local_copy_list`` functionality in ``upload_calculation``.""" @@ -191,7 +193,7 @@ def test_upload_local_copy_list( calc_info.local_copy_list = [[folder.uuid] + local_copy_list] with LocalTransport() as transport: - execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) + await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) # Check that none of the files were written to the repository of the calculation node, since they were communicated # through the ``local_copy_list``. @@ -202,7 +204,8 @@ def test_upload_local_copy_list( assert written_hierarchy == expected_hierarchy -def test_upload_local_copy_list_files_folders(fixture_sandbox, node_and_calc_info, file_hierarchy, tmp_path): +@pytest.mark.asyncio +async def test_upload_local_copy_list_files_folders(fixture_sandbox, node_and_calc_info, file_hierarchy, tmp_path): """Test the ``local_copy_list`` functionality in ``upload_calculation``. Specifically, verify that files in the ``local_copy_list`` do not end up in the repository of the node. @@ -226,7 +229,7 @@ def test_upload_local_copy_list_files_folders(fixture_sandbox, node_and_calc_inf ] with LocalTransport() as transport: - execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) + await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) # Check that none of the files were written to the repository of the calculation node, since they were communicated # through the ``local_copy_list``. 
From 14cbd29390684ad8cb7b6d5769acacca1fcbd175 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Mon, 10 Jul 2023 00:13:50 +0200 Subject: [PATCH 03/29] update run methods --- aiida/engine/processes/functions.py | 2 +- aiida/engine/processes/workchains/workchain.py | 2 +- tests/engine/processes/test_caching.py | 2 +- tests/engine/test_process.py | 6 +++--- tests/engine/test_runners.py | 2 +- tests/engine/test_work_chain.py | 2 +- tests/utils/processes.py | 14 +++++++------- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 8baf92c903..9e0c8eeff5 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -552,7 +552,7 @@ def _setup_db_record(self) -> None: self.node.store_source_info(self._func) @override - def run(self) -> 'ExitCode' | None: + async def run(self) -> 'ExitCode' | None: """Run the process.""" from .exit_code import ExitCode diff --git a/aiida/engine/processes/workchains/workchain.py b/aiida/engine/processes/workchains/workchain.py index e6ca21a4b4..de0eabf498 100644 --- a/aiida/engine/processes/workchains/workchain.py +++ b/aiida/engine/processes/workchains/workchain.py @@ -295,7 +295,7 @@ def _update_process_status(self) -> None: @override @Protect.final - def run(self) -> t.Any: + async def run(self) -> t.Any: self._stepper = self.spec().get_outline().create_stepper(self) # type: ignore[arg-type] return self._do_step() diff --git a/tests/engine/processes/test_caching.py b/tests/engine/processes/test_caching.py index 58244ec3a9..cafcb0a393 100644 --- a/tests/engine/processes/test_caching.py +++ b/tests/engine/processes/test_caching.py @@ -16,7 +16,7 @@ def define(cls, spec): spec.input('a') spec.output_namespace('nested', dynamic=True) - def run(self): + async def run(self): self.out('nested', {'a': self.inputs.a + 2}) diff --git a/tests/engine/test_process.py b/tests/engine/test_process.py index 95d40827a7..34cdb416c9 
100644 --- a/tests/engine/test_process.py +++ b/tests/engine/test_process.py @@ -76,7 +76,7 @@ class ProcessStackTest(Process): _node_class = orm.WorkflowNode @override - def run(self): + async def run(self): pass @override @@ -298,7 +298,7 @@ def define(cls, spec): spec.input_namespace('namespace', valid_type=orm.Int, dynamic=True) spec.output_namespace('namespace', valid_type=orm.Int, dynamic=True) - def run(self): + async def run(self): self.out('namespace', self.inputs.namespace) results, node = run_get_node(TestProcess1, namespace={'alpha': orm.Int(1), 'beta': orm.Int(2)}) @@ -322,7 +322,7 @@ def define(cls, spec): spec.output_namespace('integer.namespace', valid_type=orm.Int, dynamic=True) spec.output('required_string', valid_type=orm.Str, required=True) - def run(self): + async def run(self): if self.inputs.add_outputs: self.out('required_string', orm.Str('testing').store()) self.out('integer.namespace.two', orm.Int(2).store()) diff --git a/tests/engine/test_runners.py b/tests/engine/test_runners.py index 1a718c4992..3059cff515 100644 --- a/tests/engine/test_runners.py +++ b/tests/engine/test_runners.py @@ -39,7 +39,7 @@ def define(cls, spec): super().define(spec) spec.input('a') - def run(self): + async def run(self): pass diff --git a/tests/engine/test_work_chain.py b/tests/engine/test_work_chain.py index 1578a3d2db..bfa594cad4 100644 --- a/tests/engine/test_work_chain.py +++ b/tests/engine/test_work_chain.py @@ -1775,5 +1775,5 @@ def define(cls, spec): super().define(spec) spec.outline(cls.run) - def run(self): + async def run(self): pass diff --git a/tests/utils/processes.py b/tests/utils/processes.py index 88c48a25a5..17a6188746 100644 --- a/tests/utils/processes.py +++ b/tests/utils/processes.py @@ -26,7 +26,7 @@ def define(cls, spec): spec.inputs.valid_type = Data spec.outputs.valid_type = Data - def run(self): + async def run(self): pass @@ -42,7 +42,7 @@ def define(cls, spec): spec.input('b', required=True) spec.output('result', required=True) - def 
run(self): + async def run(self): summed = self.inputs.a + self.inputs.b self.out(summed.store()) @@ -57,7 +57,7 @@ def define(cls, spec): super().define(spec) spec.outputs.valid_type = Data - def run(self): + async def run(self): self.out('bad_output', 5) @@ -66,7 +66,7 @@ class ExceptionProcess(Process): _node_class = WorkflowNode - def run(self): + async def run(self): raise RuntimeError('CRASH') @@ -75,7 +75,7 @@ class WaitProcess(Process): _node_class = WorkflowNode - def run(self): + async def run(self): return plumpy.Wait(self.next_step) def next_step(self): @@ -95,7 +95,7 @@ def define(cls, spec): 123, 'GENERIC_EXIT_CODE', message='This process should not be used as cache.', invalidates_cache=True ) - def run(self): + async def run(self): if self.inputs.return_exit_code: return self.exit_codes.GENERIC_EXIT_CODE # pylint: disable=no-member @@ -110,7 +110,7 @@ def define(cls, spec): super().define(spec) spec.input('not_valid_cache', valid_type=Bool, default=lambda: Bool(False)) - def run(self): + async def run(self): pass @classmethod From 68110987bb2259b1204c14473f5127b797cad864 Mon Sep 17 00:00:00 2001 From: Ali Khosravi Date: Mon, 18 Nov 2024 16:09:02 +0100 Subject: [PATCH 04/29] async transport, the first implementation --- environment.yml | 6 +- pyproject.toml | 193 ++- requirements/requirements-py-3.10.txt | 2 +- requirements/requirements-py-3.11.txt | 2 +- requirements/requirements-py-3.9.txt | 2 +- src/aiida/engine/daemon/execmanager.py | 661 ++++++++++ .../engine/processes/calcjobs/calcjob.py | 1115 +++++++++++++++++ src/aiida/engine/processes/calcjobs/tasks.py | 683 ++++++++++ src/aiida/orm/computers.py | 4 +- src/aiida/schedulers/plugins/direct.py | 2 +- src/aiida/transports/__init__.py | 2 + src/aiida/transports/plugins/local.py | 5 + src/aiida/transports/plugins/ssh.py | 186 ++- src/aiida/transports/plugins/ssh_async.py | 915 ++++++++++++++ src/aiida/transports/transport.py | 81 +- src/aiida/transports/util.py | 67 + 
tests/engine/daemon/test_execmanager.py | 134 +- tests/engine/processes/test_caching.py | 2 +- tests/engine/test_process.py | 6 +- tests/engine/test_runners.py | 2 +- tests/engine/test_work_chain.py | 2 +- tests/transports/test_all_plugins.py | 774 +++++------- tests/utils/processes.py | 14 +- utils/dependency_management.py | 0 24 files changed, 4261 insertions(+), 599 deletions(-) create mode 100644 src/aiida/transports/plugins/ssh_async.py mode change 100755 => 100644 utils/dependency_management.py diff --git a/environment.yml b/environment.yml index ee1b308235..86eee6f90b 100644 --- a/environment.yml +++ b/environment.yml @@ -21,9 +21,9 @@ dependencies: - kiwipy[rmq]~=0.8.4 - importlib-metadata~=6.0 - numpy~=1.21 -- paramiko>=2.7.2,~=2.7 -- plumpy@ git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy -- pgsu~=0.2.1 +- paramiko~=3.0 +- plumpy~=0.22.3 +- pgsu~=0.3.0 - psutil~=5.6 - psycopg[binary]~=3.0 - pydantic~=2.4 diff --git a/pyproject.toml b/pyproject.toml index a827f52c46..ef7176fc4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,41 +18,168 @@ classifiers = [ 'Topic :: Scientific/Engineering' ] dependencies = [ - "alembic~=1.2", - "archive-path~=0.4.2", - "aio-pika~=6.6", - "circus~=0.18.0", - "click-spinner~=0.1.8", - "click~=8.1", - "disk-objectstore~=0.6.0", - "docstring-parser", - "get-annotations~=0.1;python_version<'3.10'", - "graphviz~=0.19", - "ipython>=7", - "jinja2~=3.0", - "jsonschema~=3.0", - "kiwipy[rmq]~=0.7.7", - "importlib-metadata~=4.13", - "numpy~=1.21", - "paramiko~=2.7,>=2.7.2", - "plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy", - "pgsu~=0.2.1", - "psutil~=5.6", - "psycopg2-binary~=2.8", - "pytz~=2021.1", - "pyyaml~=6.0", - "requests~=2.0", - "sqlalchemy~=1.4.22", - "tabulate~=0.8.5", - "tqdm~=4.45", - "upf_to_json~=0.9.2", - "wrapt~=1.11" + 'alembic~=1.2', + 'archive-path~=0.4.2', + 'circus~=0.18.0', + 'click-spinner~=0.1.8', + 'click~=8.1', + 'disk-objectstore~=1.2', + 'docstring-parser', 
+ 'get-annotations~=0.1;python_version<"3.10"', + 'graphviz~=0.19', + 'ipython>=7', + 'jedi<0.19', + 'jinja2~=3.0', + 'kiwipy[rmq]~=0.8.4', + 'importlib-metadata~=6.0', + 'numpy~=1.21', + 'paramiko~=3.0', + 'plumpy~=0.22.3', + 'pgsu~=0.3.0', + 'psutil~=5.6', + 'psycopg[binary]~=3.0', + 'pydantic~=2.4', + 'pytz~=2021.1', + 'pyyaml~=6.0', + 'requests~=2.0', + 'sqlalchemy~=2.0', + 'tabulate>=0.8.0,<0.10.0', + 'tqdm~=4.45', + 'upf_to_json~=0.9.2', + 'wrapt~=1.11' ] +description = 'AiiDA is a workflow manager for computational science with a strong focus on provenance, performance and extensibility.' +dynamic = ['version'] # read from aiida/__init__.py +keywords = ['aiida', 'workflows'] +license = {file = 'LICENSE.txt'} +name = 'aiida-core' +readme = 'README.md' +requires-python = '>=3.9' -[project.urls] -Home = "http://www.aiida.net/" -Documentation = "https://aiida.readthedocs.io" -Source = "https://github.com/aiidateam/aiida-core" +[project.entry-points.'aiida.brokers'] +'core.rabbitmq' = 'aiida.brokers.rabbitmq.broker:RabbitmqBroker' + +[project.entry-points.'aiida.calculations'] +'core.arithmetic.add' = 'aiida.calculations.arithmetic.add:ArithmeticAddCalculation' +'core.templatereplacer' = 'aiida.calculations.templatereplacer:TemplatereplacerCalculation' +'core.transfer' = 'aiida.calculations.transfer:TransferCalculation' + +[project.entry-points.'aiida.calculations.importers'] +'core.arithmetic.add' = 'aiida.calculations.importers.arithmetic.add:ArithmeticAddCalculationImporter' + +[project.entry-points.'aiida.calculations.monitors'] +'core.always_kill' = 'aiida.calculations.monitors.base:always_kill' + +[project.entry-points.'aiida.cmdline.computer.configure'] +'core.local' = 'aiida.transports.plugins.local:CONFIGURE_LOCAL_CMD' +'core.ssh' = 'aiida.transports.plugins.ssh:CONFIGURE_SSH_CMD' + +[project.entry-points.'aiida.cmdline.data'] +'core.array' = 'aiida.cmdline.commands.cmd_data.cmd_array:array' +'core.bands' = 
'aiida.cmdline.commands.cmd_data.cmd_bands:bands' +'core.cif' = 'aiida.cmdline.commands.cmd_data.cmd_cif:cif' +'core.dict' = 'aiida.cmdline.commands.cmd_data.cmd_dict:dictionary' +'core.remote' = 'aiida.cmdline.commands.cmd_data.cmd_remote:remote' +'core.singlefile' = 'aiida.cmdline.commands.cmd_data.cmd_singlefile:singlefile' +'core.structure' = 'aiida.cmdline.commands.cmd_data.cmd_structure:structure' +'core.trajectory' = 'aiida.cmdline.commands.cmd_data.cmd_trajectory:trajectory' +'core.upf' = 'aiida.cmdline.commands.cmd_data.cmd_upf:upf' + +[project.entry-points.'aiida.cmdline.data.structure.import'] + +[project.entry-points.'aiida.data'] +'core.array' = 'aiida.orm.nodes.data.array.array:ArrayData' +'core.array.bands' = 'aiida.orm.nodes.data.array.bands:BandsData' +'core.array.kpoints' = 'aiida.orm.nodes.data.array.kpoints:KpointsData' +'core.array.projection' = 'aiida.orm.nodes.data.array.projection:ProjectionData' +'core.array.trajectory' = 'aiida.orm.nodes.data.array.trajectory:TrajectoryData' +'core.array.xy' = 'aiida.orm.nodes.data.array.xy:XyData' +'core.base' = 'aiida.orm.nodes.data:BaseType' +'core.bool' = 'aiida.orm.nodes.data.bool:Bool' +'core.cif' = 'aiida.orm.nodes.data.cif:CifData' +'core.code' = 'aiida.orm.nodes.data.code.legacy:Code' +'core.code.containerized' = 'aiida.orm.nodes.data.code.containerized:ContainerizedCode' +'core.code.installed' = 'aiida.orm.nodes.data.code.installed:InstalledCode' +'core.code.portable' = 'aiida.orm.nodes.data.code.portable:PortableCode' +'core.dict' = 'aiida.orm.nodes.data.dict:Dict' +'core.enum' = 'aiida.orm.nodes.data.enum:EnumData' +'core.float' = 'aiida.orm.nodes.data.float:Float' +'core.folder' = 'aiida.orm.nodes.data.folder:FolderData' +'core.int' = 'aiida.orm.nodes.data.int:Int' +'core.jsonable' = 'aiida.orm.nodes.data.jsonable:JsonableData' +'core.list' = 'aiida.orm.nodes.data.list:List' +'core.numeric' = 'aiida.orm.nodes.data.numeric:NumericType' +'core.orbital' = 
'aiida.orm.nodes.data.orbital:OrbitalData' +'core.remote' = 'aiida.orm.nodes.data.remote.base:RemoteData' +'core.remote.stash' = 'aiida.orm.nodes.data.remote.stash.base:RemoteStashData' +'core.remote.stash.folder' = 'aiida.orm.nodes.data.remote.stash.folder:RemoteStashFolderData' +'core.singlefile' = 'aiida.orm.nodes.data.singlefile:SinglefileData' +'core.str' = 'aiida.orm.nodes.data.str:Str' +'core.structure' = 'aiida.orm.nodes.data.structure:StructureData' +'core.upf' = 'aiida.orm.nodes.data.upf:UpfData' + +[project.entry-points.'aiida.groups'] +'core' = 'aiida.orm.groups:Group' +'core.auto' = 'aiida.orm.groups:AutoGroup' +'core.import' = 'aiida.orm.groups:ImportGroup' +'core.upf' = 'aiida.orm.groups:UpfFamily' + +[project.entry-points.'aiida.node'] +'data' = 'aiida.orm.nodes.data.data:Data' +'process' = 'aiida.orm.nodes.process.process:ProcessNode' +'process.calculation' = 'aiida.orm.nodes.process.calculation.calculation:CalculationNode' +'process.calculation.calcfunction' = 'aiida.orm.nodes.process.calculation.calcfunction:CalcFunctionNode' +'process.calculation.calcjob' = 'aiida.orm.nodes.process.calculation.calcjob:CalcJobNode' +'process.workflow' = 'aiida.orm.nodes.process.workflow.workflow:WorkflowNode' +'process.workflow.workchain' = 'aiida.orm.nodes.process.workflow.workchain:WorkChainNode' +'process.workflow.workfunction' = 'aiida.orm.nodes.process.workflow.workfunction:WorkFunctionNode' + +[project.entry-points.'aiida.parsers'] +'core.arithmetic.add' = 'aiida.parsers.plugins.arithmetic.add:ArithmeticAddParser' +'core.templatereplacer' = 'aiida.parsers.plugins.templatereplacer.parser:TemplatereplacerParser' + +[project.entry-points.'aiida.schedulers'] +'core.direct' = 'aiida.schedulers.plugins.direct:DirectScheduler' +'core.lsf' = 'aiida.schedulers.plugins.lsf:LsfScheduler' +'core.pbspro' = 'aiida.schedulers.plugins.pbspro:PbsproScheduler' +'core.sge' = 'aiida.schedulers.plugins.sge:SgeScheduler' +'core.slurm' = 
'aiida.schedulers.plugins.slurm:SlurmScheduler' +'core.torque' = 'aiida.schedulers.plugins.torque:TorqueScheduler' + +[project.entry-points.'aiida.storage'] +'core.psql_dos' = 'aiida.storage.psql_dos.backend:PsqlDosBackend' +'core.sqlite_dos' = 'aiida.storage.sqlite_dos.backend:SqliteDosStorage' +'core.sqlite_temp' = 'aiida.storage.sqlite_temp.backend:SqliteTempBackend' +'core.sqlite_zip' = 'aiida.storage.sqlite_zip.backend:SqliteZipBackend' + +[project.entry-points.'aiida.tools.calculations'] + +[project.entry-points.'aiida.tools.data.orbitals'] +'core.orbital' = 'aiida.tools.data.orbital.orbital:Orbital' +'core.realhydrogen' = 'aiida.tools.data.orbital.realhydrogen:RealhydrogenOrbital' + +[project.entry-points.'aiida.tools.dbexporters'] + +[project.entry-points.'aiida.tools.dbimporters'] +'core.cod' = 'aiida.tools.dbimporters.plugins.cod:CodDbImporter' +'core.icsd' = 'aiida.tools.dbimporters.plugins.icsd:IcsdDbImporter' +'core.materialsproject' = 'aiida.tools.dbimporters.plugins.materialsproject:MaterialsProjectImporter' +'core.mpds' = 'aiida.tools.dbimporters.plugins.mpds:MpdsDbImporter' +'core.mpod' = 'aiida.tools.dbimporters.plugins.mpod:MpodDbImporter' +'core.nninc' = 'aiida.tools.dbimporters.plugins.nninc:NnincDbImporter' +'core.oqmd' = 'aiida.tools.dbimporters.plugins.oqmd:OqmdDbImporter' +'core.pcod' = 'aiida.tools.dbimporters.plugins.pcod:PcodDbImporter' +'core.tcod' = 'aiida.tools.dbimporters.plugins.tcod:TcodDbImporter' + +[project.entry-points.'aiida.transports'] +'core.local' = 'aiida.transports.plugins.local:LocalTransport' +'core.ssh' = 'aiida.transports.plugins.ssh:SshTransport' +'core.ssh_async' = 'aiida.transports.plugins.ssh_async:AsyncSshTransport' +'core.ssh_auto' = 'aiida.transports.plugins.ssh_auto:SshAutoTransport' + +[project.entry-points.'aiida.workflows'] +'core.arithmetic.add_multiply' = 'aiida.workflows.arithmetic.add_multiply:add_multiply' +'core.arithmetic.multiply_add' = 'aiida.workflows.arithmetic.multiply_add:MultiplyAddWorkChain' 
[project.optional-dependencies] atomic_tools = [ diff --git a/requirements/requirements-py-3.10.txt b/requirements/requirements-py-3.10.txt index 680bd9ba3c..b2408e8087 100644 --- a/requirements/requirements-py-3.10.txt +++ b/requirements/requirements-py-3.10.txt @@ -120,7 +120,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy +plumpy==0.22.3 prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 diff --git a/requirements/requirements-py-3.11.txt b/requirements/requirements-py-3.11.txt index 2880ef436e..24acc25a6b 100644 --- a/requirements/requirements-py-3.11.txt +++ b/requirements/requirements-py-3.11.txt @@ -119,7 +119,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy +plumpy==0.22.3 prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt index b3d1f37a9c..3087e62844 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -122,7 +122,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy +plumpy==0.22.3 prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 diff --git a/src/aiida/engine/daemon/execmanager.py b/src/aiida/engine/daemon/execmanager.py index e69de29bb2..fb24f8955b 100644 --- a/src/aiida/engine/daemon/execmanager.py +++ b/src/aiida/engine/daemon/execmanager.py @@ -0,0 +1,661 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""This file contains the main routines to submit, check and retrieve calculation +results. These are general and contain only the main logic; where appropriate, +the routines make reference to the suitable plugins for all +plugin-specific operations. +""" + +from __future__ import annotations + +import os +import shutil +from collections.abc import Mapping +from logging import LoggerAdapter +from pathlib import Path +from tempfile import NamedTemporaryFile, TemporaryDirectory +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union +from typing import Mapping as MappingType + +from aiida.common import AIIDA_LOGGER, exceptions +from aiida.common.datastructures import CalcInfo, FileCopyOperation +from aiida.common.folders import Folder, SandboxFolder +from aiida.common.links import LinkType +from aiida.engine.processes.exit_code import ExitCode +from aiida.manage.configuration import get_config_option +from aiida.orm import CalcJobNode, Code, FolderData, Node, PortableCode, RemoteData, load_node +from aiida.orm.utils.log import get_dblogger_extra +from aiida.repository.common import FileType +from aiida.schedulers.datastructures import JobState + +if TYPE_CHECKING: + from aiida.transports import Transport + +REMOTE_WORK_DIRECTORY_LOST_FOUND = 'lost+found' + +EXEC_LOGGER = AIIDA_LOGGER.getChild('execmanager') + + +def _find_data_node(inputs: MappingType[str, Any], uuid: str) -> Optional[Node]: + """Find and return the node with the given UUID from a nested mapping of input nodes. 
+ + :param inputs: (nested) mapping of nodes + :param uuid: UUID of the node to find + :return: instance of `Node` or `None` if not found + """ + data_node = None + + for input_node in inputs.values(): + if isinstance(input_node, Mapping): + data_node = _find_data_node(input_node, uuid) + elif isinstance(input_node, Node) and input_node.uuid == uuid: + data_node = input_node + if data_node is not None: + break + + return data_node + + +async def upload_calculation( + node: CalcJobNode, + transport: Transport, + calc_info: CalcInfo, + folder: Folder, + inputs: Optional[MappingType[str, Any]] = None, + dry_run: bool = False, +) -> RemoteData | None: + """Upload a `CalcJob` instance + + :param node: the `CalcJobNode`. + :param transport: an already opened transport to use to submit the calculation. + :param calc_info: the calculation info datastructure returned by `CalcJob.presubmit` + :param folder: temporary local file system folder containing the inputs written by `CalcJob.prepare_for_submission` + :returns: The ``RemoteData`` representing the working directory on the remote if, or ``None`` if ``dry_run=True``. + """ + # If the calculation already has a `remote_folder`, simply return. The upload was apparently already completed + # before, which can happen if the daemon is restarted and it shuts down after uploading but before getting the + # chance to perform the state transition. Upon reloading this calculation, it will re-attempt the upload. 
+ link_label = 'remote_folder' + if node.base.links.get_outgoing(RemoteData, link_label_filter=link_label).first(): + EXEC_LOGGER.warning(f'CalcJobNode<{node.pk}> already has a `{link_label}` output: skipping upload') + return calc_info + + computer = node.computer + + codes_info = calc_info.codes_info + input_codes = [load_node(_.code_uuid, sub_classes=(Code,)) for _ in codes_info] + + logger_extra = get_dblogger_extra(node) + transport.set_logger_extra(logger_extra) + logger = LoggerAdapter(logger=EXEC_LOGGER, extra=logger_extra) + + if not dry_run and not node.is_stored: + raise ValueError( + f'Cannot submit calculation {node.pk} because it is not stored! If you just want to test the submission, ' + 'set `metadata.dry_run` to True in the inputs.' + ) + + # If we are performing a dry-run, the working directory should actually be a local folder that should already exist + if dry_run: + workdir = Path(folder.abspath) + else: + remote_user = transport.whoami() + remote_working_directory = computer.get_workdir().format(username=remote_user) + if not remote_working_directory.strip(): + raise exceptions.ConfigurationError( + f'[submission of calculation {node.pk}] No remote_working_directory ' + f"configured for computer '{computer.label}'" + ) + + # If it already exists, no exception is raised + if not transport.path_exists(remote_working_directory): + logger.debug( + f'[submission of calculation {node.pk}] Path ' + f'{remote_working_directory} does not exist, trying to create it' + ) + try: + transport.makedirs(remote_working_directory) + except EnvironmentError as exc: + raise exceptions.ConfigurationError( + f'[submission of calculation {node.pk}] ' + f'Unable to create the remote directory {remote_working_directory} on ' + f"computer '{computer.label}': {exc}" + ) + # Store remotely with sharding (here is where we choose + # the folder structure of remote jobs; then I store this + # in the calculation properties using _set_remote_dir + # and I do not have to know 
the logic, but I just need to + # read the absolute path from the calculation properties. + workdir = Path(remote_working_directory).joinpath(calc_info.uuid[:2], calc_info.uuid[2:4]) + transport.makedirs(str(workdir), ignore_existing=True) + + try: + # The final directory may already exist, most likely because this function was already executed once, but + # failed and as a result was rescheduled by the engine. In this case it would be fine to delete the folder + # and create it from scratch, except that we cannot be sure that this the actual case. Therefore, to err on + # the safe side, we move the folder to the lost+found directory before recreating the folder from scratch + transport.mkdir(str(workdir.joinpath(calc_info.uuid[4:]))) + except OSError: + # Move the existing directory to lost+found, log a warning and create a clean directory anyway + path_existing = os.path.join(str(workdir), calc_info.uuid[4:]) + path_lost_found = os.path.join(remote_working_directory, REMOTE_WORK_DIRECTORY_LOST_FOUND) + path_target = os.path.join(path_lost_found, calc_info.uuid) + logger.warning( + f'tried to create path {path_existing} but it already exists, moving the entire folder to {path_target}' + ) + + # Make sure the lost+found directory exists, then copy the existing folder there and delete the original + transport.mkdir(path_lost_found, ignore_existing=True) + transport.copytree(path_existing, path_target) + transport.rmtree(path_existing) + + # Now we can create a clean folder for this calculation + transport.mkdir(str(workdir.joinpath(calc_info.uuid[4:]))) + finally: + workdir = workdir.joinpath(calc_info.uuid[4:]) + + node.set_remote_workdir(str(workdir)) + + # I first create the code files, so that the code can put + # default files to be overwritten by the plugin itself. + # Still, beware! The code file itself could be overwritten... + # But I checked for this earlier. 
+ for code in input_codes: + if isinstance(code, PortableCode): + # Note: this will possibly overwrite files + for root, dirnames, filenames in code.base.repository.walk(): + # mkdir of root + transport.makedirs(str(workdir.joinpath(root)), ignore_existing=True) + + # remotely mkdir first + for dirname in dirnames: + transport.makedirs(str(workdir.joinpath(root, dirname)), ignore_existing=True) + + # Note, once #2579 is implemented, use the `node.open` method instead of the named temporary file in + # combination with the new `Transport.put_object_from_filelike` + # Since the content of the node could potentially be binary, we read the raw bytes and pass them on + for filename in filenames: + with NamedTemporaryFile(mode='wb+') as handle: + content = code.base.repository.get_object_content(Path(root) / filename, mode='rb') + handle.write(content) + handle.flush() + await transport.put_async(handle.name, str(workdir.joinpath(root, filename))) + if code.filepath_executable.is_absolute(): + transport.chmod(str(code.filepath_executable), 0o755) # rwxr-xr-x + else: + transport.chmod(str(workdir.joinpath(code.filepath_executable)), 0o755) # rwxr-xr-x + + # local_copy_list is a list of tuples, each with (uuid, dest_path, rel_path) + # NOTE: validation of these lists are done inside calculation.presubmit() + local_copy_list = calc_info.local_copy_list or [] + remote_copy_list = calc_info.remote_copy_list or [] + remote_symlink_list = calc_info.remote_symlink_list or [] + provenance_exclude_list = calc_info.provenance_exclude_list or [] + + file_copy_operation_order = calc_info.file_copy_operation_order or [ + FileCopyOperation.SANDBOX, + FileCopyOperation.LOCAL, + FileCopyOperation.REMOTE, + ] + + for file_copy_operation in file_copy_operation_order: + if file_copy_operation is FileCopyOperation.LOCAL: + await _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir=workdir) + elif file_copy_operation is FileCopyOperation.REMOTE: + if not dry_run: + 
_copy_remote_files( + logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir=workdir + ) + elif file_copy_operation is FileCopyOperation.SANDBOX: + if not dry_run: + await _copy_sandbox_files(logger, node, transport, folder, workdir=workdir) + else: + raise RuntimeError(f'file copy operation {file_copy_operation} is not yet implemented.') + + # In a dry_run, the working directory is the raw input folder, which will already contain these resources + if dry_run: + if remote_copy_list: + filepath = os.path.join(str(workdir), '_aiida_remote_copy_list.txt') + with open(filepath, 'w', encoding='utf-8') as handle: # type: ignore[assignment] + for _, remote_abs_path, dest_rel_path in remote_copy_list: + handle.write( + f'would have copied {remote_abs_path} to {dest_rel_path} in working ' + f'directory on remote {computer.label}' + ) + + if remote_symlink_list: + filepath = os.path.join(str(workdir), '_aiida_remote_symlink_list.txt') + with open(filepath, 'w', encoding='utf-8') as handle: # type: ignore[assignment] + for _, remote_abs_path, dest_rel_path in remote_symlink_list: + handle.write( + f'would have created symlinks from {remote_abs_path} to {dest_rel_path} in working' + f'directory on remote {computer.label}' + ) + + # Loop recursively over content of the sandbox folder copying all that are not in `provenance_exclude_list`. Note + # that directories are not created explicitly. The `node.put_object_from_filelike` call will create intermediate + # directories for nested files automatically when needed. This means though that empty folders in the sandbox or + # folders that would be empty when considering the `provenance_exclude_list` will *not* be copied to the repo. 
The + # advantage of this explicit copying instead of deleting the files from `provenance_exclude_list` from the sandbox + # first before moving the entire remaining content to the node's repository, is that in this way we are guaranteed + # not to accidentally move files to the repository that should not go there at all cost. Note that all entries in + # the provenance exclude list are normalized first, just as the paths that are in the sandbox folder, otherwise the + # direct equality test may fail, e.g.: './path/file.txt' != 'path/file.txt' even though they reference the same file + provenance_exclude_list = [os.path.normpath(entry) for entry in provenance_exclude_list] + + for root, _, filenames in os.walk(folder.abspath): + for filename in filenames: + filepath = os.path.join(root, filename) + relpath = os.path.normpath(os.path.relpath(filepath, folder.abspath)) + dirname = os.path.dirname(relpath) + + # Construct a list of all (partial) filepaths + # For example, if `relpath == 'some/sub/directory/file.txt'` then the list of relative directory paths is + # ['some', 'some/sub', 'some/sub/directory'] + # This is necessary, because if any of these paths is in the `provenance_exclude_list` the file should not + # be copied over. + components = dirname.split(os.sep) + dirnames = [os.path.join(*components[:i]) for i in range(1, len(components) + 1)] + if relpath not in provenance_exclude_list and all( + dirname not in provenance_exclude_list for dirname in dirnames + ): + with open(filepath, 'rb') as handle: # type: ignore[assignment] + node.base.repository._repository.put_object_from_filelike(handle, relpath) + + # Since the node is already stored, we cannot use the normal repository interface since it will raise a + # `ModificationNotAllowed` error. To bypass it, we go straight to the underlying repository instance to store the + # files, however, this means we have to manually update the node's repository metadata. 
+ node.base.repository._update_repository_metadata() + + if not dry_run: + return RemoteData(computer=computer, remote_path=str(workdir)) + + return None + + +def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir: Path): + """Perform the copy instructions of the ``remote_copy_list`` and ``remote_symlink_list``.""" + for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_copy_list: + if remote_computer_uuid == computer.uuid: + logger.debug( + f'[submission of calculation {node.pk}] copying {dest_rel_path} ' + f'remotely, directly on the machine {computer.label}' + ) + try: + transport.copy(remote_abs_path, str(workdir.joinpath(dest_rel_path))) + except FileNotFoundError: + logger.warning( + f'[submission of calculation {node.pk}] Unable to copy remote ' + f'resource from {remote_abs_path} to {dest_rel_path}! NOT Stopping but just ignoring!.' + ) + except OSError: + logger.warning( + f'[submission of calculation {node.pk}] Unable to copy remote ' + f'resource from {remote_abs_path} to {dest_rel_path}! Stopping.' + ) + raise + else: + raise NotImplementedError( + f'[submission of calculation {node.pk}] Remote copy between two different machines is ' + 'not implemented yet' + ) + + for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_symlink_list: + if remote_computer_uuid == computer.uuid: + logger.debug( + f'[submission of calculation {node.pk}] copying {dest_rel_path} remotely, ' + f'directly on the machine {computer.label}' + ) + remote_dirname = Path(dest_rel_path).parent + try: + transport.makedirs(str(workdir.joinpath(remote_dirname)), ignore_existing=True) + transport.symlink(remote_abs_path, str(workdir.joinpath(dest_rel_path))) + except OSError: + logger.warning( + f'[submission of calculation {node.pk}] Unable to create remote symlink ' + f'from {remote_abs_path} to {dest_rel_path}! Stopping.' 
+ ) + raise + else: + raise OSError( + f'It is not possible to create a symlink between two different machines for calculation {node.pk}' + ) + + +async def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir: Path): + """Perform the copy instructions of the ``local_copy_list``.""" + for uuid, filename, target in local_copy_list: + logger.debug(f'[submission of calculation {node.uuid}] copying local file/folder to {target}') + + try: + data_node = load_node(uuid=uuid) + except exceptions.NotExistent: + data_node = _find_data_node(inputs, uuid) if inputs else None + + if data_node is None: + logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`') + continue + + # The transport class can only copy files directly from the file system, so the files in the source node's repo + # have to first be copied to a temporary directory on disk. + with TemporaryDirectory() as tmpdir: + dirpath = Path(tmpdir) + + # If no explicit source filename is defined, we assume the top-level directory + filename_source = filename or '.' + filename_target = target or '.' + + file_type_source = data_node.base.repository.get_object(filename_source).file_type + + # The logic below takes care of an edge case where the source is a file but the target is a directory. In + # this case, the v2.5.1 implementation would raise an `IsADirectoryError` exception, because it would try + # to open the directory in the sandbox folder as a file when writing the contents. + if file_type_source == FileType.FILE and target and transport.isdir(str(workdir.joinpath(target))): + raise IsADirectoryError + + # In case the source filename is specified and it is a directory that already exists in the remote, we + # want to avoid nested directories in the target path to replicate the behavior of v2.5.1. 
This is done by + # setting the target filename to '.', which means the contents of the node will be copied in the top level + # of the temporary directory, whose contents are then copied into the target directory. + if filename and transport.isdir(str(workdir.joinpath(filename))): + filename_target = '.' + + filepath_target = (dirpath / filename_target).resolve().absolute() + filepath_target.parent.mkdir(parents=True, exist_ok=True) + + if file_type_source == FileType.DIRECTORY: + # If the source object is a directory, we copy its entire contents + data_node.base.repository.copy_tree(filepath_target, filename_source) + await transport.put_async( + f'{dirpath}/*', + str(workdir.joinpath(target)) if target else str(workdir.joinpath('.')), + overwrite=True, + ) + else: + # Otherwise, simply copy the file + with filepath_target.open('wb') as handle: + with data_node.base.repository.open(filename_source, 'rb') as source: + shutil.copyfileobj(source, handle) + transport.makedirs(str(workdir.joinpath(Path(target).parent)), ignore_existing=True) + await transport.put_async(str(filepath_target), str(workdir.joinpath(target))) + + +async def _copy_sandbox_files(logger, node, transport, folder, workdir: Path): + """Copy the contents of the sandbox folder to the working directory.""" + for filename in folder.get_content_list(): + logger.debug(f'[submission of calculation {node.pk}] copying file/folder {filename}...') + await transport.put_async(folder.get_abs_path(filename), str(workdir.joinpath(filename))) + + +def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str | ExitCode: + """Submit a previously uploaded `CalcJob` to the scheduler. + + :param calculation: the instance of CalcJobNode to submit. + :param transport: an already opened transport to use to submit the calculation. 
+ :return: the job id as returned by the scheduler `submit_job` call + """ + job_id = calculation.get_job_id() + + # If the `job_id` attribute is already set, that means this function was already executed once and the scheduler + # submit command was successful as the job id it returned was set on the node. This scenario can happen when the + # daemon runner gets shutdown right after accomplishing the submission task, but before it gets the chance to + # finalize the state transition of the `CalcJob` to the `UPDATE` transport task. Since the job is already submitted + # we do not want to submit it a second time, so we simply return the existing job id here. + if job_id is not None: + return job_id + + scheduler = calculation.computer.get_scheduler() + scheduler.set_transport(transport) + + submit_script_filename = calculation.get_option('submit_script_filename') + workdir = calculation.get_remote_workdir() + result = scheduler.submit_job(workdir, submit_script_filename) + + if isinstance(result, str): + calculation.set_job_id(result) + + return result + + +def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: + """Stash files from the working directory of a completed calculation to a permanent remote folder. + + After a calculation has been completed, optionally stash files from the work directory to a storage location on the + same remote machine. This is useful if one wants to keep certain files from a completed calculation to be removed + from the scratch directory, because they are necessary for restarts, but that are too heavy to retrieve. + Instructions of which files to copy where are retrieved from the `stash.source_list` option. + + :param calculation: the calculation job node. + :param transport: an already opened transport. 
+ """ + from aiida.common.datastructures import StashMode + from aiida.orm import RemoteStashFolderData + + logger_extra = get_dblogger_extra(calculation) + + stash_options = calculation.get_option('stash') + stash_mode = stash_options.get('mode', StashMode.COPY.value) + source_list = stash_options.get('source_list', []) + + if not source_list: + return + + if stash_mode != StashMode.COPY.value: + EXEC_LOGGER.warning(f'stashing mode {stash_mode} is not implemented yet.') + return + + cls = RemoteStashFolderData + + EXEC_LOGGER.debug(f'stashing files for calculation<{calculation.pk}>: {source_list}', extra=logger_extra) + + uuid = calculation.uuid + source_basepath = Path(calculation.get_remote_workdir()) + target_basepath = Path(stash_options['target_base']) / uuid[:2] / uuid[2:4] / uuid[4:] + + for source_filename in source_list: + if transport.has_magic(source_filename): + copy_instructions = [] + for globbed_filename in transport.glob(str(source_basepath / source_filename)): + target_filepath = target_basepath / Path(globbed_filename).relative_to(source_basepath) + copy_instructions.append((globbed_filename, target_filepath)) + else: + copy_instructions = [(source_basepath / source_filename, target_basepath / source_filename)] + + for source_filepath, target_filepath in copy_instructions: + # If the source file is in a (nested) directory, create those directories first in the target directory + target_dirname = target_filepath.parent + transport.makedirs(str(target_dirname), ignore_existing=True) + + try: + transport.copy(str(source_filepath), str(target_filepath)) + except (OSError, ValueError) as exception: + EXEC_LOGGER.warning(f'failed to stash {source_filepath} to {target_filepath}: {exception}') + else: + EXEC_LOGGER.debug(f'stashed {source_filepath} to {target_filepath}') + + remote_stash = cls( + computer=calculation.computer, + target_basepath=str(target_basepath), + stash_mode=StashMode(stash_mode), + source_list=source_list, + ).store() + 
remote_stash.base.links.add_incoming(calculation, link_type=LinkType.CREATE, link_label='remote_stash') + + +async def retrieve_calculation( + calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str +) -> FolderData | None: + """Retrieve all the files of a completed job calculation using the given transport. + + If the job defined anything in the `retrieve_temporary_list`, those entries will be stored in the + `retrieved_temporary_folder`. The caller is responsible for creating and destroying this folder. + + :param calculation: the instance of CalcJobNode to update. + :param transport: an already opened transport to use for the retrieval. + :param retrieved_temporary_folder: the absolute path to a directory in which to store the files + listed, if any, in the `retrieved_temporary_folder` of the jobs CalcInfo. + :returns: The ``FolderData`` into which the files have been retrieved, or ``None`` if the calculation already has + a retrieved output node attached. + """ + logger_extra = get_dblogger_extra(calculation) + workdir = calculation.get_remote_workdir() + filepath_sandbox = get_config_option('storage.sandbox') or None + + EXEC_LOGGER.debug(f'Retrieving calc {calculation.pk}', extra=logger_extra) + EXEC_LOGGER.debug(f'[retrieval of calc {calculation.pk}] chdir {workdir}', extra=logger_extra) + + # If the calculation already has a `retrieved` folder, simply return. The retrieval was apparently already completed + # before, which can happen if the daemon is restarted and it shuts down after retrieving but before getting the + # chance to perform the state transition. Upon reloading this calculation, it will re-attempt the retrieval. 
+ link_label = calculation.link_label_retrieved + if calculation.base.links.get_outgoing(FolderData, link_label_filter=link_label).first(): + EXEC_LOGGER.warning( + f'CalcJobNode<{calculation.pk}> already has a `{link_label}` output folder: skipping retrieval' + ) + return + + # Create the FolderData node into which to store the files that are to be retrieved + retrieved_files = FolderData() + + with transport: + # First, retrieve the files of folderdata + retrieve_list = calculation.get_retrieve_list() + retrieve_temporary_list = calculation.get_retrieve_temporary_list() + + with SandboxFolder(filepath_sandbox) as folder: + await retrieve_files_from_list(calculation, transport, folder.abspath, retrieve_list) + # Here I retrieved everything; now I store them inside the calculation + retrieved_files.base.repository.put_object_from_tree(folder.abspath) + + # Retrieve the temporary files in the retrieved_temporary_folder if any files were + # specified in the 'retrieve_temporary_list' key + if retrieve_temporary_list: + await retrieve_files_from_list(calculation, transport, retrieved_temporary_folder, retrieve_temporary_list) + + # Log the files that were retrieved in the temporary folder + for filename in os.listdir(retrieved_temporary_folder): + EXEC_LOGGER.debug( + f"[retrieval of calc {calculation.pk}] Retrieved temporary file or folder '{filename}'", + extra=logger_extra, + ) + + # Store everything + EXEC_LOGGER.debug( + f'[retrieval of calc {calculation.pk}] Storing retrieved_files={retrieved_files.pk}', extra=logger_extra + ) + retrieved_files.store() + + return retrieved_files + + +def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: + """Kill the calculation through the scheduler + + :param calculation: the instance of CalcJobNode to kill. 
+ :param transport: an already opened transport to use to address the scheduler + """ + job_id = calculation.get_job_id() + + if job_id is None: + # the calculation has not yet been submitted to the scheduler + return + + # Get the scheduler plugin class and initialize it with the correct transport + scheduler = calculation.computer.get_scheduler() + scheduler.set_transport(transport) + + # Call the proper kill method for the job ID of this calculation + result = scheduler.kill_job(job_id) + + if result is not True: + # Failed to kill because the job might have already been completed + running_jobs = scheduler.get_jobs(jobs=[job_id], as_dict=True) + job = running_jobs.get(job_id, None) + + # If the job is returned it is still running and the kill really failed, so we raise + if job is not None and job.job_state != JobState.DONE: + raise exceptions.RemoteOperationError(f'scheduler.kill_job({job_id}) was unsuccessful') + else: + EXEC_LOGGER.warning( + 'scheduler.kill_job() failed but job<{%s}> no longer seems to be running regardless', job_id + ) + + +async def retrieve_files_from_list( + calculation: CalcJobNode, + transport: Transport, + folder: str, + retrieve_list: List[Union[str, Tuple[str, str, int], list]], +) -> None: + """Retrieve all the files in the retrieve_list from the remote into the + local folder instance through the transport. The entries in the retrieve_list + can be of two types: + + * a string + * a list + + If it is a string, it represents the remote absolute or relative filepath of the file. + If the item is a list, the elements will correspond to the following: + + * remotepath (relative path) + * localpath + * depth + + If the remotepath contains file patterns with wildcards, the localpath will be + treated as the work directory of the folder and the depth integer determines + upto what level of the original remotepath nesting the files will be copied. + + :param transport: the Transport instance. 
+ :param folder: an absolute path to a folder that contains the files to copy. + :param retrieve_list: the list of files to retrieve. + """ + workdir = Path(calculation.get_remote_workdir()) + for item in retrieve_list: + if isinstance(item, (list, tuple)): + tmp_rname, tmp_lname, depth = item + # if there are more than one file I do something differently + if transport.has_magic(tmp_rname): + remote_names = transport.glob(str(workdir.joinpath(tmp_rname))) + local_names = [] + for rem in remote_names: + # get the relative path so to make local_names relative + rel_rem = os.path.relpath(rem, str(workdir)) + if depth is None: + local_names.append(os.path.join(tmp_lname, rel_rem)) + else: + to_append = rel_rem.split(os.path.sep)[-depth:] if depth > 0 else [] + local_names.append(os.path.sep.join([tmp_lname] + to_append)) + else: + remote_names = [tmp_rname] + to_append = tmp_rname.split(os.path.sep)[-depth:] if depth > 0 else [] + local_names = [os.path.sep.join([tmp_lname] + to_append)] + if depth is None or depth > 1: # create directories in the folder, if needed + for this_local_file in local_names: + new_folder = os.path.join(folder, os.path.split(this_local_file)[0]) + if not os.path.exists(new_folder): + os.makedirs(new_folder) + else: + abs_item = item if item.startswith('/') else str(workdir.joinpath(item)) + + if transport.has_magic(abs_item): + remote_names = transport.glob(abs_item) + local_names = [os.path.split(rem)[1] for rem in remote_names] + else: + remote_names = [abs_item] + local_names = [os.path.split(abs_item)[1]] + + for rem, loc in zip(remote_names, local_names): + transport.logger.debug(f"[retrieval of calc {calculation.pk}] Trying to retrieve remote item '{rem}'") + + if rem.startswith('/'): + to_get = rem + else: + to_get = str(workdir.joinpath(rem)) + + await transport.get_async(to_get, os.path.join(folder, loc), ignore_nonexisting=True) diff --git a/src/aiida/engine/processes/calcjobs/calcjob.py 
def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]:
    """Validate the entire set of inputs passed to the `CalcJob` constructor.

    Reasons that will cause this validation to fail, i.e. to return an error message:

     * No `Code` has been specified and no `remote_folder` input has been specified, i.e. this is no import run
     * No `Computer` has been specified, neither directly in `metadata.computer` nor indirectly through the `Code` input
     * The specified computer is not stored
     * The `Computer` specified in `metadata.computer` is not the same as that of the specified `Code`
     * The `metadata.options.resources` port is required but the input is missing, or the scheduler rejects it

    :param inputs: the mapping of inputs that were passed to the process.
    :param ctx: the port namespace against which the inputs are validated.
    :return: string with error message in case the inputs are invalid, ``None`` otherwise.
    """
    try:
        ctx.get_port('code')
        ctx.get_port('metadata.computer')
    except ValueError:
        # If the namespace no longer contains the `code` or `metadata.computer` ports we skip validation
        return None

    remote_folder = inputs.get('remote_folder', None)

    if remote_folder is not None:
        # The `remote_folder` input has been specified and so this concerns an import run, which means that neither
        # a `Code` nor a `Computer` are required. However, they are allowed to be specified but will not be explicitly
        # checked for consistency.
        return None

    code = inputs.get('code', None)

    if code is None:
        # The `code` port is optional in the spec, so it can legitimately be absent here. Report this as a regular
        # validation error instead of failing with an `AttributeError` when accessing `code.computer` below.
        return 'no code has been specified in the `code` input, which is required unless `remote_folder` is specified.'

    computer_from_code = code.computer
    computer_from_metadata = inputs.get('metadata', {}).get('computer', None)

    if not computer_from_code and not computer_from_metadata:
        return 'no computer has been specified in `metadata.computer` nor via `code`.'

    if computer_from_code and not computer_from_code.is_stored:
        return f'the Computer<{computer_from_code}> is not stored'

    if computer_from_metadata and not computer_from_metadata.is_stored:
        return f'the Computer<{computer_from_metadata}> is not stored'

    if computer_from_code and computer_from_metadata and computer_from_code.uuid != computer_from_metadata.uuid:
        return (
            'Computer<{}> explicitly defined in `metadata.computer` is different from Computer<{}> which is the '
            'computer of Code<{}> defined as the `code` input.'.format(computer_from_metadata, computer_from_code, code)
        )

    try:
        resources_port = ctx.get_port('metadata.options.resources')
    except ValueError:
        return None

    # If the resources port exists but is not required, we don't need to validate it against the computer's scheduler
    if not resources_port.required:
        return None

    computer = computer_from_code or computer_from_metadata
    scheduler = computer.get_scheduler()
    try:
        resources = inputs['metadata']['options']['resources']
    except KeyError:
        return 'input `metadata.options.resources` is required but is not specified'

    scheduler.preprocess_resources(resources, computer.get_default_mpiprocs_per_machine())

    try:
        scheduler.validate_resources(**resources)
    except ValueError as exception:
        return f'input `metadata.options.resources` is not valid for the `{scheduler}` scheduler: {exception}'

    return None


def validate_stash_options(stash_options: Any, _: Any) -> Optional[str]:
    """Validate the ``stash`` options.

    :param stash_options: mapping with the content of the ``metadata.options.stash`` namespace.
    :return: string with error message in case the options are invalid, ``None`` otherwise.
    """
    from aiida.common.datastructures import StashMode

    target_base = stash_options.get('target_base', None)
    source_list = stash_options.get('source_list', None)
    # The port is declared as ``stash_mode`` in the process specification, so that is the key that has to be read for
    # the validation to have any effect. The ``mode`` key is still honored as a fallback for backwards compatibility.
    stash_mode = stash_options.get('stash_mode', stash_options.get('mode', StashMode.COPY.value))

    if not isinstance(target_base, str) or not os.path.isabs(target_base):
        return f'`metadata.options.stash.target_base` should be an absolute filepath, got: {target_base}'

    if not isinstance(source_list, (list, tuple)) or any(
        not isinstance(src, str) or os.path.isabs(src) for src in source_list
    ):
        port = 'metadata.options.stash.source_list'
        return f'`{port}` should be a list or tuple of relative filepaths, got: {source_list}'

    try:
        StashMode(stash_mode)
    except ValueError:
        port = 'metadata.options.stash.stash_mode'
        return f'`{port}` should be a member of aiida.common.datastructures.StashMode, got: {stash_mode}'

    return None


def validate_monitors(monitors: Any, _: PortNamespace) -> Optional[str]:
    """Validate the ``monitors`` input namespace.

    Each node in the namespace should contain a dictionary from which a ``CalcJobMonitor`` can be constructed.

    :param monitors: mapping of monitor labels onto the nodes that define them.
    :return: string with error message for the first invalid monitor, ``None`` if all are valid.
    """
    for key, monitor_node in monitors.items():
        try:
            CalcJobMonitor(**monitor_node.get_dict())
        except (exceptions.EntryPointError, TypeError, ValueError) as exception:
            return f'`monitors.{key}` is invalid: {exception}'
    return None


def validate_parser(parser_name: Any, _: PortNamespace) -> Optional[str]:
    """Validate the parser.

    :param parser_name: the entry point name of the parser that should be loadable through the ``ParserFactory``.
    :return: string with error message in case the inputs are invalid
    """
    from aiida.plugins import ParserFactory

    try:
        ParserFactory(parser_name)
    except exceptions.EntryPointError as exception:
        return f'invalid parser specified: {exception}'

    return None


def validate_additional_retrieve_list(additional_retrieve_list: Any, _: Any) -> Optional[str]:
    """Validate the additional retrieve list.

    :param additional_retrieve_list: sequence whose elements should all be strings representing relative filepaths.
    :return: string with error message in case the input is invalid.
    """
    if any(not isinstance(value, str) or os.path.isabs(value) for value in additional_retrieve_list):
        return f'`additional_retrieve_list` should only contain relative filepaths but got: {additional_retrieve_list}'

    return None
+ + Construct the instance only if it is a sub class of `CalcJob`, otherwise raise `InvalidOperation`. + + See documentation of :class:`aiida.engine.Process`. + """ + if self.__class__ == CalcJob: + raise exceptions.InvalidOperation('cannot construct or launch a base `CalcJob` class.') + + super().__init__(*args, **kwargs) + + @classmethod + def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] + """Define the process specification, including its inputs, outputs and known exit codes. + + Ports are added to the `metadata` input namespace (inherited from the base Process), + and a `code` input Port, a `remote_folder` output Port and retrieved folder output Port + are added. + + :param spec: the calculation job process spec to define. + """ + super().define(spec) + spec.inputs.validator = validate_calc_job # type: ignore[assignment] # takes only PortNamespace not Port + spec.input( + 'code', + valid_type=orm.AbstractCode, + required=False, + help='The `Code` to use for this job. This input is required, unless the `remote_folder` input is ' + 'specified, which means an existing job is being imported and no code will actually be run.', + ) + spec.input_namespace( + 'monitors', + valid_type=orm.Dict, + required=False, + validator=validate_monitors, + help='Add monitoring functions that can inspect output files while the job is running and decide to ' + 'prematurely terminate the job.', + ) + spec.input( + 'remote_folder', + valid_type=orm.RemoteData, + required=False, + help='Remote directory containing the results of an already completed calculation job without AiiDA. The ' + 'inputs should be passed to the `CalcJob` as normal but instead of launching the actual job, the ' + 'engine will recreate the input files and then proceed straight to the retrieve step where the files ' + 'of this `RemoteData` will be retrieved as if it had been actually launched through AiiDA. 
If a ' + 'parser is defined in the inputs, the results are parsed and attached as output nodes as usual.', + ) + spec.input( + 'metadata.dry_run', + valid_type=bool, + default=False, + help='When set to `True` will prepare the calculation job for submission but not actually launch it.', + ) + spec.input( + 'metadata.computer', + valid_type=orm.Computer, + required=False, + help='When using a "local" code, set the computer on which the calculation should be run.', + ) + spec.input_namespace(f'{spec.metadata_key}.{spec.options_key}', required=False) + spec.input( + 'metadata.options.input_filename', + valid_type=str, + required=False, + help='Filename to which the input for the code that is to be run is written.', + ) + spec.input( + 'metadata.options.output_filename', + valid_type=str, + required=False, + help='Filename to which the content of stdout of the code that is to be run is written.', + ) + spec.input( + 'metadata.options.submit_script_filename', + valid_type=str, + default='_aiidasubmit.sh', + help='Filename to which the job submission script is written.', + ) + spec.input( + 'metadata.options.scheduler_stdout', + valid_type=str, + default='_scheduler-stdout.txt', + help='Filename to which the content of stdout of the scheduler is written.', + ) + spec.input( + 'metadata.options.scheduler_stderr', + valid_type=str, + default='_scheduler-stderr.txt', + help='Filename to which the content of stderr of the scheduler is written.', + ) + spec.input( + 'metadata.options.resources', + valid_type=dict, + required=True, + help='Set the dictionary of resources to be used by the scheduler plugin, like the number of nodes, ' + 'cpus etc. This dictionary is scheduler-plugin dependent. 
Look at the documentation of the ' + 'scheduler for more details.', + ) + spec.input( + 'metadata.options.max_wallclock_seconds', + valid_type=int, + required=False, + help='Set the wallclock in seconds asked to the scheduler', + ) + spec.input( + 'metadata.options.custom_scheduler_commands', + valid_type=str, + default='', + help='Set a (possibly multiline) string with the commands that the user wants to manually set for the ' + 'scheduler. The difference of this option with respect to the `prepend_text` is the position in ' + 'the scheduler submission file where such text is inserted: with this option, the string is ' + 'inserted before any non-scheduler command', + ) + spec.input( + 'metadata.options.queue_name', + valid_type=str, + required=False, + help='Set the name of the queue on the remote computer', + ) + spec.input( + 'metadata.options.rerunnable', + valid_type=bool, + required=False, + help='Determines if the calculation can be requeued / rerun.', + ) + spec.input( + 'metadata.options.account', + valid_type=str, + required=False, + help='Set the account to use in for the queue on the remote computer', + ) + spec.input( + 'metadata.options.qos', + valid_type=str, + required=False, + help='Set the quality of service to use in for the queue on the remote computer', + ) + spec.input( + 'metadata.options.withmpi', + valid_type=bool, + required=False, + help='Set the calculation to use mpi', + ) + spec.input( + 'metadata.options.mpirun_extra_params', + valid_type=(list, tuple), + default=lambda: [], + help='Set the extra params to pass to the mpirun (or equivalent) command after the one provided in ' + 'computer.mpirun_command. Example: mpirun -np 8 extra_params[0] extra_params[1] ... 
exec.x', + ) + spec.input( + 'metadata.options.import_sys_environment', + valid_type=bool, + default=True, + help='If set to true, the submission script will load the system environment variables', + ) + spec.input( + 'metadata.options.environment_variables', + valid_type=dict, + default=lambda: {}, + help='Set a dictionary of custom environment variables for this calculation', + ) + spec.input( + 'metadata.options.environment_variables_double_quotes', + valid_type=bool, + default=False, + help='If set to True, use double quotes instead of single quotes to escape the environment variables ' + 'specified in ``environment_variables``.', + ) + spec.input( + 'metadata.options.priority', valid_type=str, required=False, help='Set the priority of the job to be queued' + ) + spec.input( + 'metadata.options.max_memory_kb', + valid_type=int, + required=False, + help='Set the maximum memory (in KiloBytes) to be asked to the scheduler', + ) + spec.input( + 'metadata.options.prepend_text', + valid_type=str, + default='', + help='Set the calculation-specific prepend text, which is going to be prepended in the scheduler-job ' + 'script, just before the code execution', + ) + spec.input( + 'metadata.options.append_text', + valid_type=str, + default='', + help='Set the calculation-specific append text, which is going to be appended in the scheduler-job ' + 'script, just after the code execution', + ) + spec.input( + 'metadata.options.parser_name', + valid_type=str, + required=False, + validator=validate_parser, + help='Set a string for the output parser. 
Can be None if no output plugin is available or needed', + ) + spec.input( + 'metadata.options.additional_retrieve_list', + required=False, + valid_type=(list, tuple), + validator=validate_additional_retrieve_list, + help='List of relative file paths that should be retrieved in addition to what the plugin specifies.', + ) + spec.input_namespace( + 'metadata.options.stash', + required=False, + populate_defaults=False, + validator=validate_stash_options, + help='Optional directives to stash files after the calculation job has completed.', + ) + spec.input( + 'metadata.options.stash.target_base', + valid_type=str, + required=False, + help='The base location to where the files should be stashd. For example, for the `copy` stash mode, this ' + 'should be an absolute filepath on the remote computer.', + ) + spec.input( + 'metadata.options.stash.source_list', + valid_type=(tuple, list), + required=False, + help='Sequence of relative filepaths representing files in the remote directory that should be stashed.', + ) + spec.input( + 'metadata.options.stash.stash_mode', + valid_type=str, + required=False, + help='Mode with which to perform the stashing, should be value of `aiida.common.datastructures.StashMode`.', + ) + + spec.output( + 'remote_folder', + valid_type=orm.RemoteData, + help='Input files necessary to run the process will be stored in this folder node.', + ) + spec.output( + 'remote_stash', + valid_type=orm.RemoteStashData, + required=False, + help='Contents of the `stash.source_list` option are stored in this remote folder after job completion.', + ) + spec.output( + cls.link_label_retrieved, + valid_type=orm.FolderData, + pass_to_parser=True, + help='Files that are retrieved by the daemon will be stored in this node. 
By default the stdout and stderr ' + 'of the scheduler will be added, but one can add more by specifying them in `CalcInfo.retrieve_list`.', + ) + + spec.exit_code( + 100, + 'ERROR_NO_RETRIEVED_FOLDER', + invalidates_cache=True, + message='The process did not have the required `retrieved` output.', + ) + spec.exit_code( + 110, 'ERROR_SCHEDULER_OUT_OF_MEMORY', invalidates_cache=True, message='The job ran out of memory.' + ) + spec.exit_code( + 120, 'ERROR_SCHEDULER_OUT_OF_WALLTIME', invalidates_cache=True, message='The job ran out of walltime.' + ) + spec.exit_code( + 131, 'ERROR_SCHEDULER_INVALID_ACCOUNT', invalidates_cache=True, message='The specified account is invalid.' + ) + spec.exit_code( + 140, 'ERROR_SCHEDULER_NODE_FAILURE', invalidates_cache=True, message='The node running the job failed.' + ) + spec.exit_code(150, 'STOPPED_BY_MONITOR', invalidates_cache=True, message='{message}') + + @classproperty + def spec_options(cls): # noqa: N805 + """Return the metadata options port namespace of the process specification of this process. + + :return: options dictionary + :rtype: dict + """ + return cls.spec_metadata['options'] + + @classmethod + def get_importer(cls, entry_point_name: str | None = None) -> CalcJobImporter: + """Load the `CalcJobImporter` associated with this `CalcJob` if it exists. + + By default an importer with the same entry point as the ``CalcJob`` will be loaded, however, this can be + overridden using the ``entry_point_name`` argument. + + :param entry_point_name: optional entry point name of a ``CalcJobImporter`` to override the default. + :return: the loaded ``CalcJobImporter``. + :raises: if no importer class could be loaded. 
+ """ + from aiida.plugins import CalcJobImporterFactory + from aiida.plugins.entry_point import get_entry_point_from_class + + if entry_point_name is None: + _, entry_point = get_entry_point_from_class(cls.__module__, cls.__name__) + if entry_point is not None: + entry_point_name = entry_point.name + + assert entry_point_name is not None + + return CalcJobImporterFactory(entry_point_name)() + + @property + def options(self) -> AttributeDict: + """Return the options of the metadata that were specified when this process instance was launched. + + :return: options dictionary + + """ + try: + return self.metadata.options + except AttributeError: + return AttributeDict() + + @classmethod + def get_state_classes(cls) -> Dict[Hashable, Type[plumpy.process_states.State]]: + """A mapping of the State constants to the corresponding state class. + + Overrides the waiting state with the Calcjob specific version. + """ + # Overwrite the waiting state + states_map = super().get_state_classes() + states_map[ProcessState.WAITING] = Waiting + return states_map + + @property + def node(self) -> orm.CalcJobNode: + return super().node # type: ignore[return-value] + + @override + def on_terminated(self) -> None: + """Cleanup the node by deleting the calulation job state. + + .. note:: This has to be done before calling the super because that will seal the node after we cannot change it + """ + self.node.delete_state() + super().on_terminated() + + @override + async def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]: + """Run the calculation job. + + This means invoking the `presubmit` and storing the temporary folder in the node's repository. Then we move the + process in the `Wait` state, waiting for the `UPLOAD` transport task to be started. 
+ + :returns: the `Stop` command if a dry run, int if the process has an exit status, + `Wait` command if the calcjob is to be uploaded + + """ + if self.inputs.metadata.dry_run: + await self._perform_dry_run() + return plumpy.process_states.Stop(None, True) + + if 'remote_folder' in self.inputs: + exit_code = await self._perform_import() + return exit_code + + # The following conditional is required for the caching to properly work. Even if the source node has a process + # state of `Finished` the cached process will still enter the running state. The process state will have then + # been overridden by the engine to `Running` so we cannot check that, but if the `exit_status` is anything other + # than `None`, it should mean this node was taken from the cache, so the process should not be rerun. + if self.node.exit_status is not None: + # Normally the outputs will be attached to the process by a ``Parser``, if defined in the inputs. But in + # this case, the parser will not be called. The outputs will already have been added to the process node + # though, so all that needs to be done here is just also assign them to the process instance. This such that + # when the process returns its results, it returns the actual outputs and not an empty dictionary. + self._outputs = self.node.base.links.get_outgoing(link_type=LinkType.CREATE).nested() + return self.node.exit_status + + # Launch the upload operation + return plumpy.process_states.Wait(msg='Waiting to upload', data=UPLOAD_COMMAND) + + def prepare_for_submission(self, folder: Folder) -> CalcInfo: + """Prepare the calculation for submission. + + Convert the input nodes into the corresponding input files in the format that the code will expect. In addition, + define and return a `CalcInfo` instance, which is a simple data structure that contains information for the + engine, for example, on what files to copy to the remote machine, what files to retrieve once it has completed, + specific scheduler settings and more. 
+ + :param folder: a temporary folder on the local file system. + :returns: the `CalcInfo` instance + """ + raise NotImplementedError() + + def _setup_version_info(self) -> dict[str, Any]: + """Store relevant plugin version information.""" + from aiida.plugins.entry_point import format_entry_point_string + from aiida.plugins.factories import ParserFactory + + version_info = super()._setup_version_info() + + for key, monitor in self.inputs.get('monitors', {}).items(): + entry_point = monitor.base.attributes.get('entry_point') + entry_point_string = format_entry_point_string('aiida.calculations.monitors', entry_point) + monitor_version_info = self.runner.plugin_version_provider.get_version_info(entry_point_string) + version_info['version'].setdefault('monitors', {})[key] = monitor_version_info['version']['plugin'] + + cache_version_info = {} + + if self.CACHE_VERSION is not None: + cache_version_info['calc_job'] = self.CACHE_VERSION + + parser_entry_point = self.inputs.metadata.options.get('parser_name') + + if parser_entry_point is not None: + try: + parser = ParserFactory(self.inputs.metadata.options.parser_name) + except exceptions.EntryPointError: + self.logger.warning(f'Could not load the `parser_name` entry point `{parser_entry_point}') + else: + if parser.CACHE_VERSION is not None: + cache_version_info['parser'] = parser.CACHE_VERSION + + if cache_version_info: + self.node.base.attributes.set(self.KEY_CACHE_VERSION, cache_version_info) + + return version_info + + def _setup_metadata(self, metadata: dict) -> None: + """Store the metadata on the ProcessNode.""" + computer = metadata.pop('computer', None) + if computer is not None: + self.node.computer = computer + + options = metadata.pop('options', {}) + for option_name, option_value in options.items(): + self.node.set_option(option_name, option_value) + + super()._setup_metadata(metadata) + + def _setup_inputs(self) -> None: + """Create the links between the input nodes and the ProcessNode that represents this 
process.""" + super()._setup_inputs() + + # If a computer has not yet been set, which should have been done in ``_setup_metadata`` if it was specified + # in the ``metadata`` inputs, set the computer associated with the ``code`` input. Note that not all ``code``s + # will have an associated computer, but in that case the ``computer`` property should return ``None`` and + # nothing would change anyway. + if not self.node.computer: + self.node.computer = self.inputs.code.computer + + async def _perform_dry_run(self): + """Perform a dry run. + + Instead of performing the normal sequence of steps, just the `presubmit` is called, which will call the method + `prepare_for_submission` of the plugin to generate the input files based on the inputs. Then the upload action + is called, but using a normal local transport that will copy the files to a local sandbox folder. The generated + input script and the absolute path to the sandbox folder are stored in the `dry_run_info` attribute of the node + of this process. + """ + from aiida.common.folders import SubmitTestFolder + from aiida.engine.daemon.execmanager import upload_calculation + from aiida.transports.plugins.local import LocalTransport + + with LocalTransport() as transport: + with SubmitTestFolder() as folder: + calc_info = self.presubmit(folder) + await upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True) + self.node.dry_run_info = { # type: ignore[attr-defined] + 'folder': folder.abspath, + 'script_filename': self.node.get_option('submit_script_filename'), + } + + async def _perform_import(self): + """Perform the import of an already completed calculation. + + The inputs contained a `RemoteData` under the key `remote_folder` signalling that this is not supposed to be run + as a normal calculation job, but rather the results are already computed outside of AiiDA and merely need to be + imported. 
+ """ + from aiida.common.datastructures import CalcJobState + from aiida.common.folders import SandboxFolder + from aiida.engine.daemon.execmanager import retrieve_calculation + from aiida.manage import get_config_option + from aiida.transports.plugins.local import LocalTransport + + filepath_sandbox = get_config_option('storage.sandbox') or None + + with LocalTransport() as transport: + with SandboxFolder(filepath_sandbox) as folder: + with SandboxFolder(filepath_sandbox) as retrieved_temporary_folder: + self.presubmit(folder) + self.node.set_remote_workdir(self.inputs.remote_folder.get_remote_path()) + retrieved = await retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath) + if retrieved is not None: + self.out(self.node.link_label_retrieved, retrieved) + self.update_outputs() + self.node.set_state(CalcJobState.PARSING) + self.node.base.attributes.set(orm.CalcJobNode.IMMIGRATED_KEY, True) + return self.parse(retrieved_temporary_folder.abspath) + + def parse( + self, retrieved_temporary_folder: Optional[str] = None, existing_exit_code: ExitCode | None = None + ) -> ExitCode: + """Parse a retrieved job calculation. + + This is called once it's finished waiting for the calculation to be finished and the data has been retrieved. + + :param retrieved_temporary_folder: The path to the temporary folder + + """ + try: + retrieved = self.node.outputs.retrieved + except exceptions.NotExistent: + return self.exit_codes.ERROR_NO_RETRIEVED_FOLDER + + # Call the scheduler output parser + exit_code_scheduler = self.parse_scheduler_output(retrieved) + + if exit_code_scheduler is not None and exit_code_scheduler.status > 0: + # If an exit code is returned by the scheduler output parser, we log it and set it on the node. This will + # allow the actual `Parser` implementation, if defined in the inputs, to inspect it and decide to keep it, + # or override it with a more specific exit code, if applicable. 
+ msg = f'scheduler parser returned exit code<{exit_code_scheduler.status}>: {exit_code_scheduler.message}' + self.logger.warning(msg) + self.node.set_exit_status(exit_code_scheduler.status) + self.node.set_exit_message(exit_code_scheduler.message) + + # Call the retrieved output parser + try: + exit_code_retrieved = self.parse_retrieved_output(retrieved_temporary_folder) + finally: + if retrieved_temporary_folder is not None: + shutil.rmtree(retrieved_temporary_folder, ignore_errors=True) + + if exit_code_retrieved is not None and exit_code_retrieved.status > 0: + msg = f'output parser returned exit code<{exit_code_retrieved.status}>: {exit_code_retrieved.message}' + self.logger.warning(msg) + + # The final exit code is that of the scheduler, unless the output parser returned one + exit_code: Optional[ExitCode] + if exit_code_retrieved is not None: + exit_code = exit_code_retrieved + else: + exit_code = exit_code_scheduler + + if existing_exit_code is not None: + return existing_exit_code + + return exit_code or ExitCode(0) + + @staticmethod + def terminate(exit_code: ExitCode) -> ExitCode: + """Terminate the process immediately and return the given exit code. + + This method is called by :meth:`aiida.engine.processes.calcjobs.tasks.Waiting.execute` if a monitor triggered + the job to be terminated and specified the parsing to be skipped. It will construct the running state and tell + this method to be run, which returns the given exit code which will cause the process to be terminated. + + :param exit_code: The exit code to return. + :returns: The provided exit code. 
+ """ + return exit_code + + def parse_scheduler_output(self, retrieved: orm.Node) -> Optional[ExitCode]: + """Parse the output of the scheduler if that functionality has been implemented for the plugin.""" + computer = self.node.computer + + if computer is None: + self.logger.info( + 'no computer is defined for this calculation job which suggest that it is an imported job and so ' + 'scheduler output probably is not available or not in a format that can be reliably parsed, skipping..' + ) + return None + + scheduler = computer.get_scheduler() + filename_stderr = self.node.get_option('scheduler_stderr') + filename_stdout = self.node.get_option('scheduler_stdout') + + detailed_job_info = self.node.get_detailed_job_info() + + if detailed_job_info is None: + self.logger.info('could not parse scheduler output: the `detailed_job_info` attribute is missing') + elif detailed_job_info.get('retval', 0) != 0: + self.logger.info('could not parse scheduler output: return value of `detailed_job_info` is non-zero') + detailed_job_info = None + + if filename_stderr is None: + self.logger.warning('could not determine `stderr` filename because `scheduler_stderr` option was not set.') + else: + try: + scheduler_stderr = retrieved.base.repository.get_object_content(filename_stderr, mode='r') + except FileNotFoundError: + scheduler_stderr = None + self.logger.warning(f'could not parse scheduler output: the `{filename_stderr}` file is missing') + + if filename_stdout is None: + self.logger.warning('could not determine `stdout` filename because `scheduler_stdout` option was not set.') + else: + try: + scheduler_stdout = retrieved.base.repository.get_object_content(filename_stdout, mode='r') + except FileNotFoundError: + scheduler_stdout = None + self.logger.warning(f'could not parse scheduler output: the `{filename_stdout}` file is missing') + + try: + exit_code = scheduler.parse_output( + detailed_job_info, + scheduler_stdout or '', + scheduler_stderr or '', + ) + except 
exceptions.FeatureNotAvailable: + self.logger.info(f'`{scheduler.__class__.__name__}` does not implement scheduler output parsing') + return None + except Exception as exception: + self.logger.error(f'the `parse_output` method of the scheduler excepted: {exception}') + return None + + if exit_code is not None and not isinstance(exit_code, ExitCode): + args = (scheduler.__class__.__name__, type(exit_code)) # type: ignore[unreachable] + raise ValueError('`{}.parse_output` returned neither an `ExitCode` nor None, but: {}'.format(*args)) + + return exit_code + + def parse_retrieved_output(self, retrieved_temporary_folder: Optional[str] = None) -> Optional[ExitCode]: + """Parse the retrieved data by calling the parser plugin if it was defined in the inputs.""" + parser_class = self.node.get_parser_class() + + if parser_class is None: + return None + + parser = parser_class(self.node) + parse_kwargs = parser.get_outputs_for_parsing() + + if retrieved_temporary_folder: + parse_kwargs['retrieved_temporary_folder'] = retrieved_temporary_folder + + exit_code = parser.parse(**parse_kwargs) + + for link_label, node in parser.outputs.items(): + try: + self.out(link_label, node) + except ValueError as exception: + self.logger.error(f'invalid value {node} specified with label {link_label}: {exception}') + exit_code = self.exit_codes.ERROR_INVALID_OUTPUT + break + + if exit_code is not None and not isinstance(exit_code, ExitCode): + args = (parser_class.__name__, type(exit_code)) # type: ignore[unreachable] + raise ValueError('`{}.parse` returned neither an `ExitCode` nor None, but: {}'.format(*args)) + + return exit_code + + def presubmit(self, folder: Folder) -> CalcInfo: + """Prepares the calculation folder with all inputs, ready to be copied to the cluster. + + :param folder: a SandboxFolder that can be used to write calculation input files and the scheduling script. + + :return calcinfo: the CalcInfo object containing the information needed by the daemon to handle operations. 
+ + """ + from aiida.common.datastructures import CodeInfo, CodeRunMode + from aiida.common.exceptions import InputValidationError, InvalidOperation, PluginInternalError, ValidationError + from aiida.common.utils import validate_list_of_string_tuples + from aiida.orm import AbstractCode, Computer, load_code + from aiida.schedulers.datastructures import JobTemplate, JobTemplateCodeInfo + + inputs = self.node.base.links.get_incoming(link_type=LinkType.INPUT_CALC) + + if not self.inputs.metadata.dry_run and not self.node.is_stored: + raise InvalidOperation('calculation node is not stored.') + + computer = self.node.computer + assert computer is not None + codes = [_ for _ in inputs.all_nodes() if isinstance(_, AbstractCode)] + + for code in codes: + if not code.can_run_on_computer(computer): + raise InputValidationError( + 'The selected code {} for calculation {} cannot run on computer {}'.format( + code.pk, self.node.pk, computer.label + ) + ) + + code.validate_working_directory(folder) + + calc_info = self.prepare_for_submission(folder) + calc_info.uuid = str(self.node.uuid) + + # I create the job template to pass to the scheduler + job_tmpl = JobTemplate() + job_tmpl.submit_as_hold = False + job_tmpl.rerunnable = self.options.get('rerunnable', False) + # 'email', 'email_on_started', 'email_on_terminated', + job_tmpl.job_name = f'aiida-{self.node.pk}' + job_tmpl.sched_output_path = self.options.scheduler_stdout + if computer is not None: + job_tmpl.shebang = computer.get_shebang() + if self.options.scheduler_stderr == self.options.scheduler_stdout: + job_tmpl.sched_join_files = True + else: + job_tmpl.sched_error_path = self.options.scheduler_stderr + job_tmpl.sched_join_files = False + + # Set retrieve path, add also scheduler STDOUT and STDERR + retrieve_list = calc_info.retrieve_list or [] + if job_tmpl.sched_output_path is not None and job_tmpl.sched_output_path not in retrieve_list: + retrieve_list.append(job_tmpl.sched_output_path) + if not 
job_tmpl.sched_join_files: + if job_tmpl.sched_error_path is not None and job_tmpl.sched_error_path not in retrieve_list: + retrieve_list.append(job_tmpl.sched_error_path) + retrieve_list.extend(self.node.get_option('additional_retrieve_list') or []) + self.node.set_retrieve_list(retrieve_list) + + # Handle the retrieve_temporary_list + retrieve_temporary_list = calc_info.retrieve_temporary_list or [] + self.node.set_retrieve_temporary_list(retrieve_temporary_list) + + # If the inputs contain a ``remote_folder`` input node, we are in an import scenario and can skip the rest + if 'remote_folder' in inputs.all_link_labels(): + return calc_info + + # The remaining code is only necessary for actual runs, for example, creating the submission script + scheduler = computer.get_scheduler() + + # the if is done so that if the method returns None, this is + # not added. This has two advantages: + # - it does not add too many \n\n if most of the prepend_text are empty + # - most importantly, skips the cases in which one of the methods + # would return None, in which case the join method would raise + # an exception + prepend_texts = ( + [computer.get_prepend_text()] + + [code.prepend_text for code in codes] + + [calc_info.prepend_text, self.node.get_option('prepend_text')] + ) + job_tmpl.prepend_text = '\n\n'.join(prepend_text for prepend_text in prepend_texts if prepend_text) + + append_texts = ( + [self.node.get_option('append_text'), calc_info.append_text] + + [code.append_text for code in codes] + + [computer.get_append_text()] + ) + job_tmpl.append_text = '\n\n'.join(append_text for append_text in append_texts if append_text) + + # Set resources, also with get_default_mpiprocs_per_machine + resources = self.node.get_option('resources') + scheduler.preprocess_resources(resources or {}, computer.get_default_mpiprocs_per_machine()) + job_tmpl.job_resource = scheduler.create_job_resource(**resources) # type: ignore[arg-type] + + subst_dict = {'tot_num_mpiprocs': 
job_tmpl.job_resource.get_tot_num_mpiprocs()} + + for key, value in job_tmpl.job_resource.items(): + subst_dict[key] = value + mpi_args = [arg.format(**subst_dict) for arg in computer.get_mpirun_command()] + extra_mpirun_params = self.node.get_option('mpirun_extra_params') # same for all codes in the same calc + + # set the codes_info + if not isinstance(calc_info.codes_info, (list, tuple)): + raise PluginInternalError('codes_info passed to CalcInfo must be a list of CalcInfo objects') + + tmpl_codes_info = [] + for code_info in calc_info.codes_info: + if not isinstance(code_info, CodeInfo): + raise PluginInternalError('Invalid codes_info, must be a list of CodeInfo objects') + + if code_info.code_uuid is None: + raise PluginInternalError('CalcInfo should have the information of the code to be launched') + + code = load_code(code_info.code_uuid) + + # Here are the three values that will determine whether the code is to be run with MPI _if_ they are not + # ``None``. If any of them are explicitly defined but are not equivalent, an exception is raised. We use the + # ``self._raw_inputs`` to determine the actual value passed for ``metadata.options.withmpi`` and + # distinghuish it from the default. + raw_inputs = self._raw_inputs or {} # type: ignore[var-annotated] + with_mpi_option = raw_inputs.get('metadata', {}).get('options', {}).get('withmpi', None) + with_mpi_plugin = code_info.withmpi + with_mpi_code = code.with_mpi + + with_mpi_values = [with_mpi_option, with_mpi_plugin, with_mpi_code] + with_mpi_values_defined = [value for value in with_mpi_values if value is not None] + with_mpi_values_set = set(with_mpi_values_defined) + + # If more than one value is defined, they have to be identical, or we raise that a conflict is encountered + if len(with_mpi_values_set) > 1: + error = f'Inconsistent requirements as to whether code `{code}` should be run with or without MPI.' 
+ if with_mpi_option is not None: + error += f'\nThe `metadata.options.withmpi` input was set to `{with_mpi_option}`.' + if with_mpi_plugin is not None: + error += f'\nThe plugin require `{with_mpi_plugin}`.' + if with_mpi_code is not None: + error += f'\nThe code `{code}` required `{with_mpi_code}`.' + raise RuntimeError(error) + + # At this point we know that the three explicit values agree if they are defined, so we simply set the value + if with_mpi_values_set: + with_mpi = with_mpi_values_set.pop() + else: + # Fall back to the default, which is the default of the option in the process input specification with + # ``False`` as final fallback if the default is not even specified + try: + with_mpi = self.spec().inputs['metadata']['options']['withmpi'].default # type: ignore[index] + except RuntimeError: + # ``plumpy.InputPort.default`` raises a ``RuntimeError`` if no default has been set. This is bad + # design and should be changed, but we have to deal with it like this for now. + with_mpi = False + + if with_mpi: + prepend_cmdline_params = code.get_prepend_cmdline_params(mpi_args, extra_mpirun_params) + else: + prepend_cmdline_params = code.get_prepend_cmdline_params() + + cmdline_params = code.get_executable_cmdline_params(code_info.cmdline_params) + + tmpl_code_info = JobTemplateCodeInfo() + tmpl_code_info.prepend_cmdline_params = prepend_cmdline_params + tmpl_code_info.cmdline_params = cmdline_params + tmpl_code_info.use_double_quotes = [computer.get_use_double_quotes(), code.use_double_quotes] + tmpl_code_info.wrap_cmdline_params = code.wrap_cmdline_params + tmpl_code_info.stdin_name = code_info.stdin_name + tmpl_code_info.stdout_name = code_info.stdout_name + tmpl_code_info.stderr_name = code_info.stderr_name + tmpl_code_info.join_files = code_info.join_files or False + + tmpl_codes_info.append(tmpl_code_info) + job_tmpl.codes_info = tmpl_codes_info + + # set the codes execution mode, default set to `SERIAL` + codes_run_mode = CodeRunMode.SERIAL + if 
calc_info.codes_run_mode: + codes_run_mode = calc_info.codes_run_mode + + job_tmpl.codes_run_mode = codes_run_mode + + if calc_info.file_copy_operation_order is not None: + if not isinstance(calc_info.file_copy_operation_order, list) or any( # type: ignore[redundant-expr] + not isinstance(e, FileCopyOperation) for e in calc_info.file_copy_operation_order + ): + raise PluginInternalError( + 'calc_info.file_copy_operation_order is not a list of `FileCopyOperation` enums.' + ) + else: + # Set the default + calc_info.file_copy_operation_order = [ + FileCopyOperation.SANDBOX, + FileCopyOperation.LOCAL, + FileCopyOperation.REMOTE, + ] + + ######################################################################## + + custom_sched_commands = self.node.get_option('custom_scheduler_commands') + if custom_sched_commands: + job_tmpl.custom_scheduler_commands = custom_sched_commands + + job_tmpl.import_sys_environment = self.node.get_option('import_sys_environment') + + job_tmpl.job_environment = self.node.get_option('environment_variables') + job_tmpl.environment_variables_double_quotes = self.node.get_option('environment_variables_double_quotes') + + queue_name = self.node.get_option('queue_name') + account = self.node.get_option('account') + qos = self.node.get_option('qos') + if queue_name is not None: + job_tmpl.queue_name = queue_name + if account is not None: + job_tmpl.account = account + if qos is not None: + job_tmpl.qos = qos + priority = self.node.get_option('priority') + if priority is not None: + job_tmpl.priority = priority + + job_tmpl.max_memory_kb = self.node.get_option('max_memory_kb') or computer.get_default_memory_per_machine() + + max_wallclock_seconds = self.node.get_option('max_wallclock_seconds') + if max_wallclock_seconds is not None: + job_tmpl.max_wallclock_seconds = max_wallclock_seconds + + submit_script_filename = self.node.get_option('submit_script_filename') + script_content = scheduler.get_submit_script(job_tmpl) + 
folder.create_file_from_filelike(io.StringIO(script_content), submit_script_filename, 'w', encoding='utf8') + + def encoder(obj): + if dataclasses.is_dataclass(obj): + return dataclasses.asdict(obj) + raise TypeError(f' {obj!r} is not JSON serializable') + + subfolder = folder.get_subfolder('.aiida', create=True) + subfolder.create_file_from_filelike( + io.StringIO(json.dumps(job_tmpl, default=encoder)), 'job_tmpl.json', 'w', encoding='utf8' + ) + subfolder.create_file_from_filelike(io.StringIO(json.dumps(calc_info)), 'calcinfo.json', 'w', encoding='utf8') + + if calc_info.local_copy_list is None: + calc_info.local_copy_list = [] + + if calc_info.remote_copy_list is None: + calc_info.remote_copy_list = [] + + # Some validation + this_pk = self.node.pk if self.node.pk is not None else '[UNSTORED]' + local_copy_list = calc_info.local_copy_list + try: + validate_list_of_string_tuples(local_copy_list, tuple_length=3) + except ValidationError as exception: + raise PluginInternalError( + f'[presubmission of calc {this_pk}] local_copy_list format problem: {exception}' + ) from exception + + remote_copy_list = calc_info.remote_copy_list + try: + validate_list_of_string_tuples(remote_copy_list, tuple_length=3) + except ValidationError as exception: + raise PluginInternalError( + f'[presubmission of calc {this_pk}] remote_copy_list format problem: {exception}' + ) from exception + + for remote_computer_uuid, _, dest_rel_path in remote_copy_list: + try: + Computer.collection.get(uuid=remote_computer_uuid) + except exceptions.NotExistent as exception: + raise PluginInternalError( + '[presubmission of calc {}] ' + 'The remote copy requires a computer with UUID={}' + 'but no such computer was found in the ' + 'database'.format(this_pk, remote_computer_uuid) + ) from exception + if os.path.isabs(dest_rel_path): + raise PluginInternalError( + '[presubmission of calc {}] ' 'The destination path of the remote copy ' 'is absolute! 
({})'.format( + this_pk, dest_rel_path + ) + ) + + return calc_info diff --git a/src/aiida/engine/processes/calcjobs/tasks.py b/src/aiida/engine/processes/calcjobs/tasks.py index e69de29bb2..45d5d98fa4 100644 --- a/src/aiida/engine/processes/calcjobs/tasks.py +++ b/src/aiida/engine/processes/calcjobs/tasks.py @@ -0,0 +1,683 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Transport tasks for calculation jobs.""" + +from __future__ import annotations + +import asyncio +import functools +import logging +import tempfile +from typing import TYPE_CHECKING, Any, Callable, Optional + +import plumpy +import plumpy.futures +import plumpy.persistence +import plumpy.process_states + +from aiida.common.datastructures import CalcJobState +from aiida.common.exceptions import FeatureNotAvailable, TransportTaskException +from aiida.common.folders import SandboxFolder +from aiida.engine.daemon import execmanager +from aiida.engine.processes.exit_code import ExitCode +from aiida.engine.transports import TransportQueue +from aiida.engine.utils import InterruptableFuture, exponential_backoff_retry, interruptable_task +from aiida.manage.configuration import get_config_option +from aiida.orm.nodes.process.calculation.calcjob import CalcJobNode +from aiida.schedulers.datastructures import JobState + +from ..process import ProcessState +from .monitors import CalcJobMonitorAction, CalcJobMonitorResult, CalcJobMonitors + +if TYPE_CHECKING: + from .calcjob import CalcJob + +UPLOAD_COMMAND = 'upload' +SUBMIT_COMMAND = 'submit' +UPDATE_COMMAND = 'update' +RETRIEVE_COMMAND = 
'retrieve' +STASH_COMMAND = 'stash' +KILL_COMMAND = 'kill' + +RETRY_INTERVAL_OPTION = 'transport.task_retry_initial_interval' +MAX_ATTEMPTS_OPTION = 'transport.task_maximum_attempts' + +logger = logging.getLogger(__name__) + + +class PreSubmitException(Exception): # noqa: N818 + """Raise in the `do_upload` coroutine when an exception is raised in `CalcJob.presubmit`.""" + + +async def task_upload_job(process: 'CalcJob', transport_queue: TransportQueue, cancellable: InterruptableFuture): + """Transport task that will attempt to upload the files of a job calculation to the remote. + + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. + If all retries fail, the task will raise a TransportTaskException + + :param process: the job calculation + :param transport_queue: the TransportQueue from which to request a Transport + :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled + + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted + """ + node = process.node + + if node.get_state() == CalcJobState.SUBMITTING: + logger.warning(f'CalcJob<{node.pk}> already marked as SUBMITTING, skipping task_update_job') + return + + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + filepath_sandbox = get_config_option('storage.sandbox') or None + + authinfo = node.get_authinfo() + + async def do_upload(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + + with SandboxFolder(filepath_sandbox) as folder: + # Any exception thrown in `presubmit` call is not 
transient so we circumvent the exponential backoff + try: + calc_info = process.presubmit(folder) + except Exception as exception: + raise PreSubmitException('exception occurred in presubmit call') from exception + else: + remote_folder = await execmanager.upload_calculation(node, transport, calc_info, folder) + if remote_folder is not None: + process.out('remote_folder', remote_folder) + skip_submit = calc_info.skip_submit or False + + return skip_submit + + try: + logger.info(f'scheduled request to upload CalcJob<{node.pk}>') + ignore_exceptions = (plumpy.futures.CancelledError, PreSubmitException, plumpy.process_states.Interruption) + skip_submit = await exponential_backoff_retry( + do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions + ) + except PreSubmitException: + raise + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): + raise + except Exception as exception: + logger.warning(f'uploading CalcJob<{node.pk}> failed') + raise TransportTaskException(f'upload_calculation failed {max_attempts} times consecutively') from exception + else: + logger.info(f'uploading CalcJob<{node.pk}> successful') + node.set_state(CalcJobState.SUBMITTING) + return skip_submit + + +async def task_submit_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture): + """Transport task that will attempt to submit a job calculation. + + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. 
+ If all retries fail, the task will raise a TransportTaskException + + :param node: the node that represents the job calculation + :param transport_queue: the TransportQueue from which to request a Transport + :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled + + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted + """ + if node.get_state() == CalcJobState.WITHSCHEDULER: + assert node.get_job_id() is not None, 'job is WITHSCHEDULER, however, it does not have a job id' + logger.warning(f'CalcJob<{node.pk}> already marked as WITHSCHEDULER, skipping task_submit_job') + return node.get_job_id() + + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + + authinfo = node.get_authinfo() + + async def do_submit(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + return execmanager.submit_calculation(node, transport) + + try: + logger.info(f'scheduled request to submit CalcJob<{node.pk}>') + ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) + result = await exponential_backoff_retry( + do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions + ) + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): + raise + except Exception as exception: + logger.warning(f'submitting CalcJob<{node.pk}> failed') + raise TransportTaskException(f'submit_calculation failed {max_attempts} times consecutively') from exception + else: + logger.info(f'submitting CalcJob<{node.pk}> successful') + node.set_state(CalcJobState.WITHSCHEDULER) + return result + + +async def task_update_job(node: CalcJobNode, job_manager, cancellable: InterruptableFuture): + """Transport task that will attempt to update the scheduler status of the job calculation. 
+ + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. + If all retries fail, the task will raise a TransportTaskException + + :param node: the node that represents the job calculation + :param job_manager: The job manager + :param cancellable: A cancel flag + :return: True if the tasks was successfully completed, False otherwise + """ + state = node.get_state() + + if state in [CalcJobState.RETRIEVING, CalcJobState.STASHING]: + logger.warning(f'CalcJob<{node.pk}> already marked as `{state}`, skipping task_update_job') + return True + + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + + authinfo = node.get_authinfo() + job_id = node.get_job_id() + + async def do_update(): + # Get the update request + with job_manager.request_job_info_update(authinfo, job_id) as update_request: + job_info = await cancellable.with_interrupt(update_request) + + if job_info is None: + # If the job is computed or not found assume it's done + node.set_scheduler_state(JobState.DONE) + job_done = True + else: + node.set_last_job_info(job_info) + node.set_scheduler_state(job_info.job_state) + job_done = job_info.job_state == JobState.DONE + + return job_done + + try: + logger.info(f'scheduled request to update CalcJob<{node.pk}>') + ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) + job_done = await exponential_backoff_retry( + do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions + ) + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): + raise + except Exception as exception: + logger.warning(f'updating CalcJob<{node.pk}> 
failed') + raise TransportTaskException(f'update_calculation failed {max_attempts} times consecutively') from exception + else: + logger.info(f'updating CalcJob<{node.pk}> successful') + if job_done: + node.set_state(CalcJobState.STASHING) + + return job_done + + +async def task_monitor_job( + node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture, monitors: CalcJobMonitors +): + """Transport task that will monitor the job calculation if any monitors have been defined. + + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. + If all retries fail, the task will raise a TransportTaskException + + :param node: the node that represents the job calculation + :param transport_queue: the TransportQueue from which to request a Transport + :param cancellable: A cancel flag + :param monitors: An instance of ``CalcJobMonitors`` holding the collection of monitors to process. 
+ :return: True if the tasks was successfully completed, False otherwise + """ + state = node.get_state() + + if state in [CalcJobState.RETRIEVING, CalcJobState.STASHING]: + logger.warning(f'CalcJob<{node.pk}> already marked as `{state}`, skipping task_monitor_job') + return None + + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + authinfo = node.get_authinfo() + + async def do_monitor(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + return monitors.process(node, transport) + + try: + logger.info(f'scheduled request to monitor CalcJob<{node.pk}>') + ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) + monitor_result = await exponential_backoff_retry( + do_monitor, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions + ) + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): + raise + except Exception as exception: + logger.warning(f'monitoring CalcJob<{node.pk}> failed') + raise TransportTaskException(f'monitor_calculation failed {max_attempts} times consecutively') from exception + else: + logger.info(f'monitoring CalcJob<{node.pk}> successful') + return monitor_result + + +async def task_retrieve_job( + process: 'CalcJob', + transport_queue: TransportQueue, + retrieved_temporary_folder: str, + cancellable: InterruptableFuture, +): + """Transport task that will attempt to retrieve all files of a completed job calculation. + + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. 
+ If all retries fail, the task will raise a TransportTaskException + + :param process: the job calculation + :param transport_queue: the TransportQueue from which to request a Transport + :param retrieved_temporary_folder: the absolute path to a directory to store files + :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled + + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted + """ + node = process.node + + if node.get_state() == CalcJobState.PARSING: + logger.warning(f'CalcJob<{node.pk}> already marked as PARSING, skipping task_retrieve_job') + return + + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + + authinfo = node.get_authinfo() + + async def do_retrieve(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + + # Perform the job accounting and set it on the node if successful. If the scheduler does not implement this + # still set the attribute but set it to `None`. This way we can distinguish calculation jobs for which the + # accounting was called but could not be set. 
+ scheduler = node.computer.get_scheduler() # type: ignore[union-attr] + scheduler.set_transport(transport) + + if node.get_job_id() is None: + logger.warning(f'there is no job id for CalcJobNoe<{node.pk}>: skipping `get_detailed_job_info`') + retrieved = await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) + else: + try: + detailed_job_info = scheduler.get_detailed_job_info(node.get_job_id()) + except FeatureNotAvailable: + logger.info(f'detailed job info not available for scheduler of CalcJob<{node.pk}>') + node.set_detailed_job_info(None) + else: + node.set_detailed_job_info(detailed_job_info) + + retrieved = await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder) + + if retrieved is not None: + process.out(node.link_label_retrieved, retrieved) + + return retrieved + + try: + logger.info(f'scheduled request to retrieve CalcJob<{node.pk}>') + ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) + result = await exponential_backoff_retry( + do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions + ) + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): + raise + except Exception as exception: + logger.warning(f'retrieving CalcJob<{node.pk}> failed') + raise TransportTaskException(f'retrieve_calculation failed {max_attempts} times consecutively') from exception + else: + node.set_state(CalcJobState.PARSING) + logger.info(f'retrieving CalcJob<{node.pk}> successful') + return result + + +async def task_stash_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture): + """Transport task that will optionally stash files of a completed job calculation on the remote. + + The task will first request a transport from the queue. 
Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. + If all retries fail, the task will raise a TransportTaskException + + :param node: the node that represents the job calculation + :param transport_queue: the TransportQueue from which to request a Transport + :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled + :raises: Return if the tasks was successfully completed + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted + """ + if node.get_state() == CalcJobState.RETRIEVING: + logger.warning(f'calculation<{node.pk}> already marked as RETRIEVING, skipping task_stash_job') + return + + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + + authinfo = node.get_authinfo() + + async def do_stash(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + + logger.info(f'stashing calculation<{node.pk}>') + return execmanager.stash_calculation(node, transport) + + try: + await exponential_backoff_retry( + do_stash, + initial_interval, + max_attempts, + logger=node.logger, + ignore_exceptions=plumpy.process_states.Interruption, + ) + except plumpy.process_states.Interruption: + raise + except Exception as exception: + logger.warning(f'stashing calculation<{node.pk}> failed') + raise TransportTaskException(f'stash_calculation failed {max_attempts} times consecutively') from exception + else: + node.set_state(CalcJobState.RETRIEVING) + logger.info(f'stashing calculation<{node.pk}> successful') + return + + +async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: 
InterruptableFuture): + """Transport task that will attempt to kill a job calculation. + + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. + If all retries fail, the task will raise a TransportTaskException + + :param node: the node that represents the job calculation + :param transport_queue: the TransportQueue from which to request a Transport + :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled + + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted + """ + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + + if node.get_state() in [CalcJobState.UPLOADING, CalcJobState.SUBMITTING]: + logger.warning(f'CalcJob<{node.pk}> killed, it was in the {node.get_state()} state') + return True + + authinfo = node.get_authinfo() + + async def do_kill(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + return execmanager.kill_calculation(node, transport) + + try: + logger.info(f'scheduled request to kill CalcJob<{node.pk}>') + result = await exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger) + except plumpy.process_states.Interruption: + raise + except Exception as exception: + logger.warning(f'killing CalcJob<{node.pk}> failed') + raise TransportTaskException(f'kill_calculation failed {max_attempts} times consecutively') from exception + else: + logger.info(f'killing CalcJob<{node.pk}> successful') + node.set_scheduler_state(JobState.DONE) + return result + + +@plumpy.persistence.auto_persist('msg', 'data', 
'_command', '_monitor_result') +class Waiting(plumpy.process_states.Waiting): + """The waiting state for the `CalcJob` process.""" + + def __init__( + self, + process: 'CalcJob', + done_callback: Optional[Callable[..., Any]], + msg: Optional[str] = None, + data: Optional[Any] = None, + ): + """:param process: The process this state belongs to""" + super().__init__(process, done_callback, msg, data) + self._task: InterruptableFuture | None = None + self._killing: plumpy.futures.Future | None = None + self._command: Callable[..., Any] | None = None + self._monitor_result: CalcJobMonitorResult | None = None + self._monitors: CalcJobMonitors | None = None + + if isinstance(self.data, dict): + self._command = self.data['command'] + self._monitor_result = self.data.get('monitor_result', None) + else: + self._command = self.data + + @property + def monitors(self) -> CalcJobMonitors | None: + """Return the collection of monitors if specified in the inputs. + + :return: Instance of ``CalcJobMonitors`` containing monitors if specified in the process' input. 
+ """ + if not hasattr(self, '_monitors'): + self._monitors = None + + if self._monitors is None and 'monitors' in self.process.node.inputs: + self._monitors = CalcJobMonitors(self.process.node.inputs.monitors) + + return self._monitors + + @property + def process(self) -> 'CalcJob': + """:return: The process""" + return self.state_machine # type: ignore[return-value] + + def load_instance_state(self, saved_state, load_context): + super().load_instance_state(saved_state, load_context) + self._task = None + self._killing = None + + async def execute(self) -> plumpy.process_states.State: # type: ignore[override] + """Override the execute coroutine of the base `Waiting` state.""" + node = self.process.node + transport_queue = self.process.runner.transport + result: plumpy.process_states.State = self + + process_status = f'Waiting for transport task: {self._command}' + node.set_process_status(process_status) + + try: + if self._command == UPLOAD_COMMAND: + skip_submit = await self._launch_task(task_upload_job, self.process, transport_queue) + if skip_submit: + result = self.retrieve(monitor_result=self._monitor_result) + else: + result = self.submit() + + elif self._command == SUBMIT_COMMAND: + result = await self._launch_task(task_submit_job, node, transport_queue) + + if isinstance(result, ExitCode): + # The scheduler plugin returned an exit code from ``Scheduler.submit_job`` indicating the + # job submission failed due to a non-transient problem and the job should be terminated. 
+ return self.create_state(ProcessState.RUNNING, self.process.terminate, result) + + result = self.update() + + elif self._command == UPDATE_COMMAND: + job_done = False + + while not job_done: + scheduler_state = node.get_scheduler_state() + scheduler_state_string = scheduler_state.name if scheduler_state else 'UNKNOWN' + process_status = f'Monitoring scheduler: job state {scheduler_state_string}' + node.set_process_status(process_status) + job_done = await self._launch_task(task_update_job, node, self.process.runner.job_manager) + monitor_result = await self._monitor_job(node, transport_queue, self.monitors) + + if monitor_result and monitor_result.action is CalcJobMonitorAction.KILL: + await self._kill_job(node, transport_queue) + job_done = True + + if monitor_result and not monitor_result.retrieve: + exit_code = self.process.exit_codes.STOPPED_BY_MONITOR.format(message=monitor_result.message) + return self.create_state(ProcessState.RUNNING, self.process.terminate, exit_code) # type: ignore[return-value] + + result = self.stash(monitor_result=monitor_result) + + elif self._command == STASH_COMMAND: + if node.get_option('stash') is not None: + await self._launch_task(task_stash_job, node, transport_queue) + result = self.retrieve(monitor_result=self._monitor_result) + + elif self._command == RETRIEVE_COMMAND: + temp_folder = tempfile.mkdtemp() + await self._launch_task(task_retrieve_job, self.process, transport_queue, temp_folder) + + if not self._monitor_result: + result = self.parse(temp_folder) + + elif self._monitor_result.parse is False: + exit_code = self.process.exit_codes.STOPPED_BY_MONITOR.format(message=self._monitor_result.message) + result = self.create_state( # type: ignore[assignment] + ProcessState.RUNNING, self.process.terminate, exit_code + ) + + elif self._monitor_result.override_exit_code: + exit_code = self.process.exit_codes.STOPPED_BY_MONITOR.format(message=self._monitor_result.message) + result = self.parse(temp_folder, exit_code) + else: + 
result = self.parse(temp_folder) + + else: + raise RuntimeError('Unknown waiting command') + + except TransportTaskException as exception: + raise plumpy.process_states.PauseInterruption(f'Pausing after failed transport task: {exception}') + except plumpy.process_states.KillInterruption as exception: + await self._kill_job(node, transport_queue) + node.set_process_status(str(exception)) + return self.retrieve(monitor_result=self._monitor_result) + except (plumpy.futures.CancelledError, asyncio.CancelledError): + node.set_process_status(f'Transport task {self._command} was cancelled') + raise + except plumpy.process_states.Interruption: + node.set_process_status(f'Transport task {self._command} was interrupted') + raise + else: + node.set_process_status(None) + return result + finally: + # If we were trying to kill but we didn't deal with it, make sure it's set here + if self._killing and not self._killing.done(): + self._killing.set_result(False) + + async def _monitor_job(self, node, transport_queue, monitors) -> CalcJobMonitorResult | None: + """Process job monitors if any were specified as inputs.""" + if monitors is None: + return None + + if self._monitor_result and self._monitor_result.action == CalcJobMonitorAction.DISABLE_ALL: + return None + + monitor_result = await self._launch_task(task_monitor_job, node, transport_queue, monitors=monitors) + + if monitor_result and monitor_result.outputs: + for label, output in monitor_result.outputs.items(): + self.process.out(label, output) + self.process.update_outputs() + + if monitor_result and monitor_result.action == CalcJobMonitorAction.DISABLE_SELF: + monitors.monitors[monitor_result.key].disabled = True + + if monitor_result is not None: + self._monitor_result = monitor_result + + return monitor_result + + async def _kill_job(self, node, transport_queue) -> None: + """Kill the job.""" + await self._launch_task(task_kill_job, node, transport_queue) + if self._killing is not None: + 
self._killing.set_result(True) + else: + logger.info(f'killed CalcJob<{node.pk}> but async future was None') + + async def _launch_task(self, coro, *args, **kwargs): + """Launch a coroutine as a task, making sure to make it interruptable.""" + task_fn = functools.partial(coro, *args, **kwargs) + try: + self._task = interruptable_task(task_fn) + result = await self._task + return result + finally: + self._task = None + + def upload(self) -> 'Waiting': + """Return the `Waiting` state that will `upload` the `CalcJob`.""" + msg = 'Waiting for calculation folder upload' + return self.create_state( # type: ignore[return-value] + ProcessState.WAITING, None, msg=msg, data={'command': UPLOAD_COMMAND} + ) + + def submit(self) -> 'Waiting': + """Return the `Waiting` state that will `submit` the `CalcJob`.""" + msg = 'Waiting for scheduler submission' + return self.create_state( # type: ignore[return-value] + ProcessState.WAITING, None, msg=msg, data={'command': SUBMIT_COMMAND} + ) + + def update(self) -> 'Waiting': + """Return the `Waiting` state that will `update` the `CalcJob`.""" + msg = 'Waiting for scheduler update' + return self.create_state( # type: ignore[return-value] + ProcessState.WAITING, None, msg=msg, data={'command': UPDATE_COMMAND} + ) + + def stash(self, monitor_result: CalcJobMonitorResult | None = None) -> 'Waiting': + """Return the `Waiting` state that will `stash` the `CalcJob`.""" + msg = 'Waiting to stash' + return self.create_state( # type: ignore[return-value] + ProcessState.WAITING, None, msg=msg, data={'command': STASH_COMMAND, 'monitor_result': monitor_result} + ) + + def retrieve(self, monitor_result: CalcJobMonitorResult | None = None) -> 'Waiting': + """Return the `Waiting` state that will `retrieve` the `CalcJob`.""" + msg = 'Waiting to retrieve' + return self.create_state( # type: ignore[return-value] + ProcessState.WAITING, None, msg=msg, data={'command': RETRIEVE_COMMAND, 'monitor_result': monitor_result} + ) + + def parse( + self, 
retrieved_temporary_folder: str, exit_code: ExitCode | None = None + ) -> plumpy.process_states.Running: + """Return the `Running` state that will parse the `CalcJob`. + + :param retrieved_temporary_folder: temporary folder used in retrieving that can be used during parsing. + """ + return self.create_state( # type: ignore[return-value] + ProcessState.RUNNING, self.process.parse, retrieved_temporary_folder, exit_code + ) + + def interrupt(self, reason: Any) -> Optional[plumpy.futures.Future]: # type: ignore[override] + """Interrupt the `Waiting` state by calling interrupt on the transport task `InterruptableFuture`.""" + if self._task is not None: + self._task.interrupt(reason) + + if isinstance(reason, plumpy.process_states.KillInterruption): + if self._killing is None: + self._killing = plumpy.futures.Future() + return self._killing + + return None diff --git a/src/aiida/orm/computers.py b/src/aiida/orm/computers.py index bae925b25c..1c695910af 100644 --- a/src/aiida/orm/computers.py +++ b/src/aiida/orm/computers.py @@ -626,12 +626,12 @@ def get_transport(self, user: Optional['User'] = None) -> 'Transport': """Return a Transport class, configured with all correct parameters. The Transport is closed (meaning that if you want to run any operation with it, you have to open it first (i.e., e.g. for a SSH transport, you have - to open a connection). To do this you can call ``transports.open()``, or simply + to open a connection). To do this you can call ``transport.open()``, or simply run within a ``with`` statement:: transport = Computer.get_transport() with transport: - print(transports.whoami()) + print(transport.whoami()) :param user: if None, try to obtain a transport for the default user. Otherwise, pass a valid User. 
diff --git a/src/aiida/schedulers/plugins/direct.py b/src/aiida/schedulers/plugins/direct.py index 694ff93863..0bed55bda4 100644 --- a/src/aiida/schedulers/plugins/direct.py +++ b/src/aiida/schedulers/plugins/direct.py @@ -192,7 +192,7 @@ def _get_submit_command(self, submit_script): directory. IMPORTANT: submit_script should be already escaped. """ - submit_command = f'bash {submit_script} > /dev/null 2>&1 & echo $!' + submit_command = f'(bash {submit_script} > /dev/null 2>&1 & echo $!) &' self.logger.info(f'submitting with: {submit_command}') diff --git a/src/aiida/transports/__init__.py b/src/aiida/transports/__init__.py index eecd07c04f..c09153228e 100644 --- a/src/aiida/transports/__init__.py +++ b/src/aiida/transports/__init__.py @@ -14,12 +14,14 @@ from .plugins import * from .transport import * +from .util import StrPath __all__ = ( 'SshTransport', 'Transport', 'convert_to_bool', 'parse_sshconfig', + 'StrPath', ) # fmt: on diff --git a/src/aiida/transports/plugins/local.py b/src/aiida/transports/plugins/local.py index 755476a066..1fa30f4650 100644 --- a/src/aiida/transports/plugins/local.py +++ b/src/aiida/transports/plugins/local.py @@ -16,6 +16,7 @@ import shutil import subprocess +from aiida.common.warnings import warn_deprecation from aiida.transports import cli as transport_cli from aiida.transports.transport import Transport, TransportInternalError @@ -101,6 +102,10 @@ def chdir(self, path): :param path: path to cd into :raise OSError: if the directory does not have read attributes. 
""" + warn_deprecation( + '`chdir()` is deprecated and will be removed in the next major version.', + version=3, + ) new_path = os.path.join(self.curdir, path) if not os.path.isdir(new_path): raise OSError(f"'{new_path}' is not a valid directory") diff --git a/src/aiida/transports/plugins/ssh.py b/src/aiida/transports/plugins/ssh.py index d6159fe46f..3279c4430f 100644 --- a/src/aiida/transports/plugins/ssh.py +++ b/src/aiida/transports/plugins/ssh.py @@ -13,14 +13,16 @@ import os import re from stat import S_ISDIR, S_ISREG +from typing import Optional import click from aiida.cmdline.params import options from aiida.cmdline.params.types.path import AbsolutePathOrEmptyParamType from aiida.common.escaping import escape_for_bash +from aiida.common.warnings import warn_deprecation -from ..transport import Transport, TransportInternalError +from ..transport import Transport, TransportInternalError, _TransportPath, fix_path __all__ = ('parse_sshconfig', 'convert_to_bool', 'SshTransport') @@ -580,7 +582,7 @@ def __str__(self): return f"{'OPEN' if self._is_open else 'CLOSED'} [{conn_info}]" - def chdir(self, path): + def chdir(self, path: _TransportPath): """ PLEASE DON'T USE `chdir()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE. `chdir()` is DEPRECATED and will be removed in the next major version. @@ -590,8 +592,13 @@ def chdir(self, path): Differently from paramiko, if you pass None to chdir, nothing happens and the cwd is unchanged. 
""" + warn_deprecation( + '`chdir()` is deprecated and will be removed in the next major version.', + version=3, + ) from paramiko.sftp import SFTPError + path = fix_path(path) old_path = self.sftp.getcwd() if path is not None: try: @@ -618,11 +625,13 @@ def chdir(self, path): self.chdir(old_path) raise OSError(str(exc)) - def normalize(self, path='.'): + def normalize(self, path: _TransportPath = '.'): """Returns the normalized path (removing double slashes, etc...)""" + path = fix_path(path) + return self.sftp.normalize(path) - def stat(self, path): + def stat(self, path: _TransportPath): """Retrieve information about a file on the remote system. The return value is an object whose attributes correspond to the attributes of Python's ``stat`` structure as returned by ``os.stat``, except that it @@ -635,9 +644,11 @@ def stat(self, path): :return: a `paramiko.sftp_attr.SFTPAttributes` object containing attributes about the given file. """ + path = fix_path(path) + return self.sftp.stat(path) - def lstat(self, path): + def lstat(self, path: _TransportPath): """Retrieve information about a file on the remote system, without following symbolic links (shortcuts). This otherwise behaves exactly the same as `stat`. @@ -647,6 +658,8 @@ def lstat(self, path): :return: a `paramiko.sftp_attr.SFTPAttributes` object containing attributes about the given file. """ + path = fix_path(path) + return self.sftp.lstat(path) def getcwd(self): @@ -659,9 +672,13 @@ def getcwd(self): this method will return None. But in __enter__ this is set explicitly, so this should never happen within this class. """ + warn_deprecation( + '`chdir()` is deprecated and will be removed in the next major version.', + version=3, + ) return self.sftp.getcwd() - def makedirs(self, path, ignore_existing=False): + def makedirs(self, path: _TransportPath, ignore_existing: bool = False): """Super-mkdir; create a leaf directory and all intermediate ones. 
Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. @@ -676,6 +693,8 @@ def makedirs(self, path, ignore_existing=False): :raise OSError: If the directory already exists. """ + path = fix_path(path) + # check to avoid creation of empty dirs path = os.path.normpath(path) @@ -697,7 +716,7 @@ def makedirs(self, path, ignore_existing=False): if not self.isdir(this_dir): self.mkdir(this_dir) - def mkdir(self, path, ignore_existing=False): + def mkdir(self, path: _TransportPath, ignore_existing: bool = False): """Create a folder (directory) named path. :param path: name of the folder to create @@ -706,6 +725,8 @@ def mkdir(self, path, ignore_existing=False): :raise OSError: If the directory already exists. """ + path = fix_path(path) + if ignore_existing and self.isdir(path): return @@ -725,7 +746,7 @@ def mkdir(self, path, ignore_existing=False): 'or the directory already exists? ({})'.format(path, self.getcwd(), exc) ) - def rmtree(self, path): + def rmtree(self, path: _TransportPath): """Remove a file or a directory at path, recursively Flags used: -r: recursive copy; -f: force, makes the command non interactive; @@ -733,6 +754,7 @@ def rmtree(self, path): :raise OSError: if the rm execution failed. """ + path = fix_path(path) # Assuming linux rm command! rm_exe = 'rm' @@ -752,8 +774,9 @@ def rmtree(self, path): self.logger.error(f"Problem executing rm. Exit code: {retval}, stdout: '{stdout}', stderr: '{stderr}'") raise OSError(f'Error while executing rm. Exit code: {retval}') - def rmdir(self, path): + def rmdir(self, path: _TransportPath): """Remove the folder named 'path' if empty.""" + path = fix_path(path) self.sftp.rmdir(path) def chown(self, path, uid, gid): @@ -763,14 +786,17 @@ def chown(self, path, uid, gid): """ raise NotImplementedError - def isdir(self, path): + def isdir(self, path: _TransportPath): """Return True if the given path is a directory, False otherwise. 
Return False also if the path does not exist. """ # Return False on empty string (paramiko would map this to the local # folder instead) + path = fix_path(path) + if not path: return False + path = fix_path(path) try: return S_ISDIR(self.stat(path).st_mode) except OSError as exc: @@ -779,21 +805,24 @@ def isdir(self, path): return False raise # Typically if I don't have permissions (errno=13) - def chmod(self, path, mode): + def chmod(self, path: _TransportPath, mode): """Change permissions to path :param path: path to file :param mode: new permission bits (integer) """ + path = fix_path(path) + if not path: raise OSError('Input path is an empty argument.') return self.sftp.chmod(path, mode) @staticmethod - def _os_path_split_asunder(path): - """Used by makedirs. Takes path (a str) + def _os_path_split_asunder(path: _TransportPath): + """Used by makedirs. Takes path and returns a list deconcatenating the path """ + path = fix_path(path) parts = [] while True: newpath, tail = os.path.split(path) @@ -807,7 +836,15 @@ def _os_path_split_asunder(path): parts.reverse() return parts - def put(self, localpath, remotepath, callback=None, dereference=True, overwrite=True, ignore_nonexisting=False): + def put( + self, + localpath: _TransportPath, + remotepath: _TransportPath, + callback=None, + dereference: Optional[bool] = True, + overwrite: Optional[bool] = True, + ignore_nonexisting: Optional[bool] = False, + ): """Put a file or a folder from local to remote. Redirects to putfile or puttree. 
@@ -821,6 +858,9 @@ def put(self, localpath, remotepath, callback=None, dereference=True, overwrite= :raise ValueError: if local path is invalid :raise OSError: if the localpath does not exist """ + localpath = fix_path(localpath) + remotepath = fix_path(remotepath) + if not dereference: raise NotImplementedError @@ -871,7 +911,14 @@ def put(self, localpath, remotepath, callback=None, dereference=True, overwrite= elif not ignore_nonexisting: raise OSError(f'The local path {localpath} does not exist') - def putfile(self, localpath, remotepath, callback=None, dereference=True, overwrite=True): + def putfile( + self, + localpath: _TransportPath, + remotepath: _TransportPath, + callback=None, + dereference: Optional[bool] = True, + overwrite: Optional[bool] = True, + ): """Put a file from local to remote. :param localpath: an (absolute) local path @@ -883,6 +930,9 @@ def putfile(self, localpath, remotepath, callback=None, dereference=True, overwr :raise OSError: if the localpath does not exist, or unintentionally overwriting """ + localpath = fix_path(localpath) + remotepath = fix_path(remotepath) + if not dereference: raise NotImplementedError @@ -894,7 +944,14 @@ def putfile(self, localpath, remotepath, callback=None, dereference=True, overwr return self.sftp.put(localpath, remotepath, callback=callback) - def puttree(self, localpath, remotepath, callback=None, dereference=True, overwrite=True): + def puttree( + self, + localpath: _TransportPath, + remotepath: _TransportPath, + callback=None, + dereference: Optional[bool] = True, + overwrite: Optional[bool] = True, + ): """Put a folder recursively from local to remote. By default, overwrite. @@ -913,6 +970,9 @@ def puttree(self, localpath, remotepath, callback=None, dereference=True, overwr .. note:: setting dereference equal to True could cause infinite loops. 
see os.walk() documentation """ + localpath = fix_path(localpath) + remotepath = fix_path(remotepath) + if not dereference: raise NotImplementedError @@ -958,7 +1018,15 @@ def puttree(self, localpath, remotepath, callback=None, dereference=True, overwr this_remote_file = os.path.join(remotepath, this_basename, this_file) self.putfile(this_local_file, this_remote_file) - def get(self, remotepath, localpath, callback=None, dereference=True, overwrite=True, ignore_nonexisting=False): + def get( + self, + remotepath: _TransportPath, + localpath: _TransportPath, + callback=None, + dereference: Optional[bool] = True, + overwrite: Optional[bool] = True, + ignore_nonexisting: Optional[bool] = False, + ): """Get a file or folder from remote to local. Redirects to getfile or gettree. @@ -973,6 +1041,9 @@ def get(self, remotepath, localpath, callback=None, dereference=True, overwrite= :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found """ + remotepath = fix_path(remotepath) + localpath = fix_path(localpath) + if not dereference: raise NotImplementedError @@ -1020,7 +1091,14 @@ def get(self, remotepath, localpath, callback=None, dereference=True, overwrite= else: raise OSError(f'The remote path {remotepath} does not exist') - def getfile(self, remotepath, localpath, callback=None, dereference=True, overwrite=True): + def getfile( + self, + remotepath: _TransportPath, + localpath: _TransportPath, + callback=None, + dereference: Optional[bool] = True, + overwrite: Optional[bool] = True, + ): """Get a file from remote to local. 
:param remotepath: a remote path @@ -1031,6 +1109,9 @@ def getfile(self, remotepath, localpath, callback=None, dereference=True, overwr :raise ValueError: if local path is invalid :raise OSError: if unintentionally overwriting """ + remotepath = fix_path(remotepath) + localpath = fix_path(localpath) + if not os.path.isabs(localpath): raise ValueError('localpath must be an absolute path') @@ -1050,7 +1131,14 @@ def getfile(self, remotepath, localpath, callback=None, dereference=True, overwr pass raise - def gettree(self, remotepath, localpath, callback=None, dereference=True, overwrite=True): + def gettree( + self, + remotepath: _TransportPath, + localpath: _TransportPath, + callback=None, + dereference: Optional[bool] = True, + overwrite: Optional[bool] = None, + ): """Get a folder recursively from remote to local. :param remotepath: a remote path @@ -1065,6 +1153,8 @@ def gettree(self, remotepath, localpath, callback=None, dereference=True, overwr :raise OSError: if the remotepath is not found :raise OSError: if unintentionally overwriting """ + remotepath = fix_path(remotepath) + localpath = fix_path(localpath) if not dereference: raise NotImplementedError @@ -1101,10 +1191,11 @@ def gettree(self, remotepath, localpath, callback=None, dereference=True, overwr else: self.getfile(os.path.join(remotepath, item), os.path.join(dest, item)) - def get_attribute(self, path): + def get_attribute(self, path: _TransportPath): """Returns the object Fileattribute, specified in aiida.transports Receives in input the path of a given file. 
""" + path = fix_path(path) from aiida.transports.util import FileAttribute paramiko_attr = self.lstat(path) @@ -1115,13 +1206,25 @@ def get_attribute(self, path): aiida_attr[key] = getattr(paramiko_attr, key) return aiida_attr - def copyfile(self, remotesource, remotedestination, dereference=False): + def copyfile(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference: bool = False): + remotesource = fix_path(remotesource) + remotedestination = fix_path(remotedestination) + return self.copy(remotesource, remotedestination, dereference) - def copytree(self, remotesource, remotedestination, dereference=False): + def copytree(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference: bool = False): + remotesource = fix_path(remotesource) + remotedestination = fix_path(remotedestination) + return self.copy(remotesource, remotedestination, dereference, recursive=True) - def copy(self, remotesource, remotedestination, dereference=False, recursive=True): + def copy( + self, + remotesource: _TransportPath, + remotedestination: _TransportPath, + dereference: bool = False, + recursive: bool = True, + ): """Copy a file or a directory from remote source to remote destination. Flags used: ``-r``: recursive copy; ``-f``: force, makes the command non interactive; ``-L`` follows symbolic links @@ -1138,6 +1241,9 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru .. note:: setting dereference equal to True could cause infinite loops. 
""" + remotesource = fix_path(remotesource) + remotedestination = fix_path(remotedestination) + # In the majority of cases, we should deal with linux cp commands cp_flags = '-f' if recursive: @@ -1179,9 +1285,11 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru else: self._exec_cp(cp_exe, cp_flags, remotesource, remotedestination) - def _exec_cp(self, cp_exe, cp_flags, src, dst): + def _exec_cp(self, cp_exe: str, cp_flags: str, src: _TransportPath, dst: _TransportPath): """Execute the ``cp`` command on the remote machine.""" # to simplify writing the above copy function + src = fix_path(src) + dst = fix_path(dst) command = f'{cp_exe} {cp_flags} {escape_for_bash(src)} {escape_for_bash(dst)}' retval, stdout, stderr = self.exec_command_wait_bytes(command) @@ -1205,7 +1313,7 @@ def _exec_cp(self, cp_exe, cp_flags, src, dst): ) @staticmethod - def _local_listdir(path, pattern=None): + def _local_listdir(path: str, pattern=None): """Acts on the local folder, for the rest, same as listdir""" if not pattern: return os.listdir(path) @@ -1226,6 +1334,8 @@ def listdir(self, path='.', pattern=None): :param pattern: returns the list of files matching pattern. Unix only. 
(Use to emulate ``ls *`` for example) """ + path = fix_path(path) + if path.startswith('/'): abs_dir = path else: @@ -1241,6 +1351,7 @@ def listdir(self, path='.', pattern=None): def remove(self, path): """Remove a single file at 'path'""" + path = fix_path(path) return self.sftp.remove(path) def rename(self, oldpath, newpath): @@ -1250,15 +1361,22 @@ def rename(self, oldpath, newpath): :param str newpath: new name for the file or folder :raises OSError: if oldpath/newpath is not found - :raises ValueError: if sroldpathc/newpath is not a valid string + :raises ValueError: if sroldpathc/newpath is not a valid path """ if not oldpath: - raise ValueError(f'Source {oldpath} is not a valid string') + raise ValueError(f'Source {oldpath} is not a valid path') if not newpath: - raise ValueError(f'Destination {newpath} is not a valid string') + raise ValueError(f'Destination {newpath} is not a valid path') + + oldpath = fix_path(oldpath) + newpath = fix_path(newpath) + if not self.isfile(oldpath): if not self.isdir(oldpath): raise OSError(f'Source {oldpath} does not exist') + # TODO: this seems to be a bug (?) + # why to raise an OSError if the newpath does not exist? + # ofcourse newpath shouldn't exist, that's why we are renaming it! if not self.isfile(newpath): if not self.isdir(newpath): raise OSError(f'Destination {newpath} does not exist') @@ -1274,6 +1392,8 @@ def isfile(self, path): # but this is just to be sure if not path: return False + + path = fix_path(path) try: self.logger.debug( f"stat for path '{path}' ('{self.normalize(path)}'): {self.stat(path)} [{self.stat(path).st_mode}]" @@ -1451,6 +1571,8 @@ def gotocomputer_command(self, remotedir): """Specific gotocomputer string to connect to a given remote computer via ssh and directly go to the calculation folder. 
""" + remotedir = fix_path(remotedir) + further_params = [] if 'username' in self._connect_args: further_params.append(f"-l {escape_for_bash(self._connect_args['username'])}") @@ -1479,6 +1601,8 @@ def _symlink(self, source, dest): :param source: source of link :param dest: link to create """ + source = fix_path(source) + dest = fix_path(dest) self.sftp.symlink(source, dest) def symlink(self, remotesource, remotedestination): @@ -1488,6 +1612,8 @@ def symlink(self, remotesource, remotedestination): :param remotesource: remote source. Can contain a pattern. :param remotedestination: remote destination """ + remotesource = fix_path(remotesource) + remotedestination = fix_path(remotedestination) # paramiko gives some errors if path is starting with '.' source = os.path.normpath(remotesource) dest = os.path.normpath(remotedestination) @@ -1495,7 +1621,7 @@ def symlink(self, remotesource, remotedestination): if self.has_magic(source): if self.has_magic(dest): # if there are patterns in dest, I don't know which name to assign - raise ValueError('Remotedestination cannot have patterns') + raise ValueError('`remotedestination` cannot have patterns') # find all files matching pattern for this_source in self.glob(source): @@ -1509,6 +1635,8 @@ def path_exists(self, path): """Check if path exists""" import errno + path = fix_path(path) + try: self.stat(path) except OSError as exc: diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py new file mode 100644 index 0000000000..6255c4f421 --- /dev/null +++ b/src/aiida/transports/plugins/ssh_async.py @@ -0,0 +1,915 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Plugin for transport over SSH asynchronously. + +Since for many dependencies the blocking methods are required, +this plugin develops both blocking methods, as well. +""" + +## TODO: +## and start writing tests! +## put & get methods could be simplified with the asyncssh.sftp.mget() & put() method or sftp.glob() +import asyncio +import glob +import os +from pathlib import Path, PurePath +from typing import Optional, Union + +import asyncssh +import click +from asyncssh import SFTPFileAlreadyExists + +from aiida.common.escaping import escape_for_bash +from aiida.common.exceptions import InvalidOperation + +from ..transport import Transport, TransportInternalError, _TransportPath, fix_path + +__all__ = ('AsyncSshTransport',) + + +def _validate_script(ctx, param, value: str): + if value == 'None': + return value + if not os.path.isabs(value): + raise click.BadParameter(f'{value} is not an absolute path') + if not os.path.isfile(value): + raise click.BadParameter(f'The script file: {value} does not exist') + if not os.access(value, os.X_OK): + raise click.BadParameter(f'The script {value} is not executable') + return value + + +def _validate_machine(ctx, param, value: str): + async def attempt_connection(): + try: + await asyncssh.connect(value) + except Exception: + return False + return True + + if not asyncio.run(attempt_connection()): + raise click.BadParameter("Couldn't connect! 
" 'Please make sure `ssh {value}` would work without password') + else: + click.echo(f'`ssh {value}` successful!') + + return value + + +class AsyncSshTransport(Transport): + """Transport plugin via SSH, asynchronously.""" + + # note, I intentionally wanted to keep connection parameters as simple as possible. + _valid_auth_options = [ + ( + 'machine', + { + 'type': str, + 'prompt': 'machine as in `ssh machine` command', + 'help': 'Password-less host-setup to connect, as in command `ssh machine`. ' + "You'll need to have a `Host machine` " + 'entry defined in your `~/.ssh/config` file. ', + 'non_interactive_default': True, + 'callback': _validate_machine, + }, + ), + ( + 'script_before', + { + 'type': str, + 'default': 'None', + 'prompt': 'Local script to run *before* opening connection (path)', + 'help': ' (optional) Specify a script to run *before* opening SSH connection. ' + 'The script should be executable', + 'non_interactive_default': True, + 'callback': _validate_script, + }, + ), + ( + 'script_during', + { + 'type': str, + 'default': 'None', + 'prompt': 'Local script to run *during* opening connection (path)', + 'help': '(optional) Specify a script to run *during* opening SSH connection. 
' + 'The script should be executable', + 'non_interactive_default': True, + 'callback': _validate_script, + }, + ), + ] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.machine = kwargs.pop('machine') + self.script_before = kwargs.pop('script_before', 'None') + self.script_during = kwargs.pop('script_during', 'None') + + def __str__(self): + return f"{'OPEN' if self._is_open else 'CLOSED'} [AsyncSshTransport]" + + async def open_async(self): + if self._is_open: + raise InvalidOperation('Cannot open the transport twice') + + if self.script_before != 'None': + os.system(f'{self.script_before}') + + self._conn = await asyncssh.connect(self.machine) + + if self.script_during != 'None': + os.system(f'{self.script_during}') + + self._sftp = await self._conn.start_sftp_client() + + self._is_open = True + + return self + + async def close_async(self): + if not self._is_open: + raise InvalidOperation('Cannot close the transport: it is already closed') + + self._conn.close() + await self._conn.wait_closed() + self._is_open = False + + async def get_async(self, remotepath, localpath, dereference=True, overwrite=True, ignore_nonexisting=False): + """Get a file or folder from remote to local. + Redirects to getfile or gettree. + + :param remotepath: a remote path + :param localpath: an (absolute) local path + :param dereference: follow symbolic links. + Default = True (default behaviour in paramiko). + False is not implemented. + :param overwrite: if True overwrites files and folders. 
+ Default = False + + :raise ValueError: if local path is invalid + :raise OSError: if the remotepath is not found + """ + remotepath = fix_path(remotepath) + localpath = fix_path(localpath) + + if not os.path.isabs(localpath): + raise ValueError('The localpath must be an absolute path') + + ## TODO: this whole glob part can be simplified via the asyncssh glob + ## or by using the asyncssh.sftp.mget() method + if self.has_magic(remotepath): + if self.has_magic(localpath): + raise ValueError('Pathname patterns are not allowed in the destination') + # use the self glob to analyze the path remotely + to_copy_list = await self.glob_async(remotepath) + + rename_local = False + if len(to_copy_list) > 1: + # I can't scp more than one file on a single file + if os.path.isfile(localpath): + raise OSError('Remote destination is not a directory') + # I can't scp more than one file in a non existing directory + elif not os.path.exists(localpath): # this should hold only for files + raise OSError('Remote directory does not exist') + else: # the remote path is a directory + rename_local = True + + for file in to_copy_list: + if await self.isfile_async(file): + if rename_local: # copying more than one file in one directory + # here is the case isfile and more than one file + remote = os.path.join(localpath, os.path.split(file)[1]) + await self.getfile_async(file, remote, dereference, overwrite) + else: # one file to copy on one file + await self.getfile_async(file, localpath, dereference, overwrite) + else: + await self.gettree_async(file, localpath, dereference, overwrite) + + elif await self.isdir_async(remotepath): + await self.gettree_async(remotepath, localpath, dereference, overwrite) + elif await self.isfile_async(remotepath): + if os.path.isdir(localpath): + remote = os.path.join(localpath, os.path.split(remotepath)[1]) + await self.getfile_async(remotepath, remote, dereference, overwrite) + else: + await self.getfile_async(remotepath, localpath, dereference, overwrite) + 
elif ignore_nonexisting: + pass + else: + raise OSError(f'The remote path {remotepath} does not exist') + + async def getfile_async(self, remotepath, localpath, dereference=True, overwrite=True): + """Get a file from remote to local. + + :param remotepath: a remote path + :param localpath: an (absolute) local path + :param overwrite: if True overwrites files and folders. + Default = False + + :raise ValueError: if local path is invalid + :raise OSError: if unintentionally overwriting + """ + remotepath = fix_path(remotepath) + localpath = fix_path(localpath) + + if not os.path.isabs(localpath): + raise ValueError('localpath must be an absolute path') + + if os.path.isfile(localpath) and not overwrite: + raise OSError('Destination already exists: not overwriting it') + + try: + await self._sftp.get( + remotepaths=remotepath, localpath=localpath, preserve=True, recurse=False, follow_symlinks=dereference + ) + except (OSError, asyncssh.Error) as exc: + raise OSError(f'Error while uploading file {localpath}: {exc}') + + async def gettree_async(self, remotepath, localpath, dereference=True, overwrite=True): + """Get a folder recursively from remote to local. + + :param remotepath: a remote path + :param localpath: an (absolute) local path + :param dereference: follow symbolic links. + Default = True (default behaviour in paramiko). + False is not implemented. + :param overwrite: if True overwrites files and folders. 
+ Default = False + + :raise ValueError: if local path is invalid + :raise OSError: if the remotepath is not found + :raise OSError: if unintentionally overwriting + """ + remotepath = fix_path(remotepath) + localpath = fix_path(localpath) + + if not remotepath: + raise OSError('Remotepath must be a non empty string') + if not localpath: + raise ValueError('Localpaths must be a non empty string') + + if not os.path.isabs(localpath): + raise ValueError('Localpaths must be an absolute path') + + if not await self.isdir_async(remotepath): + raise OSError(f'Input remotepath is not a folder: {localpath}') + + if os.path.exists(localpath) and not overwrite: + raise OSError("Can't overwrite existing files") + if os.path.isfile(localpath): + raise OSError('Cannot copy a directory into a file') + + if not os.path.isdir(localpath): # in this case copy things in the remotepath directly + os.makedirs(localpath, exist_ok=True) # and make a directory at its place + else: # localpath exists already: copy the folder inside of it! + localpath = os.path.join(localpath, os.path.split(remotepath)[1]) + os.makedirs(localpath, exist_ok=overwrite) # create a nested folder + + content_list = await self.listdir_async(remotepath) + for content_ in content_list: + try: + await self._sftp.get( + remotepaths=PurePath(remotepath) / content_, + localpath=localpath, + preserve=True, + recurse=True, + follow_symlinks=dereference, + ) + except (OSError, asyncssh.Error) as exc: + raise OSError(f'Error while uploading file {localpath}: {exc}') + + async def put_async(self, localpath, remotepath, dereference=True, overwrite=True, ignore_nonexisting=False): + """Put a file or a folder from local to remote. + Redirects to putfile or puttree. + + :param localpath: an (absolute) local path + :param remotepath: a remote path + :param dereference: follow symbolic links (boolean). + Default = True (default behaviour in paramiko). False is not implemented. 
+ :param overwrite: if True overwrites files and folders (boolean). + Default = False. + + :raise ValueError: if local path is invalid + :raise OSError: if the localpath does not exist + """ + localpath = fix_path(localpath) + remotepath = fix_path(remotepath) + + if not os.path.isabs(localpath): + raise ValueError('The localpath must be an absolute path') + + # TODO: this whole glob part can be simplified via the asyncssh glob + if self.has_magic(localpath): + if self.has_magic(remotepath): + raise ValueError('Pathname patterns are not allowed in the destination') + + # use the imported glob to analyze the path locally + to_copy_list = glob.glob(localpath) + + rename_remote = False + if len(to_copy_list) > 1: + # I can't scp more than one file on a single file + if await self.isfile_async(remotepath): + raise OSError('Remote destination is not a directory') + # I can't scp more than one file in a non existing directory + elif not await self.path_exists_async(remotepath): # questo dovrebbe valere solo per file + raise OSError('Remote directory does not exist') + else: # the remote path is a directory + rename_remote = True + + for file in to_copy_list: + if os.path.isfile(file): + if rename_remote: # copying more than one file in one directory + # here is the case isfile and more than one file + remotefile = os.path.join(remotepath, os.path.split(file)[1]) + await self.putfile_async(file, remotefile, dereference, overwrite) + + elif await self.isdir_async(remotepath): # one file to copy in '.' 
+ remotefile = os.path.join(remotepath, os.path.split(file)[1]) + await self.putfile_async(file, remotefile, dereference, overwrite) + else: # one file to copy on one file + await self.putfile_async(file, remotepath, dereference, overwrite) + else: + await self.puttree_async(file, remotepath, dereference, overwrite) + + elif os.path.isdir(localpath): + await self.puttree_async(localpath, remotepath, dereference, overwrite) + elif os.path.isfile(localpath): + if await self.isdir_async(remotepath): + remote = os.path.join(remotepath, os.path.split(localpath)[1]) + await self.putfile_async(localpath, remote, dereference, overwrite) + else: + await self.putfile_async(localpath, remotepath, dereference, overwrite) + elif not ignore_nonexisting: + raise OSError(f'The local path {localpath} does not exist') + + async def putfile_async(self, localpath, remotepath, dereference=True, overwrite=True): + """Put a file from local to remote. + + :param localpath: an (absolute) local path + :param remotepath: a remote path + :param overwrite: if True overwrites files and folders (boolean). + Default = True. + + :raise ValueError: if local path is invalid + :raise OSError: if the localpath does not exist, + or unintentionally overwriting + """ + localpath = fix_path(localpath) + remotepath = fix_path(remotepath) + + if not os.path.isabs(localpath): + raise ValueError('The localpath must be an absolute path') + + if await self.isfile_async(remotepath) and not overwrite: + raise OSError('Destination already exists: not overwriting it') + + try: + await self._sftp.put( + localpaths=localpath, remotepath=remotepath, preserve=True, recurse=False, follow_symlinks=dereference + ) + except (OSError, asyncssh.Error) as exc: + raise OSError(f'Error while uploading file {localpath}: {exc}') + + async def puttree_async(self, localpath, remotepath, dereference=True, overwrite=True): + """Put a folder recursively from local to remote. + + By default, overwrite. 
+ + :param localpath: an (absolute) local path + :param remotepath: a remote path + :param dereference: follow symbolic links (boolean) + Default = True (default behaviour in paramiko). False is not implemented. + :param overwrite: if True overwrites files and folders (boolean). + Default = True + + :raise ValueError: if local path is invalid + :raise OSError: if the localpath does not exist, or trying to overwrite + :raise OSError: if remotepath is invalid + + .. note:: setting dereference equal to True could cause infinite loops. + see os.walk() documentation + """ + localpath = fix_path(localpath) + remotepath = fix_path(remotepath) + + if not os.path.isabs(localpath): + raise ValueError('The localpath must be an absolute path') + + if not os.path.exists(localpath): + raise OSError('The localpath does not exists') + + if not os.path.isdir(localpath): + raise ValueError(f'Input localpath is not a folder: {localpath}') + + if not remotepath: + raise OSError('remotepath must be a non empty string') + + if await self.path_exists_async(remotepath) and not overwrite: + raise OSError("Can't overwrite existing files") + if await self.isfile_async(remotepath): + raise OSError('Cannot copy a directory into a file') + + if not await self.isdir_async(remotepath): # in this case copy things in the remotepath directly + await self.mkdir_async(remotepath) # and make a directory at its place + else: # remotepath exists already: copy the folder inside of it! 
+ remotepath = os.path.join(remotepath, os.path.split(localpath)[1]) + await self.makedirs_async(remotepath, ignore_existing=overwrite) # create a nested folder + + # This is written in this way, only because AiiDA expects to put file inside an existing folder + # Or to put and rename the parent folder at the same time + content_list = os.listdir(localpath) + for content_ in content_list: + try: + await self._sftp.put( + localpaths=PurePath(localpath) / content_, + remotepath=remotepath, + preserve=True, + recurse=True, + follow_symlinks=dereference, + ) + except (OSError, asyncssh.Error) as exc: + raise OSError(f'Error while uploading file {PurePath(localpath)/content_}: {exc}') + + async def copy_async( + self, + remotesource: _TransportPath, + remotedestination: _TransportPath, + dereference: bool = False, + recursive: bool = True, + preserve: bool = False, + ): + """ """ + remotesource = fix_path(remotesource) + remotedestination = fix_path(remotedestination) + if self.has_magic(remotedestination): + raise ValueError('Pathname patterns are not allowed in the destination') + + if not remotedestination: + raise ValueError('remotedestination must be a non empty string') + if not remotesource: + raise ValueError('remotesource must be a non empty string') + try: + if self.has_magic(remotesource): + await self._sftp.mcopy( + remotesource, + remotedestination, + preserve=preserve, + recurse=recursive, + follow_symlinks=dereference, + ) + else: + await self._sftp.copy( + remotesource, + remotedestination, + preserve=preserve, + recurse=recursive, + follow_symlinks=dereference, + ) + except asyncssh.sftp.SFTPFailure as exc: + raise OSError(f'Error while copying {remotesource} to {remotedestination}: {exc}') + + async def exec_command_wait_async( + self, + command: str, + stdin: Optional[str] = None, + encoding: str = 'utf-8', + workdir: Union[_TransportPath, None] = None, + timeout: Optional[float] = 2, + **kwargs, + ): + """Execute a command on the remote machine and 
wait for it to finish. + + :param command: the command to execute + :param stdin: the standard input to pass to the command + :param encoding: (IGNORED) this is here just to keep the same signature as the one in `Transport` class + :param workdir: the working directory where to execute the command + :param timeout: the timeout in seconds + + :type command: str + :type stdin: str + :type encoding: str + :type workdir: str + :type timeout: float + + :return: a tuple with the return code, the stdout and the stderr of the command + :rtype: tuple(int, str, str) + """ + + if workdir: + command = f'cd {workdir} && {command}' + + bash_commmand = self._bash_command_str + '-c ' + + result = await self._conn.run( + bash_commmand + escape_for_bash(command), input=stdin, check=False, timeout=timeout + ) + # both stdout and stderr are strings + return (result.returncode, ''.join(result.stdout), ''.join(result.stderr)) # type: ignore [arg-type] + + async def get_attribute_async(self, path): + """ """ + path = fix_path(path) + from aiida.transports.util import FileAttribute + + asyncssh_attr = await self._sftp.lstat(path) + aiida_attr = FileAttribute() + # map the asyncssh class into the aiida one + for key in aiida_attr._valid_fields: + if key == 'st_size': + aiida_attr[key] = asyncssh_attr.size + elif key == 'st_uid': + aiida_attr[key] = asyncssh_attr.uid + elif key == 'st_gid': + aiida_attr[key] = asyncssh_attr.gid + elif key == 'st_mode': + aiida_attr[key] = asyncssh_attr.permissions + elif key == 'st_atime': + aiida_attr[key] = asyncssh_attr.atime + elif key == 'st_mtime': + aiida_attr[key] = asyncssh_attr.mtime + else: + raise NotImplementedError(f'Mapping the {key} attribute is not implemented') + return aiida_attr + + async def isdir_async(self, path): + """Return True if the given path is a directory, False otherwise. + Return False also if the path does not exist. 
+ """ + # Return False on empty string + if not path: + return False + + path = fix_path(path) + + return await self._sftp.isdir(path) + + async def isfile_async(self, path): + """Return True if the given path is a file, False otherwise. + Return False also if the path does not exist. + """ + # Return False on empty string + if not path: + return False + + path = fix_path(path) + + return await self._sftp.isfile(path) + + async def listdir_async(self, path, pattern=None): + """ + + :param path: the absolute path to list + """ + path = fix_path(path) + if not pattern: + list_ = await self._sftp.listdir(path) + else: + patterned_path = pattern if pattern.startswith('/') else Path(path).joinpath(pattern) + list_ = await self._sftp.glob(patterned_path) + + for item in ['..', '.']: + if item in list_: + list_.remove(item) + + return list_ + + async def listdir_withattributes_async(self, path: _TransportPath, pattern: Optional[str] = None): + """Return a list of the names of the entries in the given path. + The list is in arbitrary order. It does not include the special + entries '.' and '..' even if they are present in the directory. + + :param str path: absolute path to list + :param str pattern: if used, listdir returns a list of files matching + filters in Unix style. Unix only. + :return: a list of dictionaries, one per entry. + The schema of the dictionary is + the following:: + + { + 'name': String, + 'attributes': FileAttributeObject, + 'isdir': Bool + } + + where 'name' is the file or folder directory, and any other information is metadata + (if the file is a folder, a directory, ...). 'attributes' behaves as the output of + transport.get_attribute(); isdir is a boolean indicating if the object is a directory or not. 
+ """ + path = fix_path(path) + retlist = [] + listdir = await self.listdir_async(path, pattern) + for file_name in listdir: + filepath = os.path.join(path, file_name) + attributes = await self.get_attribute_async(filepath) + retlist.append({'name': file_name, 'attributes': attributes, 'isdir': await self.isdir_async(filepath)}) + + return retlist + + async def makedirs_async(self, path, ignore_existing=False): + """Super-mkdir; create a leaf directory and all intermediate ones. + Works like mkdir, except that any intermediate path segment (not + just the rightmost) will be created if it does not exist. + + :param str path: absolute path to directory to create + :param bool ignore_existing: if set to true, it doesn't give any error + if the leaf directory does already exist + + :raises: OSError, if directory at path already exists + """ + path = fix_path(path) + + try: + await self._sftp.makedirs(path, exist_ok=ignore_existing) + except SFTPFileAlreadyExists as exc: + raise OSError(f'Error while creating directory {path}: {exc}, directory already exists') + except asyncssh.sftp.SFTPFailure as exc: + if (self._sftp.version < 6) and not ignore_existing: + raise OSError(f'Error while creating directory {path}: {exc}, probably it already exists') + else: + raise TransportInternalError(f'Error while creating directory {path}: {exc}') + + async def mkdir_async(self, path: _TransportPath, ignore_existing=False): + """Create a directory. 
+ + :param str path: absolute path to directory to create + :param bool ignore_existing: if set to true, it doesn't give any error + if the leaf directory does already exist + + :raises: OSError, if directory at path already exists + """ + path = fix_path(path) + + try: + await self._sftp.mkdir(path) + except SFTPFileAlreadyExists as exc: + # note: mkdir() in asyncssh does not support the exist_ok parameter + if ignore_existing: + return + raise OSError(f'Error while creating directory {path}: {exc}, directory already exists') + except asyncssh.sftp.SFTPFailure as exc: + if self._sftp.version < 6: + if ignore_existing: + return + else: + raise OSError(f'Error while creating directory {path}: {exc}, probably it already exists') + else: + raise TransportInternalError(f'Error while creating directory {path}: {exc}') + + async def remove_async(self, path): + """Remove the file at the given path. This only works on files; + for removing folders (directories), use rmdir. + + :param str path: path to file to remove + + :raise OSError: if the path is a directory + """ + path = fix_path(path) + # TODO: check if asyncssh does return SFTPFileIsADirectory in this case + # if that's the case, we can get rid of the isfile check + if await self.isdir_async(path): + raise OSError(f'The path {path} is a directory') + else: + await self._sftp.remove(path) + + async def rename_async(self, oldpath, newpath): + """ + Rename a file or folder from oldpath to newpath. 
+ + :param str oldpath: existing name of the file or folder + :param str newpath: new name for the file or folder + + :raises OSError: if oldpath/newpath is not found + :raises ValueError: if oldpath/newpath is not a valid string + """ + oldpath = fix_path(oldpath) + newpath = fix_path(newpath) + if not oldpath or not newpath: + raise ValueError('oldpath and newpath must be non-empty strings') + + if await self._sftp.exists(newpath): + raise OSError(f'Cannot rename {oldpath} to {newpath}: destination exists') + + await self._sftp.rename(oldpath, newpath) + + async def rmdir_async(self, path): + """Remove the folder named path. + This works only for empty folders. For recursive remove, use rmtree. + + :param str path: absolute path to the folder to remove + """ + path = fix_path(path) + try: + await self._sftp.rmdir(path) + except asyncssh.sftp.SFTPFailure: + raise OSError(f'Error while removing directory {path}: probably directory is not empty') + + async def rmtree_async(self, path): + """Remove the folder named path, and all its contents. 
+ + :param str path: absolute path to the folder to remove + """ + path = fix_path(path) + try: + await self._sftp.rmtree(path, ignore_errors=False) + except asyncssh.Error as exc: + raise OSError(f'Error while removing directory tree {path}: {exc}') + + async def path_exists_async(self, path): + """Returns True if path exists, False otherwise.""" + path = fix_path(path) + return await self._sftp.exists(path) + + async def whoami_async(self): + """Get the remote username + + :return: list of username (str), + retval (int), + stderr (str) + """ + command = 'whoami' + # Assuming here that the username is either ASCII or UTF-8 encoded + # This should be true essentially always + retval, username, stderr = await self.exec_command_wait_async(command) + if retval == 0: + if stderr.strip(): + self.logger.warning(f'There was nonempty stderr in the whoami command: {stderr}') + return username.strip() + + self.logger.error(f"Problem executing whoami. Exit code: {retval}, stdout: '{username}', stderr: '{stderr}'") + raise OSError(f'Error while executing whoami. Exit code: {retval}') + + async def symlink_async(self, remotesource, remotedestination): + """Create a symbolic link between the remote source and the remote + destination. 
+ + :param remotesource: absolute path to remote source + :param remotedestination: absolute path to remote destination + """ + remotesource = fix_path(remotesource) + remotedestination = fix_path(remotedestination) + + if self.has_magic(remotesource): + if self.has_magic(remotedestination): + raise ValueError('`remotedestination` cannot have patterns') + + # find all files matching pattern + for this_source in await self._sftp.glob(remotesource): + # create the name of the link: take the last part of the path + this_dest = os.path.join(remotedestination, os.path.split(this_source)[-1]) + await self._sftp.symlink(this_source, this_dest) + else: + await self._sftp.symlink(remotesource, remotedestination) + + async def glob_async(self, pathname): + """Return a list of paths matching a pathname pattern. + + The pattern may contain simple shell-style wildcards a la fnmatch. + + :param str pathname: the pathname pattern to match. + It should only be an absolute path. + :return: a list of paths matching the pattern. + """ + return await self._sftp.glob(pathname) + + async def chmod_async(self, path, mode, follow_symlinks=True): + """Change the permissions of a file. + + :param str path: path to the file + :param int mode: the new permissions + """ + path = fix_path(path) + if not path: + raise OSError('Input path is an empty argument.') + try: + await self._sftp.chmod(path, mode, follow_symlinks=follow_symlinks) + except asyncssh.sftp.SFTPNoSuchFile as exc: + raise OSError(f'Error {exc}, directory does not exists') + + # ## Blocking methods. We need these for backwards compatibility + # def run_command_blocking(self, func, *args, **kwargs): + # """Call an async method blocking. + # This is useful, only because in some part of engine and + # many external plugins are synchronous function calls make more sense. 
#     However, be aware these synchronous calls probably won't be performant."""
#     return asyncio.run(func(*args, **kwargs))

def run_command_blocking(self, func, *args, **kwargs):
    """Run the coroutine function ``func`` to completion and return its result.

    This is the bridge used by the synchronous methods below: each of them
    hands the matching ``*_async`` method to this helper.

    :param func: coroutine function to call
    :param args: positional arguments forwarded to ``func``
    :param kwargs: keyword arguments forwarded to ``func``
    :return: the value produced by awaiting ``func(*args, **kwargs)``
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # There is no current event loop for this thread (e.g. a worker
        # thread, or an interpreter where ``get_event_loop`` does not create
        # one implicitly). Create one and register it, so that later blocking
        # calls from the same thread run on the same loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(func(*args, **kwargs))

def open(self):
    """Blocking counterpart of ``open_async``."""
    return self.run_command_blocking(self.open_async)

def close(self):
    """Blocking counterpart of ``close_async``."""
    return self.run_command_blocking(self.close_async)

def chown(self, *args, **kwargs):
    """Not available on this transport.

    :raise NotImplementedError: always
    """
    raise NotImplementedError('Not implemented, for now')

def get(self, *args, **kwargs):
    """Blocking counterpart of ``get_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.get_async, *args, **kwargs)

def getfile(self, *args, **kwargs):
    """Blocking counterpart of ``getfile_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.getfile_async, *args, **kwargs)

def gettree(self, *args, **kwargs):
    """Blocking counterpart of ``gettree_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.gettree_async, *args, **kwargs)

def put(self, *args, **kwargs):
    """Blocking counterpart of ``put_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.put_async, *args, **kwargs)

def putfile(self, *args, **kwargs):
    """Blocking counterpart of ``putfile_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.putfile_async, *args, **kwargs)

def puttree(self, *args, **kwargs):
    """Blocking counterpart of ``puttree_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.puttree_async, *args, **kwargs)

def chmod(self, *args, **kwargs):
    """Blocking counterpart of ``chmod_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.chmod_async, *args, **kwargs)

def copy(self, *args, **kwargs):
    """Blocking counterpart of ``copy_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.copy_async, *args, **kwargs)

def copyfile(self, *args, **kwargs):
    """Alias of ``copy``: files and trees go through the same code path."""
    return self.copy(*args, **kwargs)

def copytree(self, *args, **kwargs):
    """Alias of ``copy``: files and trees go through the same code path."""
    return self.copy(*args, **kwargs)

def exec_command_wait(self, *args, **kwargs):
    """Blocking counterpart of ``exec_command_wait_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.exec_command_wait_async, *args, **kwargs)

def get_attribute(self, *args, **kwargs):
    """Blocking counterpart of ``get_attribute_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.get_attribute_async, *args, **kwargs)

def isdir(self, *args, **kwargs):
    """Blocking counterpart of ``isdir_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.isdir_async, *args, **kwargs)

def isfile(self, *args, **kwargs):
    """Blocking counterpart of ``isfile_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.isfile_async, *args, **kwargs)

def listdir(self, *args, **kwargs):
    """Blocking counterpart of ``listdir_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.listdir_async, *args, **kwargs)

def listdir_withattributes(self, *args, **kwargs):
    """Blocking counterpart of ``listdir_withattributes_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.listdir_withattributes_async, *args, **kwargs)

def makedirs(self, *args, **kwargs):
    """Blocking counterpart of ``makedirs_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.makedirs_async, *args, **kwargs)

def mkdir(self, *args, **kwargs):
    """Blocking counterpart of ``mkdir_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.mkdir_async, *args, **kwargs)

def remove(self, *args, **kwargs):
    """Blocking counterpart of ``remove_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.remove_async, *args, **kwargs)

def rename(self, *args, **kwargs):
    """Blocking counterpart of ``rename_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.rename_async, *args, **kwargs)

def rmdir(self, *args, **kwargs):
    """Blocking counterpart of ``rmdir_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.rmdir_async, *args, **kwargs)

def rmtree(self, *args, **kwargs):
    """Blocking counterpart of ``rmtree_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.rmtree_async, *args, **kwargs)

def path_exists(self, *args, **kwargs):
    """Blocking counterpart of ``path_exists_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.path_exists_async, *args, **kwargs)

def whoami(self, *args, **kwargs):
    """Blocking counterpart of ``whoami_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.whoami_async, *args, **kwargs)

def symlink(self, *args, **kwargs):
    """Blocking counterpart of ``symlink_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.symlink_async, *args, **kwargs)

def glob(self, *args, **kwargs):
    """Blocking counterpart of ``glob_async``; arguments are forwarded unchanged."""
    return self.run_command_blocking(self.glob_async, *args, **kwargs)

def gotocomputer_command(self, remotedir):
    """Return the shell command that opens an interactive ssh session in ``remotedir``.

    The command is ``ssh -t`` followed by ``self.machine`` and the string
    returned by ``self._gotocomputer_string(remotedir)``.

    :param remotedir: remote directory passed on to ``_gotocomputer_string``
    :return: the command line as a single string
    """
    connect_string = self._gotocomputer_string(remotedir)
    cmd = f'ssh -t {self.machine} {connect_string}'
    return cmd

## These methods are not implemented for async transport,
## mainly because they are not being used across the codebase.
+ ## If you need them, please open an issue on GitHub + + def exec_command_wait_bytes(self, *args, **kwargs): + raise NotImplementedError('Not implemented, waiting for a use case') + + def _exec_command_internal(self, *args, **kwargs): + raise NotImplementedError('Not implemented, waiting for a use case') + + def normalize(self, *args, **kwargs): + raise NotImplementedError('Not implemented, waiting for a use case') + + def chdir(self, *args, **kwargs): + raise NotImplementedError("It's not safe to chdir() for async transport") + + def getcwd(self, *args, **kwargs): + raise NotImplementedError("It's not safe to getcwd() for async transport") diff --git a/src/aiida/transports/transport.py b/src/aiida/transports/transport.py index 311f4bbdf6..944d6225e3 100644 --- a/src/aiida/transports/transport.py +++ b/src/aiida/transports/transport.py @@ -15,6 +15,7 @@ import sys from collections import OrderedDict from pathlib import Path +from typing import Union from aiida.common.exceptions import InternalError from aiida.common.lang import classproperty @@ -22,6 +23,8 @@ __all__ = ('Transport',) +_TransportPath = Union[str, Path] + def validate_positive_number(ctx, param, value): """Validate that the number passed to this parameter is a positive number. @@ -39,6 +42,11 @@ def validate_positive_number(ctx, param, value): return value +def fix_path(path: _TransportPath) -> str: + """Convert a Path object to a string.""" + return str(path) + + class Transport(abc.ABC): """Abstract class for a generic transport (ssh, local, ...) 
contains the set of minimal methods.""" @@ -47,7 +55,7 @@ class Transport(abc.ABC): # This is used as a global default in case subclasses don't redefine this, # but this should be redefined in plugins where appropriate - _DEFAULT_SAFE_OPEN_INTERVAL = 30.0 + _DEFAULT_SAFE_OPEN_INTERVAL = 3.0 # To be defined in the subclass # See the ssh or local plugin to see the format @@ -372,6 +380,57 @@ def copy_from_remote_to_remote(self, transportdestination, remotesource, remoted for filename in sandbox.get_content_list(): transportdestination.put(os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put) + async def copy_from_remote_to_remote_async(self, transportdestination, remotesource, remotedestination, **kwargs): + """Copy files or folders from a remote computer to another remote computer, asynchronously. + + :param transportdestination: transport to be used for the destination computer + :param str remotesource: path to the remote source directory / file + :param str remotedestination: path to the remote destination directory / file + :param kwargs: keyword parameters passed to the call to transportdestination.put, + except for 'dereference' that is passed to self.get + + .. note:: the keyword 'dereference' SHOULD be set to False for the + final put (onto the destination), while it can be set to the + value given in kwargs for the get from the source. In that + way, a symbolic link would never be followed in the final + copy to the remote destination. That way we could avoid getting + unknown (potentially malicious) files into the destination computer. + HOWEVER, since dereference=False is currently NOT + supported by all plugins, we still force it to True for the final put. + + .. note:: the supported keys in kwargs are callback, dereference, + overwrite and ignore_nonexisting. 
+ """ + from aiida.common.folders import SandboxFolder + + kwargs_get = { + 'callback': None, + 'dereference': kwargs.pop('dereference', True), + 'overwrite': True, + 'ignore_nonexisting': False, + } + kwargs_put = { + 'callback': kwargs.pop('callback', None), + 'dereference': True, + 'overwrite': kwargs.pop('overwrite', True), + 'ignore_nonexisting': kwargs.pop('ignore_nonexisting', False), + } + + if kwargs: + self.logger.error('Unknown parameters passed to copy_from_remote_to_remote') + + with SandboxFolder() as sandbox: + await self.get_async(remotesource, sandbox.abspath, **kwargs_get) + # Then we scan the full sandbox directory with get_content_list, + # because copying directly from sandbox.abspath would not work + # to copy a single file into another single file, and copying + # from sandbox.get_abs_path('*') would not work for files + # beginning with a dot ('.'). + for filename in sandbox.get_content_list(): + await transportdestination.put_async( + os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put + ) + @abc.abstractmethod def _exec_command_internal(self, command, workdir=None, **kwargs): """Execute the command on the shell, similarly to os.system. @@ -398,7 +457,7 @@ def exec_command_wait_bytes(self, command, stdin=None, workdir=None, **kwargs): The command implementation can have some additional plugin-specific kwargs. :param str command: execute the command given as a string - :param stdin: (optional,default=None) can be a string or a file-like object. + :param stdin: (optional,default=None) can be bytes or a file-like object. :param workdir: (optional, default=None) if set, the command will be executed in the specified working directory. :return: a tuple: the retcode (int), stdout (bytes) and stderr (bytes). 
@@ -436,6 +495,9 @@ def get(self, remotepath, localpath, *args, **kwargs): """Retrieve a file or folder from remote source to local destination dst must be an absolute path (src not necessarily) + This method should be able to handle remothpath containing glob patterns, + in that case should only downloading matching patterns. + :param remotepath: (str) remote_folder_path :param localpath: (str) local_folder_path """ @@ -444,7 +506,6 @@ async def get_async(self, remotepath, localpath, *args, **kwargs): """ Retrieve a file or folder from remote source to local destination dst must be an absolute path (src not necessarily) - :param remotepath: (str) remote_folder_path :param localpath: (str) local_folder_path """ @@ -518,6 +579,7 @@ def get_mode(self, path): @abc.abstractmethod def isdir(self, path): """True if path is an existing directory. + Return False also if the path does not exist. :param str path: path to directory :return: boolean @@ -526,6 +588,7 @@ def isdir(self, path): @abc.abstractmethod def isfile(self, path): """Return True if path is an existing file. + Return False also if the path does not exist. :param str path: path to file :return: boolean @@ -543,7 +606,7 @@ def listdir(self, path='.', pattern=None): :return: a list of strings """ - def listdir_withattributes(self, path='.', pattern=None): + def listdir_withattributes(self, path: _TransportPath = '.', pattern=None): """Return a list of the names of the entries in the given path. The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory. @@ -567,6 +630,7 @@ def listdir_withattributes(self, path='.', pattern=None): (if the file is a folder, a directory, ...). 'attributes' behaves as the output of transport.get_attribute(); isdir is a boolean indicating if the object is a directory or not. 
""" + path = fix_path(path) retlist = [] if path.startswith('/'): cwd = Path(path).resolve().as_posix() @@ -624,6 +688,9 @@ def put(self, localpath, remotepath, *args, **kwargs): src must be an absolute path (dst not necessarily)) Redirects to putfile and puttree. + This method should be able to handle localpath containing glob patterns, + in that case should only uploading matching patterns. + :param str localpath: absolute path to local source :param str remotepath: path to remote destination """ @@ -633,7 +700,6 @@ async def put_async(self, localpath, remotepath, *args, **kwargs): Put a file or a directory from local src to remote dst. src must be an absolute path (dst not necessarily)) Redirects to putfile and puttree. - :param str localpath: absolute path to local source :param str remotepath: path to remote destination """ @@ -691,6 +757,8 @@ def rmtree(self, path): """Remove recursively the content at path :param str path: absolute path to remove + + :raise OSError: if the rm execution failed. """ @abc.abstractmethod @@ -796,13 +864,10 @@ def iglob(self, pathname): def glob1(self, dirname, pattern): """Match subpaths of dirname against pattern.""" if not dirname: - # dirname = os.curdir # ORIGINAL dirname = self.getcwd() if isinstance(pattern, str) and not isinstance(dirname, str): dirname = dirname.decode(sys.getfilesystemencoding() or sys.getdefaultencoding()) try: - # names = os.listdir(dirname) - # print dirname names = self.listdir(dirname) except EnvironmentError: return [] diff --git a/src/aiida/transports/util.py b/src/aiida/transports/util.py index 12d3b3d882..a75547a0e3 100644 --- a/src/aiida/transports/util.py +++ b/src/aiida/transports/util.py @@ -9,12 +9,58 @@ """General utilities for Transport classes.""" import time +from pathlib import Path, PurePosixPath +from typing import Union from paramiko import ProxyCommand from aiida.common.extendeddicts import FixedFieldsAttributeDict +class StrPath: + """A class to chain paths together. 
+ This is useful to avoid the need to use os.path.join to chain paths. + + Note: + Eventually transport plugins may further develope so that functions with pathlib.Path + So far they are expected to work only with POSIX paths. + This class is a solution to avoid the need to use Path.join to chain paths and convert back again to str. + """ + + def __init__(self, path: Union[str, PurePosixPath]): + """Chain a path with multiple paths. + + :param path: the initial path (absolute) + """ + if isinstance(path, PurePosixPath): + path = str(path) + self.str = path.rstrip('/') + + def join(self, *paths: Union[str, PurePosixPath, Path], return_str=True) -> Union[str, 'StrPath']: + """Join the initial path with multiple paths. + + :param paths: the paths to chain (relative to the previous path) + :param paths: It should be of type str or Path or PurePosixPath + :param return_str: if True, return a string, otherwise return a new StrPath object + + :return: a new StrPath object + """ + path = self.str + for p in paths: + p_ = str(p) if isinstance(p, (PurePosixPath, Path)) else p + if self.str in p_: + raise ValueError( + 'The path to join is already included in the initial path, ' + 'probably you are trying to join an absolute path' + ) + path = f"{path}/{p_.strip('/')}" + + if return_str: + return path + + return StrPath(path) + + class FileAttribute(FixedFieldsAttributeDict): """A class, resembling a dictionary, to describe the attributes of a file, that is returned by get_attribute(). @@ -86,3 +132,24 @@ def copy_from_remote_to_remote(transportsource, transportdestination, remotesour .. note:: it uses the method transportsource.copy_from_remote_to_remote """ transportsource.copy_from_remote_to_remote(transportdestination, remotesource, remotedestination, **kwargs) + + +async def copy_from_remote_to_remote_async( + transportsource, transportdestination, remotesource, remotedestination, **kwargs +): + """Copy files or folders from a remote computer to another remote computer. 
+ Note: To have a proper async performance, + both transports should be instance `core.async_ssh`. + Even if either or both are not async, the function will work, + but the performance might be lower than the sync version. + + :param transportsource: transport to be used for the source computer + :param transportdestination: transport to be used for the destination computer + :param str remotesource: path to the remote source directory / file + :param str remotedestination: path to the remote destination directory / file + :param kwargs: keyword parameters passed to the final put, + except for 'dereference' that is passed to the initial get + + .. note:: it uses the method transportsource.copy_from_remote_to_remote + """ + await transportsource.copy_from_remote_to_remote(transportdestination, remotesource, remotedestination, **kwargs) diff --git a/tests/engine/daemon/test_execmanager.py b/tests/engine/daemon/test_execmanager.py index beaba4aec7..79692b689b 100644 --- a/tests/engine/daemon/test_execmanager.py +++ b/tests/engine/daemon/test_execmanager.py @@ -72,41 +72,51 @@ def test_hierarchy_utility(file_hierarchy, tmp_path, create_file_hierarchy, seri assert serialize_file_hierarchy(tmp_path, read_bytes=False) == file_hierarchy -@pytest.mark.parametrize('retrieve_list, expected_hierarchy', ( - # Single file or folder, either toplevel or nested - (['file_a.txt'], {'file_a.txt': 'file_a'}), - (['path/sub/file_c.txt'], {'file_c.txt': 'file_c'}), - (['path'], {'path': {'file_b.txt': 'file_b', 'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}}), - (['path/sub'], {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}), - (['*.txt'], {'file_a.txt': 'file_a'}), - (['*/*.txt'], {'file_b.txt': 'file_b'}), - # Single nested file that is retrieved keeping a varying level of depth of original hierarchy - ([('path/sub/file_c.txt', '.', 3)], {'path': {'sub': {'file_c.txt': 'file_c'}}}), - ([('path/sub/file_c.txt', '.', 2)], {'sub': {'file_c.txt': 'file_c'}}), - 
([('path/sub/file_c.txt', '.', 1)], {'file_c.txt': 'file_c'}), - ([('path/sub/file_c.txt', '.', 0)], {'file_c.txt': 'file_c'}), - # Single nested folder that is retrieved keeping a varying level of depth of original hierarchy - ([('path/sub', '.', 2)], {'path': {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}}), - ([('path/sub', '.', 1)], {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}), - # Using globbing patterns - ([('path/*', '.', 0)], {'file_b.txt': 'file_b', 'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}), - ([('path/sub/*', '.', 0)], {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}), # This is identical to ['path/sub'] - ([('path/sub/*c.txt', '.', 2)], {'sub': {'file_c.txt': 'file_c'}}), - ([('path/sub/*c.txt', '.', 0)], {'file_c.txt': 'file_c'}), - # Using globbing with depth `None` should maintain exact folder hierarchy - ([('path/*.txt', '.', None)], {'path': {'file_b.txt': 'file_b'}}), - ([('path/sub/*.txt', '.', None)], {'path': {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}}), - # Different target directory - ([('path/sub/file_c.txt', 'target', 3)], {'target': {'path': {'sub': {'file_c.txt': 'file_c'}}}}), - ([('path/sub', 'target', 1)], {'target': {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}}), - ([('path/sub/*c.txt', 'target', 2)], {'target': {'sub': {'file_c.txt': 'file_c'}}}), - # Missing files should be ignored and not cause the retrieval to except - (['file_a.txt', 'file_u.txt', 'path/file_u.txt', ('path/sub/file_u.txt', '.', 3)], {'file_a.txt': 'file_a'}), -)) -# yapf: enable -@pytest.mark.asyncio -async def test_retrieve_files_from_list( - tmp_path_factory, generate_calculation_node, file_hierarchy, retrieve_list, expected_hierarchy +@pytest.mark.parametrize( + 'retrieve_list, expected_hierarchy', + ( + # Single file or folder, either toplevel or nested + (['file_a.txt'], {'file_a.txt': 'file_a'}), + (['path/sub/file_c.txt'], {'file_c.txt': 'file_c'}), + (['path'], {'path': {'file_b.txt': 
'file_b', 'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}}), + (['path/sub'], {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}), + (['*.txt'], {'file_a.txt': 'file_a'}), + (['*/*.txt'], {'file_b.txt': 'file_b'}), + # Single nested file that is retrieved keeping a varying level of depth of original hierarchy + ([('path/sub/file_c.txt', '.', 3)], {'path': {'sub': {'file_c.txt': 'file_c'}}}), + ([('path/sub/file_c.txt', '.', 2)], {'sub': {'file_c.txt': 'file_c'}}), + ([('path/sub/file_c.txt', '.', 1)], {'file_c.txt': 'file_c'}), + ([('path/sub/file_c.txt', '.', 0)], {'file_c.txt': 'file_c'}), + # Single nested folder that is retrieved keeping a varying level of depth of original hierarchy + ([('path/sub', '.', 2)], {'path': {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}}), + ([('path/sub', '.', 1)], {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}), + # Using globbing patterns + ([('path/*', '.', 0)], {'file_b.txt': 'file_b', 'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}), + ( + [('path/sub/*', '.', 0)], + {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}, + ), # This is identical to ['path/sub'] + ([('path/sub/*c.txt', '.', 2)], {'sub': {'file_c.txt': 'file_c'}}), + ([('path/sub/*c.txt', '.', 0)], {'file_c.txt': 'file_c'}), + # Using globbing with depth `None` should maintain exact folder hierarchy + ([('path/*.txt', '.', None)], {'path': {'file_b.txt': 'file_b'}}), + ([('path/sub/*.txt', '.', None)], {'path': {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}}), + # Different target directory + ([('path/sub/file_c.txt', 'target', 3)], {'target': {'path': {'sub': {'file_c.txt': 'file_c'}}}}), + ([('path/sub', 'target', 1)], {'target': {'sub': {'file_c.txt': 'file_c', 'file_d.txt': 'file_d'}}}), + ([('path/sub/*c.txt', 'target', 2)], {'target': {'sub': {'file_c.txt': 'file_c'}}}), + # Missing files should be ignored and not cause the retrieval to except + (['file_a.txt', 'file_u.txt', 'path/file_u.txt', 
('path/sub/file_u.txt', '.', 3)], {'file_a.txt': 'file_a'}), + ), +) +def test_retrieve_files_from_list( + tmp_path_factory, + generate_calcjob_node, + file_hierarchy, + retrieve_list, + expected_hierarchy, + create_file_hierarchy, + serialize_file_hierarchy, ): """Test the `retrieve_files_from_list` function.""" source = tmp_path_factory.mktemp('source') @@ -115,27 +125,34 @@ async def test_retrieve_files_from_list( create_file_hierarchy(file_hierarchy, source) with LocalTransport() as transport: - node = generate_calculation_node() - transport.chdir(source) - await execmanager.retrieve_files_from_list(node, transport, target, retrieve_list) + node = generate_calcjob_node(workdir=source) + execmanager.retrieve_files_from_list(node, transport, target, retrieve_list) assert serialize_file_hierarchy(target, read_bytes=False) == expected_hierarchy -@pytest.mark.parametrize(('local_copy_list', 'expected_hierarchy'), ( - ([None, None], {'sub': {'b': 'file_b'}, 'a': 'file_a'}), - (['.', None], {'sub': {'b': 'file_b'}, 'a': 'file_a'}), - ([None, '.'], {'sub': {'b': 'file_b'}, 'a': 'file_a'}), - (['.', '.'], {'sub': {'b': 'file_b'}, 'a': 'file_a'}), - ([None, ''], {'sub': {'b': 'file_b'}, 'a': 'file_a'}), - (['sub', None], {'b': 'file_b'}), - ([None, 'target'], {'target': {'sub': {'b': 'file_b'}, 'a': 'file_a'}}), - (['sub', 'target'], {'target': {'b': 'file_b'}}), -)) -# yapf: enable -@pytest.mark.asyncio -async def test_upload_local_copy_list( - fixture_sandbox, node_and_calc_info, file_hierarchy_simple, tmp_path, local_copy_list, expected_hierarchy +@pytest.mark.parametrize( + ('local_copy_list', 'expected_hierarchy'), + ( + ([None, None], {'sub': {'b': 'file_b'}, 'a': 'file_a'}), + (['.', None], {'sub': {'b': 'file_b'}, 'a': 'file_a'}), + ([None, '.'], {'sub': {'b': 'file_b'}, 'a': 'file_a'}), + (['.', '.'], {'sub': {'b': 'file_b'}, 'a': 'file_a'}), + ([None, ''], {'sub': {'b': 'file_b'}, 'a': 'file_a'}), + (['sub', None], {'b': 'file_b'}), + ([None, 'target'], 
{'target': {'sub': {'b': 'file_b'}, 'a': 'file_a'}}), + (['sub', 'target'], {'target': {'b': 'file_b'}}), + ), +) +def test_upload_local_copy_list( + fixture_sandbox, + node_and_calc_info, + file_hierarchy_simple, + tmp_path, + local_copy_list, + expected_hierarchy, + create_file_hierarchy, + serialize_file_hierarchy, ): """Test the ``local_copy_list`` functionality in ``upload_calculation``.""" create_file_hierarchy(file_hierarchy_simple, tmp_path) @@ -146,8 +163,8 @@ async def test_upload_local_copy_list( node, calc_info = node_and_calc_info calc_info.local_copy_list = [[folder.uuid] + local_copy_list] - with LocalTransport() as transport: - await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) + with node.computer.get_transport() as transport: + execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) # Check that none of the files were written to the repository of the calculation node, since they were communicated # through the ``local_copy_list``. @@ -158,8 +175,9 @@ async def test_upload_local_copy_list( assert written_hierarchy == expected_hierarchy -@pytest.mark.asyncio -async def test_upload_local_copy_list_files_folders(fixture_sandbox, node_and_calc_info, file_hierarchy, tmp_path): +def test_upload_local_copy_list_files_folders( + fixture_sandbox, node_and_calc_info, file_hierarchy, tmp_path, create_file_hierarchy, serialize_file_hierarchy +): """Test the ``local_copy_list`` functionality in ``upload_calculation``. Specifically, verify that files in the ``local_copy_list`` do not end up in the repository of the node. 
@@ -182,8 +200,8 @@ async def test_upload_local_copy_list_files_folders(fixture_sandbox, node_and_ca (inputs['folder'].uuid, None, '.'), ] - with LocalTransport() as transport: - await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) + with node.computer.get_transport() as transport: + execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) # Check that none of the files were written to the repository of the calculation node, since they were communicated # through the ``local_copy_list``. diff --git a/tests/engine/processes/test_caching.py b/tests/engine/processes/test_caching.py index a64cf0c62f..224b87dbe3 100644 --- a/tests/engine/processes/test_caching.py +++ b/tests/engine/processes/test_caching.py @@ -16,7 +16,7 @@ def define(cls, spec): spec.input('a') spec.output_namespace('nested', dynamic=True) - async def run(self): + def run(self): self.out('nested', {'a': self.inputs.a + 2}) diff --git a/tests/engine/test_process.py b/tests/engine/test_process.py index df4252edb9..9b1f230041 100644 --- a/tests/engine/test_process.py +++ b/tests/engine/test_process.py @@ -77,7 +77,7 @@ class ProcessStackTest(Process): _node_class = orm.WorkflowNode @override - async def run(self): + def run(self): pass @override @@ -323,7 +323,7 @@ def define(cls, spec): spec.input_namespace('namespace', valid_type=orm.Int, dynamic=True) spec.output_namespace('namespace', valid_type=orm.Int, dynamic=True) - async def run(self): + def run(self): self.out('namespace', self.inputs.namespace) results, node = run_get_node(TestProcess1, namespace={'alpha': orm.Int(1), 'beta': orm.Int(2)}) @@ -347,7 +347,7 @@ def define(cls, spec): spec.output_namespace('integer.namespace', valid_type=orm.Int, dynamic=True) spec.output('required_string', valid_type=orm.Str, required=True) - async def run(self): + def run(self): if self.inputs.add_outputs: self.out('required_string', orm.Str('testing').store()) self.out('integer.namespace.two', orm.Int(2).store()) 
diff --git a/tests/engine/test_runners.py b/tests/engine/test_runners.py index 44ce6b9771..4f746e5d34 100644 --- a/tests/engine/test_runners.py +++ b/tests/engine/test_runners.py @@ -37,7 +37,7 @@ def define(cls, spec): super().define(spec) spec.input('a') - async def run(self): + def run(self): pass diff --git a/tests/engine/test_work_chain.py b/tests/engine/test_work_chain.py index 9ea07a6591..1fbb578b2a 100644 --- a/tests/engine/test_work_chain.py +++ b/tests/engine/test_work_chain.py @@ -1658,5 +1658,5 @@ def define(cls, spec): super().define(spec) spec.outline(cls.run) - async def run(self): + def run(self): pass diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index 0a25add0a7..6d4fc8a6ad 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -13,7 +13,6 @@ import io import os -import pathlib import random import shutil import signal @@ -21,10 +20,11 @@ import tempfile import time import uuid +from pathlib import Path import psutil import pytest -from aiida.plugins import SchedulerFactory, TransportFactory, entry_point +from aiida.plugins import SchedulerFactory, TransportFactory from aiida.transports import Transport # TODO : test for copy with pattern @@ -33,7 +33,8 @@ # TODO : silly cases of copy/put/get from self to self -@pytest.fixture(scope='function', params=entry_point.get_entry_point_names('aiida.transports')) +# @pytest.fixture(scope='function', params=entry_point.get_entry_point_names('aiida.transports')) +@pytest.fixture(scope='function', params=['core.ssh', 'core.ssh_async']) def custom_transport(request, tmp_path, monkeypatch) -> Transport: """Fixture that parametrizes over all the registered implementations of the ``CommonRelaxWorkChain``.""" plugin = TransportFactory(request.param) @@ -46,6 +47,8 @@ def custom_transport(request, tmp_path, monkeypatch) -> Transport: monkeypatch.setattr(plugin, 'FILEPATH_CONFIG', filepath_config) if not filepath_config.exists(): 
filepath_config.write_text('Host localhost') + elif request.param == 'core.ssh_async': + kwargs = {'machine': 'localhost'} else: kwargs = {} @@ -62,122 +65,91 @@ def test_is_open(custom_transport): assert not custom_transport.is_open -def test_makedirs(custom_transport): +def test_makedirs(custom_transport, tmpdir): """Verify the functioning of makedirs command""" with custom_transport as transport: - location = transport.normalize(os.path.join('/', 'tmp')) - directory = 'temp_dir_test' - transport.chdir(location) + _scratch = Path(tmpdir / 'sampledir') + transport.mkdir(_scratch) + assert _scratch.exists() - assert location == transport.getcwd() - while transport.isdir(directory): - # I append a random letter/number until it is unique - directory += random.choice(string.ascii_uppercase + string.digits) - transport.mkdir(directory) - transport.chdir(directory) - - # define folder structure - dir_tree = os.path.join('1', '2') - # I create the tree - transport.makedirs(dir_tree) - # verify the existence - assert transport.isdir('1') - assert dir_tree - - # try to recreate the same folder + _scratch = tmpdir / 'sampledir2' / 'subdir' + transport.makedirs(_scratch) + assert _scratch.exists() + + # raise if directory already exists + with pytest.raises(OSError): + transport.makedirs(tmpdir / 'sampledir2') with pytest.raises(OSError): - transport.makedirs(dir_tree) + transport.mkdir(tmpdir / 'sampledir') + + # don't raise if directory already exists and ignore_existing is True + transport.mkdir(tmpdir / 'sampledir', ignore_existing=True) + transport.makedirs(tmpdir / 'sampledir2', ignore_existing=True) - # recreate but with ignore flag - transport.makedirs(dir_tree, True) - transport.rmdir(dir_tree) - transport.rmdir('1') +def test_is_dir(custom_transport, tmpdir): + with custom_transport as transport: + _scratch = tmpdir / 'sampledir' + transport.mkdir(_scratch) - transport.chdir('..') - transport.rmdir(directory) + assert transport.isdir(_scratch) + assert not 
transport.isdir(_scratch / 'does_not_exist') -def test_rmtree(custom_transport): +def test_rmtree(custom_transport, tmpdir): """Verify the functioning of rmtree command""" with custom_transport as transport: - location = transport.normalize(os.path.join('/', 'tmp')) - directory = 'temp_dir_test' - transport.chdir(location) - - assert location == transport.getcwd() - while transport.isdir(directory): - # I append a random letter/number until it is unique - directory += random.choice(string.ascii_uppercase + string.digits) - transport.mkdir(directory) - transport.chdir(directory) - - # define folder structure - dir_tree = os.path.join('1', '2') - # I create the tree - transport.makedirs(dir_tree) - # remove it - transport.rmtree('1') - # verify the removal - assert not transport.isdir('1') - - # also tests that it works with a single file - # create file - local_file_name = 'file.txt' - text = 'Viva Verdi\n' - with open(os.path.join(transport.getcwd(), local_file_name), 'w', encoding='utf8') as fhandle: - fhandle.write(text) - # remove it - transport.rmtree(local_file_name) - # verify the removal - assert not transport.isfile(local_file_name) + _remote = tmpdir / 'remote' + _local = tmpdir / 'local' + _remote.mkdir() + _local.mkdir() + + Path(_local / 'samplefile').touch() + + # remove a non-empty directory with rmtree() + _scratch = _remote / 'sampledir' + _scratch.mkdir() + Path(_remote / 'sampledir' / 'samplefile_remote').touch() + transport.rmtree(_scratch) + assert not _scratch.exists() + + # remove a non-empty directory should raise with rmdir() + transport.mkdir(_remote / 'sampledir') + Path(_remote / 'sampledir' / 'samplefile_remote').touch() + with pytest.raises(OSError): + transport.rmdir(_remote / 'sampledir') - transport.chdir('..') - transport.rmdir(directory) + # remove a file with remove() + transport.remove(_remote / 'sampledir' / 'samplefile_remote') + assert not Path(_remote / 'sampledir' / 'samplefile_remote').exists() + # remove a empty directory 
with rmdir + transport.rmdir(_remote / 'sampledir') + assert not _scratch.exists() -def test_listdir(custom_transport): - """Create directories, verify listdir, delete a folder with subfolders""" - with custom_transport as trans: - # We cannot use tempfile.mkdtemp because we're on a remote folder - location = trans.normalize(os.path.join('/', 'tmp')) - directory = 'temp_dir_test' - trans.chdir(location) - assert location == trans.getcwd() - while trans.isdir(directory): - # I append a random letter/number until it is unique - directory += random.choice(string.ascii_uppercase + string.digits) - trans.mkdir(directory) - trans.chdir(directory) +def test_listdir(custom_transport, tmpdir): + """Create directories, verify listdir""" + with custom_transport as transport: list_of_dir = ['1', '-f a&', 'as', 'a2', 'a4f'] list_of_files = ['a', 'b'] for this_dir in list_of_dir: - trans.mkdir(this_dir) + transport.mkdir(tmpdir / this_dir) for fname in list_of_files: with tempfile.NamedTemporaryFile() as tmpf: # Just put an empty file there at the right file name - trans.putfile(tmpf.name, fname) + transport.putfile(tmpf.name, tmpdir / fname) - list_found = trans.listdir('.') + list_found = transport.listdir(tmpdir) assert sorted(list_found) == sorted(list_of_dir + list_of_files) - assert sorted(trans.listdir('.', 'a*')), sorted(['as', 'a2', 'a4f']) - assert sorted(trans.listdir('.', 'a?')), sorted(['as', 'a2']) - assert sorted(trans.listdir('.', 'a[2-4]*')), sorted(['a2', 'a4f']) - - for this_dir in list_of_dir: - trans.rmdir(this_dir) - - for this_file in list_of_files: - trans.remove(this_file) + assert sorted(transport.listdir(tmpdir, 'a*')), sorted(['as', 'a2', 'a4f']) + assert sorted(transport.listdir(tmpdir, 'a?')), sorted(['as', 'a2']) + assert sorted(transport.listdir(tmpdir, 'a[2-4]*')), sorted(['a2', 'a4f']) - trans.chdir('..') - trans.rmdir(directory) - -def test_listdir_withattributes(custom_transport): +def test_listdir_withattributes(custom_transport, tmpdir): 
"""Create directories, verify listdir_withattributes, delete a folder with subfolders""" def simplify_attributes(data): @@ -189,127 +161,70 @@ def simplify_attributes(data): """ return {_['name']: _['isdir'] for _ in data} - with custom_transport as trans: - # We cannot use tempfile.mkdtemp because we're on a remote folder - location = trans.normalize(os.path.join('/', 'tmp')) - directory = 'temp_dir_test' - trans.chdir(location) - - assert location == trans.getcwd() - while trans.isdir(directory): - # I append a random letter/number until it is unique - directory += random.choice(string.ascii_uppercase + string.digits) - trans.mkdir(directory) - trans.chdir(directory) + with custom_transport as transport: list_of_dir = ['1', '-f a&', 'as', 'a2', 'a4f'] list_of_files = ['a', 'b'] for this_dir in list_of_dir: - trans.mkdir(this_dir) + transport.mkdir(tmpdir / this_dir) for fname in list_of_files: with tempfile.NamedTemporaryFile() as tmpf: # Just put an empty file there at the right file name - trans.putfile(tmpf.name, fname) + transport.putfile(tmpf.name, tmpdir / fname) comparison_list = {k: True for k in list_of_dir} for k in list_of_files: comparison_list[k] = False - assert simplify_attributes(trans.listdir_withattributes('.')), comparison_list - assert simplify_attributes(trans.listdir_withattributes('.', 'a*')), { + assert simplify_attributes(transport.listdir_withattributes(tmpdir)), comparison_list + assert simplify_attributes(transport.listdir_withattributes(tmpdir, 'a*')), { 'as': True, 'a2': True, 'a4f': True, 'a': False, } - assert simplify_attributes(trans.listdir_withattributes('.', 'a?')), {'as': True, 'a2': True} - assert simplify_attributes(trans.listdir_withattributes('.', 'a[2-4]*')), {'a2': True, 'a4f': True} - - for this_dir in list_of_dir: - trans.rmdir(this_dir) - - for this_file in list_of_files: - trans.remove(this_file) - - trans.chdir('..') - trans.rmdir(directory) - + assert simplify_attributes(transport.listdir_withattributes(tmpdir, 
'a?')), {'as': True, 'a2': True} + assert simplify_attributes(transport.listdir_withattributes(tmpdir, 'a[2-4]*')), {'a2': True, 'a4f': True} -def test_dir_creation_deletion(custom_transport): - """Test creating and deleting directories.""" - with custom_transport as transport: - location = transport.normalize(os.path.join('/', 'tmp')) - directory = 'temp_dir_test' - transport.chdir(location) - - assert location == transport.getcwd() - while transport.isdir(directory): - # I append a random letter/number until it is unique - directory += random.choice(string.ascii_uppercase + string.digits) - transport.mkdir(directory) - - with pytest.raises(OSError): - # I create twice the same directory - transport.mkdir(directory) - transport.isdir(directory) - assert not transport.isfile(directory) - transport.rmdir(directory) - - -def test_dir_copy(custom_transport): +def test_dir_copy(custom_transport, tmpdir): """Verify if in the copy of a directory also the protection bits are carried over """ with custom_transport as transport: - location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - transport.chdir(location) - - while transport.isdir(directory): - # I append a random letter/number until it is unique - directory += random.choice(string.ascii_uppercase + string.digits) - transport.mkdir(directory) + transport.mkdir(tmpdir / directory) dest_directory = f'{directory}_copy' - transport.copy(directory, dest_directory) + transport.copy(tmpdir / directory, tmpdir / dest_directory) with pytest.raises(ValueError): - transport.copy(directory, '') + transport.copy(tmpdir / directory, '') with pytest.raises(ValueError): - transport.copy('', directory) - - transport.rmdir(directory) - transport.rmdir(dest_directory) + transport.copy('', tmpdir / directory) -def test_dir_permissions_creation_modification(custom_transport): +def test_dir_permissions_creation_modification(custom_transport, tmpdir): """Verify if chmod raises OSError when trying to change 
bits on a non-existing folder """ with custom_transport as transport: - location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - transport.chdir(location) - - while transport.isdir(directory): - # I append a random letter/number until it is unique - directory += random.choice(string.ascii_uppercase + string.digits) # create directory with non default permissions - transport.mkdir(directory) + transport.mkdir(tmpdir / directory) # change permissions - transport.chmod(directory, 0o777) + transport.chmod(tmpdir / directory, 0o777) # test if the security bits have changed - assert transport.get_mode(directory) == 0o777 + assert transport.get_mode(tmpdir / directory) == 0o777 # change permissions - transport.chmod(directory, 0o511) + transport.chmod(tmpdir / directory, 0o511) # test if the security bits have changed - assert transport.get_mode(directory) == 0o511 + assert transport.get_mode(tmpdir / directory) == 0o511 # TODO : bug in paramiko. When changing the directory to very low \ # I cannot set it back to higher permissions @@ -318,52 +233,31 @@ def test_dir_permissions_creation_modification(custom_transport): # the new directory modes. To see if we want a higher # level function to ask for the mode, or we just # use get_attribute - transport.chdir(directory) # change permissions of an empty string, non existing folder. - fake_dir = '' with pytest.raises(OSError): - transport.chmod(fake_dir, 0o777) + transport.chmod('', 0o777) + # change permissions of a non existing folder. fake_dir = 'pippo' with pytest.raises(OSError): # chmod to a non existing folder - transport.chmod(fake_dir, 0o777) - - transport.chdir('..') - transport.rmdir(directory) + transport.chmod(tmpdir / directory / fake_dir, 0o777) -def test_dir_reading_permissions(custom_transport): - """Try to enter a directory with no read permissions. - Verify that the cwd has not changed after failed try. 
- """ +def test_dir_reading_permissions(custom_transport, tmpdir): + """Try to enter a directory with no read & write permissions.""" with custom_transport as transport: - location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' - transport.chdir(location) - - while transport.isdir(directory): - # I append a random letter/number until it is unique - directory += random.choice(string.ascii_uppercase + string.digits) # create directory with non default permissions - transport.mkdir(directory) + transport.mkdir(tmpdir / directory) # change permissions to low ones - transport.chmod(directory, 0) + transport.chmod(tmpdir / directory, 0) # test if the security bits have changed - assert transport.get_mode(directory) == 0 - - old_cwd = transport.getcwd() - - with pytest.raises(OSError): - transport.chdir(directory) - - new_cwd = transport.getcwd() - - assert old_cwd == new_cwd + assert transport.get_mode(tmpdir / directory) == 0 # TODO : the test leaves a directory even if it is successful # The bug is in paramiko. 
After lowering the permissions, @@ -371,29 +265,21 @@ def test_dir_reading_permissions(custom_transport): # transport.rmdir(directory) -def test_isfile_isdir_to_empty_string(custom_transport): - """I check that isdir or isfile return False when executed on an - empty string - """ +def test_isfile_isdir(custom_transport, tmpdir): with custom_transport as transport: - location = transport.normalize(os.path.join('/', 'tmp')) - transport.chdir(location) assert not transport.isdir('') assert not transport.isfile('') + assert not transport.isfile(tmpdir / 'does_not_exist') + assert not transport.isdir(tmpdir / 'does_not_exist') + Path(tmpdir / 'samplefile').touch() + assert transport.isfile(tmpdir / 'samplefile') + assert not transport.isdir(tmpdir / 'samplefile') -def test_isfile_isdir_to_non_existing_string(custom_transport): - """I check that isdir or isfile return False when executed on an - empty string - """ - with custom_transport as transport: - location = transport.normalize(os.path.join('/', 'tmp')) - transport.chdir(location) - fake_folder = 'pippo' - assert not transport.isfile(fake_folder) - assert not transport.isdir(fake_folder) - with pytest.raises(OSError): - transport.chdir(fake_folder) + transport.mkdir(tmpdir / 'sampledir') + + assert transport.isdir(tmpdir / 'sampledir') + assert not transport.isfile(tmpdir / 'sampledir') def test_chdir_to_empty_string(custom_transport): @@ -401,35 +287,29 @@ def test_chdir_to_empty_string(custom_transport): not change (this is a paramiko default behavior), but getcwd() is still correctly defined. 
""" - with custom_transport as transport: - new_dir = transport.normalize(os.path.join('/', 'tmp')) - transport.chdir(new_dir) - transport.chdir('') - assert new_dir == transport.getcwd() + try: + with custom_transport as transport: + new_dir = transport.normalize(os.path.join('/', 'tmp')) + transport.chdir(new_dir) + transport.chdir('') + assert new_dir == transport.getcwd() + except NotImplementedError: + # chdir() is no longer an abstract method, to be removed from interface + pass -def test_put_and_get_file(custom_transport): +def test_put_and_get_file(custom_transport, tmpdir): """Test putting and getting files.""" - local_dir = os.path.join('/', 'tmp') - remote_dir = local_dir - directory = 'tmp_try' + local_dir = tmpdir / 'local' + remote_dir = tmpdir / 'remote' + local_dir.mkdir() + remote_dir.mkdir() with custom_transport as transport: - transport.chdir(remote_dir) - while transport.isdir(directory): - # I append a random letter/number until it is unique - directory += random.choice(string.ascii_uppercase + string.digits) - - transport.mkdir(directory) - transport.chdir(directory) - - local_file_name = os.path.join(local_dir, directory, 'file.txt') - remote_file_name = 'file_remote.txt' - retrieved_file_name = os.path.join(local_dir, directory, 'file_retrieved.txt') - - text = 'Viva Verdi\n' - with open(local_file_name, 'w', encoding='utf8') as fhandle: - fhandle.write(text) + local_file_name = local_dir / 'file.txt' + Path(local_file_name).touch() + remote_file_name = remote_dir / 'file_remote.txt' + retrieved_file_name = os.path.join(local_dir, 'file_retrieved.txt') # here use full path in src and dst transport.put(local_file_name, remote_file_name) @@ -437,92 +317,82 @@ def test_put_and_get_file(custom_transport): transport.putfile(local_file_name, remote_file_name) transport.getfile(remote_file_name, retrieved_file_name) - list_of_files = transport.listdir('.') + list_of_files = transport.listdir(remote_dir) # it is False because local_file_name has 
the full path, # while list_of_files has not assert local_file_name not in list_of_files - assert remote_file_name in list_of_files - assert retrieved_file_name not in list_of_files - - os.remove(local_file_name) - transport.remove(remote_file_name) - os.remove(retrieved_file_name) - - transport.chdir('..') - transport.rmdir(directory) + assert 'file_remote.txt' in list_of_files + assert 'file_retrieved.txt' not in list_of_files + assert 'file_retrieved.txt' in os.listdir(local_dir) def test_put_get_abs_path_file(custom_transport): """Test of exception for non existing files and abs path""" local_dir = os.path.join('/', 'tmp') - remote_dir = local_dir + remote_dir = Path(local_dir) directory = 'tmp_try' with custom_transport as transport: - transport.chdir(remote_dir) - while transport.isdir(directory): + while transport.isdir(remote_dir / directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - transport.mkdir(directory) - transport.chdir(directory) + transport.mkdir(remote_dir / directory) partial_file_name = 'file.txt' local_file_name = os.path.join(local_dir, directory, 'file.txt') remote_file_name = 'file_remote.txt' retrieved_file_name = os.path.join(local_dir, directory, 'file_retrieved.txt') - pathlib.Path(local_file_name).touch() - + Path(local_file_name).touch() + workdir = Path(remote_dir / directory) # partial_file_name is not an abs path with pytest.raises(ValueError): - transport.put(partial_file_name, remote_file_name) + transport.put(partial_file_name, workdir / remote_file_name) with pytest.raises(ValueError): - transport.putfile(partial_file_name, remote_file_name) + transport.putfile(partial_file_name, workdir / remote_file_name) # retrieved_file_name does not exist with pytest.raises(OSError): - transport.put(retrieved_file_name, remote_file_name) + transport.put(retrieved_file_name, workdir / remote_file_name) with pytest.raises(OSError): - 
transport.putfile(retrieved_file_name, remote_file_name) + transport.putfile(retrieved_file_name, workdir / remote_file_name) # remote_file_name does not exist with pytest.raises(OSError): - transport.get(remote_file_name, retrieved_file_name) + transport.get(workdir / remote_file_name, retrieved_file_name) with pytest.raises(OSError): - transport.getfile(remote_file_name, retrieved_file_name) + transport.getfile(workdir / remote_file_name, retrieved_file_name) - transport.put(local_file_name, remote_file_name) - transport.putfile(local_file_name, remote_file_name) + transport.put(local_file_name, workdir / remote_file_name) + transport.putfile(local_file_name, workdir / remote_file_name) # local filename is not an abs path with pytest.raises(ValueError): - transport.get(remote_file_name, 'delete_me.txt') + transport.get(workdir / remote_file_name, 'delete_me.txt') with pytest.raises(ValueError): - transport.getfile(remote_file_name, 'delete_me.txt') + transport.getfile(workdir / remote_file_name, 'delete_me.txt') - transport.remove(remote_file_name) + transport.remove(workdir / remote_file_name) os.remove(local_file_name) - transport.chdir('..') - transport.rmdir(directory) + transport.rmdir(workdir) def test_put_get_empty_string_file(custom_transport): """Test of exception put/get of empty strings""" # TODO : verify the correctness of \n at the end of a file local_dir = os.path.join('/', 'tmp') - remote_dir = local_dir + remote_dir = Path(local_dir) directory = 'tmp_try' with custom_transport as transport: - transport.chdir(remote_dir) - while transport.isdir(directory): + while transport.isdir(remote_dir / directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - transport.mkdir(directory) - transport.chdir(directory) + transport.mkdir(remote_dir / directory) + workdir = Path(remote_dir / directory) local_file_name = os.path.join(local_dir, directory, 'file_local.txt') remote_file_name = 
'file_remote.txt' @@ -545,9 +415,9 @@ def test_put_get_empty_string_file(custom_transport): with pytest.raises(OSError): transport.putfile(local_file_name, '') - transport.put(local_file_name, remote_file_name) + transport.put(local_file_name, workdir / remote_file_name) # overwrite the remote_file_name - transport.putfile(local_file_name, remote_file_name) + transport.putfile(local_file_name, workdir / remote_file_name) # remote path is an empty string with pytest.raises(OSError): @@ -564,30 +434,27 @@ def test_put_get_empty_string_file(custom_transport): # TODO : get doesn't retrieve empty files. # Is it what we want? - transport.get(remote_file_name, retrieved_file_name) + transport.get(workdir / remote_file_name, retrieved_file_name) # overwrite retrieved_file_name - transport.getfile(remote_file_name, retrieved_file_name) + transport.getfile(workdir / remote_file_name, retrieved_file_name) os.remove(local_file_name) - transport.remove(remote_file_name) + transport.remove(workdir / remote_file_name) # If it couldn't end the copy, it leaves what he did on # local file - assert 'file_retrieved.txt' in transport.listdir('.') + assert 'file_retrieved.txt' in transport.listdir(workdir) os.remove(retrieved_file_name) - transport.chdir('..') - transport.rmdir(directory) + transport.rmdir(workdir) def test_put_and_get_tree(custom_transport): """Test putting and getting files.""" local_dir = os.path.join('/', 'tmp') - remote_dir = local_dir + remote_dir = Path(local_dir) directory = 'tmp_try' with custom_transport as transport: - transport.chdir(remote_dir) - while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) @@ -599,7 +466,7 @@ def test_put_and_get_tree(custom_transport): os.mkdir(os.path.join(local_dir, directory)) os.mkdir(os.path.join(local_dir, directory, local_subfolder)) - transport.chdir(directory) + workdir = remote_dir / directory 
local_file_name = os.path.join(local_subfolder, 'file.txt') @@ -610,14 +477,14 @@ def test_put_and_get_tree(custom_transport): # here use full path in src and dst for i in range(2): if i == 0: - transport.put(local_subfolder, remote_subfolder) - transport.get(remote_subfolder, retrieved_subfolder) + transport.put(local_subfolder, workdir / remote_subfolder) + transport.get(workdir / remote_subfolder, retrieved_subfolder) else: - transport.puttree(local_subfolder, remote_subfolder) - transport.gettree(remote_subfolder, retrieved_subfolder) + transport.puttree(local_subfolder, workdir / remote_subfolder) + transport.gettree(workdir / remote_subfolder, retrieved_subfolder) # Here I am mixing the local with the remote fold - list_of_dirs = transport.listdir('.') + list_of_dirs = transport.listdir(workdir) # # it is False because local_file_name has the full path, # # while list_of_files has not assert local_subfolder not in list_of_dirs @@ -626,17 +493,16 @@ def test_put_and_get_tree(custom_transport): assert 'tmp1' in list_of_dirs assert 'tmp3' in list_of_dirs - list_pushed_file = transport.listdir('tmp2') - list_retrieved_file = transport.listdir('tmp3') + list_pushed_file = transport.listdir(workdir / 'tmp2') + list_retrieved_file = transport.listdir(workdir / 'tmp3') assert 'file.txt' in list_pushed_file assert 'file.txt' in list_retrieved_file shutil.rmtree(local_subfolder) shutil.rmtree(retrieved_subfolder) - transport.rmtree(remote_subfolder) + transport.rmtree(workdir / remote_subfolder) - transport.chdir('..') - transport.rmdir(directory) + transport.rmdir(workdir) @pytest.mark.parametrize( @@ -713,18 +579,16 @@ def test_put_and_get_overwrite( def test_copy(custom_transport): """Test copying.""" local_dir = os.path.join('/', 'tmp') - remote_dir = local_dir + remote_dir = Path(local_dir) directory = 'tmp_try' with custom_transport as transport: - transport.chdir(remote_dir) - while os.path.exists(os.path.join(local_dir, directory)): # I append a random 
letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - transport.mkdir(directory) - transport.chdir(directory) + transport.mkdir(remote_dir / directory) + workdir = remote_dir / directory local_base_dir = os.path.join(local_dir, directory, 'local') os.mkdir(local_base_dir) @@ -739,43 +603,42 @@ def test_copy(custom_transport): fhandle.write(text) # first test the copy. Copy of two files matching patterns, into a folder - transport.copy(os.path.join('local', '*.txt'), '.') - assert set(['a.txt', 'c.txt', 'local']) == set(transport.listdir('.')) - transport.remove('a.txt') - transport.remove('c.txt') + transport.copy(workdir / 'local' / '*.txt', workdir) + assert set(['a.txt', 'c.txt', 'local']) == set(transport.listdir(workdir)) + transport.remove(workdir / 'a.txt') + transport.remove(workdir / 'c.txt') # second test copy. Copy of two folders - transport.copy('local', 'prova') - assert set(['prova', 'local']) == set(transport.listdir('.')) - assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir('prova')) - transport.rmtree('prova') + transport.copy(workdir / 'local', workdir / 'prova') + assert set(['prova', 'local']) == set(transport.listdir(workdir)) + assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir(workdir / 'prova')) + transport.rmtree(workdir / 'prova') # third test copy. 
Can copy one file into a new file - transport.copy(os.path.join('local', '*.tmp'), 'prova') - assert set(['prova', 'local']) == set(transport.listdir('.')) - transport.remove('prova') + transport.copy(workdir / 'local' / '*.tmp', workdir / 'prova') + assert set(['prova', 'local']) == set(transport.listdir(workdir)) + transport.remove(workdir / 'prova') # fourth test copy: can't copy more than one file on the same file, # i.e., the destination should be a folder with pytest.raises(OSError): - transport.copy(os.path.join('local', '*.txt'), 'prova') + transport.copy(workdir / 'local' / '*.txt', workdir / 'prova') # fifth test, copying one file into a folder - transport.mkdir('prova') - transport.copy(os.path.join('local', 'a.txt'), 'prova') - assert set(transport.listdir('prova')) == set(['a.txt']) - transport.rmtree('prova') + transport.mkdir(workdir / 'prova') + transport.copy(workdir / 'local' / 'a.txt', workdir / 'prova') + assert set(transport.listdir(workdir / 'prova')) == set(['a.txt']) + transport.rmtree(workdir / 'prova') # sixth test, copying one file into a file - transport.copy(os.path.join('local', 'a.txt'), 'prova') - assert transport.isfile('prova') - transport.remove('prova') + transport.copy(workdir / 'local' / 'a.txt', workdir / 'prova') + assert transport.isfile(workdir / 'prova') + transport.remove(workdir / 'prova') # copy of folder into an existing folder # NOTE: the command cp has a different behavior on Mac vs Ubuntu # tests performed locally on a Mac may result in a failure. 
- transport.mkdir('prova') - transport.copy('local', 'prova') - assert set(['local']) == set(transport.listdir('prova')) - assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir(os.path.join('prova', 'local'))) - transport.rmtree('prova') + transport.mkdir(workdir / 'prova') + transport.copy(workdir / 'local', workdir / 'prova') + assert set(['local']) == set(transport.listdir(workdir / 'prova')) + assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir(workdir / 'prova' / 'local')) + transport.rmtree(workdir / 'prova') # exit - transport.chdir('..') - transport.rmtree(directory) + transport.rmtree(workdir) def test_put(custom_transport): @@ -783,18 +646,16 @@ def test_put(custom_transport): # exactly the same tests of copy, just with the put function # and therefore the local path must be absolute local_dir = os.path.join('/', 'tmp') - remote_dir = local_dir + remote_dir = Path(local_dir) directory = 'tmp_try' with custom_transport as transport: - transport.chdir(remote_dir) - while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - transport.mkdir(directory) - transport.chdir(directory) + transport.mkdir(remote_dir / directory) + workdir = remote_dir / directory local_base_dir = os.path.join(local_dir, directory, 'local') os.mkdir(local_base_dir) @@ -809,49 +670,48 @@ def test_put(custom_transport): fhandle.write(text) # first test putransport. Copy of two files matching patterns, into a folder - transport.put(os.path.join(local_base_dir, '*.txt'), '.') - assert set(['a.txt', 'c.txt', 'local']) == set(transport.listdir('.')) - transport.remove('a.txt') - transport.remove('c.txt') + transport.put(os.path.join(local_base_dir, '*.txt'), workdir) + assert set(['a.txt', 'c.txt', 'local']) == set(transport.listdir(workdir)) + transport.remove(workdir / 'a.txt') + transport.remove(workdir / 'c.txt') # second. 
Copy of folder into a non existing folder - transport.put(local_base_dir, 'prova') - assert set(['prova', 'local']) == set(transport.listdir('.')) - assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir('prova')) - transport.rmtree('prova') + transport.put(local_base_dir, workdir / 'prova') + assert set(['prova', 'local']) == set(transport.listdir(workdir)) + assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir(workdir / 'prova')) + transport.rmtree(workdir / 'prova') # third. copy of folder into an existing folder - transport.mkdir('prova') - transport.put(local_base_dir, 'prova') - assert set(['prova', 'local']) == set(transport.listdir('.')) - assert set(['local']) == set(transport.listdir('prova')) - assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir(os.path.join('prova', 'local'))) - transport.rmtree('prova') + transport.mkdir(workdir / 'prova') + transport.put(local_base_dir, workdir / 'prova') + assert set(['prova', 'local']) == set(transport.listdir(workdir)) + assert set(['local']) == set(transport.listdir(workdir / 'prova')) + assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir(workdir / 'prova' / 'local')) + transport.rmtree(workdir / 'prova') # third test copy. 
Can copy one file into a new file - transport.put(os.path.join(local_base_dir, '*.tmp'), 'prova') - assert set(['prova', 'local']) == set(transport.listdir('.')) - transport.remove('prova') + transport.put(os.path.join(local_base_dir, '*.tmp'), workdir / 'prova') + assert set(['prova', 'local']) == set(transport.listdir(workdir)) + transport.remove(workdir / 'prova') # fourth test copy: can't copy more than one file on the same file, # i.e., the destination should be a folder with pytest.raises(OSError): - transport.put(os.path.join(local_base_dir, '*.txt'), 'prova') + transport.put(os.path.join(local_base_dir, '*.txt'), workdir / 'prova') # copy of folder into file with open(os.path.join(local_dir, directory, 'existing.txt'), 'w', encoding='utf8') as fhandle: fhandle.write(text) with pytest.raises(OSError): - transport.put(os.path.join(local_base_dir), 'existing.txt') - transport.remove('existing.txt') + transport.put(os.path.join(local_base_dir), workdir / 'existing.txt') + transport.remove(workdir / 'existing.txt') # fifth test, copying one file into a folder - transport.mkdir('prova') - transport.put(os.path.join(local_base_dir, 'a.txt'), 'prova') - assert set(transport.listdir('prova')) == set(['a.txt']) - transport.rmtree('prova') + transport.mkdir(workdir / 'prova') + transport.put(os.path.join(local_base_dir, 'a.txt'), workdir / 'prova') + assert set(transport.listdir(workdir / 'prova')) == set(['a.txt']) + transport.rmtree(workdir / 'prova') # sixth test, copying one file into a file - transport.put(os.path.join(local_base_dir, 'a.txt'), 'prova') - assert transport.isfile('prova') - transport.remove('prova') + transport.put(os.path.join(local_base_dir, 'a.txt'), workdir / 'prova') + assert transport.isfile(workdir / 'prova') + transport.remove(workdir / 'prova') # exit - transport.chdir('..') - transport.rmtree(directory) + transport.rmtree(workdir) def test_get(custom_transport): @@ -859,18 +719,16 @@ def test_get(custom_transport): # exactly the same 
tests of copy, just with the put function # and therefore the local path must be absolute local_dir = os.path.join('/', 'tmp') - remote_dir = local_dir + remote_dir = Path(local_dir) directory = 'tmp_try' with custom_transport as transport: - transport.chdir(remote_dir) - while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - transport.mkdir(directory) - transport.chdir(directory) + transport.mkdir(remote_dir / directory) + workdir = remote_dir / directory local_base_dir = os.path.join(local_dir, directory, 'local') local_destination = os.path.join(local_dir, directory) @@ -886,60 +744,57 @@ def test_get(custom_transport): fhandle.write(text) # first test put. Copy of two files matching patterns, into a folder - transport.get(os.path.join('local', '*.txt'), local_destination) + transport.get(workdir / 'local' / '*.txt', local_destination) assert set(['a.txt', 'c.txt', 'local']) == set(os.listdir(local_destination)) os.remove(os.path.join(local_destination, 'a.txt')) os.remove(os.path.join(local_destination, 'c.txt')) # second. Copy of folder into a non existing folder - transport.get('local', os.path.join(local_destination, 'prova')) + transport.get(workdir / 'local', os.path.join(local_destination, 'prova')) assert set(['prova', 'local']) == set(os.listdir(local_destination)) assert set(['a.txt', 'b.tmp', 'c.txt']) == set(os.listdir(os.path.join(local_destination, 'prova'))) shutil.rmtree(os.path.join(local_destination, 'prova')) # third. 
copy of folder into an existing folder os.mkdir(os.path.join(local_destination, 'prova')) - transport.get('local', os.path.join(local_destination, 'prova')) + transport.get(workdir / 'local', os.path.join(local_destination, 'prova')) assert set(['prova', 'local']) == set(os.listdir(local_destination)) assert set(['local']) == set(os.listdir(os.path.join(local_destination, 'prova'))) assert set(['a.txt', 'b.tmp', 'c.txt']) == set(os.listdir(os.path.join(local_destination, 'prova', 'local'))) shutil.rmtree(os.path.join(local_destination, 'prova')) # third test copy. Can copy one file into a new file - transport.get(os.path.join('local', '*.tmp'), os.path.join(local_destination, 'prova')) + transport.get(workdir / 'local' / '*.tmp', os.path.join(local_destination, 'prova')) assert set(['prova', 'local']) == set(os.listdir(local_destination)) os.remove(os.path.join(local_destination, 'prova')) # fourth test copy: can't copy more than one file on the same file, # i.e., the destination should be a folder with pytest.raises(OSError): - transport.get(os.path.join('local', '*.txt'), os.path.join(local_destination, 'prova')) + transport.get(workdir / 'local' / '*.txt', os.path.join(local_destination, 'prova')) # copy of folder into file with open(os.path.join(local_destination, 'existing.txt'), 'w', encoding='utf8') as fhandle: fhandle.write(text) with pytest.raises(OSError): - transport.get('local', os.path.join(local_destination, 'existing.txt')) + transport.get(workdir / 'local', os.path.join(local_destination, 'existing.txt')) os.remove(os.path.join(local_destination, 'existing.txt')) # fifth test, copying one file into a folder os.mkdir(os.path.join(local_destination, 'prova')) - transport.get(os.path.join('local', 'a.txt'), os.path.join(local_destination, 'prova')) + transport.get(workdir / 'local' / 'a.txt', os.path.join(local_destination, 'prova')) assert set(os.listdir(os.path.join(local_destination, 'prova'))) == set(['a.txt']) 
shutil.rmtree(os.path.join(local_destination, 'prova')) # sixth test, copying one file into a file - transport.get(os.path.join('local', 'a.txt'), os.path.join(local_destination, 'prova')) + transport.get(workdir / 'local' / 'a.txt', os.path.join(local_destination, 'prova')) assert os.path.isfile(os.path.join(local_destination, 'prova')) os.remove(os.path.join(local_destination, 'prova')) # exit - transport.chdir('..') - transport.rmtree(directory) + transport.rmtree(workdir) def test_put_get_abs_path_tree(custom_transport): """Test of exception for non existing files and abs path""" local_dir = os.path.join('/', 'tmp') - remote_dir = local_dir + remote_dir = Path(local_dir) directory = 'tmp_try' with custom_transport as transport: - transport.chdir(remote_dir) - while os.path.exists(os.path.join(local_dir, directory)): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) @@ -951,62 +806,59 @@ def test_put_get_abs_path_tree(custom_transport): os.mkdir(os.path.join(local_dir, directory)) os.mkdir(os.path.join(local_dir, directory, local_subfolder)) - transport.chdir(directory) + workdir = remote_dir / directory local_file_name = os.path.join(local_subfolder, 'file.txt') - pathlib.Path(local_file_name).touch() + Path(local_file_name).touch() # 'tmp1' is not an abs path with pytest.raises(ValueError): - transport.put('tmp1', remote_subfolder) + transport.put('tmp1', workdir / remote_subfolder) with pytest.raises(ValueError): - transport.putfile('tmp1', remote_subfolder) + transport.putfile('tmp1', workdir / remote_subfolder) with pytest.raises(ValueError): - transport.puttree('tmp1', remote_subfolder) + transport.puttree('tmp1', workdir / remote_subfolder) # 'tmp3' does not exist with pytest.raises(OSError): - transport.put(retrieved_subfolder, remote_subfolder) + transport.put(retrieved_subfolder, workdir / remote_subfolder) with pytest.raises(OSError): - transport.putfile(retrieved_subfolder, 
remote_subfolder) + transport.putfile(retrieved_subfolder, workdir / remote_subfolder) with pytest.raises(OSError): - transport.puttree(retrieved_subfolder, remote_subfolder) + transport.puttree(retrieved_subfolder, workdir / remote_subfolder) # remote_file_name does not exist with pytest.raises(OSError): - transport.get('non_existing', retrieved_subfolder) + transport.get(workdir / 'non_existing', retrieved_subfolder) with pytest.raises(OSError): - transport.getfile('non_existing', retrieved_subfolder) + transport.getfile(workdir / 'non_existing', retrieved_subfolder) with pytest.raises(OSError): - transport.gettree('non_existing', retrieved_subfolder) + transport.gettree(workdir / 'non_existing', retrieved_subfolder) - transport.put(local_subfolder, remote_subfolder) + transport.put(local_subfolder, workdir / remote_subfolder) # local filename is not an abs path with pytest.raises(ValueError): - transport.get(remote_subfolder, 'delete_me_tree') + transport.get(workdir / remote_subfolder, 'delete_me_tree') with pytest.raises(ValueError): - transport.getfile(remote_subfolder, 'delete_me_tree') + transport.getfile(workdir / remote_subfolder, 'delete_me_tree') with pytest.raises(ValueError): - transport.gettree(remote_subfolder, 'delete_me_tree') + transport.gettree(workdir / remote_subfolder, 'delete_me_tree') os.remove(os.path.join(local_subfolder, 'file.txt')) os.rmdir(local_subfolder) - transport.rmtree(remote_subfolder) + transport.rmtree(workdir / remote_subfolder) - transport.chdir('..') - transport.rmdir(directory) + transport.rmdir(workdir) def test_put_get_empty_string_tree(custom_transport): """Test of exception put/get of empty strings""" # TODO : verify the correctness of \n at the end of a file local_dir = os.path.join('/', 'tmp') - remote_dir = local_dir + remote_dir = Path(local_dir) directory = 'tmp_try' with custom_transport as transport: - transport.chdir(remote_dir) - while os.path.exists(os.path.join(local_dir, directory)): # I append a random 
letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) @@ -1018,7 +870,7 @@ def test_put_get_empty_string_tree(custom_transport): os.mkdir(os.path.join(local_dir, directory)) os.mkdir(os.path.join(local_dir, directory, local_subfolder)) - transport.chdir(directory) + workdir = remote_dir / directory local_file_name = os.path.join(local_subfolder, 'file.txt') text = 'Viva Verdi\n' @@ -1028,13 +880,13 @@ def test_put_get_empty_string_tree(custom_transport): # localpath is an empty string # ValueError because it is not an abs path with pytest.raises(ValueError): - transport.puttree('', remote_subfolder) + transport.puttree('', workdir / remote_subfolder) # remote path is an empty string with pytest.raises(OSError): transport.puttree(local_subfolder, '') - transport.puttree(local_subfolder, remote_subfolder) + transport.puttree(local_subfolder, workdir / remote_subfolder) # remote path is an empty string with pytest.raises(OSError): @@ -1043,24 +895,23 @@ def test_put_get_empty_string_tree(custom_transport): # local path is an empty string # ValueError because it is not an abs path with pytest.raises(ValueError): - transport.gettree(remote_subfolder, '') + transport.gettree(workdir / remote_subfolder, '') # TODO : get doesn't retrieve empty files. # Is it what we want? 
- transport.gettree(remote_subfolder, retrieved_subfolder) + transport.gettree(workdir / remote_subfolder, retrieved_subfolder) os.remove(os.path.join(local_subfolder, 'file.txt')) os.rmdir(local_subfolder) - transport.remove(os.path.join(remote_subfolder, 'file.txt')) - transport.rmdir(remote_subfolder) + transport.remove(workdir / remote_subfolder / 'file.txt') + transport.rmdir(workdir / remote_subfolder) # If it couldn't end the copy, it leaves what he did on local file # here I am mixing local with remote - assert 'file.txt' in transport.listdir('tmp3') + assert 'file.txt' in transport.listdir(workdir / 'tmp3') os.remove(os.path.join(retrieved_subfolder, 'file.txt')) os.rmdir(retrieved_subfolder) - transport.chdir('..') - transport.rmdir(directory) + transport.rmdir(workdir) def test_gettree_nested_directory(custom_transport): @@ -1087,32 +938,36 @@ def test_exec_pwd(custom_transport): # Start value delete_at_end = False - with custom_transport as transport: - # To compare with: getcwd uses the normalized ('realpath') path - location = transport.normalize('/tmp') - subfolder = """_'s f"#""" # A folder with characters to escape - subfolder_fullpath = os.path.join(location, subfolder) - - transport.chdir(location) - if not transport.isdir(subfolder): - # Since I created the folder, I will remember to - # delete it at the end of this test - delete_at_end = True - transport.mkdir(subfolder) - - assert transport.isdir(subfolder) - transport.chdir(subfolder) - - assert subfolder_fullpath == transport.getcwd() - retcode, stdout, stderr = transport.exec_command_wait('pwd') - assert retcode == 0 - # I have to strip it because 'pwd' returns a trailing \n - assert stdout.strip() == subfolder_fullpath - assert stderr == '' + try: + with custom_transport as transport: + # To compare with: getcwd uses the normalized ('realpath') path + location = transport.normalize('/tmp') + subfolder = """_'s f"#""" # A folder with characters to escape + subfolder_fullpath = 
os.path.join(location, subfolder) - if delete_at_end: transport.chdir(location) - transport.rmdir(subfolder) + if not transport.isdir(subfolder): + # Since I created the folder, I will remember to + # delete it at the end of this test + delete_at_end = True + transport.mkdir(subfolder) + + assert transport.isdir(subfolder) + transport.chdir(subfolder) + + assert subfolder_fullpath == transport.getcwd() + retcode, stdout, stderr = transport.exec_command_wait('pwd') + assert retcode == 0 + # I have to strip it because 'pwd' returns a trailing \n + assert stdout.strip() == subfolder_fullpath + assert stderr == '' + + if delete_at_end: + transport.chdir(location) + transport.rmdir(subfolder) + except NotImplementedError: + # chdir() & getcwd() is no longer an abstract method, to be removed from interface + pass def test_exec_with_stdin_string(custom_transport): @@ -1131,6 +986,11 @@ def test_exec_with_stdin_bytes(custom_transport): I test directly the exec_command_wait_bytes function; I also pass some non-unicode bytes to check that there is no internal implicit encoding/decoding in the code. """ + + # Skip this test for AsyncSshTransport + if 'AsyncSshTransport' in custom_transport.__str__(): + return + test_string = b'some_test bytes with non-unicode -> \xfa' with custom_transport as transport: retcode, stdout, stderr = transport.exec_command_wait_bytes('cat', stdin=test_string) @@ -1141,6 +1001,11 @@ def test_exec_with_stdin_bytes(custom_transport): def test_exec_with_stdin_filelike(custom_transport): """Test command execution with a stdin from filelike.""" + + # Skip this test for AsyncSshTransport + if 'AsyncSshTransport' in custom_transport.__str__(): + return + test_string = 'some_test String' stdin = io.StringIO(test_string) with custom_transport as transport: @@ -1160,6 +1025,10 @@ def test_exec_with_stdin_filelike_bytes(custom_transport): cannot be decoded to UTF8). 
(Note: we cannot test for all encodings, we test for unicode hoping that this would already catch possible issues.) """ + # Skip this test for AsyncSshTransport + if 'AsyncSshTransport' in custom_transport.__str__(): + return + test_string = b'some_test bytes with non-unicode -> \xfa' stdin = io.BytesIO(test_string) with custom_transport as transport: @@ -1179,6 +1048,10 @@ def test_exec_with_stdin_filelike_bytes_decoding(custom_transport): cannot be decoded to UTF8). (Note: we cannot test for all encodings, we test for unicode hoping that this would already catch possible issues.) """ + # Skip this test for AsyncSshTransport + if 'AsyncSshTransport' in custom_transport.__str__(): + return + test_string = b'some_test bytes with non-unicode -> \xfa' stdin = io.BytesIO(test_string) with custom_transport as transport: @@ -1188,6 +1061,10 @@ def test_exec_with_stdin_filelike_bytes_decoding(custom_transport): def test_exec_with_wrong_stdin(custom_transport): """Test command execution with incorrect stdin string.""" + # Skip this test for AsyncSshTransport + if 'AsyncSshTransport' in custom_transport.__str__(): + return + # I pass a number with custom_transport as transport: with pytest.raises(ValueError): @@ -1209,25 +1086,23 @@ def test_transfer_big_stdout(custom_transport): line_repetitions = min_file_size_bytes // len(file_line_binary) + 1 fcontent = (file_line_binary * line_repetitions).decode('utf8') - with custom_transport as trans: + with custom_transport as transport: # We cannot use tempfile.mkdtemp because we're on a remote folder - location = trans.normalize(os.path.join('/', 'tmp')) - trans.chdir(location) - assert location == trans.getcwd() + location = Path(os.path.join('/', 'tmp')) directory = 'temp_dir_test_transfer_big_stdout' - while trans.isdir(directory): + while transport.isdir(location / directory): # I append a random letter/number until it is unique directory += random.choice(string.ascii_uppercase + string.digits) - trans.mkdir(directory) - 
trans.chdir(directory) + transport.mkdir(location / directory) + workdir = location / directory with tempfile.NamedTemporaryFile(mode='wb') as tmpf: tmpf.write(fcontent.encode('utf8')) tmpf.flush() # I put a file with specific content there at the right file name - trans.putfile(tmpf.name, fname) + transport.putfile(tmpf.name, workdir / fname) python_code = r"""import sys @@ -1251,16 +1126,18 @@ def test_transfer_big_stdout(custom_transport): tmpf.flush() # I put a file with specific content there at the right file name - trans.putfile(tmpf.name, script_fname) + transport.putfile(tmpf.name, workdir / script_fname) # I get its content via the stdout; emulate also network slowness (note I cat twice) - retcode, stdout, stderr = trans.exec_command_wait(f'cat {fname} ; sleep 1 ; cat {fname}') + retcode, stdout, stderr = transport.exec_command_wait(f'cat {fname} ; sleep 1 ; cat {fname}', workdir=workdir) assert stderr == '' assert stdout == fcontent + fcontent assert retcode == 0 # I get its content via the stderr; emulate also network slowness (note I cat twice) - retcode, stdout, stderr = trans.exec_command_wait(f'cat {fname} >&2 ; sleep 1 ; cat {fname} >&2') + retcode, stdout, stderr = transport.exec_command_wait( + f'cat {fname} >&2 ; sleep 1 ; cat {fname} >&2', workdir=workdir + ) assert stderr == fcontent + fcontent assert stdout == '' assert retcode == 0 @@ -1273,17 +1150,16 @@ def test_transfer_big_stdout(custom_transport): # line_repetitions, file_line, file_line)) # However this is pretty slow (and using 'cat' of a file containing only one line is even slower) - retcode, stdout, stderr = trans.exec_command_wait(f'python3 {script_fname}') + retcode, stdout, stderr = transport.exec_command_wait(f'python3 {script_fname}', workdir=workdir) assert stderr == fcontent assert stdout == fcontent assert retcode == 0 # Clean-up - trans.remove(fname) - trans.remove(script_fname) - trans.chdir('..') - trans.rmdir(directory) + transport.remove(workdir / fname) + 
transport.remove(workdir / script_fname) + transport.rmdir(workdir) def test_asynchronous_execution(custom_transport): diff --git a/tests/utils/processes.py b/tests/utils/processes.py index 3acf31714a..43582eea45 100644 --- a/tests/utils/processes.py +++ b/tests/utils/processes.py @@ -24,7 +24,7 @@ def define(cls, spec): spec.inputs.valid_type = Data spec.outputs.valid_type = Data - async def run(self): + def run(self): pass @@ -40,7 +40,7 @@ def define(cls, spec): spec.input('b', required=True) spec.output('result', required=True) - async def run(self): + def run(self): summed = self.inputs.a + self.inputs.b self.out(summed.store()) @@ -55,7 +55,7 @@ def define(cls, spec): super().define(spec) spec.outputs.valid_type = Data - async def run(self): + def run(self): self.out('bad_output', 5) @@ -64,7 +64,7 @@ class ExceptionProcess(Process): _node_class = WorkflowNode - async def run(self): + def run(self): raise RuntimeError('CRASH') @@ -73,7 +73,7 @@ class WaitProcess(Process): _node_class = WorkflowNode - async def run(self): + def run(self): return plumpy.Wait(self.next_step) def next_step(self): @@ -93,7 +93,7 @@ def define(cls, spec): 123, 'GENERIC_EXIT_CODE', message='This process should not be used as cache.', invalidates_cache=True ) - async def run(self): + def run(self): if self.inputs.return_exit_code: return self.exit_codes.GENERIC_EXIT_CODE @@ -108,7 +108,7 @@ def define(cls, spec): super().define(spec) spec.input('not_valid_cache', valid_type=Bool, default=lambda: Bool(False)) - async def run(self): + def run(self): pass @classmethod diff --git a/utils/dependency_management.py b/utils/dependency_management.py old mode 100755 new mode 100644 From 5fdde51318e726f5c3d70430f93bbebb532c3981 Mon Sep 17 00:00:00 2001 From: Ali Khosravi Date: Mon, 18 Nov 2024 17:34:44 +0100 Subject: [PATCH 05/29] asynchrounous counterparts are added to transport.py --- src/aiida/engine/daemon/execmanager.py | 58 ++-- src/aiida/engine/processes/calcjobs/tasks.py | 2 +- 
src/aiida/transports/plugins/ssh_async.py | 16 +- src/aiida/transports/transport.py | 266 ++++++++++++++----- 4 files changed, 235 insertions(+), 107 deletions(-) diff --git a/src/aiida/engine/daemon/execmanager.py b/src/aiida/engine/daemon/execmanager.py index fb24f8955b..d1fefcea28 100644 --- a/src/aiida/engine/daemon/execmanager.py +++ b/src/aiida/engine/daemon/execmanager.py @@ -105,7 +105,7 @@ async def upload_calculation( if dry_run: workdir = Path(folder.abspath) else: - remote_user = transport.whoami() + remote_user = await transport.whoami_async() remote_working_directory = computer.get_workdir().format(username=remote_user) if not remote_working_directory.strip(): raise exceptions.ConfigurationError( @@ -114,13 +114,13 @@ async def upload_calculation( ) # If it already exists, no exception is raised - if not transport.path_exists(remote_working_directory): + if not await transport.path_exists_async(remote_working_directory): logger.debug( f'[submission of calculation {node.pk}] Path ' f'{remote_working_directory} does not exist, trying to create it' ) try: - transport.makedirs(remote_working_directory) + await transport.makedirs_async(remote_working_directory) except EnvironmentError as exc: raise exceptions.ConfigurationError( f'[submission of calculation {node.pk}] ' @@ -133,14 +133,14 @@ async def upload_calculation( # and I do not have to know the logic, but I just need to # read the absolute path from the calculation properties. workdir = Path(remote_working_directory).joinpath(calc_info.uuid[:2], calc_info.uuid[2:4]) - transport.makedirs(str(workdir), ignore_existing=True) + await transport.makedirs_async(str(workdir), ignore_existing=True) try: # The final directory may already exist, most likely because this function was already executed once, but # failed and as a result was rescheduled by the engine. In this case it would be fine to delete the folder # and create it from scratch, except that we cannot be sure that this the actual case. 
Therefore, to err on # the safe side, we move the folder to the lost+found directory before recreating the folder from scratch - transport.mkdir(str(workdir.joinpath(calc_info.uuid[4:]))) + await transport.mkdir_async(str(workdir.joinpath(calc_info.uuid[4:]))) except OSError: # Move the existing directory to lost+found, log a warning and create a clean directory anyway path_existing = os.path.join(str(workdir), calc_info.uuid[4:]) @@ -151,12 +151,12 @@ async def upload_calculation( ) # Make sure the lost+found directory exists, then copy the existing folder there and delete the original - transport.mkdir(path_lost_found, ignore_existing=True) - transport.copytree(path_existing, path_target) - transport.rmtree(path_existing) + await transport.mkdir_async(path_lost_found, ignore_existing=True) + await transport.copytree_async(path_existing, path_target) + await transport.rmtree_async(path_existing) # Now we can create a clean folder for this calculation - transport.mkdir(str(workdir.joinpath(calc_info.uuid[4:]))) + await transport.mkdir_async(str(workdir.joinpath(calc_info.uuid[4:]))) finally: workdir = workdir.joinpath(calc_info.uuid[4:]) @@ -171,11 +171,11 @@ async def upload_calculation( # Note: this will possibly overwrite files for root, dirnames, filenames in code.base.repository.walk(): # mkdir of root - transport.makedirs(str(workdir.joinpath(root)), ignore_existing=True) + await transport.makedirs_async(str(workdir.joinpath(root)), ignore_existing=True) # remotely mkdir first for dirname in dirnames: - transport.makedirs(str(workdir.joinpath(root, dirname)), ignore_existing=True) + await transport.makedirs_async(str(workdir.joinpath(root, dirname)), ignore_existing=True) # Note, once #2579 is implemented, use the `node.open` method instead of the named temporary file in # combination with the new `Transport.put_object_from_filelike` @@ -187,9 +187,9 @@ async def upload_calculation( handle.flush() await transport.put_async(handle.name, 
str(workdir.joinpath(root, filename))) if code.filepath_executable.is_absolute(): - transport.chmod(str(code.filepath_executable), 0o755) # rwxr-xr-x + await transport.chmod_async(str(code.filepath_executable), 0o755) # rwxr-xr-x else: - transport.chmod(str(workdir.joinpath(code.filepath_executable)), 0o755) # rwxr-xr-x + await transport.chmod_async(str(workdir.joinpath(code.filepath_executable)), 0o755) # rwxr-xr-x # local_copy_list is a list of tuples, each with (uuid, dest_path, rel_path) # NOTE: validation of these lists are done inside calculation.presubmit() @@ -209,7 +209,7 @@ async def upload_calculation( await _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir=workdir) elif file_copy_operation is FileCopyOperation.REMOTE: if not dry_run: - _copy_remote_files( + await _copy_remote_files( logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir=workdir ) elif file_copy_operation is FileCopyOperation.SANDBOX: @@ -279,7 +279,7 @@ async def upload_calculation( return None -def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir: Path): +async def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir: Path): """Perform the copy instructions of the ``remote_copy_list`` and ``remote_symlink_list``.""" for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_copy_list: if remote_computer_uuid == computer.uuid: @@ -288,7 +288,7 @@ def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remo f'remotely, directly on the machine {computer.label}' ) try: - transport.copy(remote_abs_path, str(workdir.joinpath(dest_rel_path))) + await transport.copy_async(remote_abs_path, str(workdir.joinpath(dest_rel_path))) except FileNotFoundError: logger.warning( f'[submission of calculation {node.pk}] Unable to copy remote ' @@ -314,8 +314,8 @@ def _copy_remote_files(logger, node, computer, transport, 
remote_copy_list, remo ) remote_dirname = Path(dest_rel_path).parent try: - transport.makedirs(str(workdir.joinpath(remote_dirname)), ignore_existing=True) - transport.symlink(remote_abs_path, str(workdir.joinpath(dest_rel_path))) + await transport.makedirs_async(str(workdir.joinpath(remote_dirname)), ignore_existing=True) + await transport.symlink_async(remote_abs_path, str(workdir.joinpath(dest_rel_path))) except OSError: logger.warning( f'[submission of calculation {node.pk}] Unable to create remote symlink ' @@ -356,14 +356,18 @@ async def _copy_local_files(logger, node, transport, inputs, local_copy_list, wo # The logic below takes care of an edge case where the source is a file but the target is a directory. In # this case, the v2.5.1 implementation would raise an `IsADirectoryError` exception, because it would try # to open the directory in the sandbox folder as a file when writing the contents. - if file_type_source == FileType.FILE and target and transport.isdir(str(workdir.joinpath(target))): + if ( + file_type_source == FileType.FILE + and target + and await transport.isdir_async(str(workdir.joinpath(target))) + ): raise IsADirectoryError # In case the source filename is specified and it is a directory that already exists in the remote, we # want to avoid nested directories in the target path to replicate the behavior of v2.5.1. This is done by # setting the target filename to '.', which means the contents of the node will be copied in the top level # of the temporary directory, whose contents are then copied into the target directory. - if filename and transport.isdir(str(workdir.joinpath(filename))): + if filename and await transport.isdir_async(str(workdir.joinpath(filename))): filename_target = '.' 
filepath_target = (dirpath / filename_target).resolve().absolute() @@ -382,7 +386,7 @@ async def _copy_local_files(logger, node, transport, inputs, local_copy_list, wo with filepath_target.open('wb') as handle: with data_node.base.repository.open(filename_source, 'rb') as source: shutil.copyfileobj(source, handle) - transport.makedirs(str(workdir.joinpath(Path(target).parent)), ignore_existing=True) + await transport.makedirs_async(str(workdir.joinpath(Path(target).parent)), ignore_existing=True) await transport.put_async(str(filepath_target), str(workdir.joinpath(target))) @@ -423,7 +427,7 @@ def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str | return result -def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: +async def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: """Stash files from the working directory of a completed calculation to a permanent remote folder. After a calculation has been completed, optionally stash files from the work directory to a storage location on the @@ -461,7 +465,7 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: for source_filename in source_list: if transport.has_magic(source_filename): copy_instructions = [] - for globbed_filename in transport.glob(str(source_basepath / source_filename)): + for globbed_filename in await transport.glob_async(str(source_basepath / source_filename)): target_filepath = target_basepath / Path(globbed_filename).relative_to(source_basepath) copy_instructions.append((globbed_filename, target_filepath)) else: @@ -470,10 +474,10 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: for source_filepath, target_filepath in copy_instructions: # If the source file is in a (nested) directory, create those directories first in the target directory target_dirname = target_filepath.parent - transport.makedirs(str(target_dirname), ignore_existing=True) + await 
transport.makedirs_async(str(target_dirname), ignore_existing=True) try: - transport.copy(str(source_filepath), str(target_filepath)) + await transport.copy_async(str(source_filepath), str(target_filepath)) except (OSError, ValueError) as exception: EXEC_LOGGER.warning(f'failed to stash {source_filepath} to {target_filepath}: {exception}') else: @@ -621,7 +625,7 @@ async def retrieve_files_from_list( tmp_rname, tmp_lname, depth = item # if there are more than one file I do something differently if transport.has_magic(tmp_rname): - remote_names = transport.glob(str(workdir.joinpath(tmp_rname))) + remote_names = await transport.glob_async(str(workdir.joinpath(tmp_rname))) local_names = [] for rem in remote_names: # get the relative path so to make local_names relative @@ -644,7 +648,7 @@ async def retrieve_files_from_list( abs_item = item if item.startswith('/') else str(workdir.joinpath(item)) if transport.has_magic(abs_item): - remote_names = transport.glob(abs_item) + remote_names = await transport.glob_async(abs_item) local_names = [os.path.split(rem)[1] for rem in remote_names] else: remote_names = [abs_item] diff --git a/src/aiida/engine/processes/calcjobs/tasks.py b/src/aiida/engine/processes/calcjobs/tasks.py index 45d5d98fa4..80617e3bfd 100644 --- a/src/aiida/engine/processes/calcjobs/tasks.py +++ b/src/aiida/engine/processes/calcjobs/tasks.py @@ -376,7 +376,7 @@ async def do_stash(): transport = await cancellable.with_interrupt(request) logger.info(f'stashing calculation<{node.pk}>') - return execmanager.stash_calculation(node, transport) + return await execmanager.stash_calculation(node, transport) try: await exponential_backoff_retry( diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 6255c4f421..4fe3afa126 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -561,7 +561,7 @@ async def isfile_async(self, path): return await self._sftp.isfile(path) - 
async def listdir_async(self, path, pattern=None): + async def listdir_async(self, path, pattern=None): # type: ignore[override] """ :param path: the absolute path to list @@ -579,7 +579,7 @@ async def listdir_async(self, path, pattern=None): return list_ - async def listdir_withattributes_async(self, path: _TransportPath, pattern: Optional[str] = None): + async def listdir_withattributes_async(self, path: _TransportPath, pattern: Optional[str] = None): # type: ignore[override] """Return a list of the names of the entries in the given path. The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory. @@ -791,13 +791,11 @@ async def chmod_async(self, path, mode, follow_symlinks=True): except asyncssh.sftp.SFTPNoSuchFile as exc: raise OSError(f'Error {exc}, directory does not exists') - # ## Blocking methods. We need these for backwards compatibility - # def run_command_blocking(self, func, *args, **kwargs): - # """Call an async method blocking. - # This is useful, only because in some part of engine and - # many external plugins are synchronous function calls make more sense. - # However, be aware these synchronous calls probably won't be performant.""" - # return asyncio.run(func(*args, **kwargs)) + ## Blocking methods. We need these for backwards compatibility + # This is useful, only because some part of engine and + # many external plugins are synchronous, in those cases blocking calls make more sense. + # However, be aware you cannot use these methods in an async functions, + # because they will block the event loop. 
def run_command_blocking(self, func, *args, **kwargs): loop = asyncio.get_event_loop() diff --git a/src/aiida/transports/transport.py b/src/aiida/transports/transport.py index 944d6225e3..e5836e3f09 100644 --- a/src/aiida/transports/transport.py +++ b/src/aiida/transports/transport.py @@ -380,57 +380,6 @@ def copy_from_remote_to_remote(self, transportdestination, remotesource, remoted for filename in sandbox.get_content_list(): transportdestination.put(os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put) - async def copy_from_remote_to_remote_async(self, transportdestination, remotesource, remotedestination, **kwargs): - """Copy files or folders from a remote computer to another remote computer, asynchronously. - - :param transportdestination: transport to be used for the destination computer - :param str remotesource: path to the remote source directory / file - :param str remotedestination: path to the remote destination directory / file - :param kwargs: keyword parameters passed to the call to transportdestination.put, - except for 'dereference' that is passed to self.get - - .. note:: the keyword 'dereference' SHOULD be set to False for the - final put (onto the destination), while it can be set to the - value given in kwargs for the get from the source. In that - way, a symbolic link would never be followed in the final - copy to the remote destination. That way we could avoid getting - unknown (potentially malicious) files into the destination computer. - HOWEVER, since dereference=False is currently NOT - supported by all plugins, we still force it to True for the final put. - - .. note:: the supported keys in kwargs are callback, dereference, - overwrite and ignore_nonexisting. 
- """ - from aiida.common.folders import SandboxFolder - - kwargs_get = { - 'callback': None, - 'dereference': kwargs.pop('dereference', True), - 'overwrite': True, - 'ignore_nonexisting': False, - } - kwargs_put = { - 'callback': kwargs.pop('callback', None), - 'dereference': True, - 'overwrite': kwargs.pop('overwrite', True), - 'ignore_nonexisting': kwargs.pop('ignore_nonexisting', False), - } - - if kwargs: - self.logger.error('Unknown parameters passed to copy_from_remote_to_remote') - - with SandboxFolder() as sandbox: - await self.get_async(remotesource, sandbox.abspath, **kwargs_get) - # Then we scan the full sandbox directory with get_content_list, - # because copying directly from sandbox.abspath would not work - # to copy a single file into another single file, and copying - # from sandbox.get_abs_path('*') would not work for files - # beginning with a dot ('.'). - for filename in sandbox.get_content_list(): - await transportdestination.put_async( - os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put - ) - @abc.abstractmethod def _exec_command_internal(self, command, workdir=None, **kwargs): """Execute the command on the shell, similarly to os.system. 
@@ -502,15 +451,6 @@ def get(self, remotepath, localpath, *args, **kwargs): :param localpath: (str) local_folder_path """ - async def get_async(self, remotepath, localpath, *args, **kwargs): - """ - Retrieve a file or folder from remote source to local destination - dst must be an absolute path (src not necessarily) - :param remotepath: (str) remote_folder_path - :param localpath: (str) local_folder_path - """ - return self.get(remotepath, localpath, *args, **kwargs) - @abc.abstractmethod def getfile(self, remotepath, localpath, *args, **kwargs): """Retrieve a file from remote source to local destination @@ -695,16 +635,6 @@ def put(self, localpath, remotepath, *args, **kwargs): :param str remotepath: path to remote destination """ - async def put_async(self, localpath, remotepath, *args, **kwargs): - """ - Put a file or a directory from local src to remote dst. - src must be an absolute path (dst not necessarily)) - Redirects to putfile and puttree. - :param str localpath: absolute path to local source - :param str remotepath: path to remote destination - """ - return self.put(localpath, remotepath, *args, **kwargs) - @abc.abstractmethod def putfile(self, localpath, remotepath, *args, **kwargs): """Put a file from local src to remote dst. @@ -904,6 +834,202 @@ def _gotocomputer_string(self, remotedir): return connect_string + ## Here we bring the async counterparts of the methods + ## that some of them are not async yet. This is done by defining + ## a new method that calls the sync method and awaits. + ## It's up to the plugin to implement the async methods. 
+ + async def open_async(self): + """Counterpart to open() that is async.""" + return self.open() + + async def close_async(self): + """Counterpart to close() that is async.""" + return self.close() + + async def chdir_async(self, path): + """Counterpart to chdir() that is async.""" + return self.chdir(path) + + async def chmod_async(self, path, mode): + """Counterpart to chmod() that is async.""" + return self.chmod(path, mode) + + async def chown_async(self, path, uid, gid): + """Counterpart to chown() that is async.""" + return self.chown(path, uid, gid) + + async def copy_async(self, remotesource, remotedestination, dereference=False, recursive=True): + """Counterpart to copy() that is async.""" + return self.copy(remotesource, remotedestination, dereference, recursive) + + async def copyfile_async(self, remotesource, remotedestination, dereference=False): + """Counterpart to copyfile() that is async.""" + return self.copyfile(remotesource, remotedestination, dereference) + + async def copytree_async(self, remotesource, remotedestination, dereference=False): + """Counterpart to copytree() that is async.""" + return self.copytree(remotesource, remotedestination, dereference) + + async def copy_from_remote_to_remote_async(self, transportdestination, remotesource, remotedestination, **kwargs): + """Copy files or folders from a remote computer to another remote computer, asynchronously. + + :param transportdestination: transport to be used for the destination computer + :param str remotesource: path to the remote source directory / file + :param str remotedestination: path to the remote destination directory / file + :param kwargs: keyword parameters passed to the call to transportdestination.put, + except for 'dereference' that is passed to self.get + + .. note:: the keyword 'dereference' SHOULD be set to False for the + final put (onto the destination), while it can be set to the + value given in kwargs for the get from the source. 
In that + way, a symbolic link would never be followed in the final + copy to the remote destination. That way we could avoid getting + unknown (potentially malicious) files into the destination computer. + HOWEVER, since dereference=False is currently NOT + supported by all plugins, we still force it to True for the final put. + + .. note:: the supported keys in kwargs are callback, dereference, + overwrite and ignore_nonexisting. + """ + from aiida.common.folders import SandboxFolder + + kwargs_get = { + 'callback': None, + 'dereference': kwargs.pop('dereference', True), + 'overwrite': True, + 'ignore_nonexisting': False, + } + kwargs_put = { + 'callback': kwargs.pop('callback', None), + 'dereference': True, + 'overwrite': kwargs.pop('overwrite', True), + 'ignore_nonexisting': kwargs.pop('ignore_nonexisting', False), + } + + if kwargs: + self.logger.error('Unknown parameters passed to copy_from_remote_to_remote') + + with SandboxFolder() as sandbox: + await self.get_async(remotesource, sandbox.abspath, **kwargs_get) + # Then we scan the full sandbox directory with get_content_list, + # because copying directly from sandbox.abspath would not work + # to copy a single file into another single file, and copying + # from sandbox.get_abs_path('*') would not work for files + # beginning with a dot ('.'). 
+ for filename in sandbox.get_content_list(): + await transportdestination.put_async( + os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put + ) + + async def exec_command_internal_async(self, command, workdir=None, **kwargs): + """Counterpart to _exec_command_internal() that is async.""" + return self._exec_command_internal(command, workdir, **kwargs) + + async def exec_command_wait_bytes_async(self, command, stdin=None, workdir=None, **kwargs): + """Counterpart to exec_command_wait_bytes() that is async.""" + return self.exec_command_wait_bytes(command, stdin, workdir, **kwargs) + + async def exec_command_wait_async(self, command, stdin=None, encoding='utf-8', workdir=None, **kwargs): + """Counterpart to exec_command_wait() that is async.""" + return self.exec_command_wait(command, stdin, encoding, workdir, **kwargs) + + async def get_async(self, remotepath, localpath, *args, **kwargs): + """Counterpart to get() that is async.""" + return self.get(remotepath, localpath, *args, **kwargs) + + async def getfile_async(self, remotepath, localpath, *args, **kwargs): + """Counterpart to getfile() that is async.""" + return self.getfile(remotepath, localpath, *args, **kwargs) + + async def gettree_async(self, remotepath, localpath, *args, **kwargs): + """Counterpart to gettree() that is async.""" + return self.gettree(remotepath, localpath, *args, **kwargs) + + async def getcwd_async(self): + """Counterpart to getcwd() that is async.""" + return self.getcwd() + + async def get_attribute_async(self, path): + """Counterpart to get_attribute() that is async.""" + return self.get_attribute(path) + + async def get_mode_async(self, path): + """Counterpart to get_mode() that is async.""" + return self.get_mode(path) + + async def isdir_async(self, path): + """Counterpart to isdir() that is async.""" + return self.isdir(path) + + async def isfile_async(self, path): + """Counterpart to isfile() that is async.""" + return self.isfile(path) + + async def 
listdir_async(self, path='.', pattern=None): + """Counterpart to listdir() that is async.""" + return self.listdir(path, pattern) + + async def listdir_withattributes_async(self, path: _TransportPath = '.', pattern=None): + """Counterpart to listdir_withattributes() that is async.""" + return self.listdir_withattributes(path, pattern) + + async def makedirs_async(self, path, ignore_existing=False): + """Counterpart to makedirs() that is async.""" + return self.makedirs(path, ignore_existing) + + async def mkdir_async(self, path, ignore_existing=False): + """Counterpart to mkdir() that is async.""" + return self.mkdir(path, ignore_existing) + + async def normalize_async(self, path='.'): + """Counterpart to normalize() that is async.""" + return self.normalize(path) + + async def put_async(self, localpath, remotepath, *args, **kwargs): + """Counterpart to put() that is async.""" + return self.put(localpath, remotepath, *args, **kwargs) + + async def putfile_async(self, localpath, remotepath, *args, **kwargs): + """Counterpart to putfile() that is async.""" + return self.putfile(localpath, remotepath, *args, **kwargs) + + async def puttree_async(self, localpath, remotepath, *args, **kwargs): + """Counterpart to puttree() that is async.""" + return self.puttree(localpath, remotepath, *args, **kwargs) + + async def remove_async(self, path): + """Counterpart to remove() that is async.""" + return self.remove(path) + + async def rename_async(self, oldpath, newpath): + """Counterpart to rename() that is async.""" + return self.rename(oldpath, newpath) + + async def rmdir_async(self, path): + """Counterpart to rmdir() that is async.""" + return self.rmdir(path) + + async def rmtree_async(self, path): + """Counterpart to rmtree() that is async.""" + return self.rmtree(path) + + async def symlink_async(self, remotesource, remotedestination): + """Counterpart to symlink() that is async.""" + return self.symlink(remotesource, remotedestination) + + async def whoami_async(self): 
+ """Counterpart to whoami() that is async.""" + return self.whoami() + + async def path_exists_async(self, path): + """Counterpart to path_exists() that is async.""" + return self.path_exists(path) + + async def glob_async(self, pathname): + """Counterpart to glob() that is async.""" + return self.glob(pathname) + class TransportInternalError(InternalError): """Raised if there is a transport error that is raised to an internal error (e.g. From ccc545e61756eb4d8d22ebeda1ef807169c15686 Mon Sep 17 00:00:00 2001 From: Ali Khosravi Date: Tue, 19 Nov 2024 15:53:03 +0100 Subject: [PATCH 06/29] Giovanni's review applied --- src/aiida/transports/plugins/local.py | 4 +- src/aiida/transports/plugins/ssh.py | 4 +- src/aiida/transports/plugins/ssh_async.py | 67 +++- src/aiida/transports/transport.py | 405 ++++++++++++++++------ src/aiida/transports/util.py | 46 --- 5 files changed, 357 insertions(+), 169 deletions(-) diff --git a/src/aiida/transports/plugins/local.py b/src/aiida/transports/plugins/local.py index 1fa30f4650..c6e613a55b 100644 --- a/src/aiida/transports/plugins/local.py +++ b/src/aiida/transports/plugins/local.py @@ -18,11 +18,11 @@ from aiida.common.warnings import warn_deprecation from aiida.transports import cli as transport_cli -from aiida.transports.transport import Transport, TransportInternalError +from aiida.transports.transport import BlockingTransport, TransportInternalError # refactor or raise the limit: issue #1784 -class LocalTransport(Transport): +class LocalTransport(BlockingTransport): """Support copy and command execution on the same host on which AiiDA is running via direct file copy and execution commands. 
diff --git a/src/aiida/transports/plugins/ssh.py b/src/aiida/transports/plugins/ssh.py index 3279c4430f..16b7e20a8f 100644 --- a/src/aiida/transports/plugins/ssh.py +++ b/src/aiida/transports/plugins/ssh.py @@ -22,7 +22,7 @@ from aiida.common.escaping import escape_for_bash from aiida.common.warnings import warn_deprecation -from ..transport import Transport, TransportInternalError, _TransportPath, fix_path +from ..transport import BlockingTransport, TransportInternalError, _TransportPath, fix_path __all__ = ('parse_sshconfig', 'convert_to_bool', 'SshTransport') @@ -63,7 +63,7 @@ def convert_to_bool(string): raise ValueError('Invalid boolean value provided') -class SshTransport(Transport): +class SshTransport(BlockingTransport): """Support connection, command execution and data transfer to remote computers via SSH+SFTP.""" # Valid keywords accepted by the connect method of paramiko.SSHClient diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 4fe3afa126..b4391f0e7b 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -28,7 +28,7 @@ from aiida.common.escaping import escape_for_bash from aiida.common.exceptions import InvalidOperation -from ..transport import Transport, TransportInternalError, _TransportPath, fix_path +from ..transport import AsyncTransport, TransportInternalError, _TransportPath, fix_path __all__ = ('AsyncSshTransport',) @@ -61,7 +61,7 @@ async def attempt_connection(): return value -class AsyncSshTransport(Transport): +class AsyncSshTransport(AsyncTransport): """Transport plugin via SSH, asynchronously.""" # note, I intentionally wanted to keep connection parameters as simple as possible. 
@@ -110,9 +110,6 @@ def __init__(self, *args, **kwargs): self.script_before = kwargs.pop('script_before', 'None') self.script_during = kwargs.pop('script_during', 'None') - def __str__(self): - return f"{'OPEN' if self._is_open else 'CLOSED'} [AsyncSshTransport]" - async def open_async(self): if self._is_open: raise InvalidOperation('Cannot open the transport twice') @@ -139,6 +136,9 @@ async def close_async(self): await self._conn.wait_closed() self._is_open = False + def __str__(self): + return f"{'OPEN' if self._is_open else 'CLOSED'} [AsyncSshTransport]" + async def get_async(self, remotepath, localpath, dereference=True, overwrite=True, ignore_nonexisting=False): """Get a file or folder from remote to local. Redirects to getfile or gettree. @@ -791,6 +791,57 @@ async def chmod_async(self, path, mode, follow_symlinks=True): except asyncssh.sftp.SFTPNoSuchFile as exc: raise OSError(f'Error {exc}, directory does not exists') + async def copy_from_remote_to_remote_async(self, transportdestination, remotesource, remotedestination, **kwargs): + """Copy files or folders from a remote computer to another remote computer, asynchronously. + + :param transportdestination: transport to be used for the destination computer + :param str remotesource: path to the remote source directory / file + :param str remotedestination: path to the remote destination directory / file + :param kwargs: keyword parameters passed to the call to transportdestination.put, + except for 'dereference' that is passed to self.get + + .. note:: the keyword 'dereference' SHOULD be set to False for the + final put (onto the destination), while it can be set to the + value given in kwargs for the get from the source. In that + way, a symbolic link would never be followed in the final + copy to the remote destination. That way we could avoid getting + unknown (potentially malicious) files into the destination computer. 
+ HOWEVER, since dereference=False is currently NOT + supported by all plugins, we still force it to True for the final put. + + .. note:: the supported keys in kwargs are callback, dereference, + overwrite and ignore_nonexisting. + """ + from aiida.common.folders import SandboxFolder + + kwargs_get = { + 'callback': None, + 'dereference': kwargs.pop('dereference', True), + 'overwrite': True, + 'ignore_nonexisting': False, + } + kwargs_put = { + 'callback': kwargs.pop('callback', None), + 'dereference': True, + 'overwrite': kwargs.pop('overwrite', True), + 'ignore_nonexisting': kwargs.pop('ignore_nonexisting', False), + } + + if kwargs: + self.logger.error('Unknown parameters passed to copy_from_remote_to_remote') + + with SandboxFolder() as sandbox: + await self.get_async(remotesource, sandbox.abspath, **kwargs_get) + # Then we scan the full sandbox directory with get_content_list, + # because copying directly from sandbox.abspath would not work + # to copy a single file into another single file, and copying + # from sandbox.get_abs_path('*') would not work for files + # beginning with a dot ('.'). + for filename in sandbox.get_content_list(): + await transportdestination.put_async( + os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put + ) + ## Blocking methods. We need these for backwards compatibility # This is useful, only because some part of engine and # many external plugins are synchronous, in those cases blocking calls make more sense. 
@@ -905,9 +956,3 @@ def _exec_command_internal(self, *args, **kwargs): def normalize(self, *args, **kwargs): raise NotImplementedError('Not implemented, waiting for a use case') - - def chdir(self, *args, **kwargs): - raise NotImplementedError("It's not safe to chdir() for async transport") - - def getcwd(self, *args, **kwargs): - raise NotImplementedError("It's not safe to getcwd() for async transport") diff --git a/src/aiida/transports/transport.py b/src/aiida/transports/transport.py index e5836e3f09..f15b584dc2 100644 --- a/src/aiida/transports/transport.py +++ b/src/aiida/transports/transport.py @@ -9,6 +9,7 @@ """Transport interface.""" import abc +import asyncio import fnmatch import os import re @@ -47,8 +48,9 @@ def fix_path(path: _TransportPath) -> str: return str(path) -class Transport(abc.ABC): - """Abstract class for a generic transport (ssh, local, ...) contains the set of minimal methods.""" +class _BaseTransport: + """Abstract class for a generic blocking transport. + A plugin inhereting from this class should implement the blocking methods, only.""" # This will be used for ``Computer.get_minimum_job_poll_interval`` DEFAULT_MINIMUM_JOB_POLL_INTERVAL = 10 @@ -111,6 +113,14 @@ def __init__(self, *args, **kwargs): # for accessing the identity of the underlying machine self.hostname = kwargs.get('machine') + @abc.abstractmethod + def open(self): + """Opens a local transport channel""" + + @abc.abstractmethod + def close(self): + """Closes the local transport channel""" + def __enter__(self): """For transports that require opening a connection, opens all required channels (used in 'with' statements). 
@@ -150,21 +160,6 @@ def __exit__(self, type_, value, traceback): def is_open(self): return self._is_open - @abc.abstractmethod - def open(self): - """Opens a local transport channel""" - - @abc.abstractmethod - def close(self): - """Closes the local transport channel""" - - def __repr__(self): - return f'<{self.__class__.__name__}: {self!s}>' - - # redefine this in each subclass - def __str__(self): - return '[Transport class or subclass]' - def set_logger_extra(self, logger_extra): """Pass the data that should be passed automatically to self.logger as 'extra' keyword. This is typically useful if you pass data @@ -253,23 +248,42 @@ def get_safe_open_interval(self): """ return self._safe_open_interval - @abc.abstractmethod - def chdir(self, path): - """ - DEPRECATED: This method is deprecated and should be removed in the next major version. - PLEASE DON'T USE IT IN THE INTERFACE!! - - Change directory to 'path'. - :param str path: path to change working directory into. - :raises: OSError, if the requested path does not exist - :rtype: str - """ + def has_magic(self, string): + """Return True if the given string contains any special shell characters.""" + return self._MAGIC_CHECK.search(string) is not None - warn_deprecation( - '`chdir()` is deprecated and will be removed in the next major version.', - version=3, + def _gotocomputer_string(self, remotedir): + """Command executed when goto computer.""" + connect_string = ( + """ "if [ -d {escaped_remotedir} ] ;""" + """ then cd {escaped_remotedir} ; {bash_command} ; else echo ' ** The directory' ; """ + """echo ' ** {remotedir}' ; echo ' ** seems to have been deleted, I logout...' ; fi" """.format( + bash_command=self._bash_command_str, escaped_remotedir="'{}'".format(remotedir), remotedir=remotedir + ) ) + return connect_string + + +class BlockingTransport(abc.ABC, _BaseTransport): + """Abstract class for a generic blocking transport. 
+ A plugin inhereting from this class should implement the blocking methods, only.""" + + # This will be used for connection authentication + # To be defined in the subclass, the format is a list of tuples + # where the first element is the name of the parameter and the second + # is a dictionary with the following + # keys: 'default', 'prompt', 'help', 'non_interactive_default' + _valid_auth_options = [] + + @abc.abstractmethod + def __repr__(self): + return f'<{self.__class__.__name__}: {self!s}>' + + @abc.abstractmethod + def __str__(self): + return '[Transport class or subclass]' + @abc.abstractmethod def chmod(self, path, mode): """Change permissions of a path. @@ -331,6 +345,7 @@ def copytree(self, remotesource, remotedestination, dereference=False): :raise OSError: if one of src or dst does not exist """ + ## non-abtract methods. Plugin developers can safely ingore developing these methods def copy_from_remote_to_remote(self, transportdestination, remotesource, remotedestination, **kwargs): """Copy files or folders from a remote computer to another remote computer. @@ -818,26 +833,10 @@ def glob0(self, dirname, basename): return [basename] return [] - def has_magic(self, string): - """Return True if the given string contains any special shell characters.""" - return self._MAGIC_CHECK.search(string) is not None - - def _gotocomputer_string(self, remotedir): - """Command executed when goto computer.""" - connect_string = ( - """ "if [ -d {escaped_remotedir} ] ;""" - """ then cd {escaped_remotedir} ; {bash_command} ; else echo ' ** The directory' ; """ - """echo ' ** {remotedir}' ; echo ' ** seems to have been deleted, I logout...' ; fi" """.format( - bash_command=self._bash_command_str, escaped_remotedir="'{}'".format(remotedir), remotedir=remotedir - ) - ) - - return connect_string - - ## Here we bring the async counterparts of the methods - ## that some of them are not async yet. 
This is done by defining - ## a new method that calls the sync method and awaits. - ## It's up to the plugin to implement the async methods. + ## Here we bring the async counterparts of the methods. + ## This is done by defining new methods that calls the sync method and awaits. + ## aiida-core engine is ultimately moving towards async, so this is a step in that direction. + ## To keep backward compatibility, the sync methods are kept as they are. async def open_async(self): """Counterpart to open() that is async.""" @@ -847,10 +846,6 @@ async def close_async(self): """Counterpart to close() that is async.""" return self.close() - async def chdir_async(self, path): - """Counterpart to chdir() that is async.""" - return self.chdir(path) - async def chmod_async(self, path, mode): """Counterpart to chmod() that is async.""" return self.chmod(path, mode) @@ -872,55 +867,8 @@ async def copytree_async(self, remotesource, remotedestination, dereference=Fals return self.copytree(remotesource, remotedestination, dereference) async def copy_from_remote_to_remote_async(self, transportdestination, remotesource, remotedestination, **kwargs): - """Copy files or folders from a remote computer to another remote computer, asynchronously. - - :param transportdestination: transport to be used for the destination computer - :param str remotesource: path to the remote source directory / file - :param str remotedestination: path to the remote destination directory / file - :param kwargs: keyword parameters passed to the call to transportdestination.put, - except for 'dereference' that is passed to self.get - - .. note:: the keyword 'dereference' SHOULD be set to False for the - final put (onto the destination), while it can be set to the - value given in kwargs for the get from the source. In that - way, a symbolic link would never be followed in the final - copy to the remote destination. 
That way we could avoid getting - unknown (potentially malicious) files into the destination computer. - HOWEVER, since dereference=False is currently NOT - supported by all plugins, we still force it to True for the final put. - - .. note:: the supported keys in kwargs are callback, dereference, - overwrite and ignore_nonexisting. - """ - from aiida.common.folders import SandboxFolder - - kwargs_get = { - 'callback': None, - 'dereference': kwargs.pop('dereference', True), - 'overwrite': True, - 'ignore_nonexisting': False, - } - kwargs_put = { - 'callback': kwargs.pop('callback', None), - 'dereference': True, - 'overwrite': kwargs.pop('overwrite', True), - 'ignore_nonexisting': kwargs.pop('ignore_nonexisting', False), - } - - if kwargs: - self.logger.error('Unknown parameters passed to copy_from_remote_to_remote') - - with SandboxFolder() as sandbox: - await self.get_async(remotesource, sandbox.abspath, **kwargs_get) - # Then we scan the full sandbox directory with get_content_list, - # because copying directly from sandbox.abspath would not work - # to copy a single file into another single file, and copying - # from sandbox.get_abs_path('*') would not work for files - # beginning with a dot ('.'). 
- for filename in sandbox.get_content_list(): - await transportdestination.put_async( - os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put - ) + """Counterpart to copy_from_remote_to_remote().""" + return self.copy_from_remote_to_remote(transportdestination, remotesource, remotedestination, **kwargs) async def exec_command_internal_async(self, command, workdir=None, **kwargs): """Counterpart to _exec_command_internal() that is async.""" @@ -946,10 +894,6 @@ async def gettree_async(self, remotepath, localpath, *args, **kwargs): """Counterpart to gettree() that is async.""" return self.gettree(remotepath, localpath, *args, **kwargs) - async def getcwd_async(self): - """Counterpart to getcwd() that is async.""" - return self.getcwd() - async def get_attribute_async(self, path): """Counterpart to get_attribute() that is async.""" return self.get_attribute(path) @@ -1031,6 +975,251 @@ async def glob_async(self, pathname): return self.glob(pathname) +class AsyncTransport(abc.ABC, _BaseTransport): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @abc.abstractmethod + async def open_async(self): + """Open the transport.""" + + @abc.abstractmethod + async def close_async(self): + """Close the transport.""" + + @abc.abstractmethod + async def chmod_async(self, path, mode): + """Change permissions of a path.""" + + @abc.abstractmethod + async def chown_async(self, path, uid, gid): + """Change the owner (uid) and group (gid) of a file.""" + + @abc.abstractmethod + async def copy_async(self, remotesource, remotedestination, dereference=False, recursive=True): + """Copy a file or a directory from remote source to remote destination + (On the same remote machine)""" + + @abc.abstractmethod + async def copyfile_async(self, remotesource, remotedestination, dereference=False): + """Copy a file from remote source to remote destination + (On the same remote machine)""" + + @abc.abstractmethod + async def copytree_async(self, 
remotesource, remotedestination, dereference=False): + """Copy a folder from remote source to remote destination + (On the same remote machine)""" + + @abc.abstractmethod + async def copy_from_remote_to_remote_async(self, transportdestination, remotesource, remotedestination, **kwargs): + """Copy files or folders from a remote computer to another remote computer.""" + + @abc.abstractmethod + async def exec_command_wait_async(self, command, stdin=None, encoding='utf-8', workdir=None, **kwargs): + """Executes the specified command and waits for it to finish.""" + + @abc.abstractmethod + async def get_async(self, remotepath, localpath, *args, **kwargs): + """Retrieve a file or folder from remote source to local destination""" + + @abc.abstractmethod + async def getfile_async(self, remotepath, localpath, *args, **kwargs): + """Retrieve a file from remote source to local destination""" + + @abc.abstractmethod + async def gettree_async(self, remotepath, localpath, *args, **kwargs): + """Retrieve a folder recursively from remote source to local destination""" + + @abc.abstractmethod + async def get_attribute_async(self, path): + """Return an object FixedFieldsAttributeDict for file in a given path""" + + @abc.abstractmethod + async def get_mode_async(self, path): + """Return the portion of the file's mode that can be set by chmod().""" + + @abc.abstractmethod + async def isdir_async(self, path): + """True if path is an existing directory.""" + + @abc.abstractmethod + async def isfile_async(self, path): + """Return True if path is an existing file.""" + + @abc.abstractmethod + async def listdir_async(self, path='.', pattern=None): + """Return a list of the names of the entries in the given path.""" + + @abc.abstractmethod + async def listdir_withattributes_async(self, path: _TransportPath = '.', pattern=None): + """Return a list of the names of the entries in the given path.""" + + @abc.abstractmethod + async def makedirs_async(self, path, ignore_existing=False): + 
"""Super-mkdir; create a leaf directory and all intermediate ones.""" + + @abc.abstractmethod + async def mkdir_async(self, path, ignore_existing=False): + """Create a folder (directory) named path.""" + + @abc.abstractmethod + async def normalize_async(self, path='.'): + """Return the normalized path (on the server) of a given path.""" + + @abc.abstractmethod + async def put_async(self, localpath, remotepath, *args, **kwargs): + """Put a file or a directory from local src to remote dst.""" + + @abc.abstractmethod + async def putfile_async(self, localpath, remotepath, *args, **kwargs): + """Put a file from local src to remote dst.""" + + @abc.abstractmethod + async def puttree_async(self, localpath, remotepath, *args, **kwargs): + """Put a folder recursively from local src to remote dst.""" + + @abc.abstractmethod + async def remove_async(self, path): + """Remove the file at the given path.""" + + @abc.abstractmethod + async def rename_async(self, oldpath, newpath): + """Rename a file or folder from oldpath to newpath.""" + + @abc.abstractmethod + async def rmdir_async(self, path): + """Remove the folder named path.""" + + @abc.abstractmethod + async def rmtree_async(self, path): + """Remove recursively the content at path""" + + @abc.abstractmethod + async def symlink_async(self, remotesource, remotedestination): + """Create a symbolic link between the remote source and the remote destination.""" + + @abc.abstractmethod + async def whoami_async(self): + """Get the remote username""" + + @abc.abstractmethod + async def path_exists_async(self, path): + """Returns True if path exists, False otherwise.""" + + @abc.abstractmethod + async def glob_async(self, pathname): + """Return a list of paths matching a pathname pattern.""" + + @abc.abstractmethod + def gotocomputer_command(self, remotedir): + """Return a string to be run using os.system in order to connect + via the transport to the remote directory.""" + + ## Blocking counterpart methods. 
We need these for backwards compatibility + # This is useful, only because some part of engine and + # many external plugins are synchronous, in those cases blocking calls make more sense. + # However, be aware you cannot use these methods in an async functions, + # because they will block the event loop. + + def run_command_blocking(self, func, *args, **kwargs): + loop = asyncio.get_event_loop() + return loop.run_until_complete(func(*args, **kwargs)) + + def open(self): + return self.run_command_blocking(self.open_async) + + def close(self): + return self.run_command_blocking(self.close_async) + + def chown(self, *args, **kwargs): + raise NotImplementedError('Not implemented, for now') + + def get(self, *args, **kwargs): + return self.run_command_blocking(self.get_async, *args, **kwargs) + + def getfile(self, *args, **kwargs): + return self.run_command_blocking(self.getfile_async, *args, **kwargs) + + def gettree(self, *args, **kwargs): + return self.run_command_blocking(self.gettree_async, *args, **kwargs) + + def put(self, *args, **kwargs): + return self.run_command_blocking(self.put_async, *args, **kwargs) + + def putfile(self, *args, **kwargs): + return self.run_command_blocking(self.putfile_async, *args, **kwargs) + + def puttree(self, *args, **kwargs): + return self.run_command_blocking(self.puttree_async, *args, **kwargs) + + def chmod(self, *args, **kwargs): + return self.run_command_blocking(self.chmod_async, *args, **kwargs) + + def copy(self, *args, **kwargs): + return self.run_command_blocking(self.copy_async, *args, **kwargs) + + def copyfile(self, *args, **kwargs): + return self.copy(*args, **kwargs) + + def copytree(self, *args, **kwargs): + return self.copy(*args, **kwargs) + + def exec_command_wait(self, *args, **kwargs): + return self.run_command_blocking(self.exec_command_wait_async, *args, **kwargs) + + def get_attribute(self, *args, **kwargs): + return self.run_command_blocking(self.get_attribute_async, *args, **kwargs) + + def isdir(self, 
*args, **kwargs): + return self.run_command_blocking(self.isdir_async, *args, **kwargs) + + def isfile(self, *args, **kwargs): + return self.run_command_blocking(self.isfile_async, *args, **kwargs) + + def listdir(self, *args, **kwargs): + return self.run_command_blocking(self.listdir_async, *args, **kwargs) + + def listdir_withattributes(self, *args, **kwargs): + return self.run_command_blocking(self.listdir_withattributes_async, *args, **kwargs) + + def makedirs(self, *args, **kwargs): + return self.run_command_blocking(self.makedirs_async, *args, **kwargs) + + def mkdir(self, *args, **kwargs): + return self.run_command_blocking(self.mkdir_async, *args, **kwargs) + + def remove(self, *args, **kwargs): + return self.run_command_blocking(self.remove_async, *args, **kwargs) + + def rename(self, *args, **kwargs): + return self.run_command_blocking(self.rename_async, *args, **kwargs) + + def rmdir(self, *args, **kwargs): + return self.run_command_blocking(self.rmdir_async, *args, **kwargs) + + def rmtree(self, *args, **kwargs): + return self.run_command_blocking(self.rmtree_async, *args, **kwargs) + + def path_exists(self, *args, **kwargs): + return self.run_command_blocking(self.path_exists_async, *args, **kwargs) + + def whoami(self, *args, **kwargs): + return self.run_command_blocking(self.whoami_async, *args, **kwargs) + + def symlink(self, *args, **kwargs): + return self.run_command_blocking(self.symlink_async, *args, **kwargs) + + def glob(self, *args, **kwargs): + return self.run_command_blocking(self.glob_async, *args, **kwargs) + + def normalize(self, *args, **kwargs): + return self.run_command_blocking(self.normalize_async, *args, **kwargs) + + +# This is here for backwards compatibility +Transport = BlockingTransport + + class TransportInternalError(InternalError): """Raised if there is a transport error that is raised to an internal error (e.g. a transport method called without opening the channel first). 
diff --git a/src/aiida/transports/util.py b/src/aiida/transports/util.py index a75547a0e3..dfda089e95 100644 --- a/src/aiida/transports/util.py +++ b/src/aiida/transports/util.py @@ -9,58 +9,12 @@ """General utilities for Transport classes.""" import time -from pathlib import Path, PurePosixPath -from typing import Union from paramiko import ProxyCommand from aiida.common.extendeddicts import FixedFieldsAttributeDict -class StrPath: - """A class to chain paths together. - This is useful to avoid the need to use os.path.join to chain paths. - - Note: - Eventually transport plugins may further develope so that functions with pathlib.Path - So far they are expected to work only with POSIX paths. - This class is a solution to avoid the need to use Path.join to chain paths and convert back again to str. - """ - - def __init__(self, path: Union[str, PurePosixPath]): - """Chain a path with multiple paths. - - :param path: the initial path (absolute) - """ - if isinstance(path, PurePosixPath): - path = str(path) - self.str = path.rstrip('/') - - def join(self, *paths: Union[str, PurePosixPath, Path], return_str=True) -> Union[str, 'StrPath']: - """Join the initial path with multiple paths. - - :param paths: the paths to chain (relative to the previous path) - :param paths: It should be of type str or Path or PurePosixPath - :param return_str: if True, return a string, otherwise return a new StrPath object - - :return: a new StrPath object - """ - path = self.str - for p in paths: - p_ = str(p) if isinstance(p, (PurePosixPath, Path)) else p - if self.str in p_: - raise ValueError( - 'The path to join is already included in the initial path, ' - 'probably you are trying to join an absolute path' - ) - path = f"{path}/{p_.strip('/')}" - - if return_str: - return path - - return StrPath(path) - - class FileAttribute(FixedFieldsAttributeDict): """A class, resembling a dictionary, to describe the attributes of a file, that is returned by get_attribute(). 
From f187fdc4656ebd55c992dd4f211007114bd6f1dc Mon Sep 17 00:00:00 2001 From: Ali Khosravi Date: Tue, 19 Nov 2024 16:44:23 +0100 Subject: [PATCH 07/29] adopted tests --- environment.yml | 3 +- pyproject.toml | 4 +- requirements/requirements-py-3.10.txt | 3 +- requirements/requirements-py-3.11.txt | 3 +- requirements/requirements-py-3.12.txt | 3 +- requirements/requirements-py-3.9.txt | 3 +- src/aiida/calculations/monitors/base.py | 5 +- src/aiida/engine/daemon/execmanager.py | 66 ++++--- .../engine/processes/calcjobs/monitors.py | 9 +- src/aiida/engine/transports.py | 8 +- src/aiida/orm/authinfos.py | 6 +- src/aiida/orm/computers.py | 6 +- src/aiida/orm/nodes/data/remote/base.py | 3 +- .../orm/nodes/process/calculation/calcjob.py | 7 +- src/aiida/orm/utils/remote.py | 7 +- src/aiida/plugins/factories.py | 18 +- src/aiida/schedulers/scheduler.py | 5 +- src/aiida/tools/pytest_fixtures/__init__.py | 2 + src/aiida/tools/pytest_fixtures/orm.py | 32 ++++ src/aiida/transports/__init__.py | 6 +- src/aiida/transports/plugins/ssh.py | 4 +- src/aiida/transports/plugins/ssh_async.py | 161 +++++------------- src/aiida/transports/transport.py | 50 +++--- tests/engine/daemon/test_execmanager.py | 38 +++-- tests/manage/tests/test_pytest_fixtures.py | 29 +++- tests/orm/test_computers.py | 2 +- tests/plugins/test_factories.py | 12 +- tests/test_calculation_node.py | 4 +- tests/transports/test_all_plugins.py | 25 +-- utils/dependency_management.py | 15 +- 30 files changed, 292 insertions(+), 247 deletions(-) diff --git a/environment.yml b/environment.yml index 86eee6f90b..e517f374d1 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,7 @@ dependencies: - python~=3.9 - alembic~=1.2 - archive-path~=0.4.2 +- asyncssh~=2.18.0 - circus~=0.18.0 - click-spinner~=0.1.8 - click~=8.1 @@ -22,7 +23,7 @@ dependencies: - importlib-metadata~=6.0 - numpy~=1.21 - paramiko~=3.0 -- plumpy~=0.22.3 +- plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy - 
pgsu~=0.3.0 - psutil~=5.6 - psycopg[binary]~=3.0 diff --git a/pyproject.toml b/pyproject.toml index ef7176fc4c..18ab84e069 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ classifiers = [ dependencies = [ 'alembic~=1.2', 'archive-path~=0.4.2', + 'asyncssh~=2.18.0', 'circus~=0.18.0', 'click-spinner~=0.1.8', 'click~=8.1', @@ -34,7 +35,7 @@ dependencies = [ 'importlib-metadata~=6.0', 'numpy~=1.21', 'paramiko~=3.0', - 'plumpy~=0.22.3', + 'plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy', 'pgsu~=0.3.0', 'psutil~=5.6', 'psycopg[binary]~=3.0', @@ -306,6 +307,7 @@ module = 'tests.*' ignore_missing_imports = true module = [ 'ase.*', + 'asyncssh.*', 'bpython.*', 'bs4.*', 'CifFile.*', diff --git a/requirements/requirements-py-3.10.txt b/requirements/requirements-py-3.10.txt index b2408e8087..6bc2c11ef2 100644 --- a/requirements/requirements-py-3.10.txt +++ b/requirements/requirements-py-3.10.txt @@ -20,6 +20,7 @@ ase==3.22.1 asn1crypto==1.5.1 asttokens==2.2.1 async-generator==1.10 +asyncssh~=2.18.0 attrs==23.1.0 babel==2.12.1 backcall==0.2.0 @@ -120,7 +121,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy==0.22.3 +plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 diff --git a/requirements/requirements-py-3.11.txt b/requirements/requirements-py-3.11.txt index 24acc25a6b..036a8b6242 100644 --- a/requirements/requirements-py-3.11.txt +++ b/requirements/requirements-py-3.11.txt @@ -20,6 +20,7 @@ ase==3.22.1 asn1crypto==1.5.1 asttokens==2.2.1 async-generator==1.10 +asyncssh~=2.18.0 attrs==23.1.0 babel==2.12.1 backcall==0.2.0 @@ -119,7 +120,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy==0.22.3 +plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 diff --git 
a/requirements/requirements-py-3.12.txt b/requirements/requirements-py-3.12.txt index 3f5d72ebb6..7ef132ef76 100644 --- a/requirements/requirements-py-3.12.txt +++ b/requirements/requirements-py-3.12.txt @@ -20,6 +20,7 @@ ase==3.22.1 asn1crypto==1.5.1 asttokens==2.4.0 async-generator==1.10 +asyncssh~=2.18.0 attrs==23.1.0 babel==2.13.1 backcall==0.2.0 @@ -119,7 +120,7 @@ pillow==10.1.0 platformdirs==3.11.0 plotly==5.17.0 pluggy==1.3.0 -plumpy==0.22.3 +plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy prometheus-client==0.17.1 prompt-toolkit==3.0.39 psutil==5.9.6 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt index 3087e62844..1d837f5469 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -20,6 +20,7 @@ ase==3.22.1 asn1crypto==1.5.1 asttokens==2.2.1 async-generator==1.10 +asyncssh~=2.18.0 attrs==23.1.0 babel==2.12.1 backcall==0.2.0 @@ -122,7 +123,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy==0.22.3 +plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 diff --git a/src/aiida/calculations/monitors/base.py b/src/aiida/calculations/monitors/base.py index 459f4eba9d..588b0debb6 100644 --- a/src/aiida/calculations/monitors/base.py +++ b/src/aiida/calculations/monitors/base.py @@ -4,12 +4,13 @@ import tempfile from pathlib import Path +from typing import Union from aiida.orm import CalcJobNode -from aiida.transports import Transport +from aiida.transports import AsyncTransport, BlockingTransport -def always_kill(node: CalcJobNode, transport: Transport) -> str | None: +def always_kill(node: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport']) -> str | None: """Retrieve and inspect files in working directory of job to determine whether the job should be killed. 
This particular implementation is just for demonstration purposes and will kill the job as long as there is a diff --git a/src/aiida/engine/daemon/execmanager.py b/src/aiida/engine/daemon/execmanager.py index d1fefcea28..18e650ebae 100644 --- a/src/aiida/engine/daemon/execmanager.py +++ b/src/aiida/engine/daemon/execmanager.py @@ -35,7 +35,7 @@ from aiida.schedulers.datastructures import JobState if TYPE_CHECKING: - from aiida.transports import Transport + from aiida.transports import AsyncTransport, BlockingTransport REMOTE_WORK_DIRECTORY_LOST_FOUND = 'lost+found' @@ -64,7 +64,7 @@ def _find_data_node(inputs: MappingType[str, Any], uuid: str) -> Optional[Node]: async def upload_calculation( node: CalcJobNode, - transport: Transport, + transport: Union['BlockingTransport', 'AsyncTransport'], calc_info: CalcInfo, folder: Folder, inputs: Optional[MappingType[str, Any]] = None, @@ -133,14 +133,14 @@ async def upload_calculation( # and I do not have to know the logic, but I just need to # read the absolute path from the calculation properties. workdir = Path(remote_working_directory).joinpath(calc_info.uuid[:2], calc_info.uuid[2:4]) - await transport.makedirs_async(str(workdir), ignore_existing=True) + await transport.makedirs_async(workdir, ignore_existing=True) try: # The final directory may already exist, most likely because this function was already executed once, but # failed and as a result was rescheduled by the engine. In this case it would be fine to delete the folder # and create it from scratch, except that we cannot be sure that this the actual case. 
Therefore, to err on # the safe side, we move the folder to the lost+found directory before recreating the folder from scratch - await transport.mkdir_async(str(workdir.joinpath(calc_info.uuid[4:]))) + await transport.mkdir_async(workdir.joinpath(calc_info.uuid[4:])) except OSError: # Move the existing directory to lost+found, log a warning and create a clean directory anyway path_existing = os.path.join(str(workdir), calc_info.uuid[4:]) @@ -156,7 +156,7 @@ async def upload_calculation( await transport.rmtree_async(path_existing) # Now we can create a clean folder for this calculation - await transport.mkdir_async(str(workdir.joinpath(calc_info.uuid[4:]))) + await transport.mkdir_async(workdir.joinpath(calc_info.uuid[4:])) finally: workdir = workdir.joinpath(calc_info.uuid[4:]) @@ -171,11 +171,11 @@ async def upload_calculation( # Note: this will possibly overwrite files for root, dirnames, filenames in code.base.repository.walk(): # mkdir of root - await transport.makedirs_async(str(workdir.joinpath(root)), ignore_existing=True) + await transport.makedirs_async(workdir.joinpath(root), ignore_existing=True) # remotely mkdir first for dirname in dirnames: - await transport.makedirs_async(str(workdir.joinpath(root, dirname)), ignore_existing=True) + await transport.makedirs_async(workdir.joinpath(root, dirname), ignore_existing=True) # Note, once #2579 is implemented, use the `node.open` method instead of the named temporary file in # combination with the new `Transport.put_object_from_filelike` @@ -185,11 +185,11 @@ async def upload_calculation( content = code.base.repository.get_object_content(Path(root) / filename, mode='rb') handle.write(content) handle.flush() - await transport.put_async(handle.name, str(workdir.joinpath(root, filename))) + await transport.put_async(handle.name, workdir.joinpath(root, filename)) if code.filepath_executable.is_absolute(): - await transport.chmod_async(str(code.filepath_executable), 0o755) # rwxr-xr-x + await 
transport.chmod_async(code.filepath_executable, 0o755) # rwxr-xr-x else: - await transport.chmod_async(str(workdir.joinpath(code.filepath_executable)), 0o755) # rwxr-xr-x + await transport.chmod_async(workdir.joinpath(code.filepath_executable), 0o755) # rwxr-xr-x # local_copy_list is a list of tuples, each with (uuid, dest_path, rel_path) # NOTE: validation of these lists are done inside calculation.presubmit() @@ -288,7 +288,7 @@ async def _copy_remote_files(logger, node, computer, transport, remote_copy_list f'remotely, directly on the machine {computer.label}' ) try: - await transport.copy_async(remote_abs_path, str(workdir.joinpath(dest_rel_path))) + await transport.copy_async(remote_abs_path, workdir.joinpath(dest_rel_path)) except FileNotFoundError: logger.warning( f'[submission of calculation {node.pk}] Unable to copy remote ' @@ -314,8 +314,8 @@ async def _copy_remote_files(logger, node, computer, transport, remote_copy_list ) remote_dirname = Path(dest_rel_path).parent try: - await transport.makedirs_async(str(workdir.joinpath(remote_dirname)), ignore_existing=True) - await transport.symlink_async(remote_abs_path, str(workdir.joinpath(dest_rel_path))) + await transport.makedirs_async(workdir.joinpath(remote_dirname), ignore_existing=True) + await transport.symlink_async(remote_abs_path, workdir.joinpath(dest_rel_path)) except OSError: logger.warning( f'[submission of calculation {node.pk}] Unable to create remote symlink ' @@ -356,18 +356,14 @@ async def _copy_local_files(logger, node, transport, inputs, local_copy_list, wo # The logic below takes care of an edge case where the source is a file but the target is a directory. In # this case, the v2.5.1 implementation would raise an `IsADirectoryError` exception, because it would try # to open the directory in the sandbox folder as a file when writing the contents. 
- if ( - file_type_source == FileType.FILE - and target - and await transport.isdir_async(str(workdir.joinpath(target))) - ): + if file_type_source == FileType.FILE and target and await transport.isdir_async(workdir.joinpath(target)): raise IsADirectoryError # In case the source filename is specified and it is a directory that already exists in the remote, we # want to avoid nested directories in the target path to replicate the behavior of v2.5.1. This is done by # setting the target filename to '.', which means the contents of the node will be copied in the top level # of the temporary directory, whose contents are then copied into the target directory. - if filename and await transport.isdir_async(str(workdir.joinpath(filename))): + if filename and await transport.isdir_async(workdir.joinpath(filename)): filename_target = '.' filepath_target = (dirpath / filename_target).resolve().absolute() @@ -378,7 +374,7 @@ async def _copy_local_files(logger, node, transport, inputs, local_copy_list, wo data_node.base.repository.copy_tree(filepath_target, filename_source) await transport.put_async( f'{dirpath}/*', - str(workdir.joinpath(target)) if target else str(workdir.joinpath('.')), + workdir.joinpath(target) if target else workdir.joinpath('.'), overwrite=True, ) else: @@ -386,18 +382,20 @@ async def _copy_local_files(logger, node, transport, inputs, local_copy_list, wo with filepath_target.open('wb') as handle: with data_node.base.repository.open(filename_source, 'rb') as source: shutil.copyfileobj(source, handle) - await transport.makedirs_async(str(workdir.joinpath(Path(target).parent)), ignore_existing=True) - await transport.put_async(str(filepath_target), str(workdir.joinpath(target))) + await transport.makedirs_async(workdir.joinpath(Path(target).parent), ignore_existing=True) + await transport.put_async(filepath_target, workdir.joinpath(target)) async def _copy_sandbox_files(logger, node, transport, folder, workdir: Path): """Copy the contents of the sandbox 
folder to the working directory.""" for filename in folder.get_content_list(): logger.debug(f'[submission of calculation {node.pk}] copying file/folder {filename}...') - await transport.put_async(folder.get_abs_path(filename), str(workdir.joinpath(filename))) + await transport.put_async(folder.get_abs_path(filename), workdir.joinpath(filename)) -def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str | ExitCode: +def submit_calculation( + calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport'] +) -> str | ExitCode: """Submit a previously uploaded `CalcJob` to the scheduler. :param calculation: the instance of CalcJobNode to submit. @@ -427,7 +425,7 @@ def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str | return result -async def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: +async def stash_calculation(calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport']) -> None: """Stash files from the working directory of a completed calculation to a permanent remote folder. 
After a calculation has been completed, optionally stash files from the work directory to a storage location on the @@ -465,7 +463,7 @@ async def stash_calculation(calculation: CalcJobNode, transport: Transport) -> N for source_filename in source_list: if transport.has_magic(source_filename): copy_instructions = [] - for globbed_filename in await transport.glob_async(str(source_basepath / source_filename)): + for globbed_filename in await transport.glob_async(source_basepath / source_filename): target_filepath = target_basepath / Path(globbed_filename).relative_to(source_basepath) copy_instructions.append((globbed_filename, target_filepath)) else: @@ -474,10 +472,10 @@ async def stash_calculation(calculation: CalcJobNode, transport: Transport) -> N for source_filepath, target_filepath in copy_instructions: # If the source file is in a (nested) directory, create those directories first in the target directory target_dirname = target_filepath.parent - await transport.makedirs_async(str(target_dirname), ignore_existing=True) + await transport.makedirs_async(target_dirname, ignore_existing=True) try: - await transport.copy_async(str(source_filepath), str(target_filepath)) + await transport.copy_async(source_filepath, target_filepath) except (OSError, ValueError) as exception: EXEC_LOGGER.warning(f'failed to stash {source_filepath} to {target_filepath}: {exception}') else: @@ -493,7 +491,7 @@ async def stash_calculation(calculation: CalcJobNode, transport: Transport) -> N async def retrieve_calculation( - calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str + calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport'], retrieved_temporary_folder: str ) -> FolderData | None: """Retrieve all the files of a completed job calculation using the given transport. 
@@ -558,7 +556,7 @@ async def retrieve_calculation( return retrieved_files -def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: +def kill_calculation(calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport']) -> None: """Kill the calculation through the scheduler :param calculation: the instance of CalcJobNode to kill. @@ -593,7 +591,7 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: async def retrieve_files_from_list( calculation: CalcJobNode, - transport: Transport, + transport: Union['BlockingTransport', 'AsyncTransport'], folder: str, retrieve_list: List[Union[str, Tuple[str, str, int], list]], ) -> None: @@ -616,7 +614,7 @@ async def retrieve_files_from_list( upto what level of the original remotepath nesting the files will be copied. :param transport: the Transport instance. - :param folder: an absolute path to a folder that contains the files to copy. + :param folder: an absolute path to a folder that contains the files to retrieve. :param retrieve_list: the list of files to retrieve. 
""" workdir = Path(calculation.get_remote_workdir()) @@ -625,7 +623,7 @@ async def retrieve_files_from_list( tmp_rname, tmp_lname, depth = item # if there are more than one file I do something differently if transport.has_magic(tmp_rname): - remote_names = await transport.glob_async(str(workdir.joinpath(tmp_rname))) + remote_names = await transport.glob_async(workdir.joinpath(tmp_rname)) local_names = [] for rem in remote_names: # get the relative path so to make local_names relative @@ -660,6 +658,6 @@ async def retrieve_files_from_list( if rem.startswith('/'): to_get = rem else: - to_get = str(workdir.joinpath(rem)) + to_get = workdir.joinpath(rem) await transport.get_async(to_get, os.path.join(folder, loc), ignore_nonexisting=True) diff --git a/src/aiida/engine/processes/calcjobs/monitors.py b/src/aiida/engine/processes/calcjobs/monitors.py index 507122ff1e..a9d2853b1d 100644 --- a/src/aiida/engine/processes/calcjobs/monitors.py +++ b/src/aiida/engine/processes/calcjobs/monitors.py @@ -8,6 +8,7 @@ import inspect import typing as t from datetime import datetime, timedelta +from typing import Union from aiida.common.lang import type_check from aiida.common.log import AIIDA_LOGGER @@ -15,7 +16,7 @@ from aiida.plugins import BaseFactory if t.TYPE_CHECKING: - from aiida.transports import Transport + from aiida.transports import AsyncTransport, BlockingTransport LOGGER = AIIDA_LOGGER.getChild(__name__) @@ -122,7 +123,9 @@ def validate(self): parameters = list(signature.parameters.keys()) if any(required_parameter not in parameters for required_parameter in ('node', 'transport')): - correct_signature = '(node: CalcJobNode, transport: Transport, **kwargs) str | None:' + correct_signature = ( + "(node: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport'], **kwargs) str | None:" + ) raise ValueError( f'The monitor `{self.entry_point}` has an invalid function signature, it should be: {correct_signature}' ) @@ -176,7 +179,7 @@ def monitors(self) -> 
collections.OrderedDict: def process( self, node: CalcJobNode, - transport: Transport, + transport: Union['BlockingTransport', 'AsyncTransport'], ) -> CalcJobMonitorResult | None: """Call all monitors in order and return the result as one returns anything other than ``None``. diff --git a/src/aiida/engine/transports.py b/src/aiida/engine/transports.py index fe32df7884..33e43e5b62 100644 --- a/src/aiida/engine/transports.py +++ b/src/aiida/engine/transports.py @@ -13,12 +13,12 @@ import contextvars import logging import traceback -from typing import TYPE_CHECKING, Awaitable, Dict, Hashable, Iterator, Optional +from typing import TYPE_CHECKING, Awaitable, Dict, Hashable, Iterator, Optional, Union from aiida.orm import AuthInfo if TYPE_CHECKING: - from aiida.transports import Transport + from aiida.transports import AsyncTransport, BlockingTransport _LOGGER = logging.getLogger(__name__) @@ -54,7 +54,9 @@ def loop(self) -> asyncio.AbstractEventLoop: return self._loop @contextlib.contextmanager - def request_transport(self, authinfo: AuthInfo) -> Iterator[Awaitable['Transport']]: + def request_transport( + self, authinfo: AuthInfo + ) -> Iterator[Awaitable[Union['BlockingTransport', 'AsyncTransport']]]: """Request a transport from an authinfo. 
Because the client is not allowed to request a transport immediately they will instead be given back a future that can be awaited to get the transport:: diff --git a/src/aiida/orm/authinfos.py b/src/aiida/orm/authinfos.py index e87be97367..3d8a45afa0 100644 --- a/src/aiida/orm/authinfos.py +++ b/src/aiida/orm/authinfos.py @@ -8,7 +8,7 @@ ########################################################################### """Module for the `AuthInfo` ORM class.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union from aiida.common import exceptions from aiida.manage import get_manager @@ -21,7 +21,7 @@ from aiida.orm import Computer, User from aiida.orm.implementation import StorageBackend from aiida.orm.implementation.authinfos import BackendAuthInfo # noqa: F401 - from aiida.transports import Transport + from aiida.transports import AsyncTransport, BlockingTransport __all__ = ('AuthInfo',) @@ -166,7 +166,7 @@ def get_workdir(self) -> str: except KeyError: return self.computer.get_workdir() - def get_transport(self) -> 'Transport': + def get_transport(self) -> Union['BlockingTransport', 'AsyncTransport']: """Return a fully configured transport that can be used to connect to the computer set for this instance.""" computer = self.computer transport_type = computer.transport_type diff --git a/src/aiida/orm/computers.py b/src/aiida/orm/computers.py index 1c695910af..46b4ec522b 100644 --- a/src/aiida/orm/computers.py +++ b/src/aiida/orm/computers.py @@ -23,7 +23,7 @@ from aiida.orm import AuthInfo, User from aiida.orm.implementation import StorageBackend from aiida.schedulers import Scheduler - from aiida.transports import Transport + from aiida.transports import AsyncTransport, BlockingTransport __all__ = ('Computer',) @@ -622,7 +622,7 @@ def is_user_enabled(self, user: 'User') -> bool: # Return False if the user is not configured (in a sense, it is disabled for that user) return False - def 
get_transport(self, user: Optional['User'] = None) -> 'Transport': + def get_transport(self, user: Optional['User'] = None) -> Union['BlockingTransport', 'AsyncTransport']: """Return a Transport class, configured with all correct parameters. The Transport is closed (meaning that if you want to run any operation with it, you have to open it first (i.e., e.g. for a SSH transport, you have @@ -646,7 +646,7 @@ def get_transport(self, user: Optional['User'] = None) -> 'Transport': authinfo = authinfos.AuthInfo.get_collection(self.backend).get(dbcomputer=self, aiidauser=user) return authinfo.get_transport() - def get_transport_class(self) -> Type['Transport']: + def get_transport_class(self) -> Union[Type['BlockingTransport'], Type['AsyncTransport']]: """Get the transport class for this computer. Can be used to instantiate a transport instance.""" try: return TransportFactory(self.transport_type) diff --git a/src/aiida/orm/nodes/data/remote/base.py b/src/aiida/orm/nodes/data/remote/base.py index 1fc691d113..655d2fccad 100644 --- a/src/aiida/orm/nodes/data/remote/base.py +++ b/src/aiida/orm/nodes/data/remote/base.py @@ -117,7 +117,8 @@ def listdir_withattributes(self, path='.'): """Connects to the remote folder and lists the directory content. :param relpath: If 'relpath' is specified, lists the content of the given subfolder. - :return: a list of dictionaries, where the documentation is in :py:class:Transport.listdir_withattributes. + :return: a list of dictionaries, where the documentation + is in :py:class:BlockingTransport.listdir_withattributes. 
""" authinfo = self.get_authinfo() diff --git a/src/aiida/orm/nodes/process/calculation/calcjob.py b/src/aiida/orm/nodes/process/calculation/calcjob.py index a7cd20c88e..c7580cd91b 100644 --- a/src/aiida/orm/nodes/process/calculation/calcjob.py +++ b/src/aiida/orm/nodes/process/calculation/calcjob.py @@ -26,7 +26,7 @@ from aiida.parsers import Parser from aiida.schedulers.datastructures import JobInfo, JobState from aiida.tools.calculations import CalculationTools - from aiida.transports import Transport + from aiida.transports import AsyncTransport, BlockingTransport __all__ = ('CalcJobNode',) @@ -450,10 +450,11 @@ def get_authinfo(self) -> 'AuthInfo': return computer.get_authinfo(self.user) - def get_transport(self) -> 'Transport': + def get_transport(self) -> Union['BlockingTransport', 'AsyncTransport']: """Return the transport for this calculation. - :return: `Transport` configured with the `AuthInfo` associated to the computer of this node + :return: Union['BlockingTransport', 'AsyncTransport'] configured + with the `AuthInfo` associated to the computer of this node """ return self.get_authinfo().get_transport() diff --git a/src/aiida/orm/utils/remote.py b/src/aiida/orm/utils/remote.py index f55cedc35a..a8aa19b3fc 100644 --- a/src/aiida/orm/utils/remote.py +++ b/src/aiida/orm/utils/remote.py @@ -12,6 +12,7 @@ import os import typing as t +from typing import Union from aiida.orm.nodes.data.remote.base import RemoteData @@ -20,14 +21,14 @@ from aiida import orm from aiida.orm.implementation import StorageBackend - from aiida.transports import Transport + from aiida.transports import AsyncTransport, BlockingTransport -def clean_remote(transport: Transport, path: str) -> None: +def clean_remote(transport: Union['BlockingTransport', 'AsyncTransport'], path: str) -> None: """Recursively remove a remote folder, with the given absolute path, and all its contents. 
The path should be made accessible through the transport channel, which should already be open - :param transport: an open Transport channel + :param transport: an open Union['BlockingTransport', 'AsyncTransport'] channel :param path: an absolute path on the remote made available through the transport """ if not isinstance(path, str): diff --git a/src/aiida/plugins/factories.py b/src/aiida/plugins/factories.py index 3c028a3c47..affce3d405 100644 --- a/src/aiida/plugins/factories.py +++ b/src/aiida/plugins/factories.py @@ -42,7 +42,7 @@ from aiida.schedulers import Scheduler from aiida.tools.data.orbital import Orbital from aiida.tools.dbimporters import DbImporter - from aiida.transports import Transport + from aiida.transports import AsyncTransport, BlockingTransport def raise_invalid_type_error(entry_point_name: str, entry_point_group: str, valid_classes: Tuple[Any, ...]) -> NoReturn: @@ -410,15 +410,19 @@ def StorageFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint @overload -def TransportFactory(entry_point_name: str, load: Literal[True] = True) -> Type['Transport']: ... +def TransportFactory( + entry_point_name: str, load: Literal[True] = True +) -> Union[Type['BlockingTransport'], Type['AsyncTransport']]: ... @overload def TransportFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ... -def TransportFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['Transport']]: - """Return the `Transport` sub class registered under the given entry point. +def TransportFactory( + entry_point_name: str, load: bool = True +) -> Union[EntryPoint, Type['BlockingTransport'], Type['AsyncTransport']]: + """Return the Union['BlockingTransport', 'AsyncTransport'] sub class registered under the given entry point. :param entry_point_name: the entry point name. :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. 
@@ -426,16 +430,16 @@ def TransportFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoi """ from inspect import isclass - from aiida.transports import Transport + from aiida.transports import AsyncTransport, BlockingTransport entry_point_group = 'aiida.transports' entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) - valid_classes = (Transport,) + valid_classes = (BlockingTransport, AsyncTransport) if not load: return entry_point - if isclass(entry_point) and issubclass(entry_point, Transport): + if isclass(entry_point) and (issubclass(entry_point, BlockingTransport) or issubclass(entry_point, AsyncTransport)): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) diff --git a/src/aiida/schedulers/scheduler.py b/src/aiida/schedulers/scheduler.py index 3cd4136984..e9fc2db3e2 100644 --- a/src/aiida/schedulers/scheduler.py +++ b/src/aiida/schedulers/scheduler.py @@ -12,6 +12,7 @@ import abc import typing as t +from typing import Union from aiida.common import exceptions, log, warnings from aiida.common.datastructures import CodeRunMode @@ -21,7 +22,7 @@ from aiida.schedulers.datastructures import JobInfo, JobResource, JobTemplate, JobTemplateCodeInfo if t.TYPE_CHECKING: - from aiida.transports import Transport + from aiida.transports import AsyncTransport, BlockingTransport __all__ = ('Scheduler', 'SchedulerError', 'SchedulerParsingError') @@ -365,7 +366,7 @@ def transport(self): return self._transport - def set_transport(self, transport: Transport): + def set_transport(self, transport: Union['BlockingTransport', 'AsyncTransport']): """Set the transport to be used to query the machine or to submit scripts. This class assumes that the transport is open and active. 
diff --git a/src/aiida/tools/pytest_fixtures/__init__.py b/src/aiida/tools/pytest_fixtures/__init__.py index 1b2c38e285..c092ad9f28 100644 --- a/src/aiida/tools/pytest_fixtures/__init__.py +++ b/src/aiida/tools/pytest_fixtures/__init__.py @@ -22,6 +22,7 @@ aiida_computer, aiida_computer_local, aiida_computer_ssh, + aiida_computer_ssh_async, aiida_localhost, ssh_key, ) @@ -32,6 +33,7 @@ 'aiida_code', 'aiida_computer_local', 'aiida_computer_ssh', + 'aiida_computer_ssh_async', 'aiida_computer', 'aiida_config_factory', 'aiida_config_tmp', diff --git a/src/aiida/tools/pytest_fixtures/orm.py b/src/aiida/tools/pytest_fixtures/orm.py index 0ed7ea18d7..076eac2ddb 100644 --- a/src/aiida/tools/pytest_fixtures/orm.py +++ b/src/aiida/tools/pytest_fixtures/orm.py @@ -190,6 +190,38 @@ def factory(label: str | None = None, configure: bool = True) -> 'Computer': return factory +@pytest.fixture +def aiida_computer_ssh_async(aiida_computer) -> t.Callable[[], 'Computer']: + """Factory to return a :class:`aiida.orm.computers.Computer` instance with ``core.ssh_async`` transport. + + Usage:: + + def test(aiida_computer_ssh_async): + computer = aiida_computer_ssh_async(label='some-label', configure=True) + assert computer.transport_type == 'core.ssh_async' + assert computer.is_configured + + The factory has the following signature: + + :param label: The computer label. If not specified, a random UUID4 is used. + :param configure: Boolean, if ``True``, ensures the computer is configured, otherwise the computer is returned + as is. Note that if a computer with the given label already exists and it was configured before, the + computer will not be "un-"configured. If an unconfigured computer is absolutely required, make sure to first + delete the existing computer or specify another label. + :return: A stored computer instance. 
+ """ + + def factory(label: str | None = None, configure: bool = True) -> 'Computer': + computer = aiida_computer(label=label, hostname='localhost', transport_type='core.ssh_async') + + if configure: + computer.configure() + + return computer + + return factory + + @pytest.fixture def aiida_localhost(aiida_computer_local) -> 'Computer': """Return a :class:`aiida.orm.computers.Computer` instance representing localhost with ``core.local`` transport. diff --git a/src/aiida/transports/__init__.py b/src/aiida/transports/__init__.py index c09153228e..8b6080f77d 100644 --- a/src/aiida/transports/__init__.py +++ b/src/aiida/transports/__init__.py @@ -14,14 +14,14 @@ from .plugins import * from .transport import * -from .util import StrPath __all__ = ( - 'SshTransport', 'Transport', + 'BlockingTransport', + 'SshTransport', + 'AsyncTransport', 'convert_to_bool', 'parse_sshconfig', - 'StrPath', ) # fmt: on diff --git a/src/aiida/transports/plugins/ssh.py b/src/aiida/transports/plugins/ssh.py index 16b7e20a8f..a7f45e6e87 100644 --- a/src/aiida/transports/plugins/ssh.py +++ b/src/aiida/transports/plugins/ssh.py @@ -1137,7 +1137,7 @@ def gettree( localpath: _TransportPath, callback=None, dereference: Optional[bool] = True, - overwrite: Optional[bool] = None, + overwrite: Optional[bool] = True, ): """Get a folder recursively from remote to local. @@ -1147,7 +1147,7 @@ def gettree( Default = True (default behaviour in paramiko). False is not implemented. :param overwrite: if True overwrites files and folders. 
- Default = False + Default = True :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index b4391f0e7b..c698acb9ca 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -454,6 +454,7 @@ async def copy_async( raise ValueError('remotedestination must be a non empty string') if not remotesource: raise ValueError('remotesource must be a non empty string') + try: if self.has_magic(remotesource): await self._sftp.mcopy( @@ -464,6 +465,8 @@ async def copy_async( follow_symlinks=dereference, ) else: + if not await self.path_exists_async(remotesource): + raise OSError(f'The remote path {remotesource} does not exist') await self._sftp.copy( remotesource, remotedestination, @@ -474,6 +477,24 @@ async def copy_async( except asyncssh.sftp.SFTPFailure as exc: raise OSError(f'Error while copying {remotesource} to {remotedestination}: {exc}') + async def copyfile_async( + self, + remotesource: _TransportPath, + remotedestination: _TransportPath, + dereference: bool = False, + preserve: bool = False, + ): + return await self.copy_async(remotesource, remotedestination, dereference, recursive=False, preserve=preserve) + + async def copytree_async( + self, + remotesource: _TransportPath, + remotedestination: _TransportPath, + dereference: bool = False, + preserve: bool = False, + ): + return await self.copy_async(remotesource, remotedestination, dereference, recursive=True, preserve=preserve) + async def exec_command_wait_async( self, command: str, @@ -487,7 +508,7 @@ async def exec_command_wait_async( :param command: the command to execute :param stdin: the standard input to pass to the command - :param encoding: (IGNORED) this is here just to keep the same signature as the one in `Transport` class + :param encoding: (IGNORED) this is here just to keep the same signature as the one in 
`BlockingTransport` class :param workdir: the working directory where to execute the command :param timeout: the timeout in seconds @@ -509,8 +530,8 @@ async def exec_command_wait_async( result = await self._conn.run( bash_commmand + escape_for_bash(command), input=stdin, check=False, timeout=timeout ) - # both stdout and stderr are strings - return (result.returncode, ''.join(result.stdout), ''.join(result.stderr)) # type: ignore [arg-type] + # Since the command is str, both stdout and stderr are strings + return (result.returncode, ''.join(str(result.stdout)), ''.join(str(result.stderr))) async def get_attribute_async(self, path): """ """ @@ -561,17 +582,19 @@ async def isfile_async(self, path): return await self._sftp.isfile(path) - async def listdir_async(self, path, pattern=None): # type: ignore[override] + async def listdir_async(self, path: _TransportPath, pattern=None): """ :param path: the absolute path to list """ path = fix_path(path) if not pattern: - list_ = await self._sftp.listdir(path) + list_ = list(await self._sftp.listdir(path)) else: patterned_path = pattern if pattern.startswith('/') else Path(path).joinpath(pattern) - list_ = await self._sftp.glob(patterned_path) + # I put the type ignore here because the asyncssh.sftp.glob() + # method always returns a sequence of str, if input is str + list_ = list(await self._sftp.glob(patterned_path)) # type: ignore[arg-type] for item in ['..', '.']: if item in list_: @@ -579,7 +602,7 @@ async def listdir_async(self, path, pattern=None): # type: ignore[override] return list_ - async def listdir_withattributes_async(self, path: _TransportPath, pattern: Optional[str] = None): # type: ignore[override] + async def listdir_withattributes_async(self, path: _TransportPath, pattern: Optional[str] = None): """Return a list of the names of the entries in the given path. The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory. 
@@ -791,6 +814,21 @@ async def chmod_async(self, path, mode, follow_symlinks=True): except asyncssh.sftp.SFTPNoSuchFile as exc: raise OSError(f'Error {exc}, directory does not exists') + async def chown_async(self, path, uid, gid): + """Change the owner and group id of a file. + + :param str path: path to the file + :param int uid: the new owner id + :param int gid: the new group id + """ + path = fix_path(path) + if not path: + raise OSError('Input path is an empty argument.') + try: + await self._sftp.chown(path, uid, gid, follow_symlinks=True) + except asyncssh.sftp.SFTPNoSuchFile as exc: + raise OSError(f'Error {exc}, directory does not exists') + async def copy_from_remote_to_remote_async(self, transportdestination, remotesource, remotedestination, **kwargs): """Copy files or folders from a remote computer to another remote computer, asynchronously. @@ -842,117 +880,10 @@ async def copy_from_remote_to_remote_async(self, transportdestination, remotesou os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put ) - ## Blocking methods. We need these for backwards compatibility - # This is useful, only because some part of engine and - # many external plugins are synchronous, in those cases blocking calls make more sense. - # However, be aware you cannot use these methods in an async functions, - # because they will block the event loop. 
- - def run_command_blocking(self, func, *args, **kwargs): - loop = asyncio.get_event_loop() - return loop.run_until_complete(func(*args, **kwargs)) - - def open(self): - return self.run_command_blocking(self.open_async) - - def close(self): - return self.run_command_blocking(self.close_async) - - def chown(self, *args, **kwargs): - raise NotImplementedError('Not implemented, for now') - - def get(self, *args, **kwargs): - return self.run_command_blocking(self.get_async, *args, **kwargs) - - def getfile(self, *args, **kwargs): - return self.run_command_blocking(self.getfile_async, *args, **kwargs) - - def gettree(self, *args, **kwargs): - return self.run_command_blocking(self.gettree_async, *args, **kwargs) - - def put(self, *args, **kwargs): - return self.run_command_blocking(self.put_async, *args, **kwargs) - - def putfile(self, *args, **kwargs): - return self.run_command_blocking(self.putfile_async, *args, **kwargs) - - def puttree(self, *args, **kwargs): - return self.run_command_blocking(self.puttree_async, *args, **kwargs) - - def chmod(self, *args, **kwargs): - return self.run_command_blocking(self.chmod_async, *args, **kwargs) - - def copy(self, *args, **kwargs): - return self.run_command_blocking(self.copy_async, *args, **kwargs) - - def copyfile(self, *args, **kwargs): - return self.copy(*args, **kwargs) - - def copytree(self, *args, **kwargs): - return self.copy(*args, **kwargs) - - def exec_command_wait(self, *args, **kwargs): - return self.run_command_blocking(self.exec_command_wait_async, *args, **kwargs) - - def get_attribute(self, *args, **kwargs): - return self.run_command_blocking(self.get_attribute_async, *args, **kwargs) - - def isdir(self, *args, **kwargs): - return self.run_command_blocking(self.isdir_async, *args, **kwargs) - - def isfile(self, *args, **kwargs): - return self.run_command_blocking(self.isfile_async, *args, **kwargs) - - def listdir(self, *args, **kwargs): - return self.run_command_blocking(self.listdir_async, *args, **kwargs) 
- - def listdir_withattributes(self, *args, **kwargs): - return self.run_command_blocking(self.listdir_withattributes_async, *args, **kwargs) - - def makedirs(self, *args, **kwargs): - return self.run_command_blocking(self.makedirs_async, *args, **kwargs) - - def mkdir(self, *args, **kwargs): - return self.run_command_blocking(self.mkdir_async, *args, **kwargs) - - def remove(self, *args, **kwargs): - return self.run_command_blocking(self.remove_async, *args, **kwargs) - - def rename(self, *args, **kwargs): - return self.run_command_blocking(self.rename_async, *args, **kwargs) - - def rmdir(self, *args, **kwargs): - return self.run_command_blocking(self.rmdir_async, *args, **kwargs) - - def rmtree(self, *args, **kwargs): - return self.run_command_blocking(self.rmtree_async, *args, **kwargs) - - def path_exists(self, *args, **kwargs): - return self.run_command_blocking(self.path_exists_async, *args, **kwargs) - - def whoami(self, *args, **kwargs): - return self.run_command_blocking(self.whoami_async, *args, **kwargs) - - def symlink(self, *args, **kwargs): - return self.run_command_blocking(self.symlink_async, *args, **kwargs) - - def glob(self, *args, **kwargs): - return self.run_command_blocking(self.glob_async, *args, **kwargs) - def gotocomputer_command(self, remotedir): connect_string = self._gotocomputer_string(remotedir) cmd = f'ssh -t {self.machine} {connect_string}' return cmd - ## These methods are not implemented for async transport, - ## mainly because they are not being used across the codebase. 
- ## If you need them, please open an issue on GitHub - - def exec_command_wait_bytes(self, *args, **kwargs): - raise NotImplementedError('Not implemented, waiting for a use case') - - def _exec_command_internal(self, *args, **kwargs): - raise NotImplementedError('Not implemented, waiting for a use case') - - def normalize(self, *args, **kwargs): + async def normalize_async(self, path: _TransportPath): raise NotImplementedError('Not implemented, waiting for a use case') diff --git a/src/aiida/transports/transport.py b/src/aiida/transports/transport.py index f15b584dc2..1155a46895 100644 --- a/src/aiida/transports/transport.py +++ b/src/aiida/transports/transport.py @@ -15,16 +15,16 @@ import re import sys from collections import OrderedDict -from pathlib import Path +from pathlib import Path, PurePosixPath from typing import Union from aiida.common.exceptions import InternalError from aiida.common.lang import classproperty from aiida.common.warnings import warn_deprecation -__all__ = ('Transport',) +__all__ = ('Transport', 'AsyncTransport', 'BlockingTransport') -_TransportPath = Union[str, Path] +_TransportPath = Union[str, Path, PurePosixPath] def validate_positive_number(ctx, param, value): @@ -44,7 +44,8 @@ def validate_positive_number(ctx, param, value): def fix_path(path: _TransportPath) -> str: - """Convert a Path object to a string.""" + """Convert an instance of _TransportPath = Union[str, Path, PurePosixPath] instance to a string.""" + # We could check if the path is a Path or PurePosixPath instance, but it's too much overhead. 
return str(path) @@ -249,6 +250,7 @@ def get_safe_open_interval(self): return self._safe_open_interval def has_magic(self, string): + string = fix_path(string) """Return True if the given string contains any special shell characters.""" return self._MAGIC_CHECK.search(string) is not None @@ -276,13 +278,12 @@ class BlockingTransport(abc.ABC, _BaseTransport): # keys: 'default', 'prompt', 'help', 'non_interactive_default' _valid_auth_options = [] - @abc.abstractmethod def __repr__(self): return f'<{self.__class__.__name__}: {self!s}>' @abc.abstractmethod def __str__(self): - return '[Transport class or subclass]' + """return [Transport class or subclass]""" @abc.abstractmethod def chmod(self, path, mode): @@ -754,16 +755,17 @@ def path_exists(self, path): # The following definitions are almost copied and pasted # from the python module glob. - def glob(self, pathname): + def glob(self, pathname: _TransportPath): """Return a list of paths matching a pathname pattern. The pattern may contain simple shell-style wildcards a la fnmatch. - :param str pathname: the pathname pattern to match. - It should only be absolute path. + :param pathname: the pathname pattern to match. + It should only be absolute path of type _TransportPath. DEPRECATED: using relative path is deprecated. :return: a list of paths matching the pattern. 
""" + pathname = fix_path(pathname) if not pathname.startswith('/'): warn_deprecation( 'Using relative paths across transport in `glob` is deprecated ' @@ -910,11 +912,11 @@ async def isfile_async(self, path): """Counterpart to isfile() that is async.""" return self.isfile(path) - async def listdir_async(self, path='.', pattern=None): + async def listdir_async(self, path, pattern=None): """Counterpart to listdir() that is async.""" return self.listdir(path, pattern) - async def listdir_withattributes_async(self, path: _TransportPath = '.', pattern=None): + async def listdir_withattributes_async(self, path: _TransportPath, pattern=None): """Counterpart to listdir_withattributes() that is async.""" return self.listdir_withattributes(path, pattern) @@ -926,7 +928,7 @@ async def mkdir_async(self, path, ignore_existing=False): """Counterpart to mkdir() that is async.""" return self.mkdir(path, ignore_existing) - async def normalize_async(self, path='.'): + async def normalize_async(self, path): """Counterpart to normalize() that is async.""" return self.normalize(path) @@ -1034,9 +1036,16 @@ async def gettree_async(self, remotepath, localpath, *args, **kwargs): async def get_attribute_async(self, path): """Return an object FixedFieldsAttributeDict for file in a given path""" - @abc.abstractmethod async def get_mode_async(self, path): - """Return the portion of the file's mode that can be set by chmod().""" + """Return the portion of the file's mode that can be set by chmod(). 
+ + :param str path: path to file + :return: the portion of the file's mode that can be set by chmod() + """ + import stat + + attr = await self.get_attribute_async(path) + return stat.S_IMODE(attr.st_mode) @abc.abstractmethod async def isdir_async(self, path): @@ -1047,11 +1056,11 @@ async def isfile_async(self, path): """Return True if path is an existing file.""" @abc.abstractmethod - async def listdir_async(self, path='.', pattern=None): + async def listdir_async(self, path: _TransportPath, pattern=None): """Return a list of the names of the entries in the given path.""" @abc.abstractmethod - async def listdir_withattributes_async(self, path: _TransportPath = '.', pattern=None): + async def listdir_withattributes_async(self, path: _TransportPath, pattern=None): """Return a list of the names of the entries in the given path.""" @abc.abstractmethod @@ -1063,7 +1072,7 @@ async def mkdir_async(self, path, ignore_existing=False): """Create a folder (directory) named path.""" @abc.abstractmethod - async def normalize_async(self, path='.'): + async def normalize_async(self, path: _TransportPath): """Return the normalized path (on the server) of a given path.""" @abc.abstractmethod @@ -1143,6 +1152,9 @@ def getfile(self, *args, **kwargs): def gettree(self, *args, **kwargs): return self.run_command_blocking(self.gettree_async, *args, **kwargs) + def get_mode(self, *args, **kwargs): + return self.run_command_blocking(self.get_mode_async, *args, **kwargs) + def put(self, *args, **kwargs): return self.run_command_blocking(self.put_async, *args, **kwargs) @@ -1159,10 +1171,10 @@ def copy(self, *args, **kwargs): return self.run_command_blocking(self.copy_async, *args, **kwargs) def copyfile(self, *args, **kwargs): - return self.copy(*args, **kwargs) + return self.run_command_blocking(self.copyfile_async, *args, **kwargs) def copytree(self, *args, **kwargs): - return self.copy(*args, **kwargs) + return self.run_command_blocking(self.copytree_async, *args, **kwargs) def 
exec_command_wait(self, *args, **kwargs): return self.run_command_blocking(self.exec_command_wait_async, *args, **kwargs) diff --git a/tests/engine/daemon/test_execmanager.py b/tests/engine/daemon/test_execmanager.py index 79692b689b..9ced2a0cd4 100644 --- a/tests/engine/daemon/test_execmanager.py +++ b/tests/engine/daemon/test_execmanager.py @@ -15,6 +15,7 @@ from aiida.common.datastructures import CalcInfo, CodeInfo, FileCopyOperation from aiida.common.folders import SandboxFolder from aiida.engine.daemon import execmanager +from aiida.manage import get_manager from aiida.orm import CalcJobNode, FolderData, PortableCode, RemoteData, SinglefileData from aiida.transports.plugins.local import LocalTransport @@ -123,10 +124,12 @@ def test_retrieve_files_from_list( target = tmp_path_factory.mktemp('target') create_file_hierarchy(file_hierarchy, source) + runner = get_manager().get_runner() with LocalTransport() as transport: node = generate_calcjob_node(workdir=source) - execmanager.retrieve_files_from_list(node, transport, target, retrieve_list) + runner.loop.run_until_complete(execmanager.retrieve_files_from_list(node, transport, target, retrieve_list)) + # await execmanager.retrieve_files_from_list(node, transport, target, retrieve_list) assert serialize_file_hierarchy(target, read_bytes=False) == expected_hierarchy @@ -164,7 +167,8 @@ def test_upload_local_copy_list( calc_info.local_copy_list = [[folder.uuid] + local_copy_list] with node.computer.get_transport() as transport: - execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) + runner = get_manager().get_runner() + runner.loop.run_until_complete(execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)) # Check that none of the files were written to the repository of the calculation node, since they were communicated # through the ``local_copy_list``. 
@@ -201,7 +205,8 @@ def test_upload_local_copy_list_files_folders( ] with node.computer.get_transport() as transport: - execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) + runner = get_manager().get_runner() + runner.loop.run_until_complete(execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)) # Check that none of the files were written to the repository of the calculation node, since they were communicated # through the ``local_copy_list``. @@ -232,7 +237,8 @@ def test_upload_remote_symlink_list( ] with node.computer.get_transport() as transport: - execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) + runner = get_manager().get_runner() + runner.loop.run_until_complete(execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)) filepath_workdir = pathlib.Path(node.get_remote_workdir()) assert (filepath_workdir / 'file_a.txt').is_symlink() @@ -296,7 +302,8 @@ def test_upload_file_copy_operation_order(node_and_calc_info, tmp_path, order, e calc_info.file_copy_operation_order = order with node.computer.get_transport() as transport: - execmanager.upload_calculation(node, transport, calc_info, sandbox, inputs) + runner = get_manager().get_runner() + runner.loop.run_until_complete(execmanager.upload_calculation(node, transport, calc_info, sandbox, inputs)) filepath = pathlib.Path(node.get_remote_workdir()) / 'file.txt' assert filepath.is_file() assert filepath.read_text() == expected @@ -567,18 +574,20 @@ def test_upload_combinations( calc_info.remote_copy_list.append( (node.computer.uuid, (sub_tmp_path_remote / source_path).as_posix(), target_path) ) - + runner = get_manager().get_runner() if expected_exception is not None: with pytest.raises(expected_exception): with node.computer.get_transport() as transport: - execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) + runner.loop.run_until_complete( + execmanager.upload_calculation(node, transport, calc_info, 
fixture_sandbox) + ) filepath_workdir = pathlib.Path(node.get_remote_workdir()) assert serialize_file_hierarchy(filepath_workdir, read_bytes=False) == expected_hierarchy else: with node.computer.get_transport() as transport: - execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) + runner.loop.run_until_complete(execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)) filepath_workdir = pathlib.Path(node.get_remote_workdir()) @@ -606,9 +615,12 @@ def test_upload_calculation_portable_code(fixture_sandbox, node_and_calc_info, t calc_info.codes_info = [code_info] with node.computer.get_transport() as transport: - execmanager.upload_calculation( - node, - transport, - calc_info, - fixture_sandbox, + runner = get_manager().get_runner() + runner.loop.run_until_complete( + execmanager.upload_calculation( + node, + transport, + calc_info, + fixture_sandbox, + ) ) diff --git a/tests/manage/tests/test_pytest_fixtures.py b/tests/manage/tests/test_pytest_fixtures.py index ab2e9d82b6..c3f4de39dc 100644 --- a/tests/manage/tests/test_pytest_fixtures.py +++ b/tests/manage/tests/test_pytest_fixtures.py @@ -6,7 +6,7 @@ from aiida.manage.configuration import get_config from aiida.manage.configuration.config import Config from aiida.orm import Computer -from aiida.transports import Transport +from aiida.transports import AsyncTransport, BlockingTransport def test_profile_config(): @@ -29,7 +29,7 @@ def test_aiida_computer_local(aiida_computer_local): assert computer.transport_type == 'core.local' with computer.get_transport() as transport: - assert isinstance(transport, Transport) + assert isinstance(transport, BlockingTransport) # Calling it again with the same label should simply return the existing computer computer_alt = aiida_computer_local(label=computer.label) @@ -52,7 +52,7 @@ def test_aiida_computer_ssh(aiida_computer_ssh): assert computer.transport_type == 'core.ssh' with computer.get_transport() as transport: - assert 
isinstance(transport, Transport) + assert isinstance(transport, BlockingTransport) # Calling it again with the same label should simply return the existing computer computer_alt = aiida_computer_ssh(label=computer.label) @@ -63,3 +63,26 @@ def test_aiida_computer_ssh(aiida_computer_ssh): computer_unconfigured = aiida_computer_ssh(label=str(uuid.uuid4()), configure=False) assert not computer_unconfigured.is_configured + + +@pytest.mark.usefixtures('aiida_profile_clean') +def test_aiida_computer_ssh_async(aiida_computer_ssh_async): + """Test the ``aiida_computer_ssh_async`` fixture.""" + computer = aiida_computer_ssh_async() + assert isinstance(computer, Computer) + assert computer.is_configured + assert computer.hostname == 'localhost' + assert computer.transport_type == 'core.ssh_async' + + with computer.get_transport() as transport: + assert isinstance(transport, AsyncTransport) + + # Calling it again with the same label should simply return the existing computer + computer_alt = aiida_computer_ssh_async(label=computer.label) + assert computer_alt.uuid == computer.uuid + + computer_new = aiida_computer_ssh_async(label=str(uuid.uuid4())) + assert computer_new.uuid != computer.uuid + + computer_unconfigured = aiida_computer_ssh_async(label=str(uuid.uuid4()), configure=False) + assert not computer_unconfigured.is_configured diff --git a/tests/orm/test_computers.py b/tests/orm/test_computers.py index adb1cbaae8..572cfa9c7c 100644 --- a/tests/orm/test_computers.py +++ b/tests/orm/test_computers.py @@ -67,7 +67,7 @@ def test_get_minimum_job_poll_interval(self): # No transport class defined: fall back on class default. assert computer.get_minimum_job_poll_interval() == Computer.PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL__DEFAULT - # Transport class defined: use default of the transport class. + # BlockingTransport class defined: use default of the transport class. 
transport = TransportFactory('core.local') computer.transport_type = 'core.local' assert computer.get_minimum_job_poll_interval() == transport.DEFAULT_MINIMUM_JOB_POLL_INTERVAL diff --git a/tests/plugins/test_factories.py b/tests/plugins/test_factories.py index 5ad4a83c13..f077e5ba6d 100644 --- a/tests/plugins/test_factories.py +++ b/tests/plugins/test_factories.py @@ -18,7 +18,7 @@ from aiida.schedulers import Scheduler from aiida.tools.data.orbital import Orbital from aiida.tools.dbimporters import DbImporter -from aiida.transports import Transport +from aiida.transports import AsyncTransport, BlockingTransport def custom_load_entry_point(group, name): @@ -68,7 +68,8 @@ def work_function(): 'invalid': Node, }, 'aiida.transports': { - 'valid': Transport, + 'valid_A': AsyncTransport, + 'valid_B': BlockingTransport, 'invalid': Node, }, 'aiida.workflows': { @@ -189,8 +190,11 @@ def test_storage_factory(self): @pytest.mark.usefixtures('mock_load_entry_point') def test_transport_factory(self): """Test the ``TransportFactory``.""" - plugin = factories.TransportFactory('valid') - assert plugin is Transport + plugin = factories.TransportFactory('valid_B') + assert plugin is BlockingTransport + + plugin = factories.TransportFactory('valid_A') + assert plugin is AsyncTransport with pytest.raises(InvalidEntryPointTypeError): factories.TransportFactory('invalid') diff --git a/tests/test_calculation_node.py b/tests/test_calculation_node.py index d1120499f8..d845de6df4 100644 --- a/tests/test_calculation_node.py +++ b/tests/test_calculation_node.py @@ -120,7 +120,7 @@ def test_get_authinfo(self): def test_get_transport(self): """Test that we can get the Transport object from the calculation instance.""" - from aiida.transports import Transport + from aiida.transports import AsyncTransport, BlockingTransport transport = self.calcjob.get_transport() - assert isinstance(transport, Transport) + assert isinstance(transport, BlockingTransport) or isinstance(transport, AsyncTransport) 
diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index 6d4fc8a6ad..707a38020b 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -21,11 +21,12 @@ import time import uuid from pathlib import Path +from typing import Union import psutil import pytest -from aiida.plugins import SchedulerFactory, TransportFactory -from aiida.transports import Transport +from aiida.plugins import SchedulerFactory, TransportFactory, entry_point +from aiida.transports import AsyncTransport, BlockingTransport # TODO : test for copy with pattern # TODO : test for copy with/without patterns, overwriting folder @@ -33,9 +34,8 @@ # TODO : silly cases of copy/put/get from self to self -# @pytest.fixture(scope='function', params=entry_point.get_entry_point_names('aiida.transports')) -@pytest.fixture(scope='function', params=['core.ssh', 'core.ssh_async']) -def custom_transport(request, tmp_path, monkeypatch) -> Transport: +@pytest.fixture(scope='function', params=entry_point.get_entry_point_names('aiida.transports')) +def custom_transport(request, tmp_path, monkeypatch) -> Union['BlockingTransport', 'AsyncTransport']: """Fixture that parametrizes over all the registered implementations of the ``CommonRelaxWorkChain``.""" plugin = TransportFactory(request.param) @@ -130,23 +130,26 @@ def test_rmtree(custom_transport, tmpdir): def test_listdir(custom_transport, tmpdir): """Create directories, verify listdir""" + # we need another directory as tmpdir is polluted by custom_transport for the case of core.ssh_auto + tmpdir_ = tmpdir / 'remote' + tmpdir_.mkdir() with custom_transport as transport: list_of_dir = ['1', '-f a&', 'as', 'a2', 'a4f'] list_of_files = ['a', 'b'] for this_dir in list_of_dir: - transport.mkdir(tmpdir / this_dir) + transport.mkdir(tmpdir_ / this_dir) for fname in list_of_files: with tempfile.NamedTemporaryFile() as tmpf: # Just put an empty file there at the right file name - 
transport.putfile(tmpf.name, tmpdir / fname) + transport.putfile(tmpf.name, tmpdir_ / fname) - list_found = transport.listdir(tmpdir) + list_found = transport.listdir(tmpdir_) assert sorted(list_found) == sorted(list_of_dir + list_of_files) - assert sorted(transport.listdir(tmpdir, 'a*')), sorted(['as', 'a2', 'a4f']) - assert sorted(transport.listdir(tmpdir, 'a?')), sorted(['as', 'a2']) - assert sorted(transport.listdir(tmpdir, 'a[2-4]*')), sorted(['a2', 'a4f']) + assert sorted(transport.listdir(tmpdir_, 'a*')), sorted(['as', 'a2', 'a4f']) + assert sorted(transport.listdir(tmpdir_, 'a?')), sorted(['as', 'a2']) + assert sorted(transport.listdir(tmpdir_, 'a[2-4]*')), sorted(['a2', 'a4f']) def test_listdir_withattributes(custom_transport, tmpdir): diff --git a/utils/dependency_management.py b/utils/dependency_management.py index dd6875e57e..14d4c6c862 100644 --- a/utils/dependency_management.py +++ b/utils/dependency_management.py @@ -294,10 +294,17 @@ def check_requirements(extras, github_annotate): for requirement_abstract in requirements_abstract: for requirement_concrete in requirements_concrete: - version = Specifier(str(requirement_concrete.specifier)).version - if canonicalize_name(requirement_abstract.name) == canonicalize_name( - requirement_concrete.name - ) and requirement_abstract.specifier.contains(version): + if '@' in str(requirement_concrete): + version = str(requirement_concrete).split('@')[1] + abstract_contains = version in str(requirement_abstract) + else: + version = Specifier(str(requirement_concrete.specifier)).version + abstract_contains = requirement_abstract.specifier.contains(version) + + if ( + canonicalize_name(requirement_abstract.name) == canonicalize_name(requirement_concrete.name) + and abstract_contains + ): installed.append(requirement_abstract) break From 565724dac46f73dc579070cbcef28838d13e41cd Mon Sep 17 00:00:00 2001 From: Ali Khosravi Date: Wed, 20 Nov 2024 21:21:52 +0100 Subject: [PATCH 08/29] docstring updated --- 
environment.yml | 2 +- src/aiida/transports/plugins/local.py | 108 ++- src/aiida/transports/plugins/ssh.py | 141 ++-- src/aiida/transports/plugins/ssh_async.py | 443 ++++++++---- src/aiida/transports/transport.py | 781 +++++++++++++++++----- 5 files changed, 1086 insertions(+), 389 deletions(-) diff --git a/environment.yml b/environment.yml index e517f374d1..abbe6282f0 100644 --- a/environment.yml +++ b/environment.yml @@ -23,7 +23,7 @@ dependencies: - importlib-metadata~=6.0 - numpy~=1.21 - paramiko~=3.0 -- plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy +- plumpy@ git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy - pgsu~=0.3.0 - psutil~=5.6 - psycopg[binary]~=3.0 diff --git a/src/aiida/transports/plugins/local.py b/src/aiida/transports/plugins/local.py index c6e613a55b..d8bbc2ddc5 100644 --- a/src/aiida/transports/plugins/local.py +++ b/src/aiida/transports/plugins/local.py @@ -15,10 +15,11 @@ import os import shutil import subprocess +from typing import Optional from aiida.common.warnings import warn_deprecation from aiida.transports import cli as transport_cli -from aiida.transports.transport import BlockingTransport, TransportInternalError +from aiida.transports.transport import BlockingTransport, TransportInternalError, _TransportPath, path_2_str # refactor or raise the limit: issue #1784 @@ -93,7 +94,7 @@ def curdir(self): raise TransportInternalError('Error, local method called for LocalTransport without opening the channel first') - def chdir(self, path): + def chdir(self, path: _TransportPath): """ PLEASE DON'T USE `chdir()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE. `chdir()` is DEPRECATED and will be removed in the next major version. 
@@ -106,6 +107,7 @@ def chdir(self, path): '`chdir()` is deprecated and will be removed in the next major version.', version=3, ) + path = path_2_str(path) new_path = os.path.join(self.curdir, path) if not os.path.isdir(new_path): raise OSError(f"'{new_path}' is not a valid directory") @@ -114,13 +116,15 @@ def chdir(self, path): self._internal_dir = os.path.normpath(new_path) - def chown(self, path, uid, gid): + def chown(self, path: _TransportPath, uid, gid): + path = path_2_str(path) os.chown(path, uid, gid) - def normalize(self, path='.'): + def normalize(self, path: _TransportPath = '.'): """Normalizes path, eliminating double slashes, etc.. :param path: path to normalize """ + path = path_2_str(path) return os.path.realpath(os.path.join(self.curdir, path)) def getcwd(self): @@ -132,8 +136,9 @@ def getcwd(self): return self.curdir @staticmethod - def _os_path_split_asunder(path): + def _os_path_split_asunder(path: _TransportPath): """Used by makedirs, Takes path (a str) and returns a list deconcatenating the path.""" + path = path_2_str(path) parts = [] while True: newpath, tail = os.path.split(path) @@ -147,7 +152,7 @@ def _os_path_split_asunder(path): parts.reverse() return parts - def makedirs(self, path, ignore_existing=False): + def makedirs(self, path: _TransportPath, ignore_existing=False): """Super-mkdir; create a leaf directory and all intermediate ones. Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. 
@@ -158,6 +163,7 @@ def makedirs(self, path, ignore_existing=False): :raise OSError: If the directory already exists and is not ignore_existing """ + path = path_2_str(path) # check to avoid creation of empty dirs path = os.path.normpath(path) @@ -173,7 +179,7 @@ def makedirs(self, path, ignore_existing=False): if not os.path.exists(this_dir): os.mkdir(this_dir) - def mkdir(self, path, ignore_existing=False): + def mkdir(self, path: _TransportPath, ignore_existing=False): """Create a folder (directory) named path. :param path: name of the folder to create @@ -182,33 +188,37 @@ def mkdir(self, path, ignore_existing=False): :raise OSError: If the directory already exists. """ + path = path_2_str(path) if ignore_existing and self.isdir(path): return os.mkdir(os.path.join(self.curdir, path)) - def rmdir(self, path): + def rmdir(self, path: _TransportPath): """Removes a folder at location path. :param path: path to remove """ + path = path_2_str(path) os.rmdir(os.path.join(self.curdir, path)) - def isdir(self, path): + def isdir(self, path: _TransportPath): """Checks if 'path' is a directory. :return: a boolean """ + path = path_2_str(path) if not path: return False return os.path.isdir(os.path.join(self.curdir, path)) - def chmod(self, path, mode): + def chmod(self, path: _TransportPath, mode): """Changes permission bits of object at path :param path: path to modify :param mode: permission bits :raise OSError: if path does not exist. """ + path = path_2_str(path) if not path: raise OSError('Directory not given in input') real_path = os.path.join(self.curdir, path) @@ -219,7 +229,7 @@ def chmod(self, path, mode): # please refactor: issue #1782 - def put(self, localpath, remotepath, *args, **kwargs): + def put(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): """Copies a file or a folder from localpath to remotepath. Automatically redirects to putfile or puttree. 
@@ -233,6 +243,8 @@ def put(self, localpath, remotepath, *args, **kwargs): :raise OSError: if remotepath is not valid :raise ValueError: if localpath is not valid """ + localpath = path_2_str(localpath) + remotepath = path_2_str(remotepath) from aiida.common.warnings import warn_deprecation if 'ignore_noexisting' in kwargs: @@ -299,7 +311,7 @@ def put(self, localpath, remotepath, *args, **kwargs): else: raise OSError(f'The local path {localpath} does not exist') - def putfile(self, localpath, remotepath, *args, **kwargs): + def putfile(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): """Copies a file from localpath to remotepath. Automatically redirects to putfile or puttree. @@ -312,6 +324,9 @@ def putfile(self, localpath, remotepath, *args, **kwargs): :raise ValueError: if localpath is not valid :raise OSError: if localpath does not exist """ + localpath = path_2_str(localpath) + remotepath = path_2_str(remotepath) + overwrite = kwargs.get('overwrite', args[0] if args else True) if not remotepath: raise OSError('Input remotepath to putfile must be a non empty string') @@ -330,7 +345,7 @@ def putfile(self, localpath, remotepath, *args, **kwargs): shutil.copyfile(localpath, the_destination) - def puttree(self, localpath, remotepath, *args, **kwargs): + def puttree(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): """Copies a folder recursively from localpath to remotepath. Automatically redirects to putfile or puttree. 
@@ -345,6 +360,8 @@ def puttree(self, localpath, remotepath, *args, **kwargs): :raise ValueError: if localpath is not valid :raise OSError: if localpath does not exist """ + localpath = path_2_str(localpath) + remotepath = path_2_str(remotepath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) if not remotepath: @@ -370,11 +387,12 @@ def puttree(self, localpath, remotepath, *args, **kwargs): shutil.copytree(localpath, the_destination, symlinks=not dereference, dirs_exist_ok=overwrite) - def rmtree(self, path): + def rmtree(self, path: _TransportPath): """Remove tree as rm -r would do :param path: a string to path """ + path = path_2_str(path) the_path = os.path.join(self.curdir, path) try: shutil.rmtree(the_path) @@ -388,7 +406,7 @@ def rmtree(self, path): # please refactor: issue #1781 - def get(self, remotepath, localpath, *args, **kwargs): + def get(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): """Copies a folder or a file recursively from 'remote' remotepath to 'local' localpath. Automatically redirects to getfile or gettree. 
@@ -403,6 +421,8 @@ def get(self, remotepath, localpath, *args, **kwargs): :raise OSError: if 'remote' remotepath is not valid :raise ValueError: if 'local' localpath is not valid """ + remotepath = path_2_str(remotepath) + localpath = path_2_str(localpath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) ignore_nonexisting = kwargs.get('ignore_nonexisting', args[2] if len(args) > 2 else False) @@ -454,7 +474,7 @@ def get(self, remotepath, localpath, *args, **kwargs): else: raise OSError(f'The remote path {remotepath} does not exist') - def getfile(self, remotepath, localpath, *args, **kwargs): + def getfile(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): """Copies a file recursively from 'remote' remotepath to 'local' localpath. @@ -467,6 +487,8 @@ def getfile(self, remotepath, localpath, *args, **kwargs): :raise ValueError: if 'local' localpath is not valid :raise OSError: if unintentionally overwriting """ + remotepath = path_2_str(remotepath) + localpath = path_2_str(localpath) overwrite = kwargs.get('overwrite', args[0] if args else True) if not localpath: raise ValueError('Input localpath to get function must be a non empty string') @@ -482,7 +504,7 @@ def getfile(self, remotepath, localpath, *args, **kwargs): shutil.copyfile(the_source, localpath) - def gettree(self, remotepath, localpath, *args, **kwargs): + def gettree(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): """Copies a folder recursively from 'remote' remotepath to 'local' localpath. 
@@ -495,6 +517,8 @@ def gettree(self, remotepath, localpath, *args, **kwargs): :raise ValueError: if 'local' localpath is not valid :raise OSError: if unintentionally overwriting """ + remotepath = path_2_str(remotepath) + localpath = path_2_str(localpath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) if not remotepath: @@ -521,7 +545,7 @@ def gettree(self, remotepath, localpath, *args, **kwargs): # please refactor: issue #1780 on github - def copy(self, remotesource, remotedestination, dereference=False, recursive=True): + def copy(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False, recursive=True): """Copies a file or a folder from 'remote' remotesource to 'remote' remotedestination. Automatically redirects to copyfile or copytree. @@ -534,6 +558,8 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru :raise ValueError: if 'remote' remotesource or remotedestinationis not valid :raise OSError: if remotesource does not exist """ + remotesource = path_2_str(remotesource) + remotedestination = path_2_str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copy must be a non empty object') if not remotedestination: @@ -581,7 +607,7 @@ def copy(self, remotesource, remotedestination, dereference=False, recursive=Tru # With self.copytree, the (possible) relative path is OK self.copytree(remotesource, remotedestination, dereference) - def copyfile(self, remotesource, remotedestination, dereference=False): + def copyfile(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): """Copies a file from 'remote' remotesource to 'remote' remotedestination. 
@@ -592,6 +618,8 @@ def copyfile(self, remotesource, remotedestination, dereference=False): :raise ValueError: if 'remote' remotesource or remotedestination is not valid :raise OSError: if remotesource does not exist """ + remotesource = path_2_str(remotesource) + remotedestination = path_2_str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copyfile must be a non empty object') if not remotedestination: @@ -607,7 +635,7 @@ def copyfile(self, remotesource, remotedestination, dereference=False): else: shutil.copyfile(the_source, the_destination) - def copytree(self, remotesource, remotedestination, dereference=False): + def copytree(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): """Copies a folder from 'remote' remotesource to 'remote' remotedestination. @@ -618,6 +646,8 @@ def copytree(self, remotesource, remotedestination, dereference=False): :raise ValueError: if 'remote' remotesource or remotedestination is not valid :raise OSError: if remotesource does not exist """ + remotesource = path_2_str(remotesource) + remotedestination = path_2_str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copytree must be a non empty object') if not remotedestination: @@ -633,11 +663,12 @@ def copytree(self, remotesource, remotedestination, dereference=False): shutil.copytree(the_source, the_destination, symlinks=not dereference) - def get_attribute(self, path): + def get_attribute(self, path: _TransportPath): """Returns an object FileAttribute, as specified in aiida.transports. :param path: the path of the given file. 
""" + path = path_2_str(path) from aiida.transports.util import FileAttribute os_attr = os.lstat(os.path.join(self.curdir, path)) @@ -648,10 +679,12 @@ def get_attribute(self, path): aiida_attr[key] = getattr(os_attr, key) return aiida_attr - def _local_listdir(self, path, pattern=None): + def _local_listdir(self, path: _TransportPath, pattern=None): """Act on the local folder, for the rest, same as listdir.""" import re + path = path_2_str(path) + if not pattern: return os.listdir(path) @@ -665,12 +698,13 @@ def _local_listdir(self, path, pattern=None): base_dir += os.sep return [re.sub(base_dir, '', i) for i in filtered_list] - def listdir(self, path='.', pattern=None): + def listdir(self, path: _TransportPath = '.', pattern=None): """:return: a list containing the names of the entries in the directory. :param path: default ='.' :param pattern: if set, returns the list of files matching pattern. Unix only. (Use to emulate ls * for example) """ + path = path_2_str(path) the_path = os.path.join(self.curdir, path).strip() if not pattern: try: @@ -687,20 +721,22 @@ def listdir(self, path='.', pattern=None): the_path += '/' return [re.sub(the_path, '', i) for i in filtered_list] - def remove(self, path): + def remove(self, path: _TransportPath): """Removes a file at position path.""" + path = path_2_str(path) os.remove(os.path.join(self.curdir, path)) - def isfile(self, path): + def isfile(self, path: _TransportPath): """Checks if object at path is a file. Returns a boolean. """ + path = path_2_str(path) if not path: return False return os.path.isfile(os.path.join(self.curdir, path)) @contextlib.contextmanager - def _exec_command_internal(self, command, workdir=None, **kwargs): + def _exec_command_internal(self, command, workdir: Optional[_TransportPath] = None, **kwargs): """Executes the specified command in bash login shell. 
@@ -725,6 +761,7 @@ def _exec_command_internal(self, command, workdir=None, **kwargs): """ from aiida.common.escaping import escape_for_bash + workdir = path_2_str(workdir) # Note: The outer shell will eat one level of escaping, while # 'bash -l -c ...' will eat another. Thus, we need to escape again. bash_commmand = f'{self._bash_command_str}-c ' @@ -747,7 +784,7 @@ def _exec_command_internal(self, command, workdir=None, **kwargs): ) as process: yield process - def exec_command_wait_bytes(self, command, stdin=None, workdir=None, **kwargs): + def exec_command_wait_bytes(self, command, stdin=None, workdir: Optional[_TransportPath] = None, **kwargs): """Executes the specified command and waits for it to finish. :param command: the command to execute @@ -759,6 +796,7 @@ def exec_command_wait_bytes(self, command, stdin=None, workdir=None, **kwargs): :return: a tuple with (return_value, stdout, stderr) where stdout and stderr are both bytes and the return_value is an int. """ + workdir = path_2_str(workdir) with self._exec_command_internal(command, workdir) as process: if stdin is not None: # Implicitly assume that the desired encoding is 'utf-8' if I receive a string. @@ -801,7 +839,7 @@ def line_encoder(iterator, encoding='utf-8'): return retval, output_text, stderr_text - def gotocomputer_command(self, remotedir): + def gotocomputer_command(self, remotedir: _TransportPath): """Return a string to be run using os.system in order to connect via the transport to the remote directory. @@ -812,11 +850,12 @@ def gotocomputer_command(self, remotedir): :param str remotedir: the full path of the remote directory """ + remotedir = path_2_str(remotedir) connect_string = self._gotocomputer_string(remotedir) cmd = f'bash -c {connect_string}' return cmd - def rename(self, oldpath, newpath): + def rename(self, oldpath: _TransportPath, newpath: _TransportPath): """Rename a file or folder from oldpath to newpath. 
:param str oldpath: existing name of the file or folder @@ -825,6 +864,8 @@ def rename(self, oldpath, newpath): :raises OSError: if src/dst is not found :raises ValueError: if src/dst is not a valid string """ + oldpath = path_2_str(oldpath) + newpath = path_2_str(newpath) if not oldpath: raise ValueError(f'Source {oldpath} is not a valid string') if not newpath: @@ -836,15 +877,15 @@ def rename(self, oldpath, newpath): shutil.move(oldpath, newpath) - def symlink(self, remotesource, remotedestination): + def symlink(self, remotesource: _TransportPath, remotedestination: _TransportPath): """Create a symbolic link between the remote source and the remote remotedestination :param remotesource: remote source. Can contain a pattern. :param remotedestination: remote destination """ - remotesource = os.path.normpath(remotesource) - remotedestination = os.path.normpath(remotedestination) + remotesource = os.path.normpath(path_2_str(remotesource)) + remotedestination = os.path.normpath(path_2_str(remotedestination)) if self.has_magic(remotesource): if self.has_magic(remotedestination): @@ -863,8 +904,9 @@ def symlink(self, remotesource, remotedestination): except OSError: raise OSError(f'!!: {remotesource}, {self.curdir}, {remotedestination}') - def path_exists(self, path): + def path_exists(self, path: _TransportPath): """Check if path exists""" + path = path_2_str(path) return os.path.exists(os.path.join(self.curdir, path)) diff --git a/src/aiida/transports/plugins/ssh.py b/src/aiida/transports/plugins/ssh.py index a7f45e6e87..6c810fe523 100644 --- a/src/aiida/transports/plugins/ssh.py +++ b/src/aiida/transports/plugins/ssh.py @@ -13,7 +13,6 @@ import os import re from stat import S_ISDIR, S_ISREG -from typing import Optional import click @@ -22,7 +21,7 @@ from aiida.common.escaping import escape_for_bash from aiida.common.warnings import warn_deprecation -from ..transport import BlockingTransport, TransportInternalError, _TransportPath, fix_path +from ..transport import 
BlockingTransport, TransportInternalError, _TransportPath, path_2_str __all__ = ('parse_sshconfig', 'convert_to_bool', 'SshTransport') @@ -598,7 +597,7 @@ def chdir(self, path: _TransportPath): ) from paramiko.sftp import SFTPError - path = fix_path(path) + path = path_2_str(path) old_path = self.sftp.getcwd() if path is not None: try: @@ -627,7 +626,7 @@ def chdir(self, path: _TransportPath): def normalize(self, path: _TransportPath = '.'): """Returns the normalized path (removing double slashes, etc...)""" - path = fix_path(path) + path = path_2_str(path) return self.sftp.normalize(path) @@ -644,7 +643,7 @@ def stat(self, path: _TransportPath): :return: a `paramiko.sftp_attr.SFTPAttributes` object containing attributes about the given file. """ - path = fix_path(path) + path = path_2_str(path) return self.sftp.stat(path) @@ -658,7 +657,7 @@ def lstat(self, path: _TransportPath): :return: a `paramiko.sftp_attr.SFTPAttributes` object containing attributes about the given file. """ - path = fix_path(path) + path = path_2_str(path) return self.sftp.lstat(path) @@ -693,7 +692,7 @@ def makedirs(self, path: _TransportPath, ignore_existing: bool = False): :raise OSError: If the directory already exists. """ - path = fix_path(path) + path = path_2_str(path) # check to avoid creation of empty dirs path = os.path.normpath(path) @@ -725,7 +724,7 @@ def mkdir(self, path: _TransportPath, ignore_existing: bool = False): :raise OSError: If the directory already exists. """ - path = fix_path(path) + path = path_2_str(path) if ignore_existing and self.isdir(path): return @@ -754,7 +753,7 @@ def rmtree(self, path: _TransportPath): :raise OSError: if the rm execution failed. """ - path = fix_path(path) + path = path_2_str(path) # Assuming linux rm command! 
rm_exe = 'rm' @@ -776,10 +775,10 @@ def rmtree(self, path: _TransportPath): def rmdir(self, path: _TransportPath): """Remove the folder named 'path' if empty.""" - path = fix_path(path) + path = path_2_str(path) self.sftp.rmdir(path) - def chown(self, path, uid, gid): + def chown(self, path: _TransportPath, uid, gid): """Change owner permissions of a file. For now, this is not implemented for the SSH transport. @@ -792,11 +791,11 @@ def isdir(self, path: _TransportPath): """ # Return False on empty string (paramiko would map this to the local # folder instead) - path = fix_path(path) + path = path_2_str(path) if not path: return False - path = fix_path(path) + path = path_2_str(path) try: return S_ISDIR(self.stat(path).st_mode) except OSError as exc: @@ -811,7 +810,7 @@ def chmod(self, path: _TransportPath, mode): :param path: path to file :param mode: new permission bits (integer) """ - path = fix_path(path) + path = path_2_str(path) if not path: raise OSError('Input path is an empty argument.') @@ -822,7 +821,7 @@ def _os_path_split_asunder(path: _TransportPath): """Used by makedirs. Takes path and returns a list deconcatenating the path """ - path = fix_path(path) + path = path_2_str(path) parts = [] while True: newpath, tail = os.path.split(path) @@ -841,9 +840,9 @@ def put( localpath: _TransportPath, remotepath: _TransportPath, callback=None, - dereference: Optional[bool] = True, - overwrite: Optional[bool] = True, - ignore_nonexisting: Optional[bool] = False, + dereference: bool = True, + overwrite: bool = True, + ignore_nonexisting: bool = False, ): """Put a file or a folder from local to remote. Redirects to putfile or puttree. 
@@ -858,8 +857,8 @@ def put( :raise ValueError: if local path is invalid :raise OSError: if the localpath does not exist """ - localpath = fix_path(localpath) - remotepath = fix_path(remotepath) + localpath = path_2_str(localpath) + remotepath = path_2_str(remotepath) if not dereference: raise NotImplementedError @@ -916,8 +915,8 @@ def putfile( localpath: _TransportPath, remotepath: _TransportPath, callback=None, - dereference: Optional[bool] = True, - overwrite: Optional[bool] = True, + dereference: bool = True, + overwrite: bool = True, ): """Put a file from local to remote. @@ -930,8 +929,8 @@ def putfile( :raise OSError: if the localpath does not exist, or unintentionally overwriting """ - localpath = fix_path(localpath) - remotepath = fix_path(remotepath) + localpath = path_2_str(localpath) + remotepath = path_2_str(remotepath) if not dereference: raise NotImplementedError @@ -949,8 +948,8 @@ def puttree( localpath: _TransportPath, remotepath: _TransportPath, callback=None, - dereference: Optional[bool] = True, - overwrite: Optional[bool] = True, + dereference: bool = True, + overwrite: bool = True, ): """Put a folder recursively from local to remote. @@ -970,8 +969,8 @@ def puttree( .. note:: setting dereference equal to True could cause infinite loops. see os.walk() documentation """ - localpath = fix_path(localpath) - remotepath = fix_path(remotepath) + localpath = path_2_str(localpath) + remotepath = path_2_str(remotepath) if not dereference: raise NotImplementedError @@ -1023,9 +1022,9 @@ def get( remotepath: _TransportPath, localpath: _TransportPath, callback=None, - dereference: Optional[bool] = True, - overwrite: Optional[bool] = True, - ignore_nonexisting: Optional[bool] = False, + dereference: bool = True, + overwrite: bool = True, + ignore_nonexisting: bool = False, ): """Get a file or folder from remote to local. Redirects to getfile or gettree. 
@@ -1041,8 +1040,8 @@ def get( :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found """ - remotepath = fix_path(remotepath) - localpath = fix_path(localpath) + remotepath = path_2_str(remotepath) + localpath = path_2_str(localpath) if not dereference: raise NotImplementedError @@ -1096,8 +1095,8 @@ def getfile( remotepath: _TransportPath, localpath: _TransportPath, callback=None, - dereference: Optional[bool] = True, - overwrite: Optional[bool] = True, + dereference: bool = True, + overwrite: bool = True, ): """Get a file from remote to local. @@ -1109,8 +1108,8 @@ def getfile( :raise ValueError: if local path is invalid :raise OSError: if unintentionally overwriting """ - remotepath = fix_path(remotepath) - localpath = fix_path(localpath) + remotepath = path_2_str(remotepath) + localpath = path_2_str(localpath) if not os.path.isabs(localpath): raise ValueError('localpath must be an absolute path') @@ -1136,8 +1135,8 @@ def gettree( remotepath: _TransportPath, localpath: _TransportPath, callback=None, - dereference: Optional[bool] = True, - overwrite: Optional[bool] = True, + dereference: bool = True, + overwrite: bool = True, ): """Get a folder recursively from remote to local. @@ -1153,8 +1152,8 @@ def gettree( :raise OSError: if the remotepath is not found :raise OSError: if unintentionally overwriting """ - remotepath = fix_path(remotepath) - localpath = fix_path(localpath) + remotepath = path_2_str(remotepath) + localpath = path_2_str(localpath) if not dereference: raise NotImplementedError @@ -1195,7 +1194,7 @@ def get_attribute(self, path: _TransportPath): """Returns the object Fileattribute, specified in aiida.transports Receives in input the path of a given file. 
""" - path = fix_path(path) + path = path_2_str(path) from aiida.transports.util import FileAttribute paramiko_attr = self.lstat(path) @@ -1207,14 +1206,14 @@ def get_attribute(self, path: _TransportPath): return aiida_attr def copyfile(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference: bool = False): - remotesource = fix_path(remotesource) - remotedestination = fix_path(remotedestination) + remotesource = path_2_str(remotesource) + remotedestination = path_2_str(remotedestination) return self.copy(remotesource, remotedestination, dereference) def copytree(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference: bool = False): - remotesource = fix_path(remotesource) - remotedestination = fix_path(remotedestination) + remotesource = path_2_str(remotesource) + remotedestination = path_2_str(remotedestination) return self.copy(remotesource, remotedestination, dereference, recursive=True) @@ -1241,8 +1240,8 @@ def copy( .. note:: setting dereference equal to True could cause infinite loops. 
""" - remotesource = fix_path(remotesource) - remotedestination = fix_path(remotedestination) + remotesource = path_2_str(remotesource) + remotedestination = path_2_str(remotedestination) # In the majority of cases, we should deal with linux cp commands cp_flags = '-f' @@ -1285,11 +1284,9 @@ def copy( else: self._exec_cp(cp_exe, cp_flags, remotesource, remotedestination) - def _exec_cp(self, cp_exe: str, cp_flags: str, src: _TransportPath, dst: _TransportPath): + def _exec_cp(self, cp_exe: str, cp_flags: str, src: str, dst: str): """Execute the ``cp`` command on the remote machine.""" # to simplify writing the above copy function - src = fix_path(src) - dst = fix_path(dst) command = f'{cp_exe} {cp_flags} {escape_for_bash(src)} {escape_for_bash(dst)}' retval, stdout, stderr = self.exec_command_wait_bytes(command) @@ -1327,14 +1324,14 @@ def _local_listdir(path: str, pattern=None): base_dir += os.sep return [re.sub(base_dir, '', i) for i in filtered_list] - def listdir(self, path='.', pattern=None): + def listdir(self, path: _TransportPath = '.', pattern=None): """Get the list of files at path. :param path: default = '.' :param pattern: returns the list of files matching pattern. Unix only. (Use to emulate ``ls *`` for example) """ - path = fix_path(path) + path = path_2_str(path) if path.startswith('/'): abs_dir = path @@ -1349,12 +1346,12 @@ def listdir(self, path='.', pattern=None): abs_dir += '/' return [re.sub(abs_dir, '', i) for i in filtered_list] - def remove(self, path): + def remove(self, path: _TransportPath): """Remove a single file at 'path'""" - path = fix_path(path) + path = path_2_str(path) return self.sftp.remove(path) - def rename(self, oldpath, newpath): + def rename(self, oldpath: _TransportPath, newpath: _TransportPath): """Rename a file or folder from oldpath to newpath. 
:param str oldpath: existing name of the file or folder @@ -1368,8 +1365,8 @@ def rename(self, oldpath, newpath): if not newpath: raise ValueError(f'Destination {newpath} is not a valid path') - oldpath = fix_path(oldpath) - newpath = fix_path(newpath) + oldpath = path_2_str(oldpath) + newpath = path_2_str(newpath) if not self.isfile(oldpath): if not self.isdir(oldpath): @@ -1383,7 +1380,7 @@ def rename(self, oldpath, newpath): return self.sftp.rename(oldpath, newpath) - def isfile(self, path): + def isfile(self, path: _TransportPath): """Return True if the given path is a file, False otherwise. Return False also if the path does not exist. """ @@ -1393,7 +1390,7 @@ def isfile(self, path): if not path: return False - path = fix_path(path) + path = path_2_str(path) try: self.logger.debug( f"stat for path '{path}' ('{self.normalize(path)}'): {self.stat(path)} [{self.stat(path).st_mode}]" @@ -1454,7 +1451,7 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1, work return stdin, stdout, stderr, channel def exec_command_wait_bytes( - self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01, workdir=None + self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01, workdir: _TransportPath = None ): """Executes the specified command and waits for it to finish. @@ -1474,6 +1471,8 @@ def exec_command_wait_bytes( import socket import time + workdir = path_2_str(workdir) + ssh_stdin, stdout, stderr, channel = self._exec_command_internal( command, combine_stderr, bufsize=bufsize, workdir=workdir ) @@ -1567,11 +1566,11 @@ def exec_command_wait_bytes( return (retval, b''.join(stdout_bytes), b''.join(stderr_bytes)) - def gotocomputer_command(self, remotedir): + def gotocomputer_command(self, remotedir: _TransportPath): """Specific gotocomputer string to connect to a given remote computer via ssh and directly go to the calculation folder. 
""" - remotedir = fix_path(remotedir) + remotedir = path_2_str(remotedir) further_params = [] if 'username' in self._connect_args: @@ -1595,25 +1594,25 @@ def gotocomputer_command(self, remotedir): cmd = f'ssh -t {self._machine} {further_params_str} {connect_string}' return cmd - def _symlink(self, source, dest): + def _symlink(self, source: _TransportPath, dest: _TransportPath): """Wrap SFTP symlink call without breaking API :param source: source of link :param dest: link to create """ - source = fix_path(source) - dest = fix_path(dest) + source = path_2_str(source) + dest = path_2_str(dest) self.sftp.symlink(source, dest) - def symlink(self, remotesource, remotedestination): + def symlink(self, remotesource: _TransportPath, remotedestination: _TransportPath): """Create a symbolic link between the remote source and the remote destination. :param remotesource: remote source. Can contain a pattern. :param remotedestination: remote destination """ - remotesource = fix_path(remotesource) - remotedestination = fix_path(remotedestination) + remotesource = path_2_str(remotesource) + remotedestination = path_2_str(remotedestination) # paramiko gives some errors if path is starting with '.' 
source = os.path.normpath(remotesource) dest = os.path.normpath(remotedestination) @@ -1631,11 +1630,11 @@ def symlink(self, remotesource, remotedestination): else: self._symlink(source, dest) - def path_exists(self, path): + def path_exists(self, path: _TransportPath): """Check if path exists""" import errno - path = fix_path(path) + path = path_2_str(path) try: self.stat(path) diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index c698acb9ca..444746c4dd 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -6,15 +6,9 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Plugin for transport over SSH asynchronously. +"""Plugin for transport over SSH asynchronously.""" -Since for many dependencies the blocking methods are required, -this plugin develops both blocking methods, as well. -""" - -## TODO: -## and start writing tests! -## put & get methods could be simplified with the asyncssh.sftp.mget() & put() method or sftp.glob() +## TODO: put & get methods could be simplified with the asyncssh.sftp.mget() & put() method or sftp.glob() import asyncio import glob import os @@ -28,7 +22,7 @@ from aiida.common.escaping import escape_for_bash from aiida.common.exceptions import InvalidOperation -from ..transport import AsyncTransport, TransportInternalError, _TransportPath, fix_path +from ..transport import AsyncTransport, BlockingTransport, TransportInternalError, _TransportPath, path_2_str __all__ = ('AsyncSshTransport',) @@ -111,6 +105,12 @@ def __init__(self, *args, **kwargs): self.script_during = kwargs.pop('script_during', 'None') async def open_async(self): + """Open the transport. + This plugin supports running scripts before and during the connection. 
+ The scripts are run locally, not on the remote machine. + + :raises InvalidOperation: if the transport is already open + """ if self._is_open: raise InvalidOperation('Cannot open the transport twice') @@ -129,6 +129,10 @@ async def open_async(self): return self async def close_async(self): + """Close the transport. + + :raises InvalidOperation: if the transport is already closed + """ if not self._is_open: raise InvalidOperation('Cannot close the transport: it is already closed') @@ -139,23 +143,35 @@ async def close_async(self): def __str__(self): return f"{'OPEN' if self._is_open else 'CLOSED'} [AsyncSshTransport]" - async def get_async(self, remotepath, localpath, dereference=True, overwrite=True, ignore_nonexisting=False): + async def get_async( + self, + remotepath: _TransportPath, + localpath: _TransportPath, + dereference=True, + overwrite=True, + ignore_nonexisting=False, + ): """Get a file or folder from remote to local. Redirects to getfile or gettree. - :param remotepath: a remote path - :param localpath: an (absolute) local path + :param remotepath: an absolute remote path + :param localpath: an absolute local path :param dereference: follow symbolic links. - Default = True (default behaviour in paramiko). - False is not implemented. + Default = True :param overwrite: if True overwrites files and folders. 
Default = False + :type remotepath: _TransportPath + :type localpath: _TransportPath + :type dereference: bool + :type overwrite: bool + :type ignore_nonexisting: bool + :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found """ - remotepath = fix_path(remotepath) - localpath = fix_path(localpath) + remotepath = path_2_str(remotepath) + localpath = path_2_str(localpath) if not os.path.isabs(localpath): raise ValueError('The localpath must be an absolute path') @@ -203,19 +219,28 @@ async def get_async(self, remotepath, localpath, dereference=True, overwrite=Tru else: raise OSError(f'The remote path {remotepath} does not exist') - async def getfile_async(self, remotepath, localpath, dereference=True, overwrite=True): + async def getfile_async( + self, remotepath: _TransportPath, localpath: _TransportPath, dereference=True, overwrite=True + ): """Get a file from remote to local. - :param remotepath: a remote path - :param localpath: an (absolute) local path + :param remotepath: an absolute remote path + :param localpath: an absolute local path :param overwrite: if True overwrites files and folders. Default = False + :param dereference: follow symbolic links. 
+ Default = True + + :type remotepath: _TransportPath + :type localpath: _TransportPath + :type dereference: bool + :type overwrite: bool :raise ValueError: if local path is invalid :raise OSError: if unintentionally overwriting """ - remotepath = fix_path(remotepath) - localpath = fix_path(localpath) + remotepath = path_2_str(remotepath) + localpath = path_2_str(localpath) if not os.path.isabs(localpath): raise ValueError('localpath must be an absolute path') @@ -230,23 +255,29 @@ async def getfile_async(self, remotepath, localpath, dereference=True, overwrite except (OSError, asyncssh.Error) as exc: raise OSError(f'Error while uploading file {localpath}: {exc}') - async def gettree_async(self, remotepath, localpath, dereference=True, overwrite=True): + async def gettree_async( + self, remotepath: _TransportPath, localpath: _TransportPath, dereference=True, overwrite=True + ): """Get a folder recursively from remote to local. - :param remotepath: a remote path - :param localpath: an (absolute) local path + :param remotepath: an absolute remote path + :param localpath: an absolute local path :param dereference: follow symbolic links. - Default = True (default behaviour in paramiko). - False is not implemented. + Default = True :param overwrite: if True overwrites files and folders. 
- Default = False + Default = True + + :type remotepath: _TransportPath + :type localpath: _TransportPath + :type dereference: bool + :type overwrite: bool :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found :raise OSError: if unintentionally overwriting """ - remotepath = fix_path(remotepath) - localpath = fix_path(localpath) + remotepath = path_2_str(remotepath) + localpath = path_2_str(localpath) if not remotepath: raise OSError('Remotepath must be a non empty string') @@ -283,22 +314,35 @@ async def gettree_async(self, remotepath, localpath, dereference=True, overwrite except (OSError, asyncssh.Error) as exc: raise OSError(f'Error while uploading file {localpath}: {exc}') - async def put_async(self, localpath, remotepath, dereference=True, overwrite=True, ignore_nonexisting=False): + async def put_async( + self, + localpath: _TransportPath, + remotepath: _TransportPath, + dereference=True, + overwrite=True, + ignore_nonexisting=False, + ): """Put a file or a folder from local to remote. Redirects to putfile or puttree. - :param localpath: an (absolute) local path - :param remotepath: a remote path - :param dereference: follow symbolic links (boolean). - Default = True (default behaviour in paramiko). False is not implemented. - :param overwrite: if True overwrites files and folders (boolean). - Default = False. 
+ :param remotepath: an absolute remote path + :param localpath: an absolute local path + :param dereference: follow symbolic links + Default = True + :param overwrite: if True overwrites files and folders + Default = False + + :type remotepath: _TransportPath + :type localpath: _TransportPath + :type dereference: bool + :type overwrite: bool + :type ignore_nonexisting: bool :raise ValueError: if local path is invalid :raise OSError: if the localpath does not exist """ - localpath = fix_path(localpath) - remotepath = fix_path(remotepath) + localpath = path_2_str(localpath) + remotepath = path_2_str(remotepath) if not os.path.isabs(localpath): raise ValueError('The localpath must be an absolute path') @@ -348,20 +392,27 @@ async def put_async(self, localpath, remotepath, dereference=True, overwrite=Tru elif not ignore_nonexisting: raise OSError(f'The local path {localpath} does not exist') - async def putfile_async(self, localpath, remotepath, dereference=True, overwrite=True): + async def putfile_async( + self, localpath: _TransportPath, remotepath: _TransportPath, dereference=True, overwrite=True + ): """Put a file from local to remote. - :param localpath: an (absolute) local path - :param remotepath: a remote path - :param overwrite: if True overwrites files and folders (boolean). - Default = True. 
+ :param remotepath: an absolute remote path + :param localpath: an absolute local path + :param overwrite: if True overwrites files and folders + Default = True + + :type remotepath: _TransportPath + :type localpath: _TransportPath + :type dereference: bool + :type overwrite: bool :raise ValueError: if local path is invalid :raise OSError: if the localpath does not exist, or unintentionally overwriting """ - localpath = fix_path(localpath) - remotepath = fix_path(remotepath) + localpath = path_2_str(localpath) + remotepath = path_2_str(remotepath) if not os.path.isabs(localpath): raise ValueError('The localpath must be an absolute path') @@ -376,27 +427,29 @@ async def putfile_async(self, localpath, remotepath, dereference=True, overwrite except (OSError, asyncssh.Error) as exc: raise OSError(f'Error while uploading file {localpath}: {exc}') - async def puttree_async(self, localpath, remotepath, dereference=True, overwrite=True): + async def puttree_async( + self, localpath: _TransportPath, remotepath: _TransportPath, dereference=True, overwrite=True + ): """Put a folder recursively from local to remote. - By default, overwrite. - - :param localpath: an (absolute) local path - :param remotepath: a remote path - :param dereference: follow symbolic links (boolean) - Default = True (default behaviour in paramiko). False is not implemented. + :param localpath: an absolute local path + :param remotepath: an absolute remote path + :param dereference: follow symbolic links + Default = True :param overwrite: if True overwrites files and folders (boolean). Default = True + :type localpath: _TransportPath + :type remotepath: _TransportPath + :type dereference: bool + :type overwrite: bool + :raise ValueError: if local path is invalid :raise OSError: if the localpath does not exist, or trying to overwrite :raise OSError: if remotepath is invalid - - .. note:: setting dereference equal to True could cause infinite loops. 
- see os.walk() documentation """ - localpath = fix_path(localpath) - remotepath = fix_path(remotepath) + localpath = path_2_str(localpath) + remotepath = path_2_str(remotepath) if not os.path.isabs(localpath): raise ValueError('The localpath must be an absolute path') @@ -444,9 +497,25 @@ async def copy_async( recursive: bool = True, preserve: bool = False, ): - """ """ - remotesource = fix_path(remotesource) - remotedestination = fix_path(remotedestination) + """Copy a file or a folder from remote to remote. + + :param remotesource: path to the remote source directory / file + :param remotedestination: path to the remote destination directory / file + :param dereference: follow symbolic links + :param recursive: copy recursively + :param preserve: preserve file attributes + + :type remotesource: _TransportPath + :type remotedestination: _TransportPath + :type dereference: bool + :type recursive: bool + :type preserve: bool + + :raises: OSError, src does not exist or if the copy execution failed. + """ + + remotesource = path_2_str(remotesource) + remotedestination = path_2_str(remotedestination) if self.has_magic(remotedestination): raise ValueError('Pathname patterns are not allowed in the destination') @@ -484,6 +553,20 @@ async def copyfile_async( dereference: bool = False, preserve: bool = False, ): + """Copy a file from remote to remote. + + :param remotesource: path to the remote source file + :param remotedestination: path to the remote destination file + :param dereference: follow symbolic links + :param preserve: preserve file attributes + + :type remotesource: _TransportPath + :type remotedestination: _TransportPath + :type dereference: bool + :type preserve: bool + + :raises: OSError, src does not exist or if the copy execution failed. 
+ """ return await self.copy_async(remotesource, remotedestination, dereference, recursive=False, preserve=preserve) async def copytree_async( @@ -493,6 +576,20 @@ async def copytree_async( dereference: bool = False, preserve: bool = False, ): + """Copy a folder from remote to remote. + + :param remotesource: path to the remote source directory + :param remotedestination: path to the remote destination directory + :param dereference: follow symbolic links + :param preserve: preserve file attributes + + :type remotesource: _TransportPath + :type remotedestination: _TransportPath + :type dereference: bool + :type preserve: bool + + :raises: OSError, src does not exist or if the copy execution failed. + """ return await self.copy_async(remotesource, remotedestination, dereference, recursive=True, preserve=preserve) async def exec_command_wait_async( @@ -500,14 +597,14 @@ async def exec_command_wait_async( command: str, stdin: Optional[str] = None, encoding: str = 'utf-8', - workdir: Union[_TransportPath, None] = None, + workdir: Optional[_TransportPath] = None, timeout: Optional[float] = 2, **kwargs, ): """Execute a command on the remote machine and wait for it to finish. 
:param command: the command to execute - :param stdin: the standard input to pass to the command + :param stdin: the input to pass to the command :param encoding: (IGNORED) this is here just to keep the same signature as the one in `BlockingTransport` class :param workdir: the working directory where to execute the command :param timeout: the timeout in seconds @@ -515,7 +612,7 @@ async def exec_command_wait_async( :type command: str :type stdin: str :type encoding: str - :type workdir: str + :type workdir: Union[_TransportPath, None] :type timeout: float :return: a tuple with the return code, the stdout and the stderr of the command @@ -533,9 +630,30 @@ async def exec_command_wait_async( # Since the command is str, both stdout and stderr are strings return (result.returncode, ''.join(str(result.stdout)), ''.join(str(result.stderr))) - async def get_attribute_async(self, path): - """ """ - path = fix_path(path) + async def get_attribute_async(self, path: _TransportPath): + """Return an object FixedFieldsAttributeDict for file in a given path, + as defined in aiida.common.extendeddicts + Each attribute object consists in a dictionary with the following keys: + + * st_size: size of files, in bytes + + * st_uid: user id of owner + + * st_gid: group id of owner + + * st_mode: protection bits + + * st_atime: time of most recent access + + * st_mtime: time of most recent modification + + :param path: path to file + + :type path: _TransportPath + + :return: object FixedFieldsAttributeDict + """ + path = path_2_str(path) from aiida.transports.util import FileAttribute asyncssh_attr = await self._sftp.lstat(path) @@ -558,36 +676,56 @@ async def get_attribute_async(self, path): raise NotImplementedError(f'Mapping the {key} attribute is not implemented') return aiida_attr - async def isdir_async(self, path): + async def isdir_async(self, path: _TransportPath): """Return True if the given path is a directory, False otherwise. Return False also if the path does not exist. 
+ + :param path: the absolute path to check + + :type path: _TransportPath + + :return: True if the path is a directory, False otherwise """ # Return False on empty string if not path: return False - path = fix_path(path) + path = path_2_str(path) return await self._sftp.isdir(path) - async def isfile_async(self, path): + async def isfile_async(self, path: _TransportPath): """Return True if the given path is a file, False otherwise. Return False also if the path does not exist. + + :param path: the absolute path to check + + :type path: _TransportPath + + :return: True if the path is a file, False otherwise """ # Return False on empty string if not path: return False - path = fix_path(path) + path = path_2_str(path) return await self._sftp.isfile(path) async def listdir_async(self, path: _TransportPath, pattern=None): - """ + """Return a list of the names of the entries in the given path. + The list is in arbitrary order. It does not include the special + entries '.' and '..' even if they are present in the directory. + + :param path: an absolute path + :param pattern: if used, listdir returns a list of files matching + filters in Unix style. Unix only. + + :type path: _TransportPath - :param path: the absolute path to list + :return: a list of strings """ - path = fix_path(path) + path = path_2_str(path) if not pattern: list_ = list(await self._sftp.listdir(path)) else: @@ -607,9 +745,12 @@ async def listdir_withattributes_async(self, path: _TransportPath, pattern: Opti The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory. - :param str path: absolute path to list - :param str pattern: if used, listdir returns a list of files matching + :param path: absolute path to list + :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. + + :type path: _TransportPath + :type pattern: str :return: a list of dictionaries, one per entry. 
The schema of the dictionary is the following:: @@ -624,7 +765,7 @@ async def listdir_withattributes_async(self, path: _TransportPath, pattern: Opti (if the file is a folder, a directory, ...). 'attributes' behaves as the output of transport.get_attribute(); isdir is a boolean indicating if the object is a directory or not. """ - path = fix_path(path) + path = path_2_str(path) retlist = [] listdir = await self.listdir_async(path, pattern) for file_name in listdir: @@ -639,13 +780,15 @@ async def makedirs_async(self, path, ignore_existing=False): Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. - :param str path: absolute path to directory to create + :param path: absolute path to directory to create :param bool ignore_existing: if set to true, it doesn't give any error if the leaf directory does already exist + :type path: _TransportPath + :raises: OSError, if directory at path already exists """ - path = fix_path(path) + path = path_2_str(path) try: await self._sftp.makedirs(path, exist_ok=ignore_existing) @@ -660,13 +803,15 @@ async def makedirs_async(self, path, ignore_existing=False): async def mkdir_async(self, path: _TransportPath, ignore_existing=False): """Create a directory. 
- :param str path: absolute path to directory to create + :param path: absolute path to directory to create :param bool ignore_existing: if set to true, it doesn't give any error if the leaf directory does already exist + :type path: _TransportPath + :raises: OSError, if directory at path already exists """ - path = fix_path(path) + path = path_2_str(path) try: await self._sftp.mkdir(path) @@ -684,15 +829,20 @@ async def mkdir_async(self, path: _TransportPath, ignore_existing=False): else: raise TransportInternalError(f'Error while creating directory {path}: {exc}') - async def remove_async(self, path): + async def normalize_async(self, path: _TransportPath): + raise NotImplementedError('Not implemented, waiting for a use case.') + + async def remove_async(self, path: _TransportPath): """Remove the file at the given path. This only works on files; for removing folders (directories), use rmdir. - :param str path: path to file to remove + :param path: path to file to remove + + :type path: _TransportPath :raise OSError: if the path is a directory """ - path = fix_path(path) + path = path_2_str(path) # TODO: check if asyncssh does return SFTPFileIsADirectory in this case # if that's the case, we can get rid of the isfile check if await self.isdir_async(path): @@ -700,18 +850,21 @@ async def remove_async(self, path): else: await self._sftp.remove(path) - async def rename_async(self, oldpath, newpath): + async def rename_async(self, oldpath: _TransportPath, newpath: _TransportPath): """ Rename a file or folder from oldpath to newpath. 
- :param str oldpath: existing name of the file or folder - :param str newpath: new name for the file or folder + :param oldpath: existing name of the file or folder + :param newpath: new name for the file or folder + + :type oldpath: _TransportPath + :type newpath: _TransportPath :raises OSError: if oldpath/newpath is not found :raises ValueError: if oldpath/newpath is not a valid string """ - oldpath = fix_path(oldpath) - newpath = fix_path(newpath) + oldpath = path_2_str(oldpath) + newpath = path_2_str(newpath) if not oldpath or not newpath: raise ValueError('oldpath and newpath must be non-empty strings') @@ -720,40 +873,51 @@ async def rename_async(self, oldpath, newpath): await self._sftp.rename(oldpath, newpath) - async def rmdir_async(self, path): + async def rmdir_async(self, path: _TransportPath): """Remove the folder named path. This works only for empty folders. For recursive remove, use rmtree. :param str path: absolute path to the folder to remove + + :type path: _TransportPath """ - path = fix_path(path) + path = path_2_str(path) try: await self._sftp.rmdir(path) except asyncssh.sftp.SFTPFailure: raise OSError(f'Error while removing directory {path}: probably directory is not empty') - async def rmtree_async(self, path): + async def rmtree_async(self, path: _TransportPath): """Remove the folder named path, and all its contents. :param str path: absolute path to the folder to remove + + :type path: _TransportPath + + :raises OSError: if the operation fails """ - path = fix_path(path) + path = path_2_str(path) try: await self._sftp.rmtree(path, ignore_errors=False) except asyncssh.Error as exc: raise OSError(f'Error while removing directory tree {path}: {exc}') - async def path_exists_async(self, path): - """Returns True if path exists, False otherwise.""" - path = fix_path(path) + async def path_exists_async(self, path: _TransportPath): + """Returns True if path exists, False otherwise. 
+ + :param path: path to check + + :type path: _TransportPath + """ + path = path_2_str(path) return await self._sftp.exists(path) async def whoami_async(self): """Get the remote username - :return: list of username (str), - retval (int), - stderr (str) + :return: username (str), + + :raises OSError: if the command fails """ command = 'whoami' # Assuming here that the username is either ASCII or UTF-8 encoded @@ -767,15 +931,20 @@ async def whoami_async(self): self.logger.error(f"Problem executing whoami. Exit code: {retval}, stdout: '{username}', stderr: '{stderr}'") raise OSError(f'Error while executing whoami. Exit code: {retval}') - async def symlink_async(self, remotesource, remotedestination): + async def symlink_async(self, remotesource: _TransportPath, remotedestination: _TransportPath): """Create a symbolic link between the remote source and the remote destination. :param remotesource: absolute path to remote source :param remotedestination: absolute path to remote destination + + :type remotesource: _TransportPath + :type remotedestination: _TransportPath + + :raises ValueError: if remotedestination has patterns """ - remotesource = fix_path(remotesource) - remotedestination = fix_path(remotedestination) + remotesource = path_2_str(remotesource) + remotedestination = path_2_str(remotedestination) if self.has_magic(remotesource): if self.has_magic(remotedestination): @@ -784,29 +953,42 @@ async def symlink_async(self, remotesource, remotedestination): # find all files matching pattern for this_source in await self._sftp.glob(remotesource): # create the name of the link: take the last part of the path - this_dest = os.path.join(remotedestination, os.path.split(this_source)[-1]) + this_dest = os.path.join(remotedestination, os.path.split(this_source)[-1]) # type: ignore [arg-type] + # in the line above I am sure that this_source is a string, + # since asyncssh.sftp.glob() returns only str if argument remotesource is a str await 
self._sftp.symlink(this_source, this_dest) else: await self._sftp.symlink(remotesource, remotedestination) - async def glob_async(self, pathname): + async def glob_async(self, pathname: _TransportPath): """Return a list of paths matching a pathname pattern. The pattern may contain simple shell-style wildcards a la fnmatch. - :param str pathname: the pathname pattern to match. - It should only be an absolute path. + :param pathname: the pathname pattern to match. + It should only be an absolute path. + + :type pathname: _TransportPath + :return: a list of paths matching the pattern. """ + pathname = path_2_str(pathname) return await self._sftp.glob(pathname) - async def chmod_async(self, path, mode, follow_symlinks=True): + async def chmod_async(self, path: _TransportPath, mode: int, follow_symlinks: bool = True): """Change the permissions of a file. - :param str path: path to the file - :param int mode: the new permissions + :param path: path to the file + :param mode: the new permissions + :param bool follow_symlinks: if True, follow symbolic links + + :type path: _TransportPath + :type mode: int + :type follow_symlinks: bool + + :raises OSError: if the path is empty """ - path = fix_path(path) + path = path_2_str(path) if not path: raise OSError('Input path is an empty argument.') try: @@ -814,14 +996,20 @@ async def chmod_async(self, path, mode, follow_symlinks=True): except asyncssh.sftp.SFTPNoSuchFile as exc: raise OSError(f'Error {exc}, directory does not exists') - async def chown_async(self, path, uid, gid): + async def chown_async(self, path: _TransportPath, uid: int, gid: int): """Change the owner and group id of a file. 
- :param str path: path to the file - :param int uid: the new owner id - :param int gid: the new group id + :param path: path to the file + :param uid: the new owner id + :param gid: the new group id + + :type path: _TransportPath + :type uid: int + :type gid: int + + :raises OSError: if the path is empty """ - path = fix_path(path) + path = path_2_str(path) if not path: raise OSError('Input path is an empty argument.') try: @@ -829,15 +1017,25 @@ async def chown_async(self, path, uid, gid): except asyncssh.sftp.SFTPNoSuchFile as exc: raise OSError(f'Error {exc}, directory does not exists') - async def copy_from_remote_to_remote_async(self, transportdestination, remotesource, remotedestination, **kwargs): + async def copy_from_remote_to_remote_async( + self, + transportdestination: Union['BlockingTransport', 'AsyncTransport'], + remotesource: _TransportPath, + remotedestination: _TransportPath, + **kwargs, + ): """Copy files or folders from a remote computer to another remote computer, asynchronously. :param transportdestination: transport to be used for the destination computer - :param str remotesource: path to the remote source directory / file - :param str remotedestination: path to the remote destination directory / file + :param remotesource: path to the remote source directory / file + :param remotedestination: path to the remote destination directory / file :param kwargs: keyword parameters passed to the call to transportdestination.put, except for 'dereference' that is passed to self.get + :type transportdestination: Union['BlockingTransport', 'AsyncTransport'] + :type remotesource: _TransportPath + :type remotedestination: _TransportPath + .. note:: the keyword 'dereference' SHOULD be set to False for the final put (onto the destination), while it can be set to the value given in kwargs for the get from the source. 
In that @@ -880,10 +1078,13 @@ async def copy_from_remote_to_remote_async(self, transportdestination, remotesou os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put ) - def gotocomputer_command(self, remotedir): + def gotocomputer_command(self, remotedir: _TransportPath): + """Return a string to be used to connect to the remote computer. + + :param remotedir: the remote directory to connect to + + :type remotedir: _TransportPath + """ connect_string = self._gotocomputer_string(remotedir) cmd = f'ssh -t {self.machine} {connect_string}' return cmd - - async def normalize_async(self, path: _TransportPath): - raise NotImplementedError('Not implemented, waiting for a use case') diff --git a/src/aiida/transports/transport.py b/src/aiida/transports/transport.py index 1155a46895..4b9cb2bd7b 100644 --- a/src/aiida/transports/transport.py +++ b/src/aiida/transports/transport.py @@ -16,7 +16,7 @@ import sys from collections import OrderedDict from pathlib import Path, PurePosixPath -from typing import Union +from typing import Optional, Union from aiida.common.exceptions import InternalError from aiida.common.lang import classproperty @@ -43,7 +43,7 @@ def validate_positive_number(ctx, param, value): return value -def fix_path(path: _TransportPath) -> str: +def path_2_str(path: _TransportPath) -> str: """Convert an instance of _TransportPath = Union[str, Path, PurePosixPath] instance to a string.""" # We could check if the path is a Path or PurePosixPath instance, but it's too much overhead. return str(path) @@ -116,11 +116,17 @@ def __init__(self, *args, **kwargs): @abc.abstractmethod def open(self): - """Opens a local transport channel""" + """Opens a transport channel + + :raises InvalidOperation: if the transport is already open. + """ @abc.abstractmethod def close(self): - """Closes the local transport channel""" + """Closes the transport channel. + + :raises InvalidOperation: if the transport is already closed. 
+ """ def __enter__(self): """For transports that require opening a connection, opens @@ -249,8 +255,8 @@ def get_safe_open_interval(self): """ return self._safe_open_interval - def has_magic(self, string): - string = fix_path(string) + def has_magic(self, string: _TransportPath): + string = path_2_str(string) """Return True if the given string contains any special shell characters.""" return self._MAGIC_CHECK.search(string) is not None @@ -278,6 +284,10 @@ class BlockingTransport(abc.ABC, _BaseTransport): # keys: 'default', 'prompt', 'help', 'non_interactive_default' _valid_auth_options = [] + def __init__(self, *args, **kwargs): + # if __init__ is overridden in a subclass, it should always call the parent __init__ + super().__init__(*args, **kwargs) + def __repr__(self): return f'<{self.__class__.__name__}: {self!s}>' @@ -286,76 +296,102 @@ def __str__(self): """return [Transport class or subclass]""" @abc.abstractmethod - def chmod(self, path, mode): + def chmod(self, path: _TransportPath, mode): """Change permissions of a path. - :param str path: path to file - :param int mode: new permissions + :param path: path to file + :param mode: new permissions + + :type path: _TransportPath + :type mode: int """ @abc.abstractmethod - def chown(self, path, uid, gid): + def chown(self, path: _TransportPath, uid: int, gid: int): """Change the owner (uid) and group (gid) of a file. As with python's os.chown function, you must pass both arguments, so if you only want to change one, use stat first to retrieve the current owner and group. 
- :param str path: path to the file to change the owner and group of - :param int uid: new owner's uid - :param int gid: new group id + :param path: path to the file to change the owner and group of + :param uid: new owner's uid + :param gid: new group id + + :type path: _TransportPath + :type uid: int + :type gid: int """ @abc.abstractmethod - def copy(self, remotesource, remotedestination, dereference=False, recursive=True): + def copy(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False, recursive=True): """Copy a file or a directory from remote source to remote destination (On the same remote machine) - :param str remotesource: path of the remote source directory / file - :param str remotedestination: path of the remote destination directory / file + :param remotesource: path of the remote source directory / file + :param remotedestination: path of the remote destination directory / file :param dereference: if True copy the contents of any symlinks found, otherwise copy the symlinks themselves - :type dereference: bool :param recursive: if True copy directories recursively, otherwise only copy the specified file(s) + + :type remotesource: _TransportPath + :type remotedestination: _TransportPath + :type dereference: bool :type recursive: bool :raises: OSError, if one of src or dst does not exist """ @abc.abstractmethod - def copyfile(self, remotesource, remotedestination, dereference=False): + def copyfile(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): """Copy a file from remote source to remote destination (On the same remote machine) - :param str remotesource: path of the remote source directory / file - :param str remotedestination: path of the remote destination directory / file + :param remotesource: path of the remote source directory / file + :param remotedestination: path of the remote destination directory / file :param dereference: if True copy the contents of any symlinks 
found, otherwise copy the symlinks themselves + + :type remotesource: _TransportPath + :type remotedestination: _TransportPath :type dereference: bool :raises OSError: if one of src or dst does not exist """ @abc.abstractmethod - def copytree(self, remotesource, remotedestination, dereference=False): + def copytree(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): """Copy a folder from remote source to remote destination (On the same remote machine) - :param str remotesource: path of the remote source directory / file - :param str remotedestination: path of the remote destination directory / file + :param remotesource: path of the remote source directory / file + :param remotedestination: path of the remote destination directory / file :param dereference: if True copy the contents of any symlinks found, otherwise copy the symlinks themselves + + :type remotesource: _TransportPath + :type remotedestination: _TransportPath :type dereference: bool :raise OSError: if one of src or dst does not exist """ ## non-abtract methods. Plugin developers can safely ingore developing these methods - def copy_from_remote_to_remote(self, transportdestination, remotesource, remotedestination, **kwargs): + def copy_from_remote_to_remote( + self, + transportdestination: Union['BlockingTransport', 'AsyncTransport'], + remotesource: _TransportPath, + remotedestination: _TransportPath, + **kwargs, + ): """Copy files or folders from a remote computer to another remote computer. 
:param transportdestination: transport to be used for the destination computer - :param str remotesource: path to the remote source directory / file - :param str remotedestination: path to the remote destination directory / file + :param remotesource: path to the remote source directory / file + :param remotedestination: path to the remote destination directory / file :param kwargs: keyword parameters passed to the call to transportdestination.put, except for 'dereference' that is passed to self.get + :type transportdestination: Union['BlockingTransport', 'AsyncTransport'] + :type remotesource: _TransportPath + :type remotedestination: _TransportPath + .. note:: the keyword 'dereference' SHOULD be set to False for the final put (onto the destination), while it can be set to the value given in kwargs for the get from the source. In that @@ -394,10 +430,12 @@ def copy_from_remote_to_remote(self, transportdestination, remotesource, remoted # from sandbox.get_abs_path('*') would not work for files # beginning with a dot ('.'). for filename in sandbox.get_content_list(): + # no matter whether transportdestination is a BlockingTransport or an AsyncTransport, + # the following call will work, as both classes support the blocking put() method transportdestination.put(os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put) @abc.abstractmethod - def _exec_command_internal(self, command, workdir=None, **kwargs): + def _exec_command_internal(self, command: str, workdir: Optional[_TransportPath] = None, **kwargs): """Execute the command on the shell, similarly to os.system. Enforce the execution to be run from `workdir`. @@ -405,15 +443,19 @@ def _exec_command_internal(self, command, workdir=None, **kwargs): If possible, use the higher-level exec_command_wait function. 
- :param str command: execute the command given as a string + :param command: execute the command given as a string :param workdir: (optional, default=None) if set, the command will be executed in the specified working directory. + + :type command: str + :type workdir: _TransportPath + :return: stdin, stdout, stderr and the session, when this exists \ (can be None). """ @abc.abstractmethod - def exec_command_wait_bytes(self, command, stdin=None, workdir=None, **kwargs): + def exec_command_wait_bytes(self, command: str, stdin=None, workdir: Optional[_TransportPath] = None, **kwargs): """Execute the command on the shell, waits for it to finish, and return the retcode, the stdout and the stderr as bytes. @@ -421,14 +463,20 @@ def exec_command_wait_bytes(self, command, stdin=None, workdir=None, **kwargs): The command implementation can have some additional plugin-specific kwargs. - :param str command: execute the command given as a string + :param command: execute the command given as a string :param stdin: (optional,default=None) can be bytes or a file-like object. :param workdir: (optional, default=None) if set, the command will be executed in the specified working directory. + + :type command: str + :type workdir: _TransportPath + :return: a tuple: the retcode (int), stdout (bytes) and stderr (bytes). """ - def exec_command_wait(self, command, stdin=None, encoding='utf-8', workdir=None, **kwargs): + def exec_command_wait( + self, command, stdin=None, encoding='utf-8', workdir: Optional[_TransportPath] = None, **kwargs + ): """Executes the specified command and waits for it to finish. :note: this function also decodes the bytes received into a string with the specified encoding, @@ -446,6 +494,10 @@ def exec_command_wait(self, command, stdin=None, encoding='utf-8', workdir=None, :param workdir: (optional, default=None) if set, the command will be executed in the specified working directory. 
+ :type command: str + :type encoding: str + :type workdir: _TransportPath + :return: a tuple with (return_value, stdout, stderr) where stdout and stderr are both strings, decoded with the specified encoding. """ @@ -456,33 +508,42 @@ def exec_command_wait(self, command, stdin=None, encoding='utf-8', workdir=None, return (retval, stdout_bytes.decode(encoding), stderr_bytes.decode(encoding)) @abc.abstractmethod - def get(self, remotepath, localpath, *args, **kwargs): + def get(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): """Retrieve a file or folder from remote source to local destination - dst must be an absolute path (src not necessarily) + both localpath and remotepath must be an absolute path. This method should be able to handle remothpath containing glob patterns, in that case should only downloading matching patterns. - :param remotepath: (str) remote_folder_path - :param localpath: (str) local_folder_path + :param remotepath: remote_folder_path + :param localpath: local_folder_path + + :type remotepath: _TransportPath + :type localpath: _TransportPath """ @abc.abstractmethod - def getfile(self, remotepath, localpath, *args, **kwargs): + def getfile(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): """Retrieve a file from remote source to local destination - dst must be an absolute path (src not necessarily) + both localpath and remotepath must be an absolute path. 
+ + :param remotepath: remote_folder_path + :param localpath: local_folder_path - :param str remotepath: remote_folder_path - :param str localpath: local_folder_path + :type remotepath: _TransportPath + :type localpath: _TransportPath """ @abc.abstractmethod - def gettree(self, remotepath, localpath, *args, **kwargs): + def gettree(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): """Retrieve a folder recursively from remote source to local destination - dst must be an absolute path (src not necessarily) + both localpath and remotepath must be an absolute path. - :param str remotepath: remote_folder_path - :param str localpath: local_folder_path + :param remotepath: remote_folder_path + :param localpath: local_folder_path + + :type remotepath: _TransportPath + :type localpath: _TransportPath """ @abc.abstractmethod @@ -501,7 +562,7 @@ def getcwd(self): ) @abc.abstractmethod - def get_attribute(self, path): + def get_attribute(self, path: _TransportPath): """Return an object FixedFieldsAttributeDict for file in a given path, as defined in aiida.common.extendeddicts Each attribute object consists in a dictionary with the following keys: @@ -518,14 +579,20 @@ def get_attribute(self, path): * st_mtime: time of most recent modification - :param str path: path to file + :param path: path to file + + :type path: _TransportPath + :return: object FixedFieldsAttributeDict """ - def get_mode(self, path): + def get_mode(self, path: _TransportPath): """Return the portion of the file's mode that can be set by chmod(). - :param str path: path to file + :param path: path to file + + :type path: _TransportPath + :return: the portion of the file's mode that can be set by chmod() """ import stat @@ -533,45 +600,57 @@ def get_mode(self, path): return stat.S_IMODE(self.get_attribute(path).st_mode) @abc.abstractmethod - def isdir(self, path): + def isdir(self, path: _TransportPath): """True if path is an existing directory. 
Return False also if the path does not exist. - :param str path: path to directory + :param path: path to directory + + :type path: _TransportPath + :return: boolean """ @abc.abstractmethod - def isfile(self, path): + def isfile(self, path: _TransportPath): """Return True if path is an existing file. Return False also if the path does not exist. - :param str path: path to file + :param path: path to file + + :type path: _TransportPath + :return: boolean """ @abc.abstractmethod - def listdir(self, path='.', pattern=None): + def listdir(self, path: _TransportPath = '.', pattern=None): """Return a list of the names of the entries in the given path. The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory. - :param str path: path to list (default to '.') - :param str pattern: if used, listdir returns a list of files matching + :param path: path to list (default to '.') + DEPRECATED: using '.' is deprecated and will be removed in the next major version. + :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. + + :type path: _TransportPath + :return: a list of strings """ - def listdir_withattributes(self, path: _TransportPath = '.', pattern=None): + def listdir_withattributes(self, path: _TransportPath = '.', pattern: Optional[str] = None): """Return a list of the names of the entries in the given path. The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory. - :param str path: path to list (default to '.') + :param path: path to list (default to '.') if using a relative path, it is relative to the current working directory, taken from DEPRECATED `self.getcwd()`. - :param str pattern: if used, listdir returns a list of files matching + :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. 
+ :type path: _TransportPath + :type pattern: str :return: a list of dictionaries, one per entry. The schema of the dictionary is the following:: @@ -586,7 +665,7 @@ def listdir_withattributes(self, path: _TransportPath = '.', pattern=None): (if the file is a folder, a directory, ...). 'attributes' behaves as the output of transport.get_attribute(); isdir is a boolean indicating if the object is a directory or not. """ - path = fix_path(path) + path = path_2_str(path) retlist = [] if path.startswith('/'): cwd = Path(path).resolve().as_posix() @@ -604,111 +683,133 @@ def listdir_withattributes(self, path: _TransportPath = '.', pattern=None): return retlist @abc.abstractmethod - def makedirs(self, path, ignore_existing=False): + def makedirs(self, path: _TransportPath, ignore_existing=False): """Super-mkdir; create a leaf directory and all intermediate ones. Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. - :param str path: directory to create + :param path: directory to create :param bool ignore_existing: if set to true, it doesn't give any error if the leaf directory does already exist + :type path: _TransportPath :raises: OSError, if directory at path already exists """ @abc.abstractmethod - def mkdir(self, path, ignore_existing=False): + def mkdir(self, path: _TransportPath, ignore_existing=False): """Create a folder (directory) named path. - :param str path: name of the folder to create + :param path: name of the folder to create :param bool ignore_existing: if True, does not give any error if the directory already exists + :type path: _TransportPath :raises: OSError, if directory at path already exists """ @abc.abstractmethod - def normalize(self, path='.'): + def normalize(self, path: _TransportPath = '.'): """Return the normalized path (on the server) of a given path. 
This can be used to quickly resolve symbolic links or determine what the server is considering to be the "current folder". - :param str path: path to be normalized + :param path: path to be normalized + + :type path: _TransportPath :raise OSError: if the path can't be resolved on the server """ @abc.abstractmethod - def put(self, localpath, remotepath, *args, **kwargs): + def put(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): """Put a file or a directory from local src to remote dst. - src must be an absolute path (dst not necessarily)) + both localpath and remotepath must be an absolute path. Redirects to putfile and puttree. This method should be able to handle localpath containing glob patterns, in that case should only uploading matching patterns. - :param str localpath: absolute path to local source - :param str remotepath: path to remote destination + :param localpath: absolute path to local source + :param remotepath: path to remote destination + + :type localpath: _TransportPath + :type remotepath: _TransportPath """ @abc.abstractmethod - def putfile(self, localpath, remotepath, *args, **kwargs): + def putfile(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): """Put a file from local src to remote dst. - src must be an absolute path (dst not necessarily)) + both localpath and remotepath must be an absolute path. - :param str localpath: absolute path to local file - :param str remotepath: path to remote file + :param localpath: absolute path to local file + :param remotepath: path to remote file + + :type localpath: _TransportPath + :type remotepath: _TransportPath """ @abc.abstractmethod - def puttree(self, localpath, remotepath, *args, **kwargs): + def puttree(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): """Put a folder recursively from local src to remote dst. 
- src must be an absolute path (dst not necessarily)) + both localpath and remotepath must be an absolute path. + + :param localpath: absolute path to local folder + :param remotepath: path to remote folder - :param str localpath: absolute path to local folder - :param str remotepath: path to remote folder + :type localpath: _TransportPath + :type remotepath: _TransportPath """ @abc.abstractmethod - def remove(self, path): + def remove(self, path: _TransportPath): """Remove the file at the given path. This only works on files; for removing folders (directories), use rmdir. - :param str path: path to file to remove + :param path: path to file to remove + + :type path: _TransportPath :raise OSError: if the path is a directory """ @abc.abstractmethod - def rename(self, oldpath, newpath): + def rename(self, oldpath: _TransportPath, newpath: _TransportPath): """Rename a file or folder from oldpath to newpath. - :param str oldpath: existing name of the file or folder - :param str newpath: new name for the file or folder + :param oldpath: existing name of the file or folder + :param newpath: new name for the file or folder + + :type oldpath: _TransportPath + :type newpath: _TransportPath :raises OSError: if oldpath/newpath is not found :raises ValueError: if oldpath/newpath is not a valid string """ @abc.abstractmethod - def rmdir(self, path): + def rmdir(self, path: _TransportPath): """Remove the folder named path. This works only for empty folders. For recursive remove, use rmtree. - :param str path: absolute path to the folder to remove + :param path: absolute path to the folder to remove + + :type path: _TransportPath """ @abc.abstractmethod - def rmtree(self, path): + def rmtree(self, path: _TransportPath): """Remove recursively the content at path - :param str path: absolute path to remove + :param path: absolute path to remove + + :type path: _TransportPath :raise OSError: if the rm execution failed. 
""" @abc.abstractmethod - def gotocomputer_command(self, remotedir): + def gotocomputer_command(self, remotedir: _TransportPath): """Return a string to be run using os.system in order to connect via the transport to the remote directory. @@ -718,24 +819,29 @@ def gotocomputer_command(self, remotedir): * A reasonable error message is produced if the folder does not exist - :param str remotedir: the full path of the remote directory + :param remotedir: the full path of the remote directory + + :type remotedir: _TransportPath """ @abc.abstractmethod - def symlink(self, remotesource, remotedestination): + def symlink(self, remotesource: _TransportPath, remotedestination: _TransportPath): """Create a symbolic link between the remote source and the remote destination. :param remotesource: remote source :param remotedestination: remote destination + + :type remotesource: _TransportPath + :type remotedestination: _TransportPath """ def whoami(self): """Get the remote username - :return: list of username (str), - retval (int), - stderr (str) + :return: username (str) + + :raise OSError: if the whoami command fails. """ command = 'whoami' # Assuming here that the username is either ASCII or UTF-8 encoded @@ -750,8 +856,12 @@ def whoami(self): raise OSError(f'Error while executing whoami. Exit code: {retval}') @abc.abstractmethod - def path_exists(self, path): - """Returns True if path exists, False otherwise.""" + def path_exists(self, path: _TransportPath): + """Returns True if path exists, False otherwise. + + :param path: path to check for existence + + :type path: _TransportPath""" # The following definitions are almost copied and pasted # from the python module glob. @@ -761,11 +871,14 @@ def glob(self, pathname: _TransportPath): The pattern may contain simple shell-style wildcards a la fnmatch. :param pathname: the pathname pattern to match. - It should only be absolute path of type _TransportPath. + It should only be an absolute path. 
DEPRECATED: using relative path is deprecated. + + :type pathname: _TransportPath + :return: a list of paths matching the pattern. """ - pathname = fix_path(pathname) + pathname = path_2_str(pathname) if not pathname.startswith('/'): warn_deprecation( 'Using relative paths across transport in `glob` is deprecated ' @@ -779,6 +892,7 @@ def iglob(self, pathname): The pattern may contain simple shell-style wildcards a la fnmatch. + :param pathname: the pathname pattern to match. """ if not self.has_magic(pathname): # if os.path.lexists(pathname): # ORIGINAL @@ -809,7 +923,11 @@ def iglob(self, pathname): # takes a literal basename (so it only has to check for its existence). def glob1(self, dirname, pattern): - """Match subpaths of dirname against pattern.""" + """Match subpaths of dirname against pattern. + + :param dirname: path to the directory + :param pattern: pattern to match against + """ if not dirname: dirname = self.getcwd() if isinstance(pattern, str) and not isinstance(dirname, str): @@ -823,7 +941,11 @@ def glob1(self, dirname, pattern): return fnmatch.filter(names, pattern) def glob0(self, dirname, basename): - """Wrap basename i a list if it is empty or if dirname/basename is an existing path, else return empty list.""" + """Wrap basename i a list if it is empty or if dirname/basename is an existing path, else return empty list. + + :param dirname: path to the directory + :param basename: basename to match against + """ if basename == '': # `os.path.split()` returns an empty basename for paths ending with a # directory separator. 'q*x/' should match only directories. @@ -978,68 +1100,223 @@ async def glob_async(self, pathname): class AsyncTransport(abc.ABC, _BaseTransport): + """An abstract base class for asynchronous transports. + All methods are asynchronous and should be implemented by subclasses. 
+ avoid overriding the sync methods, as they are implemented for backward compatibility, only.""" + + # This will be used for connection authentication + # To be defined in the subclass, the format is a list of tuples + # where the first element is the name of the parameter and the second + # is a dictionary with the following + # keys: 'default', 'prompt', 'help', 'non_interactive_default' + _valid_auth_options = [] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @abc.abstractmethod async def open_async(self): - """Open the transport.""" + """Open a transport channel. + + :raises InvalidOperation: if the transport is already open. + """ @abc.abstractmethod async def close_async(self): - """Close the transport.""" + """Close the transport channel. + + :raises InvalidOperation: if the transport is already closed. + """ @abc.abstractmethod - async def chmod_async(self, path, mode): - """Change permissions of a path.""" + async def chmod_async(self, path: _TransportPath, mode: int): + """Change permissions of a path. + + :param path: path to file or directory + :param mode: new permissions + + :type path: _TransportPath + :type mode: int + """ @abc.abstractmethod - async def chown_async(self, path, uid, gid): - """Change the owner (uid) and group (gid) of a file.""" + async def chown_async(self, path: _TransportPath, uid: int, gid: int): + """Change the owner (uid) and group (gid) of a file. 
+ + :param path: path to file + :param uid: user id of the new owner + :param gid: group id of the new owner + + :type path: _TransportPath + :type uid: int + :type gid: int + """ @abc.abstractmethod async def copy_async(self, remotesource, remotedestination, dereference=False, recursive=True): """Copy a file or a directory from remote source to remote destination - (On the same remote machine)""" + (On the same remote machine) + + :param remotesource: path to the remote source directory / file + :param remotedestination: path to the remote destination directory / file + :param dereference: follow symbolic links + :param recursive: copy recursively + + :type remotesource: _TransportPath + :type remotedestination: _TransportPath + :type dereference: bool + :type recursive: bool + + :raises: OSError, src does not exist or if the copy execution failed. + """ @abc.abstractmethod - async def copyfile_async(self, remotesource, remotedestination, dereference=False): + async def copyfile_async(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): """Copy a file from remote source to remote destination - (On the same remote machine)""" + (On the same remote machine) + + :param remotesource: path to the remote source file + :param remotedestination: path to the remote destination file + :param dereference: follow symbolic links + + :type remotesource: _TransportPath + :type remotedestination: _TransportPath + :type dereference: bool + + :raises: OSError, src does not exist or if the copy execution failed.""" @abc.abstractmethod - async def copytree_async(self, remotesource, remotedestination, dereference=False): + async def copytree_async(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): """Copy a folder from remote source to remote destination - (On the same remote machine)""" + (On the same remote machine) + + :param remotesource: path to the remote source folder + :param remotedestination: path 
to the remote destination folder + :param dereference: follow symbolic links + + :type remotesource: _TransportPath + :type remotedestination: _TransportPath + :type dereference: bool + + :raises: OSError, src does not exist or if the copy execution failed.""" @abc.abstractmethod - async def copy_from_remote_to_remote_async(self, transportdestination, remotesource, remotedestination, **kwargs): - """Copy files or folders from a remote computer to another remote computer.""" + async def copy_from_remote_to_remote_async( + self, + transportdestination: Union['BlockingTransport', 'AsyncTransport'], + remotesource: _TransportPath, + remotedestination: _TransportPath, + **kwargs, + ): + """Copy files or folders from a remote computer to another remote computer. + + :param transportdestination: destination transport + :param remotesource: path to the remote source directory / file + :param remotedestination: path to the remote destination directory / file + :param kwargs: keyword parameters passed to the call to transportdestination.put, + except for 'dereference' that is passed to self.get + + :type transportdestination: Union['BlockingTransport', 'AsyncTransport'] + :type remotesource: _TransportPath + :type remotedestination: _TransportPath + """ @abc.abstractmethod - async def exec_command_wait_async(self, command, stdin=None, encoding='utf-8', workdir=None, **kwargs): - """Executes the specified command and waits for it to finish.""" + async def exec_command_wait_async( + self, + command: str, + stdin: Optional[str] = None, + encoding: str = 'utf-8', + workdir: Optional[_TransportPath] = None, + **kwargs, + ): + """Executes the specified command and waits for it to finish. 
+ + :param command: the command to execute + :param stdin: input to the command + :param encoding: (IGNORED) this is here just to keep the same signature as the one in `BlockingTransport` class + :param workdir: working directory where the command will be executed + + :type command: str + :type stdin: str + :type encoding: str + :type workdir: Union[_TransportPath, None] + + :return: a tuple with (return_value, stdout, stderr) where stdout and stderr are both strings. + :rtype: Tuple[int, str, str] + """ @abc.abstractmethod - async def get_async(self, remotepath, localpath, *args, **kwargs): - """Retrieve a file or folder from remote source to local destination""" + async def get_async(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): + """Retrieve a file or folder from remote source to local destination + both remotepath and localpath must be absolute paths + + This method should be able to handle remotepath containing glob patterns, + in that case should only downloading matching patterns. 
+ + :param remotepath: remote_folder_path + :param localpath: local_folder_path + + :type remotepath: _TransportPath + :type localpath: _TransportPath + """ @abc.abstractmethod - async def getfile_async(self, remotepath, localpath, *args, **kwargs): - """Retrieve a file from remote source to local destination""" + async def getfile_async(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): + """Retrieve a file from remote source to local destination + both remotepath and localpath must be absolute paths + + :param remotepath: remote_folder_path + :param localpath: local_folder_path + + :type remotepath: _TransportPath + :type localpath: _TransportPath + """ @abc.abstractmethod - async def gettree_async(self, remotepath, localpath, *args, **kwargs): - """Retrieve a folder recursively from remote source to local destination""" + async def gettree_async(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): + """Retrieve a folder recursively from remote source to local destination + both remotepath and localpath must be absolute paths + + :param remotepath: remote_folder_path + :param localpath: local_folder_path + + :type remotepath: _TransportPath + :type localpath: _TransportPath + """ @abc.abstractmethod - async def get_attribute_async(self, path): - """Return an object FixedFieldsAttributeDict for file in a given path""" + async def get_attribute_async(self, path: _TransportPath): + """Return an object FixedFieldsAttributeDict for file in a given path, + as defined in aiida.common.extendeddicts + Each attribute object consists in a dictionary with the following keys: - async def get_mode_async(self, path): + * st_size: size of files, in bytes + + * st_uid: user id of owner + + * st_gid: group id of owner + + * st_mode: protection bits + + * st_atime: time of most recent access + + * st_mtime: time of most recent modification + + :param path: path to file + + :type path: _TransportPath + + :return: object 
FixedFieldsAttributeDict + """ + + async def get_mode_async(self, path: _TransportPath): """Return the portion of the file's mode that can be set by chmod(). :param str path: path to file + + :type path: _TransportPath + :return: the portion of the file's mode that can be set by chmod() """ import stat @@ -1048,84 +1325,265 @@ async def get_mode_async(self, path): return stat.S_IMODE(attr.st_mode) @abc.abstractmethod - async def isdir_async(self, path): - """True if path is an existing directory.""" + async def isdir_async(self, path: _TransportPath): + """True if path is an existing directory. + Return False also if the path does not exist. + + :param path: path to directory + + :type path: _TransportPath + + :return: boolean + """ @abc.abstractmethod - async def isfile_async(self, path): - """Return True if path is an existing file.""" + async def isfile_async(self, path: _TransportPath): + """Return True if path is an existing file. + Return False also if the path does not exist. + + :param path: path to file + + :type path: _TransportPath + + :return: boolean + """ @abc.abstractmethod - async def listdir_async(self, path: _TransportPath, pattern=None): - """Return a list of the names of the entries in the given path.""" + async def listdir_async(self, path: _TransportPath, pattern: Optional[str] = None): + """Return a list of the names of the entries in the given path. + The list is in arbitrary order. It does not include the special + entries '.' and '..' even if they are present in the directory. + + :param path: an absolute path + :param pattern: if used, listdir returns a list of files matching + filters in Unix style. Unix only. 
+ + :type path: _TransportPath + + :return: a list of strings + """ @abc.abstractmethod - async def listdir_withattributes_async(self, path: _TransportPath, pattern=None): - """Return a list of the names of the entries in the given path.""" + async def listdir_withattributes_async( + self, + path: _TransportPath, + pattern: Optional[str] = None, + ): + """Return a list of the names of the entries in the given path. + The list is in arbitrary order. It does not include the special + entries '.' and '..' even if they are present in the directory. + + :param path: absolute path to list. + :param pattern: if used, listdir returns a list of files matching + filters in Unix style. Unix only. + + :type path: _TransportPath + :type pattern: str + :return: a list of dictionaries, one per entry. + The schema of the dictionary is + the following:: + + { + 'name': String, + 'attributes': FileAttributeObject, + 'isdir': Bool + } + + where 'name' is the file or folder directory, and any other information is metadata + (if the file is a folder, a directory, ...). 'attributes' behaves as the output of + transport.get_attribute(); isdir is a boolean indicating if the object is a directory or not. + """ @abc.abstractmethod - async def makedirs_async(self, path, ignore_existing=False): - """Super-mkdir; create a leaf directory and all intermediate ones.""" + async def makedirs_async(self, path: _TransportPath, ignore_existing=False): + """Super-mkdir; create a leaf directory and all intermediate ones. + Works like mkdir, except that any intermediate path segment (not + just the rightmost) will be created if it does not exist. 
+ + :param path: directory to create + :param bool ignore_existing: if set to true, it doesn't give any error + if the leaf directory does already exist + :type path: _TransportPath + + :raises: OSError, if directory at path already exists + """ @abc.abstractmethod - async def mkdir_async(self, path, ignore_existing=False): - """Create a folder (directory) named path.""" + async def mkdir_async(self, path: _TransportPath, ignore_existing=False): + """Create a folder (directory) named path. + + :param path: name of the folder to create + :param bool ignore_existing: if True, does not give any error if the + directory already exists. + :type path: _TransportPath + + :raises: OSError, if directory at path already exists + """ @abc.abstractmethod async def normalize_async(self, path: _TransportPath): - """Return the normalized path (on the server) of a given path.""" + """Return the normalized path (on the server) of a given path. + This can be used to quickly resolve symbolic links or determine + what the server is considering to be the "current folder". + + :param path: path to be normalized + + :type path: _TransportPath + + :raise OSError: if the path can't be resolved on the server + """ @abc.abstractmethod - async def put_async(self, localpath, remotepath, *args, **kwargs): - """Put a file or a directory from local src to remote dst.""" + async def put_async(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + """Put a file or a directory from local src to remote dst. + both localpath and remotepath must be absolute paths. + Redirects to putfile and puttree. + + This method should be able to handle localpath containing glob patterns, + in that case should only uploading matching patterns. 
+ + :param localpath: absolute path to local source + :param remotepath: path to remote destination + + :type localpath: _TransportPath + :type remotepath: _TransportPath + """ @abc.abstractmethod - async def putfile_async(self, localpath, remotepath, *args, **kwargs): - """Put a file from local src to remote dst.""" + async def putfile_async(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + """Put a file from local src to remote dst. + both localpath and remotepath must be absolute paths. + + :param localpath: absolute path to local file + :param remotepath: path to remote file + + :type localpath: _TransportPath + :type remotepath: _TransportPath + """ @abc.abstractmethod - async def puttree_async(self, localpath, remotepath, *args, **kwargs): - """Put a folder recursively from local src to remote dst.""" + async def puttree_async(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + """Put a folder recursively from local src to remote dst. + both localpath and remotepath must be absolute paths. + + :param localpath: absolute path to local folder + :param remotepath: path to remote folder + + :type localpath: _TransportPath + :type remotepath: _TransportPath + """ @abc.abstractmethod - async def remove_async(self, path): - """Remove the file at the given path.""" + async def remove_async(self, path: _TransportPath): + """Remove the file at the given path. This only works on files; + for removing folders (directories), use rmdir. + + :param path: path to file to remove + + :type path: _TransportPath + + :raise OSError: if the path is a directory + """ @abc.abstractmethod - async def rename_async(self, oldpath, newpath): - """Rename a file or folder from oldpath to newpath.""" + async def rename_async(self, oldpath: _TransportPath, newpath: _TransportPath): + """Rename a file or folder from oldpath to newpath. 
+ + :param oldpath: existing name of the file or folder + :param newpath: new name for the file or folder + + :type oldpath: _TransportPath + :type newpath: _TransportPath + + :raises OSError: if oldpath/newpath is not found + :raises ValueError: if oldpath/newpath is not a valid string + """ @abc.abstractmethod - async def rmdir_async(self, path): - """Remove the folder named path.""" + async def rmdir_async(self, path: _TransportPath): + """Remove the folder named path. + This works only for empty folders. For recursive remove, use rmtree. + + :param path: absolute path to the folder to remove + + :type path: _TransportPath + """ @abc.abstractmethod - async def rmtree_async(self, path): - """Remove recursively the content at path""" + async def rmtree_async(self, path: _TransportPath): + """Remove recursively the content at path + + :param path: absolute path to remove + + :type path: _TransportPath + + :raise OSError: if the rm execution failed. + """ @abc.abstractmethod - async def symlink_async(self, remotesource, remotedestination): - """Create a symbolic link between the remote source and the remote destination.""" + def gotocomputer_command(self, remotedir: _TransportPath): + """Return a string to be run using os.system in order to connect + via the transport to the remote directory. + + NOTE: This method is not async, abd need not to be, + as it's eventually used for interactive shell commands. + + Expected behaviors: + + * A new bash session is opened + + * A reasonable error message is produced if the folder does not exist + + :param remotedir: the full path of the remote directory + + :type remotedir: _TransportPath + """ @abc.abstractmethod - async def whoami_async(self): - """Get the remote username""" + async def symlink_async(self, remotesource: _TransportPath, remotedestination: _TransportPath): + """Create a symbolic link between the remote source and the remote destination. 
+ + :param remotesource: remote source + :param remotedestination: remote destination + + :param remotesource: absolute path to remote source + :param remotedestination: absolute path to remote destination + """ @abc.abstractmethod - async def path_exists_async(self, path): - """Returns True if path exists, False otherwise.""" + async def whoami_async(self): + """Get the remote username + + :return: username (str) + + :raise OSError: if the whoami command fails. + """ @abc.abstractmethod - async def glob_async(self, pathname): - """Return a list of paths matching a pathname pattern.""" + async def path_exists_async(self, path: _TransportPath): + """Returns True if path exists, False otherwise. + + :param path: path to check for existence + + :type path: _TransportPath + """ @abc.abstractmethod - def gotocomputer_command(self, remotedir): - """Return a string to be run using os.system in order to connect - via the transport to the remote directory.""" + async def glob_async(self, pathname: _TransportPath): + """Return a list of paths matching a pathname pattern. + + The pattern may contain simple shell-style wildcards a la fnmatch. + + :param pathname: the pathname pattern to match. + It should only be absolute path. + + :type pathname: _TransportPath + + :return: a list of paths matching the pattern. + """ ## Blocking counterpart methods. We need these for backwards compatibility - # This is useful, only because some part of engine and + # We need these methods, only because some part of codebase and # many external plugins are synchronous, in those cases blocking calls make more sense. # However, be aware you cannot use these methods in an async functions, # because they will block the event loop. 
@@ -1140,9 +1598,6 @@ def open(self): def close(self): return self.run_command_blocking(self.close_async) - def chown(self, *args, **kwargs): - raise NotImplementedError('Not implemented, for now') - def get(self, *args, **kwargs): return self.run_command_blocking(self.get_async, *args, **kwargs) From 6e350e7f4f4dfe2af48152650b27b5e77ff6a352 Mon Sep 17 00:00:00 2001 From: Ali Khosravi Date: Thu, 21 Nov 2024 08:22:52 +0100 Subject: [PATCH 09/29] added computer test for ssh_async --- src/aiida/transports/plugins/local.py | 7 ++++--- src/aiida/transports/plugins/ssh.py | 3 ++- src/aiida/transports/plugins/ssh_async.py | 1 + tests/cmdline/commands/test_computer.py | 15 +++++++++++++++ 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/aiida/transports/plugins/local.py b/src/aiida/transports/plugins/local.py index d8bbc2ddc5..32608dad2c 100644 --- a/src/aiida/transports/plugins/local.py +++ b/src/aiida/transports/plugins/local.py @@ -761,13 +761,13 @@ def _exec_command_internal(self, command, workdir: Optional[_TransportPath] = No """ from aiida.common.escaping import escape_for_bash - workdir = path_2_str(workdir) + if workdir: + workdir = path_2_str(workdir) # Note: The outer shell will eat one level of escaping, while # 'bash -l -c ...' will eat another. Thus, we need to escape again. bash_commmand = f'{self._bash_command_str}-c ' command = bash_commmand + escape_for_bash(command) - if workdir: cwd = workdir else: @@ -796,7 +796,8 @@ def exec_command_wait_bytes(self, command, stdin=None, workdir: Optional[_Transp :return: a tuple with (return_value, stdout, stderr) where stdout and stderr are both bytes and the return_value is an int. """ - workdir = path_2_str(workdir) + if workdir: + workdir = path_2_str(workdir) with self._exec_command_internal(command, workdir) as process: if stdin is not None: # Implicitly assume that the desired encoding is 'utf-8' if I receive a string. 
diff --git a/src/aiida/transports/plugins/ssh.py b/src/aiida/transports/plugins/ssh.py index 6c810fe523..c2fb643f6d 100644 --- a/src/aiida/transports/plugins/ssh.py +++ b/src/aiida/transports/plugins/ssh.py @@ -1471,7 +1471,8 @@ def exec_command_wait_bytes( import socket import time - workdir = path_2_str(workdir) + if workdir: + workdir = path_2_str(workdir) ssh_stdin, stdout, stderr, channel = self._exec_command_internal( command, combine_stderr, bufsize=bufsize, workdir=workdir diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 444746c4dd..77549be75b 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -620,6 +620,7 @@ async def exec_command_wait_async( """ if workdir: + workdir = path_2_str(workdir) command = f'cd {workdir} && {command}' bash_commmand = self._bash_command_str + '-c ' diff --git a/tests/cmdline/commands/test_computer.py b/tests/cmdline/commands/test_computer.py index 422a155214..2741d8e619 100644 --- a/tests/cmdline/commands/test_computer.py +++ b/tests/cmdline/commands/test_computer.py @@ -985,3 +985,18 @@ def test_computer_ssh_auto(run_cli_command, aiida_computer): options = ['core.ssh_auto', computer.uuid, '--non-interactive', '--safe-interval', '0'] run_cli_command(computer_configure, options, use_subprocess=False) assert computer.is_configured + + +def test_computer_ssh_async(run_cli_command, aiida_computer): + """Test setup of computer with ``core.ssh_async`` entry point. + + The configure step should only require the common shared options ``safe_interval`` and ``use_login_shell``. + """ + computer = aiida_computer(transport_type='core.ssh_async').store() + assert not computer.is_configured + + # It is important that 'ssh localhost' is functional in your test environment. + # It should connect without asking for a password. 
+ options = ['core.ssh_async', computer.uuid, '--non-interactive', '--safe-interval', '0', '--machine', 'localhost'] + run_cli_command(computer_configure, options, use_subprocess=False) + assert computer.is_configured From 03ccc304d66c1f718b2ffbead9ed66636c3dd2c5 Mon Sep 17 00:00:00 2001 From: Ali Date: Mon, 25 Nov 2024 16:14:15 +0100 Subject: [PATCH 10/29] review applied --- src/aiida/calculations/monitors/base.py | 4 +- src/aiida/engine/daemon/execmanager.py | 16 +- .../engine/processes/calcjobs/monitors.py | 6 +- src/aiida/engine/transports.py | 6 +- src/aiida/orm/authinfos.py | 4 +- src/aiida/orm/computers.py | 6 +- src/aiida/orm/nodes/data/remote/base.py | 2 +- .../orm/nodes/process/calculation/calcjob.py | 6 +- src/aiida/orm/utils/remote.py | 6 +- src/aiida/plugins/factories.py | 14 +- src/aiida/schedulers/scheduler.py | 4 +- src/aiida/transports/__init__.py | 1 - src/aiida/transports/plugins/local.py | 146 ++++---- src/aiida/transports/plugins/ssh.py | 166 ++++----- src/aiida/transports/plugins/ssh_async.py | 226 ++++++------ src/aiida/transports/transport.py | 330 +++++++++--------- tests/engine/daemon/test_execmanager.py | 1 - tests/manage/tests/test_pytest_fixtures.py | 6 +- tests/orm/test_computers.py | 2 +- tests/plugins/test_factories.py | 6 +- tests/test_calculation_node.py | 4 +- tests/transports/test_all_plugins.py | 4 +- utils/dependency_management.py | 3 + 23 files changed, 483 insertions(+), 486 deletions(-) diff --git a/src/aiida/calculations/monitors/base.py b/src/aiida/calculations/monitors/base.py index 588b0debb6..87d5054915 100644 --- a/src/aiida/calculations/monitors/base.py +++ b/src/aiida/calculations/monitors/base.py @@ -7,10 +7,10 @@ from typing import Union from aiida.orm import CalcJobNode -from aiida.transports import AsyncTransport, BlockingTransport +from aiida.transports import AsyncTransport, Transport -def always_kill(node: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport']) -> str | None: +def 
always_kill(node: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> str | None: """Retrieve and inspect files in working directory of job to determine whether the job should be killed. This particular implementation is just for demonstration purposes and will kill the job as long as there is a diff --git a/src/aiida/engine/daemon/execmanager.py b/src/aiida/engine/daemon/execmanager.py index 18e650ebae..ece32173f9 100644 --- a/src/aiida/engine/daemon/execmanager.py +++ b/src/aiida/engine/daemon/execmanager.py @@ -35,7 +35,7 @@ from aiida.schedulers.datastructures import JobState if TYPE_CHECKING: - from aiida.transports import AsyncTransport, BlockingTransport + from aiida.transports import AsyncTransport, Transport REMOTE_WORK_DIRECTORY_LOST_FOUND = 'lost+found' @@ -64,7 +64,7 @@ def _find_data_node(inputs: MappingType[str, Any], uuid: str) -> Optional[Node]: async def upload_calculation( node: CalcJobNode, - transport: Union['BlockingTransport', 'AsyncTransport'], + transport: Union['Transport', 'AsyncTransport'], calc_info: CalcInfo, folder: Folder, inputs: Optional[MappingType[str, Any]] = None, @@ -393,9 +393,7 @@ async def _copy_sandbox_files(logger, node, transport, folder, workdir: Path): await transport.put_async(folder.get_abs_path(filename), workdir.joinpath(filename)) -def submit_calculation( - calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport'] -) -> str | ExitCode: +def submit_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> str | ExitCode: """Submit a previously uploaded `CalcJob` to the scheduler. :param calculation: the instance of CalcJobNode to submit. 
@@ -425,7 +423,7 @@ def submit_calculation( return result -async def stash_calculation(calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport']) -> None: +async def stash_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> None: """Stash files from the working directory of a completed calculation to a permanent remote folder. After a calculation has been completed, optionally stash files from the work directory to a storage location on the @@ -491,7 +489,7 @@ async def stash_calculation(calculation: CalcJobNode, transport: Union['Blocking async def retrieve_calculation( - calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport'], retrieved_temporary_folder: str + calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport'], retrieved_temporary_folder: str ) -> FolderData | None: """Retrieve all the files of a completed job calculation using the given transport. @@ -556,7 +554,7 @@ async def retrieve_calculation( return retrieved_files -def kill_calculation(calculation: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport']) -> None: +def kill_calculation(calculation: CalcJobNode, transport: Union['Transport', 'AsyncTransport']) -> None: """Kill the calculation through the scheduler :param calculation: the instance of CalcJobNode to kill. 
@@ -591,7 +589,7 @@ def kill_calculation(calculation: CalcJobNode, transport: Union['BlockingTranspo async def retrieve_files_from_list( calculation: CalcJobNode, - transport: Union['BlockingTransport', 'AsyncTransport'], + transport: Union['Transport', 'AsyncTransport'], folder: str, retrieve_list: List[Union[str, Tuple[str, str, int], list]], ) -> None: diff --git a/src/aiida/engine/processes/calcjobs/monitors.py b/src/aiida/engine/processes/calcjobs/monitors.py index a9d2853b1d..e13f01a5f3 100644 --- a/src/aiida/engine/processes/calcjobs/monitors.py +++ b/src/aiida/engine/processes/calcjobs/monitors.py @@ -16,7 +16,7 @@ from aiida.plugins import BaseFactory if t.TYPE_CHECKING: - from aiida.transports import AsyncTransport, BlockingTransport + from aiida.transports import AsyncTransport, Transport LOGGER = AIIDA_LOGGER.getChild(__name__) @@ -124,7 +124,7 @@ def validate(self): if any(required_parameter not in parameters for required_parameter in ('node', 'transport')): correct_signature = ( - "(node: CalcJobNode, transport: Union['BlockingTransport', 'AsyncTransport'], **kwargs) str | None:" + "(node: CalcJobNode, transport: Union['Transport', 'AsyncTransport'], **kwargs) str | None:" ) raise ValueError( f'The monitor `{self.entry_point}` has an invalid function signature, it should be: {correct_signature}' @@ -179,7 +179,7 @@ def monitors(self) -> collections.OrderedDict: def process( self, node: CalcJobNode, - transport: Union['BlockingTransport', 'AsyncTransport'], + transport: Union['Transport', 'AsyncTransport'], ) -> CalcJobMonitorResult | None: """Call all monitors in order and return the result as one returns anything other than ``None``. 
diff --git a/src/aiida/engine/transports.py b/src/aiida/engine/transports.py index 33e43e5b62..cade4c04ca 100644 --- a/src/aiida/engine/transports.py +++ b/src/aiida/engine/transports.py @@ -18,7 +18,7 @@ from aiida.orm import AuthInfo if TYPE_CHECKING: - from aiida.transports import AsyncTransport, BlockingTransport + from aiida.transports import AsyncTransport, Transport _LOGGER = logging.getLogger(__name__) @@ -54,9 +54,7 @@ def loop(self) -> asyncio.AbstractEventLoop: return self._loop @contextlib.contextmanager - def request_transport( - self, authinfo: AuthInfo - ) -> Iterator[Awaitable[Union['BlockingTransport', 'AsyncTransport']]]: + def request_transport(self, authinfo: AuthInfo) -> Iterator[Awaitable[Union['Transport', 'AsyncTransport']]]: """Request a transport from an authinfo. Because the client is not allowed to request a transport immediately they will instead be given back a future that can be awaited to get the transport:: diff --git a/src/aiida/orm/authinfos.py b/src/aiida/orm/authinfos.py index 3d8a45afa0..bff6ef849d 100644 --- a/src/aiida/orm/authinfos.py +++ b/src/aiida/orm/authinfos.py @@ -21,7 +21,7 @@ from aiida.orm import Computer, User from aiida.orm.implementation import StorageBackend from aiida.orm.implementation.authinfos import BackendAuthInfo # noqa: F401 - from aiida.transports import AsyncTransport, BlockingTransport + from aiida.transports import AsyncTransport, Transport __all__ = ('AuthInfo',) @@ -166,7 +166,7 @@ def get_workdir(self) -> str: except KeyError: return self.computer.get_workdir() - def get_transport(self) -> Union['BlockingTransport', 'AsyncTransport']: + def get_transport(self) -> Union['Transport', 'AsyncTransport']: """Return a fully configured transport that can be used to connect to the computer set for this instance.""" computer = self.computer transport_type = computer.transport_type diff --git a/src/aiida/orm/computers.py b/src/aiida/orm/computers.py index 46b4ec522b..9bf12fbb2e 100644 --- 
a/src/aiida/orm/computers.py +++ b/src/aiida/orm/computers.py @@ -23,7 +23,7 @@ from aiida.orm import AuthInfo, User from aiida.orm.implementation import StorageBackend from aiida.schedulers import Scheduler - from aiida.transports import AsyncTransport, BlockingTransport + from aiida.transports import AsyncTransport, Transport __all__ = ('Computer',) @@ -622,7 +622,7 @@ def is_user_enabled(self, user: 'User') -> bool: # Return False if the user is not configured (in a sense, it is disabled for that user) return False - def get_transport(self, user: Optional['User'] = None) -> Union['BlockingTransport', 'AsyncTransport']: + def get_transport(self, user: Optional['User'] = None) -> Union['Transport', 'AsyncTransport']: """Return a Transport class, configured with all correct parameters. The Transport is closed (meaning that if you want to run any operation with it, you have to open it first (i.e., e.g. for a SSH transport, you have @@ -646,7 +646,7 @@ def get_transport(self, user: Optional['User'] = None) -> Union['BlockingTranspo authinfo = authinfos.AuthInfo.get_collection(self.backend).get(dbcomputer=self, aiidauser=user) return authinfo.get_transport() - def get_transport_class(self) -> Union[Type['BlockingTransport'], Type['AsyncTransport']]: + def get_transport_class(self) -> Union[Type['Transport'], Type['AsyncTransport']]: """Get the transport class for this computer. Can be used to instantiate a transport instance.""" try: return TransportFactory(self.transport_type) diff --git a/src/aiida/orm/nodes/data/remote/base.py b/src/aiida/orm/nodes/data/remote/base.py index 655d2fccad..60e6f9bbee 100644 --- a/src/aiida/orm/nodes/data/remote/base.py +++ b/src/aiida/orm/nodes/data/remote/base.py @@ -118,7 +118,7 @@ def listdir_withattributes(self, path='.'): :param relpath: If 'relpath' is specified, lists the content of the given subfolder. :return: a list of dictionaries, where the documentation - is in :py:class:BlockingTransport.listdir_withattributes. 
+ is in :py:class:Transport.listdir_withattributes. """ authinfo = self.get_authinfo() diff --git a/src/aiida/orm/nodes/process/calculation/calcjob.py b/src/aiida/orm/nodes/process/calculation/calcjob.py index c7580cd91b..3d448d957b 100644 --- a/src/aiida/orm/nodes/process/calculation/calcjob.py +++ b/src/aiida/orm/nodes/process/calculation/calcjob.py @@ -26,7 +26,7 @@ from aiida.parsers import Parser from aiida.schedulers.datastructures import JobInfo, JobState from aiida.tools.calculations import CalculationTools - from aiida.transports import AsyncTransport, BlockingTransport + from aiida.transports import AsyncTransport, Transport __all__ = ('CalcJobNode',) @@ -450,10 +450,10 @@ def get_authinfo(self) -> 'AuthInfo': return computer.get_authinfo(self.user) - def get_transport(self) -> Union['BlockingTransport', 'AsyncTransport']: + def get_transport(self) -> Union['Transport', 'AsyncTransport']: """Return the transport for this calculation. - :return: Union['BlockingTransport', 'AsyncTransport'] configured + :return: Union['Transport', 'AsyncTransport'] configured with the `AuthInfo` associated to the computer of this node """ return self.get_authinfo().get_transport() diff --git a/src/aiida/orm/utils/remote.py b/src/aiida/orm/utils/remote.py index a8aa19b3fc..2a9846af7c 100644 --- a/src/aiida/orm/utils/remote.py +++ b/src/aiida/orm/utils/remote.py @@ -21,14 +21,14 @@ from aiida import orm from aiida.orm.implementation import StorageBackend - from aiida.transports import AsyncTransport, BlockingTransport + from aiida.transports import AsyncTransport, Transport -def clean_remote(transport: Union['BlockingTransport', 'AsyncTransport'], path: str) -> None: +def clean_remote(transport: Union['Transport', 'AsyncTransport'], path: str) -> None: """Recursively remove a remote folder, with the given absolute path, and all its contents. 
The path should be made accessible through the transport channel, which should already be open - :param transport: an open Union['BlockingTransport', 'AsyncTransport'] channel + :param transport: an open Union['Transport', 'AsyncTransport'] channel :param path: an absolute path on the remote made available through the transport """ if not isinstance(path, str): diff --git a/src/aiida/plugins/factories.py b/src/aiida/plugins/factories.py index affce3d405..6cce4bf270 100644 --- a/src/aiida/plugins/factories.py +++ b/src/aiida/plugins/factories.py @@ -42,7 +42,7 @@ from aiida.schedulers import Scheduler from aiida.tools.data.orbital import Orbital from aiida.tools.dbimporters import DbImporter - from aiida.transports import AsyncTransport, BlockingTransport + from aiida.transports import AsyncTransport, Transport def raise_invalid_type_error(entry_point_name: str, entry_point_group: str, valid_classes: Tuple[Any, ...]) -> NoReturn: @@ -412,7 +412,7 @@ def StorageFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint @overload def TransportFactory( entry_point_name: str, load: Literal[True] = True -) -> Union[Type['BlockingTransport'], Type['AsyncTransport']]: ... +) -> Union[Type['Transport'], Type['AsyncTransport']]: ... @overload @@ -421,8 +421,8 @@ def TransportFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: def TransportFactory( entry_point_name: str, load: bool = True -) -> Union[EntryPoint, Type['BlockingTransport'], Type['AsyncTransport']]: - """Return the Union['BlockingTransport', 'AsyncTransport'] sub class registered under the given entry point. +) -> Union[EntryPoint, Type['Transport'], Type['AsyncTransport']]: + """Return the Union['Transport', 'AsyncTransport'] sub class registered under the given entry point. :param entry_point_name: the entry point name. :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. 
@@ -430,16 +430,16 @@ def TransportFactory( """ from inspect import isclass - from aiida.transports import AsyncTransport, BlockingTransport + from aiida.transports import AsyncTransport, Transport entry_point_group = 'aiida.transports' entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) - valid_classes = (BlockingTransport, AsyncTransport) + valid_classes = (Transport, AsyncTransport) if not load: return entry_point - if isclass(entry_point) and (issubclass(entry_point, BlockingTransport) or issubclass(entry_point, AsyncTransport)): + if isclass(entry_point) and (issubclass(entry_point, Transport) or issubclass(entry_point, AsyncTransport)): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) diff --git a/src/aiida/schedulers/scheduler.py b/src/aiida/schedulers/scheduler.py index e9fc2db3e2..3bb540c84a 100644 --- a/src/aiida/schedulers/scheduler.py +++ b/src/aiida/schedulers/scheduler.py @@ -22,7 +22,7 @@ from aiida.schedulers.datastructures import JobInfo, JobResource, JobTemplate, JobTemplateCodeInfo if t.TYPE_CHECKING: - from aiida.transports import AsyncTransport, BlockingTransport + from aiida.transports import AsyncTransport, Transport __all__ = ('Scheduler', 'SchedulerError', 'SchedulerParsingError') @@ -366,7 +366,7 @@ def transport(self): return self._transport - def set_transport(self, transport: Union['BlockingTransport', 'AsyncTransport']): + def set_transport(self, transport: Union['Transport', 'AsyncTransport']): """Set the transport to be used to query the machine or to submit scripts. This class assumes that the transport is open and active. 
diff --git a/src/aiida/transports/__init__.py b/src/aiida/transports/__init__.py index 8b6080f77d..f3427ff5e3 100644 --- a/src/aiida/transports/__init__.py +++ b/src/aiida/transports/__init__.py @@ -17,7 +17,6 @@ __all__ = ( 'Transport', - 'BlockingTransport', 'SshTransport', 'AsyncTransport', 'convert_to_bool', diff --git a/src/aiida/transports/plugins/local.py b/src/aiida/transports/plugins/local.py index 32608dad2c..a3518dda04 100644 --- a/src/aiida/transports/plugins/local.py +++ b/src/aiida/transports/plugins/local.py @@ -19,11 +19,11 @@ from aiida.common.warnings import warn_deprecation from aiida.transports import cli as transport_cli -from aiida.transports.transport import BlockingTransport, TransportInternalError, _TransportPath, path_2_str +from aiida.transports.transport import Transport, TransportInternalError, TransportPath, path_to_str # refactor or raise the limit: issue #1784 -class LocalTransport(BlockingTransport): +class LocalTransport(Transport): """Support copy and command execution on the same host on which AiiDA is running via direct file copy and execution commands. @@ -94,7 +94,7 @@ def curdir(self): raise TransportInternalError('Error, local method called for LocalTransport without opening the channel first') - def chdir(self, path: _TransportPath): + def chdir(self, path: TransportPath): """ PLEASE DON'T USE `chdir()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE. `chdir()` is DEPRECATED and will be removed in the next major version. 
@@ -107,7 +107,7 @@ def chdir(self, path: _TransportPath): '`chdir()` is deprecated and will be removed in the next major version.', version=3, ) - path = path_2_str(path) + path = path_to_str(path) new_path = os.path.join(self.curdir, path) if not os.path.isdir(new_path): raise OSError(f"'{new_path}' is not a valid directory") @@ -116,15 +116,15 @@ def chdir(self, path: _TransportPath): self._internal_dir = os.path.normpath(new_path) - def chown(self, path: _TransportPath, uid, gid): - path = path_2_str(path) + def chown(self, path: TransportPath, uid, gid): + path = path_to_str(path) os.chown(path, uid, gid) - def normalize(self, path: _TransportPath = '.'): + def normalize(self, path: TransportPath = '.'): """Normalizes path, eliminating double slashes, etc.. :param path: path to normalize """ - path = path_2_str(path) + path = path_to_str(path) return os.path.realpath(os.path.join(self.curdir, path)) def getcwd(self): @@ -136,9 +136,9 @@ def getcwd(self): return self.curdir @staticmethod - def _os_path_split_asunder(path: _TransportPath): + def _os_path_split_asunder(path: TransportPath): """Used by makedirs, Takes path (a str) and returns a list deconcatenating the path.""" - path = path_2_str(path) + path = path_to_str(path) parts = [] while True: newpath, tail = os.path.split(path) @@ -152,7 +152,7 @@ def _os_path_split_asunder(path: _TransportPath): parts.reverse() return parts - def makedirs(self, path: _TransportPath, ignore_existing=False): + def makedirs(self, path: TransportPath, ignore_existing=False): """Super-mkdir; create a leaf directory and all intermediate ones. Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. 
@@ -163,7 +163,7 @@ def makedirs(self, path: _TransportPath, ignore_existing=False): :raise OSError: If the directory already exists and is not ignore_existing """ - path = path_2_str(path) + path = path_to_str(path) # check to avoid creation of empty dirs path = os.path.normpath(path) @@ -179,7 +179,7 @@ def makedirs(self, path: _TransportPath, ignore_existing=False): if not os.path.exists(this_dir): os.mkdir(this_dir) - def mkdir(self, path: _TransportPath, ignore_existing=False): + def mkdir(self, path: TransportPath, ignore_existing=False): """Create a folder (directory) named path. :param path: name of the folder to create @@ -188,37 +188,37 @@ def mkdir(self, path: _TransportPath, ignore_existing=False): :raise OSError: If the directory already exists. """ - path = path_2_str(path) + path = path_to_str(path) if ignore_existing and self.isdir(path): return os.mkdir(os.path.join(self.curdir, path)) - def rmdir(self, path: _TransportPath): + def rmdir(self, path: TransportPath): """Removes a folder at location path. :param path: path to remove """ - path = path_2_str(path) + path = path_to_str(path) os.rmdir(os.path.join(self.curdir, path)) - def isdir(self, path: _TransportPath): + def isdir(self, path: TransportPath): """Checks if 'path' is a directory. :return: a boolean """ - path = path_2_str(path) + path = path_to_str(path) if not path: return False return os.path.isdir(os.path.join(self.curdir, path)) - def chmod(self, path: _TransportPath, mode): + def chmod(self, path: TransportPath, mode): """Changes permission bits of object at path :param path: path to modify :param mode: permission bits :raise OSError: if path does not exist. 
""" - path = path_2_str(path) + path = path_to_str(path) if not path: raise OSError('Directory not given in input') real_path = os.path.join(self.curdir, path) @@ -229,7 +229,7 @@ def chmod(self, path: _TransportPath, mode): # please refactor: issue #1782 - def put(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + def put(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Copies a file or a folder from localpath to remotepath. Automatically redirects to putfile or puttree. @@ -243,8 +243,8 @@ def put(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kw :raise OSError: if remotepath is not valid :raise ValueError: if localpath is not valid """ - localpath = path_2_str(localpath) - remotepath = path_2_str(remotepath) + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) from aiida.common.warnings import warn_deprecation if 'ignore_noexisting' in kwargs: @@ -311,7 +311,7 @@ def put(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kw else: raise OSError(f'The local path {localpath} does not exist') - def putfile(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + def putfile(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Copies a file from localpath to remotepath. Automatically redirects to putfile or puttree. 
@@ -324,8 +324,8 @@ def putfile(self, localpath: _TransportPath, remotepath: _TransportPath, *args, :raise ValueError: if localpath is not valid :raise OSError: if localpath does not exist """ - localpath = path_2_str(localpath) - remotepath = path_2_str(remotepath) + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) overwrite = kwargs.get('overwrite', args[0] if args else True) if not remotepath: @@ -345,7 +345,7 @@ def putfile(self, localpath: _TransportPath, remotepath: _TransportPath, *args, shutil.copyfile(localpath, the_destination) - def puttree(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + def puttree(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Copies a folder recursively from localpath to remotepath. Automatically redirects to putfile or puttree. @@ -360,8 +360,8 @@ def puttree(self, localpath: _TransportPath, remotepath: _TransportPath, *args, :raise ValueError: if localpath is not valid :raise OSError: if localpath does not exist """ - localpath = path_2_str(localpath) - remotepath = path_2_str(remotepath) + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) if not remotepath: @@ -387,12 +387,12 @@ def puttree(self, localpath: _TransportPath, remotepath: _TransportPath, *args, shutil.copytree(localpath, the_destination, symlinks=not dereference, dirs_exist_ok=overwrite) - def rmtree(self, path: _TransportPath): + def rmtree(self, path: TransportPath): """Remove tree as rm -r would do :param path: a string to path """ - path = path_2_str(path) + path = path_to_str(path) the_path = os.path.join(self.curdir, path) try: shutil.rmtree(the_path) @@ -406,7 +406,7 @@ def rmtree(self, path: _TransportPath): # please refactor: issue #1781 - def get(self, remotepath: _TransportPath, localpath: _TransportPath, 
*args, **kwargs): + def get(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Copies a folder or a file recursively from 'remote' remotepath to 'local' localpath. Automatically redirects to getfile or gettree. @@ -421,8 +421,8 @@ def get(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kw :raise OSError: if 'remote' remotepath is not valid :raise ValueError: if 'local' localpath is not valid """ - remotepath = path_2_str(remotepath) - localpath = path_2_str(localpath) + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) ignore_nonexisting = kwargs.get('ignore_nonexisting', args[2] if len(args) > 2 else False) @@ -474,7 +474,7 @@ def get(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kw else: raise OSError(f'The remote path {remotepath} does not exist') - def getfile(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): + def getfile(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Copies a file recursively from 'remote' remotepath to 'local' localpath. 
@@ -487,8 +487,8 @@ def getfile(self, remotepath: _TransportPath, localpath: _TransportPath, *args, :raise ValueError: if 'local' localpath is not valid :raise OSError: if unintentionally overwriting """ - remotepath = path_2_str(remotepath) - localpath = path_2_str(localpath) + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) overwrite = kwargs.get('overwrite', args[0] if args else True) if not localpath: raise ValueError('Input localpath to get function must be a non empty string') @@ -504,7 +504,7 @@ def getfile(self, remotepath: _TransportPath, localpath: _TransportPath, *args, shutil.copyfile(the_source, localpath) - def gettree(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): + def gettree(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Copies a folder recursively from 'remote' remotepath to 'local' localpath. @@ -517,8 +517,8 @@ def gettree(self, remotepath: _TransportPath, localpath: _TransportPath, *args, :raise ValueError: if 'local' localpath is not valid :raise OSError: if unintentionally overwriting """ - remotepath = path_2_str(remotepath) - localpath = path_2_str(localpath) + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) dereference = kwargs.get('dereference', args[0] if args else True) overwrite = kwargs.get('overwrite', args[1] if len(args) > 1 else True) if not remotepath: @@ -545,7 +545,7 @@ def gettree(self, remotepath: _TransportPath, localpath: _TransportPath, *args, # please refactor: issue #1780 on github - def copy(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False, recursive=True): + def copy(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False, recursive=True): """Copies a file or a folder from 'remote' remotesource to 'remote' remotedestination. Automatically redirects to copyfile or copytree. 
@@ -558,8 +558,8 @@ def copy(self, remotesource: _TransportPath, remotedestination: _TransportPath, :raise ValueError: if 'remote' remotesource or remotedestinationis not valid :raise OSError: if remotesource does not exist """ - remotesource = path_2_str(remotesource) - remotedestination = path_2_str(remotedestination) + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copy must be a non empty object') if not remotedestination: @@ -607,7 +607,7 @@ def copy(self, remotesource: _TransportPath, remotedestination: _TransportPath, # With self.copytree, the (possible) relative path is OK self.copytree(remotesource, remotedestination, dereference) - def copyfile(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): + def copyfile(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False): """Copies a file from 'remote' remotesource to 'remote' remotedestination. 
@@ -618,8 +618,8 @@ def copyfile(self, remotesource: _TransportPath, remotedestination: _TransportPa :raise ValueError: if 'remote' remotesource or remotedestination is not valid :raise OSError: if remotesource does not exist """ - remotesource = path_2_str(remotesource) - remotedestination = path_2_str(remotedestination) + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copyfile must be a non empty object') if not remotedestination: @@ -635,7 +635,7 @@ def copyfile(self, remotesource: _TransportPath, remotedestination: _TransportPa else: shutil.copyfile(the_source, the_destination) - def copytree(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): + def copytree(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False): """Copies a folder from 'remote' remotesource to 'remote' remotedestination. @@ -646,8 +646,8 @@ def copytree(self, remotesource: _TransportPath, remotedestination: _TransportPa :raise ValueError: if 'remote' remotesource or remotedestination is not valid :raise OSError: if remotesource does not exist """ - remotesource = path_2_str(remotesource) - remotedestination = path_2_str(remotedestination) + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) if not remotesource: raise ValueError('Input remotesource to copytree must be a non empty object') if not remotedestination: @@ -663,12 +663,12 @@ def copytree(self, remotesource: _TransportPath, remotedestination: _TransportPa shutil.copytree(the_source, the_destination, symlinks=not dereference) - def get_attribute(self, path: _TransportPath): + def get_attribute(self, path: TransportPath): """Returns an object FileAttribute, as specified in aiida.transports. :param path: the path of the given file. 
""" - path = path_2_str(path) + path = path_to_str(path) from aiida.transports.util import FileAttribute os_attr = os.lstat(os.path.join(self.curdir, path)) @@ -679,11 +679,11 @@ def get_attribute(self, path: _TransportPath): aiida_attr[key] = getattr(os_attr, key) return aiida_attr - def _local_listdir(self, path: _TransportPath, pattern=None): + def _local_listdir(self, path: TransportPath, pattern=None): """Act on the local folder, for the rest, same as listdir.""" import re - path = path_2_str(path) + path = path_to_str(path) if not pattern: return os.listdir(path) @@ -698,13 +698,13 @@ def _local_listdir(self, path: _TransportPath, pattern=None): base_dir += os.sep return [re.sub(base_dir, '', i) for i in filtered_list] - def listdir(self, path: _TransportPath = '.', pattern=None): + def listdir(self, path: TransportPath = '.', pattern=None): """:return: a list containing the names of the entries in the directory. :param path: default ='.' :param pattern: if set, returns the list of files matching pattern. Unix only. (Use to emulate ls * for example) """ - path = path_2_str(path) + path = path_to_str(path) the_path = os.path.join(self.curdir, path).strip() if not pattern: try: @@ -721,22 +721,22 @@ def listdir(self, path: _TransportPath = '.', pattern=None): the_path += '/' return [re.sub(the_path, '', i) for i in filtered_list] - def remove(self, path: _TransportPath): + def remove(self, path: TransportPath): """Removes a file at position path.""" - path = path_2_str(path) + path = path_to_str(path) os.remove(os.path.join(self.curdir, path)) - def isfile(self, path: _TransportPath): + def isfile(self, path: TransportPath): """Checks if object at path is a file. Returns a boolean. 
""" - path = path_2_str(path) + path = path_to_str(path) if not path: return False return os.path.isfile(os.path.join(self.curdir, path)) @contextlib.contextmanager - def _exec_command_internal(self, command, workdir: Optional[_TransportPath] = None, **kwargs): + def _exec_command_internal(self, command, workdir: Optional[TransportPath] = None, **kwargs): """Executes the specified command in bash login shell. @@ -762,7 +762,7 @@ def _exec_command_internal(self, command, workdir: Optional[_TransportPath] = No from aiida.common.escaping import escape_for_bash if workdir: - workdir = path_2_str(workdir) + workdir = path_to_str(workdir) # Note: The outer shell will eat one level of escaping, while # 'bash -l -c ...' will eat another. Thus, we need to escape again. bash_commmand = f'{self._bash_command_str}-c ' @@ -784,7 +784,7 @@ def _exec_command_internal(self, command, workdir: Optional[_TransportPath] = No ) as process: yield process - def exec_command_wait_bytes(self, command, stdin=None, workdir: Optional[_TransportPath] = None, **kwargs): + def exec_command_wait_bytes(self, command, stdin=None, workdir: Optional[TransportPath] = None, **kwargs): """Executes the specified command and waits for it to finish. :param command: the command to execute @@ -797,7 +797,7 @@ def exec_command_wait_bytes(self, command, stdin=None, workdir: Optional[_Transp are both bytes and the return_value is an int. """ if workdir: - workdir = path_2_str(workdir) + workdir = path_to_str(workdir) with self._exec_command_internal(command, workdir) as process: if stdin is not None: # Implicitly assume that the desired encoding is 'utf-8' if I receive a string. @@ -840,7 +840,7 @@ def line_encoder(iterator, encoding='utf-8'): return retval, output_text, stderr_text - def gotocomputer_command(self, remotedir: _TransportPath): + def gotocomputer_command(self, remotedir: TransportPath): """Return a string to be run using os.system in order to connect via the transport to the remote directory. 
@@ -851,12 +851,12 @@ def gotocomputer_command(self, remotedir: _TransportPath): :param str remotedir: the full path of the remote directory """ - remotedir = path_2_str(remotedir) + remotedir = path_to_str(remotedir) connect_string = self._gotocomputer_string(remotedir) cmd = f'bash -c {connect_string}' return cmd - def rename(self, oldpath: _TransportPath, newpath: _TransportPath): + def rename(self, oldpath: TransportPath, newpath: TransportPath): """Rename a file or folder from oldpath to newpath. :param str oldpath: existing name of the file or folder @@ -865,8 +865,8 @@ def rename(self, oldpath: _TransportPath, newpath: _TransportPath): :raises OSError: if src/dst is not found :raises ValueError: if src/dst is not a valid string """ - oldpath = path_2_str(oldpath) - newpath = path_2_str(newpath) + oldpath = path_to_str(oldpath) + newpath = path_to_str(newpath) if not oldpath: raise ValueError(f'Source {oldpath} is not a valid string') if not newpath: @@ -878,15 +878,15 @@ def rename(self, oldpath: _TransportPath, newpath: _TransportPath): shutil.move(oldpath, newpath) - def symlink(self, remotesource: _TransportPath, remotedestination: _TransportPath): + def symlink(self, remotesource: TransportPath, remotedestination: TransportPath): """Create a symbolic link between the remote source and the remote remotedestination :param remotesource: remote source. Can contain a pattern. 
:param remotedestination: remote destination """ - remotesource = os.path.normpath(path_2_str(remotesource)) - remotedestination = os.path.normpath(path_2_str(remotedestination)) + remotesource = os.path.normpath(path_to_str(remotesource)) + remotedestination = os.path.normpath(path_to_str(remotedestination)) if self.has_magic(remotesource): if self.has_magic(remotedestination): @@ -905,9 +905,9 @@ def symlink(self, remotesource: _TransportPath, remotedestination: _TransportPat except OSError: raise OSError(f'!!: {remotesource}, {self.curdir}, {remotedestination}') - def path_exists(self, path: _TransportPath): + def path_exists(self, path: TransportPath): """Check if path exists""" - path = path_2_str(path) + path = path_to_str(path) return os.path.exists(os.path.join(self.curdir, path)) diff --git a/src/aiida/transports/plugins/ssh.py b/src/aiida/transports/plugins/ssh.py index c2fb643f6d..21c0b84784 100644 --- a/src/aiida/transports/plugins/ssh.py +++ b/src/aiida/transports/plugins/ssh.py @@ -21,7 +21,7 @@ from aiida.common.escaping import escape_for_bash from aiida.common.warnings import warn_deprecation -from ..transport import BlockingTransport, TransportInternalError, _TransportPath, path_2_str +from ..transport import Transport, TransportInternalError, TransportPath, path_to_str __all__ = ('parse_sshconfig', 'convert_to_bool', 'SshTransport') @@ -62,7 +62,7 @@ def convert_to_bool(string): raise ValueError('Invalid boolean value provided') -class SshTransport(BlockingTransport): +class SshTransport(Transport): """Support connection, command execution and data transfer to remote computers via SSH+SFTP.""" # Valid keywords accepted by the connect method of paramiko.SSHClient @@ -581,7 +581,7 @@ def __str__(self): return f"{'OPEN' if self._is_open else 'CLOSED'} [{conn_info}]" - def chdir(self, path: _TransportPath): + def chdir(self, path: TransportPath): """ PLEASE DON'T USE `chdir()` IN NEW DEVELOPMENTS, INSTEAD DIRECTLY PASS ABSOLUTE PATHS TO INTERFACE. 
`chdir()` is DEPRECATED and will be removed in the next major version. @@ -597,7 +597,7 @@ def chdir(self, path: _TransportPath): ) from paramiko.sftp import SFTPError - path = path_2_str(path) + path = path_to_str(path) old_path = self.sftp.getcwd() if path is not None: try: @@ -624,13 +624,13 @@ def chdir(self, path: _TransportPath): self.chdir(old_path) raise OSError(str(exc)) - def normalize(self, path: _TransportPath = '.'): + def normalize(self, path: TransportPath = '.'): """Returns the normalized path (removing double slashes, etc...)""" - path = path_2_str(path) + path = path_to_str(path) return self.sftp.normalize(path) - def stat(self, path: _TransportPath): + def stat(self, path: TransportPath): """Retrieve information about a file on the remote system. The return value is an object whose attributes correspond to the attributes of Python's ``stat`` structure as returned by ``os.stat``, except that it @@ -643,11 +643,11 @@ def stat(self, path: _TransportPath): :return: a `paramiko.sftp_attr.SFTPAttributes` object containing attributes about the given file. """ - path = path_2_str(path) + path = path_to_str(path) return self.sftp.stat(path) - def lstat(self, path: _TransportPath): + def lstat(self, path: TransportPath): """Retrieve information about a file on the remote system, without following symbolic links (shortcuts). This otherwise behaves exactly the same as `stat`. @@ -657,7 +657,7 @@ def lstat(self, path: _TransportPath): :return: a `paramiko.sftp_attr.SFTPAttributes` object containing attributes about the given file. """ - path = path_2_str(path) + path = path_to_str(path) return self.sftp.lstat(path) @@ -677,7 +677,7 @@ def getcwd(self): ) return self.sftp.getcwd() - def makedirs(self, path: _TransportPath, ignore_existing: bool = False): + def makedirs(self, path: TransportPath, ignore_existing: bool = False): """Super-mkdir; create a leaf directory and all intermediate ones. 
Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. @@ -692,7 +692,7 @@ def makedirs(self, path: _TransportPath, ignore_existing: bool = False): :raise OSError: If the directory already exists. """ - path = path_2_str(path) + path = path_to_str(path) # check to avoid creation of empty dirs path = os.path.normpath(path) @@ -715,7 +715,7 @@ def makedirs(self, path: _TransportPath, ignore_existing: bool = False): if not self.isdir(this_dir): self.mkdir(this_dir) - def mkdir(self, path: _TransportPath, ignore_existing: bool = False): + def mkdir(self, path: TransportPath, ignore_existing: bool = False): """Create a folder (directory) named path. :param path: name of the folder to create @@ -724,7 +724,7 @@ def mkdir(self, path: _TransportPath, ignore_existing: bool = False): :raise OSError: If the directory already exists. """ - path = path_2_str(path) + path = path_to_str(path) if ignore_existing and self.isdir(path): return @@ -745,7 +745,7 @@ def mkdir(self, path: _TransportPath, ignore_existing: bool = False): 'or the directory already exists? ({})'.format(path, self.getcwd(), exc) ) - def rmtree(self, path: _TransportPath): + def rmtree(self, path: TransportPath): """Remove a file or a directory at path, recursively Flags used: -r: recursive copy; -f: force, makes the command non interactive; @@ -753,7 +753,7 @@ def rmtree(self, path: _TransportPath): :raise OSError: if the rm execution failed. """ - path = path_2_str(path) + path = path_to_str(path) # Assuming linux rm command! rm_exe = 'rm' @@ -773,29 +773,29 @@ def rmtree(self, path: _TransportPath): self.logger.error(f"Problem executing rm. Exit code: {retval}, stdout: '{stdout}', stderr: '{stderr}'") raise OSError(f'Error while executing rm. 
Exit code: {retval}') - def rmdir(self, path: _TransportPath): + def rmdir(self, path: TransportPath): """Remove the folder named 'path' if empty.""" - path = path_2_str(path) + path = path_to_str(path) self.sftp.rmdir(path) - def chown(self, path: _TransportPath, uid, gid): + def chown(self, path: TransportPath, uid, gid): """Change owner permissions of a file. For now, this is not implemented for the SSH transport. """ raise NotImplementedError - def isdir(self, path: _TransportPath): + def isdir(self, path: TransportPath): """Return True if the given path is a directory, False otherwise. Return False also if the path does not exist. """ # Return False on empty string (paramiko would map this to the local # folder instead) - path = path_2_str(path) + path = path_to_str(path) if not path: return False - path = path_2_str(path) + path = path_to_str(path) try: return S_ISDIR(self.stat(path).st_mode) except OSError as exc: @@ -804,24 +804,24 @@ def isdir(self, path: _TransportPath): return False raise # Typically if I don't have permissions (errno=13) - def chmod(self, path: _TransportPath, mode): + def chmod(self, path: TransportPath, mode): """Change permissions to path :param path: path to file :param mode: new permission bits (integer) """ - path = path_2_str(path) + path = path_to_str(path) if not path: raise OSError('Input path is an empty argument.') return self.sftp.chmod(path, mode) @staticmethod - def _os_path_split_asunder(path: _TransportPath): + def _os_path_split_asunder(path: TransportPath): """Used by makedirs. 
Takes path and returns a list deconcatenating the path """ - path = path_2_str(path) + path = path_to_str(path) parts = [] while True: newpath, tail = os.path.split(path) @@ -837,8 +837,8 @@ def _os_path_split_asunder(path: _TransportPath): def put( self, - localpath: _TransportPath, - remotepath: _TransportPath, + localpath: TransportPath, + remotepath: TransportPath, callback=None, dereference: bool = True, overwrite: bool = True, @@ -857,8 +857,8 @@ def put( :raise ValueError: if local path is invalid :raise OSError: if the localpath does not exist """ - localpath = path_2_str(localpath) - remotepath = path_2_str(remotepath) + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) if not dereference: raise NotImplementedError @@ -912,8 +912,8 @@ def put( def putfile( self, - localpath: _TransportPath, - remotepath: _TransportPath, + localpath: TransportPath, + remotepath: TransportPath, callback=None, dereference: bool = True, overwrite: bool = True, @@ -929,8 +929,8 @@ def putfile( :raise OSError: if the localpath does not exist, or unintentionally overwriting """ - localpath = path_2_str(localpath) - remotepath = path_2_str(remotepath) + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) if not dereference: raise NotImplementedError @@ -945,8 +945,8 @@ def putfile( def puttree( self, - localpath: _TransportPath, - remotepath: _TransportPath, + localpath: TransportPath, + remotepath: TransportPath, callback=None, dereference: bool = True, overwrite: bool = True, @@ -969,8 +969,8 @@ def puttree( .. note:: setting dereference equal to True could cause infinite loops. 
see os.walk() documentation """ - localpath = path_2_str(localpath) - remotepath = path_2_str(remotepath) + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) if not dereference: raise NotImplementedError @@ -1019,8 +1019,8 @@ def puttree( def get( self, - remotepath: _TransportPath, - localpath: _TransportPath, + remotepath: TransportPath, + localpath: TransportPath, callback=None, dereference: bool = True, overwrite: bool = True, @@ -1040,8 +1040,8 @@ def get( :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found """ - remotepath = path_2_str(remotepath) - localpath = path_2_str(localpath) + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) if not dereference: raise NotImplementedError @@ -1092,8 +1092,8 @@ def get( def getfile( self, - remotepath: _TransportPath, - localpath: _TransportPath, + remotepath: TransportPath, + localpath: TransportPath, callback=None, dereference: bool = True, overwrite: bool = True, @@ -1108,8 +1108,8 @@ def getfile( :raise ValueError: if local path is invalid :raise OSError: if unintentionally overwriting """ - remotepath = path_2_str(remotepath) - localpath = path_2_str(localpath) + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) if not os.path.isabs(localpath): raise ValueError('localpath must be an absolute path') @@ -1132,8 +1132,8 @@ def getfile( def gettree( self, - remotepath: _TransportPath, - localpath: _TransportPath, + remotepath: TransportPath, + localpath: TransportPath, callback=None, dereference: bool = True, overwrite: bool = True, @@ -1152,8 +1152,8 @@ def gettree( :raise OSError: if the remotepath is not found :raise OSError: if unintentionally overwriting """ - remotepath = path_2_str(remotepath) - localpath = path_2_str(localpath) + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) if not dereference: raise NotImplementedError @@ -1190,11 +1190,11 @@ def gettree( else: 
self.getfile(os.path.join(remotepath, item), os.path.join(dest, item)) - def get_attribute(self, path: _TransportPath): + def get_attribute(self, path: TransportPath): """Returns the object Fileattribute, specified in aiida.transports Receives in input the path of a given file. """ - path = path_2_str(path) + path = path_to_str(path) from aiida.transports.util import FileAttribute paramiko_attr = self.lstat(path) @@ -1205,22 +1205,22 @@ def get_attribute(self, path: _TransportPath): aiida_attr[key] = getattr(paramiko_attr, key) return aiida_attr - def copyfile(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference: bool = False): - remotesource = path_2_str(remotesource) - remotedestination = path_2_str(remotedestination) + def copyfile(self, remotesource: TransportPath, remotedestination: TransportPath, dereference: bool = False): + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) return self.copy(remotesource, remotedestination, dereference) - def copytree(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference: bool = False): - remotesource = path_2_str(remotesource) - remotedestination = path_2_str(remotedestination) + def copytree(self, remotesource: TransportPath, remotedestination: TransportPath, dereference: bool = False): + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) return self.copy(remotesource, remotedestination, dereference, recursive=True) def copy( self, - remotesource: _TransportPath, - remotedestination: _TransportPath, + remotesource: TransportPath, + remotedestination: TransportPath, dereference: bool = False, recursive: bool = True, ): @@ -1240,8 +1240,8 @@ def copy( .. note:: setting dereference equal to True could cause infinite loops. 
""" - remotesource = path_2_str(remotesource) - remotedestination = path_2_str(remotedestination) + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) # In the majority of cases, we should deal with linux cp commands cp_flags = '-f' @@ -1324,14 +1324,14 @@ def _local_listdir(path: str, pattern=None): base_dir += os.sep return [re.sub(base_dir, '', i) for i in filtered_list] - def listdir(self, path: _TransportPath = '.', pattern=None): + def listdir(self, path: TransportPath = '.', pattern=None): """Get the list of files at path. :param path: default = '.' :param pattern: returns the list of files matching pattern. Unix only. (Use to emulate ``ls *`` for example) """ - path = path_2_str(path) + path = path_to_str(path) if path.startswith('/'): abs_dir = path @@ -1346,12 +1346,12 @@ def listdir(self, path: _TransportPath = '.', pattern=None): abs_dir += '/' return [re.sub(abs_dir, '', i) for i in filtered_list] - def remove(self, path: _TransportPath): + def remove(self, path: TransportPath): """Remove a single file at 'path'""" - path = path_2_str(path) + path = path_to_str(path) return self.sftp.remove(path) - def rename(self, oldpath: _TransportPath, newpath: _TransportPath): + def rename(self, oldpath: TransportPath, newpath: TransportPath): """Rename a file or folder from oldpath to newpath. 
:param str oldpath: existing name of the file or folder @@ -1365,8 +1365,8 @@ def rename(self, oldpath: _TransportPath, newpath: _TransportPath): if not newpath: raise ValueError(f'Destination {newpath} is not a valid path') - oldpath = path_2_str(oldpath) - newpath = path_2_str(newpath) + oldpath = path_to_str(oldpath) + newpath = path_to_str(newpath) if not self.isfile(oldpath): if not self.isdir(oldpath): @@ -1380,7 +1380,7 @@ def rename(self, oldpath: _TransportPath, newpath: _TransportPath): return self.sftp.rename(oldpath, newpath) - def isfile(self, path: _TransportPath): + def isfile(self, path: TransportPath): """Return True if the given path is a file, False otherwise. Return False also if the path does not exist. """ @@ -1390,7 +1390,7 @@ def isfile(self, path: _TransportPath): if not path: return False - path = path_2_str(path) + path = path_to_str(path) try: self.logger.debug( f"stat for path '{path}' ('{self.normalize(path)}'): {self.stat(path)} [{self.stat(path).st_mode}]" @@ -1451,7 +1451,7 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1, work return stdin, stdout, stderr, channel def exec_command_wait_bytes( - self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01, workdir: _TransportPath = None + self, command, stdin=None, combine_stderr=False, bufsize=-1, timeout=0.01, workdir: TransportPath = None ): """Executes the specified command and waits for it to finish. 
@@ -1472,7 +1472,7 @@ def exec_command_wait_bytes( import time if workdir: - workdir = path_2_str(workdir) + workdir = path_to_str(workdir) ssh_stdin, stdout, stderr, channel = self._exec_command_internal( command, combine_stderr, bufsize=bufsize, workdir=workdir @@ -1567,11 +1567,11 @@ def exec_command_wait_bytes( return (retval, b''.join(stdout_bytes), b''.join(stderr_bytes)) - def gotocomputer_command(self, remotedir: _TransportPath): + def gotocomputer_command(self, remotedir: TransportPath): """Specific gotocomputer string to connect to a given remote computer via ssh and directly go to the calculation folder. """ - remotedir = path_2_str(remotedir) + remotedir = path_to_str(remotedir) further_params = [] if 'username' in self._connect_args: @@ -1595,25 +1595,25 @@ def gotocomputer_command(self, remotedir: _TransportPath): cmd = f'ssh -t {self._machine} {further_params_str} {connect_string}' return cmd - def _symlink(self, source: _TransportPath, dest: _TransportPath): + def _symlink(self, source: TransportPath, dest: TransportPath): """Wrap SFTP symlink call without breaking API :param source: source of link :param dest: link to create """ - source = path_2_str(source) - dest = path_2_str(dest) + source = path_to_str(source) + dest = path_to_str(dest) self.sftp.symlink(source, dest) - def symlink(self, remotesource: _TransportPath, remotedestination: _TransportPath): + def symlink(self, remotesource: TransportPath, remotedestination: TransportPath): """Create a symbolic link between the remote source and the remote destination. :param remotesource: remote source. Can contain a pattern. :param remotedestination: remote destination """ - remotesource = path_2_str(remotesource) - remotedestination = path_2_str(remotedestination) + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) # paramiko gives some errors if path is starting with '.' 
source = os.path.normpath(remotesource) dest = os.path.normpath(remotedestination) @@ -1631,11 +1631,11 @@ def symlink(self, remotesource: _TransportPath, remotedestination: _TransportPat else: self._symlink(source, dest) - def path_exists(self, path: _TransportPath): + def path_exists(self, path: TransportPath): """Check if path exists""" import errno - path = path_2_str(path) + path = path_to_str(path) try: self.stat(path) diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 77549be75b..5050516577 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -22,7 +22,7 @@ from aiida.common.escaping import escape_for_bash from aiida.common.exceptions import InvalidOperation -from ..transport import AsyncTransport, BlockingTransport, TransportInternalError, _TransportPath, path_2_str +from ..transport import AsyncTransport, Transport, TransportInternalError, TransportPath, path_to_str __all__ = ('AsyncSshTransport',) @@ -145,11 +145,13 @@ def __str__(self): async def get_async( self, - remotepath: _TransportPath, - localpath: _TransportPath, + remotepath: TransportPath, + localpath: TransportPath, dereference=True, overwrite=True, ignore_nonexisting=False, + *args, + **kwargs, ): """Get a file or folder from remote to local. Redirects to getfile or gettree. @@ -161,8 +163,8 @@ async def get_async( :param overwrite: if True overwrites files and folders. 
Default = False - :type remotepath: _TransportPath - :type localpath: _TransportPath + :type remotepath: TransportPath + :type localpath: TransportPath :type dereference: bool :type overwrite: bool :type ignore_nonexisting: bool @@ -170,8 +172,8 @@ async def get_async( :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found """ - remotepath = path_2_str(remotepath) - localpath = path_2_str(localpath) + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) if not os.path.isabs(localpath): raise ValueError('The localpath must be an absolute path') @@ -220,7 +222,7 @@ async def get_async( raise OSError(f'The remote path {remotepath} does not exist') async def getfile_async( - self, remotepath: _TransportPath, localpath: _TransportPath, dereference=True, overwrite=True + self, remotepath: TransportPath, localpath: TransportPath, dereference=True, overwrite=True, *args, **kwargs ): """Get a file from remote to local. @@ -231,16 +233,16 @@ async def getfile_async( :param dereference: follow symbolic links. Default = True - :type remotepath: _TransportPath - :type localpath: _TransportPath + :type remotepath: TransportPath + :type localpath: TransportPath :type dereference: bool :type overwrite: bool :raise ValueError: if local path is invalid :raise OSError: if unintentionally overwriting """ - remotepath = path_2_str(remotepath) - localpath = path_2_str(localpath) + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) if not os.path.isabs(localpath): raise ValueError('localpath must be an absolute path') @@ -256,7 +258,7 @@ async def getfile_async( raise OSError(f'Error while uploading file {localpath}: {exc}') async def gettree_async( - self, remotepath: _TransportPath, localpath: _TransportPath, dereference=True, overwrite=True + self, remotepath: TransportPath, localpath: TransportPath, dereference=True, overwrite=True, *args, **kwargs ): """Get a folder recursively from remote to local. 
@@ -267,8 +269,8 @@ async def gettree_async( :param overwrite: if True overwrites files and folders. Default = True - :type remotepath: _TransportPath - :type localpath: _TransportPath + :type remotepath: TransportPath + :type localpath: TransportPath :type dereference: bool :type overwrite: bool @@ -276,8 +278,8 @@ async def gettree_async( :raise OSError: if the remotepath is not found :raise OSError: if unintentionally overwriting """ - remotepath = path_2_str(remotepath) - localpath = path_2_str(localpath) + remotepath = path_to_str(remotepath) + localpath = path_to_str(localpath) if not remotepath: raise OSError('Remotepath must be a non empty string') @@ -316,11 +318,13 @@ async def gettree_async( async def put_async( self, - localpath: _TransportPath, - remotepath: _TransportPath, + localpath: TransportPath, + remotepath: TransportPath, dereference=True, overwrite=True, ignore_nonexisting=False, + *args, + **kwargs, ): """Put a file or a folder from local to remote. Redirects to putfile or puttree. 
@@ -332,8 +336,8 @@ async def put_async( :param overwrite: if True overwrites files and folders Default = False - :type remotepath: _TransportPath - :type localpath: _TransportPath + :type remotepath: TransportPath + :type localpath: TransportPath :type dereference: bool :type overwrite: bool :type ignore_nonexisting: bool @@ -341,8 +345,8 @@ async def put_async( :raise ValueError: if local path is invalid :raise OSError: if the localpath does not exist """ - localpath = path_2_str(localpath) - remotepath = path_2_str(remotepath) + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) if not os.path.isabs(localpath): raise ValueError('The localpath must be an absolute path') @@ -393,7 +397,7 @@ async def put_async( raise OSError(f'The local path {localpath} does not exist') async def putfile_async( - self, localpath: _TransportPath, remotepath: _TransportPath, dereference=True, overwrite=True + self, localpath: TransportPath, remotepath: TransportPath, dereference=True, overwrite=True, *args, **kwargs ): """Put a file from local to remote. 
@@ -402,8 +406,8 @@ async def putfile_async( :param overwrite: if True overwrites files and folders Default = True - :type remotepath: _TransportPath - :type localpath: _TransportPath + :type remotepath: TransportPath + :type localpath: TransportPath :type dereference: bool :type overwrite: bool @@ -411,8 +415,8 @@ async def putfile_async( :raise OSError: if the localpath does not exist, or unintentionally overwriting """ - localpath = path_2_str(localpath) - remotepath = path_2_str(remotepath) + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) if not os.path.isabs(localpath): raise ValueError('The localpath must be an absolute path') @@ -428,7 +432,7 @@ async def putfile_async( raise OSError(f'Error while uploading file {localpath}: {exc}') async def puttree_async( - self, localpath: _TransportPath, remotepath: _TransportPath, dereference=True, overwrite=True + self, localpath: TransportPath, remotepath: TransportPath, dereference=True, overwrite=True, *args, **kwargs ): """Put a folder recursively from local to remote. @@ -439,8 +443,8 @@ async def puttree_async( :param overwrite: if True overwrites files and folders (boolean). 
Default = True - :type localpath: _TransportPath - :type remotepath: _TransportPath + :type localpath: TransportPath + :type remotepath: TransportPath :type dereference: bool :type overwrite: bool @@ -448,8 +452,8 @@ async def puttree_async( :raise OSError: if the localpath does not exist, or trying to overwrite :raise OSError: if remotepath is invalid """ - localpath = path_2_str(localpath) - remotepath = path_2_str(remotepath) + localpath = path_to_str(localpath) + remotepath = path_to_str(remotepath) if not os.path.isabs(localpath): raise ValueError('The localpath must be an absolute path') @@ -491,8 +495,8 @@ async def puttree_async( async def copy_async( self, - remotesource: _TransportPath, - remotedestination: _TransportPath, + remotesource: TransportPath, + remotedestination: TransportPath, dereference: bool = False, recursive: bool = True, preserve: bool = False, @@ -505,8 +509,8 @@ async def copy_async( :param recursive: copy recursively :param preserve: preserve file attributes - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type remotesource: TransportPath + :type remotedestination: TransportPath :type dereference: bool :type recursive: bool :type preserve: bool @@ -514,8 +518,8 @@ async def copy_async( :raises: OSError, src does not exist or if the copy execution failed. 
""" - remotesource = path_2_str(remotesource) - remotedestination = path_2_str(remotedestination) + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) if self.has_magic(remotedestination): raise ValueError('Pathname patterns are not allowed in the destination') @@ -548,8 +552,8 @@ async def copy_async( async def copyfile_async( self, - remotesource: _TransportPath, - remotedestination: _TransportPath, + remotesource: TransportPath, + remotedestination: TransportPath, dereference: bool = False, preserve: bool = False, ): @@ -560,8 +564,8 @@ async def copyfile_async( :param dereference: follow symbolic links :param preserve: preserve file attributes - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type remotesource: TransportPath + :type remotedestination: TransportPath :type dereference: bool :type preserve: bool @@ -571,8 +575,8 @@ async def copyfile_async( async def copytree_async( self, - remotesource: _TransportPath, - remotedestination: _TransportPath, + remotesource: TransportPath, + remotedestination: TransportPath, dereference: bool = False, preserve: bool = False, ): @@ -583,8 +587,8 @@ async def copytree_async( :param dereference: follow symbolic links :param preserve: preserve file attributes - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type remotesource: TransportPath + :type remotedestination: TransportPath :type dereference: bool :type preserve: bool @@ -597,7 +601,7 @@ async def exec_command_wait_async( command: str, stdin: Optional[str] = None, encoding: str = 'utf-8', - workdir: Optional[_TransportPath] = None, + workdir: Optional[TransportPath] = None, timeout: Optional[float] = 2, **kwargs, ): @@ -605,14 +609,14 @@ async def exec_command_wait_async( :param command: the command to execute :param stdin: the input to pass to the command - :param encoding: (IGNORED) this is here just to keep the same signature as the one in 
`BlockingTransport` class + :param encoding: (IGNORED) this is here just to keep the same signature as the one in `Transport` class :param workdir: the working directory where to execute the command :param timeout: the timeout in seconds :type command: str :type stdin: str :type encoding: str - :type workdir: Union[_TransportPath, None] + :type workdir: Union[TransportPath, None] :type timeout: float :return: a tuple with the return code, the stdout and the stderr of the command @@ -620,7 +624,7 @@ async def exec_command_wait_async( """ if workdir: - workdir = path_2_str(workdir) + workdir = path_to_str(workdir) command = f'cd {workdir} && {command}' bash_commmand = self._bash_command_str + '-c ' @@ -631,7 +635,7 @@ async def exec_command_wait_async( # Since the command is str, both stdout and stderr are strings return (result.returncode, ''.join(str(result.stdout)), ''.join(str(result.stderr))) - async def get_attribute_async(self, path: _TransportPath): + async def get_attribute_async(self, path: TransportPath): """Return an object FixedFieldsAttributeDict for file in a given path, as defined in aiida.common.extendeddicts Each attribute object consists in a dictionary with the following keys: @@ -650,11 +654,11 @@ async def get_attribute_async(self, path: _TransportPath): :param path: path to file - :type path: _TransportPath + :type path: TransportPath :return: object FixedFieldsAttributeDict """ - path = path_2_str(path) + path = path_to_str(path) from aiida.transports.util import FileAttribute asyncssh_attr = await self._sftp.lstat(path) @@ -677,13 +681,13 @@ async def get_attribute_async(self, path: _TransportPath): raise NotImplementedError(f'Mapping the {key} attribute is not implemented') return aiida_attr - async def isdir_async(self, path: _TransportPath): + async def isdir_async(self, path: TransportPath): """Return True if the given path is a directory, False otherwise. Return False also if the path does not exist. 
:param path: the absolute path to check - :type path: _TransportPath + :type path: TransportPath :return: True if the path is a directory, False otherwise """ @@ -691,17 +695,17 @@ async def isdir_async(self, path: _TransportPath): if not path: return False - path = path_2_str(path) + path = path_to_str(path) return await self._sftp.isdir(path) - async def isfile_async(self, path: _TransportPath): + async def isfile_async(self, path: TransportPath): """Return True if the given path is a file, False otherwise. Return False also if the path does not exist. :param path: the absolute path to check - :type path: _TransportPath + :type path: TransportPath :return: True if the path is a file, False otherwise """ @@ -709,11 +713,11 @@ async def isfile_async(self, path: _TransportPath): if not path: return False - path = path_2_str(path) + path = path_to_str(path) return await self._sftp.isfile(path) - async def listdir_async(self, path: _TransportPath, pattern=None): + async def listdir_async(self, path: TransportPath, pattern=None): """Return a list of the names of the entries in the given path. The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory. @@ -722,11 +726,11 @@ async def listdir_async(self, path: _TransportPath, pattern=None): :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. - :type path: _TransportPath + :type path: TransportPath :return: a list of strings """ - path = path_2_str(path) + path = path_to_str(path) if not pattern: list_ = list(await self._sftp.listdir(path)) else: @@ -741,7 +745,7 @@ async def listdir_async(self, path: _TransportPath, pattern=None): return list_ - async def listdir_withattributes_async(self, path: _TransportPath, pattern: Optional[str] = None): + async def listdir_withattributes_async(self, path: TransportPath, pattern: Optional[str] = None): """Return a list of the names of the entries in the given path. 
The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory. @@ -750,7 +754,7 @@ async def listdir_withattributes_async(self, path: _TransportPath, pattern: Opti :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. - :type path: _TransportPath + :type path: TransportPath :type pattern: str :return: a list of dictionaries, one per entry. The schema of the dictionary is @@ -766,7 +770,7 @@ async def listdir_withattributes_async(self, path: _TransportPath, pattern: Opti (if the file is a folder, a directory, ...). 'attributes' behaves as the output of transport.get_attribute(); isdir is a boolean indicating if the object is a directory or not. """ - path = path_2_str(path) + path = path_to_str(path) retlist = [] listdir = await self.listdir_async(path, pattern) for file_name in listdir: @@ -785,11 +789,11 @@ async def makedirs_async(self, path, ignore_existing=False): :param bool ignore_existing: if set to true, it doesn't give any error if the leaf directory does already exist - :type path: _TransportPath + :type path: TransportPath :raises: OSError, if directory at path already exists """ - path = path_2_str(path) + path = path_to_str(path) try: await self._sftp.makedirs(path, exist_ok=ignore_existing) @@ -801,18 +805,18 @@ async def makedirs_async(self, path, ignore_existing=False): else: raise TransportInternalError(f'Error while creating directory {path}: {exc}') - async def mkdir_async(self, path: _TransportPath, ignore_existing=False): + async def mkdir_async(self, path: TransportPath, ignore_existing=False): """Create a directory. 
:param path: absolute path to directory to create :param bool ignore_existing: if set to true, it doesn't give any error if the leaf directory does already exist - :type path: _TransportPath + :type path: TransportPath :raises: OSError, if directory at path already exists """ - path = path_2_str(path) + path = path_to_str(path) try: await self._sftp.mkdir(path) @@ -830,20 +834,20 @@ async def mkdir_async(self, path: _TransportPath, ignore_existing=False): else: raise TransportInternalError(f'Error while creating directory {path}: {exc}') - async def normalize_async(self, path: _TransportPath): + async def normalize_async(self, path: TransportPath): raise NotImplementedError('Not implemented, waiting for a use case.') - async def remove_async(self, path: _TransportPath): + async def remove_async(self, path: TransportPath): """Remove the file at the given path. This only works on files; for removing folders (directories), use rmdir. :param path: path to file to remove - :type path: _TransportPath + :type path: TransportPath :raise OSError: if the path is a directory """ - path = path_2_str(path) + path = path_to_str(path) # TODO: check if asyncssh does return SFTPFileIsADirectory in this case # if that's the case, we can get rid of the isfile check if await self.isdir_async(path): @@ -851,21 +855,21 @@ async def remove_async(self, path: _TransportPath): else: await self._sftp.remove(path) - async def rename_async(self, oldpath: _TransportPath, newpath: _TransportPath): + async def rename_async(self, oldpath: TransportPath, newpath: TransportPath): """ Rename a file or folder from oldpath to newpath. 
:param oldpath: existing name of the file or folder :param newpath: new name for the file or folder - :type oldpath: _TransportPath - :type newpath: _TransportPath + :type oldpath: TransportPath + :type newpath: TransportPath :raises OSError: if oldpath/newpath is not found :raises ValueError: if oldpath/newpath is not a valid string """ - oldpath = path_2_str(oldpath) - newpath = path_2_str(newpath) + oldpath = path_to_str(oldpath) + newpath = path_to_str(newpath) if not oldpath or not newpath: raise ValueError('oldpath and newpath must be non-empty strings') @@ -874,43 +878,43 @@ async def rename_async(self, oldpath: _TransportPath, newpath: _TransportPath): await self._sftp.rename(oldpath, newpath) - async def rmdir_async(self, path: _TransportPath): + async def rmdir_async(self, path: TransportPath): """Remove the folder named path. This works only for empty folders. For recursive remove, use rmtree. :param str path: absolute path to the folder to remove - :type path: _TransportPath + :type path: TransportPath """ - path = path_2_str(path) + path = path_to_str(path) try: await self._sftp.rmdir(path) except asyncssh.sftp.SFTPFailure: raise OSError(f'Error while removing directory {path}: probably directory is not empty') - async def rmtree_async(self, path: _TransportPath): + async def rmtree_async(self, path: TransportPath): """Remove the folder named path, and all its contents. :param str path: absolute path to the folder to remove - :type path: _TransportPath + :type path: TransportPath :raises OSError: if the operation fails """ - path = path_2_str(path) + path = path_to_str(path) try: await self._sftp.rmtree(path, ignore_errors=False) except asyncssh.Error as exc: raise OSError(f'Error while removing directory tree {path}: {exc}') - async def path_exists_async(self, path: _TransportPath): + async def path_exists_async(self, path: TransportPath): """Returns True if path exists, False otherwise. 
:param path: path to check - :type path: _TransportPath + :type path: TransportPath """ - path = path_2_str(path) + path = path_to_str(path) return await self._sftp.exists(path) async def whoami_async(self): @@ -932,20 +936,20 @@ async def whoami_async(self): self.logger.error(f"Problem executing whoami. Exit code: {retval}, stdout: '{username}', stderr: '{stderr}'") raise OSError(f'Error while executing whoami. Exit code: {retval}') - async def symlink_async(self, remotesource: _TransportPath, remotedestination: _TransportPath): + async def symlink_async(self, remotesource: TransportPath, remotedestination: TransportPath): """Create a symbolic link between the remote source and the remote destination. :param remotesource: absolute path to remote source :param remotedestination: absolute path to remote destination - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type remotesource: TransportPath + :type remotedestination: TransportPath :raises ValueError: if remotedestination has patterns """ - remotesource = path_2_str(remotesource) - remotedestination = path_2_str(remotedestination) + remotesource = path_to_str(remotesource) + remotedestination = path_to_str(remotedestination) if self.has_magic(remotesource): if self.has_magic(remotedestination): @@ -961,7 +965,7 @@ async def symlink_async(self, remotesource: _TransportPath, remotedestination: _ else: await self._sftp.symlink(remotesource, remotedestination) - async def glob_async(self, pathname: _TransportPath): + async def glob_async(self, pathname: TransportPath): """Return a list of paths matching a pathname pattern. The pattern may contain simple shell-style wildcards a la fnmatch. @@ -969,27 +973,27 @@ async def glob_async(self, pathname: _TransportPath): :param pathname: the pathname pattern to match. It should only be absolute path. - :type pathname: _TransportPath + :type pathname: TransportPath :return: a list of paths matching the pattern. 
""" - pathname = path_2_str(pathname) + pathname = path_to_str(pathname) return await self._sftp.glob(pathname) - async def chmod_async(self, path: _TransportPath, mode: int, follow_symlinks: bool = True): + async def chmod_async(self, path: TransportPath, mode: int, follow_symlinks: bool = True): """Change the permissions of a file. :param path: path to the file :param mode: the new permissions :param bool follow_symlinks: if True, follow symbolic links - :type path: _TransportPath + :type path: TransportPath :type mode: int :type follow_symlinks: bool :raises OSError: if the path is empty """ - path = path_2_str(path) + path = path_to_str(path) if not path: raise OSError('Input path is an empty argument.') try: @@ -997,20 +1001,20 @@ async def chmod_async(self, path: _TransportPath, mode: int, follow_symlinks: bo except asyncssh.sftp.SFTPNoSuchFile as exc: raise OSError(f'Error {exc}, directory does not exists') - async def chown_async(self, path: _TransportPath, uid: int, gid: int): + async def chown_async(self, path: TransportPath, uid: int, gid: int): """Change the owner and group id of a file. :param path: path to the file :param uid: the new owner id :param gid: the new group id - :type path: _TransportPath + :type path: TransportPath :type uid: int :type gid: int :raises OSError: if the path is empty """ - path = path_2_str(path) + path = path_to_str(path) if not path: raise OSError('Input path is an empty argument.') try: @@ -1020,9 +1024,9 @@ async def chown_async(self, path: _TransportPath, uid: int, gid: int): async def copy_from_remote_to_remote_async( self, - transportdestination: Union['BlockingTransport', 'AsyncTransport'], - remotesource: _TransportPath, - remotedestination: _TransportPath, + transportdestination: Union['Transport', 'AsyncTransport'], + remotesource: TransportPath, + remotedestination: TransportPath, **kwargs, ): """Copy files or folders from a remote computer to another remote computer, asynchronously. 
@@ -1033,9 +1037,9 @@ async def copy_from_remote_to_remote_async( :param kwargs: keyword parameters passed to the call to transportdestination.put, except for 'dereference' that is passed to self.get - :type transportdestination: Union['BlockingTransport', 'AsyncTransport'] - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type transportdestination: Union['Transport', 'AsyncTransport'] + :type remotesource: TransportPath + :type remotedestination: TransportPath .. note:: the keyword 'dereference' SHOULD be set to False for the final put (onto the destination), while it can be set to the @@ -1079,12 +1083,12 @@ async def copy_from_remote_to_remote_async( os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put ) - def gotocomputer_command(self, remotedir: _TransportPath): + def gotocomputer_command(self, remotedir: TransportPath): """Return a string to be used to connect to the remote computer. :param remotedir: the remote directory to connect to - :type remotedir: _TransportPath + :type remotedir: TransportPath """ connect_string = self._gotocomputer_string(remotedir) cmd = f'ssh -t {self.machine} {connect_string}' diff --git a/src/aiida/transports/transport.py b/src/aiida/transports/transport.py index 4b9cb2bd7b..fea7ee4bfa 100644 --- a/src/aiida/transports/transport.py +++ b/src/aiida/transports/transport.py @@ -22,9 +22,9 @@ from aiida.common.lang import classproperty from aiida.common.warnings import warn_deprecation -__all__ = ('Transport', 'AsyncTransport', 'BlockingTransport') +__all__ = ('AsyncTransport', 'Transport') -_TransportPath = Union[str, Path, PurePosixPath] +TransportPath = Union[str, Path, PurePosixPath] def validate_positive_number(ctx, param, value): @@ -43,8 +43,8 @@ def validate_positive_number(ctx, param, value): return value -def path_2_str(path: _TransportPath) -> str: - """Convert an instance of _TransportPath = Union[str, Path, PurePosixPath] instance to a string.""" +def path_to_str(path: 
TransportPath) -> str: + """Convert an instance of TransportPath = Union[str, Path, PurePosixPath] instance to a string.""" # We could check if the path is a Path or PurePosixPath instance, but it's too much overhead. return str(path) @@ -255,8 +255,8 @@ def get_safe_open_interval(self): """ return self._safe_open_interval - def has_magic(self, string: _TransportPath): - string = path_2_str(string) + def has_magic(self, string: TransportPath): + string = path_to_str(string) """Return True if the given string contains any special shell characters.""" return self._MAGIC_CHECK.search(string) is not None @@ -273,7 +273,7 @@ def _gotocomputer_string(self, remotedir): return connect_string -class BlockingTransport(abc.ABC, _BaseTransport): +class Transport(abc.ABC, _BaseTransport): """Abstract class for a generic blocking transport. A plugin inhereting from this class should implement the blocking methods, only.""" @@ -296,18 +296,18 @@ def __str__(self): """return [Transport class or subclass]""" @abc.abstractmethod - def chmod(self, path: _TransportPath, mode): + def chmod(self, path: TransportPath, mode): """Change permissions of a path. :param path: path to file :param mode: new permissions - :type path: _TransportPath + :type path: TransportPath :type mode: int """ @abc.abstractmethod - def chown(self, path: _TransportPath, uid: int, gid: int): + def chown(self, path: TransportPath, uid: int, gid: int): """Change the owner (uid) and group (gid) of a file. 
As with python's os.chown function, you must pass both arguments, so if you only want to change one, use stat first to retrieve the @@ -317,13 +317,13 @@ def chown(self, path: _TransportPath, uid: int, gid: int): :param uid: new owner's uid :param gid: new group id - :type path: _TransportPath + :type path: TransportPath :type uid: int :type gid: int """ @abc.abstractmethod - def copy(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False, recursive=True): + def copy(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False, recursive=True): """Copy a file or a directory from remote source to remote destination (On the same remote machine) @@ -332,8 +332,8 @@ def copy(self, remotesource: _TransportPath, remotedestination: _TransportPath, :param dereference: if True copy the contents of any symlinks found, otherwise copy the symlinks themselves :param recursive: if True copy directories recursively, otherwise only copy the specified file(s) - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type remotesource: TransportPath + :type remotedestination: TransportPath :type dereference: bool :type recursive: bool @@ -341,7 +341,7 @@ def copy(self, remotesource: _TransportPath, remotedestination: _TransportPath, """ @abc.abstractmethod - def copyfile(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): + def copyfile(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False): """Copy a file from remote source to remote destination (On the same remote machine) @@ -349,15 +349,15 @@ def copyfile(self, remotesource: _TransportPath, remotedestination: _TransportPa :param remotedestination: path of the remote destination directory / file :param dereference: if True copy the contents of any symlinks found, otherwise copy the symlinks themselves - :type remotesource: _TransportPath - :type remotedestination: 
_TransportPath + :type remotesource: TransportPath + :type remotedestination: TransportPath :type dereference: bool :raises OSError: if one of src or dst does not exist """ @abc.abstractmethod - def copytree(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): + def copytree(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False): """Copy a folder from remote source to remote destination (On the same remote machine) @@ -365,8 +365,8 @@ def copytree(self, remotesource: _TransportPath, remotedestination: _TransportPa :param remotedestination: path of the remote destination directory / file :param dereference: if True copy the contents of any symlinks found, otherwise copy the symlinks themselves - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type remotesource: TransportPath + :type remotedestination: TransportPath :type dereference: bool :raise OSError: if one of src or dst does not exist @@ -375,9 +375,9 @@ def copytree(self, remotesource: _TransportPath, remotedestination: _TransportPa ## non-abtract methods. Plugin developers can safely ingore developing these methods def copy_from_remote_to_remote( self, - transportdestination: Union['BlockingTransport', 'AsyncTransport'], - remotesource: _TransportPath, - remotedestination: _TransportPath, + transportdestination: Union['Transport', 'AsyncTransport'], + remotesource: TransportPath, + remotedestination: TransportPath, **kwargs, ): """Copy files or folders from a remote computer to another remote computer. 
@@ -388,9 +388,9 @@ def copy_from_remote_to_remote( :param kwargs: keyword parameters passed to the call to transportdestination.put, except for 'dereference' that is passed to self.get - :type transportdestination: Union['BlockingTransport', 'AsyncTransport'] - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type transportdestination: Union['Transport', 'AsyncTransport'] + :type remotesource: TransportPath + :type remotedestination: TransportPath .. note:: the keyword 'dereference' SHOULD be set to False for the final put (onto the destination), while it can be set to the @@ -430,12 +430,12 @@ def copy_from_remote_to_remote( # from sandbox.get_abs_path('*') would not work for files # beginning with a dot ('.'). for filename in sandbox.get_content_list(): - # no matter is transpordestination is BlockingTransport or AsyncTransport + # no matter is transpordestination is Transport or AsyncTransport # the following method will work, as both classes support put(), blocking method transportdestination.put(os.path.join(sandbox.abspath, filename), remotedestination, **kwargs_put) @abc.abstractmethod - def _exec_command_internal(self, command: str, workdir: Optional[_TransportPath] = None, **kwargs): + def _exec_command_internal(self, command: str, workdir: Optional[TransportPath] = None, **kwargs): """Execute the command on the shell, similarly to os.system. Enforce the execution to be run from `workdir`. @@ -448,14 +448,14 @@ def _exec_command_internal(self, command: str, workdir: Optional[_TransportPath] in the specified working directory. :type command: str - :type workdir: _TransportPath + :type workdir: TransportPath :return: stdin, stdout, stderr and the session, when this exists \ (can be None). 
""" @abc.abstractmethod - def exec_command_wait_bytes(self, command: str, stdin=None, workdir: Optional[_TransportPath] = None, **kwargs): + def exec_command_wait_bytes(self, command: str, stdin=None, workdir: Optional[TransportPath] = None, **kwargs): """Execute the command on the shell, waits for it to finish, and return the retcode, the stdout and the stderr as bytes. @@ -469,13 +469,13 @@ def exec_command_wait_bytes(self, command: str, stdin=None, workdir: Optional[_T in the specified working directory. :type command: str - :type workdir: _TransportPath + :type workdir: TransportPath :return: a tuple: the retcode (int), stdout (bytes) and stderr (bytes). """ def exec_command_wait( - self, command, stdin=None, encoding='utf-8', workdir: Optional[_TransportPath] = None, **kwargs + self, command, stdin=None, encoding='utf-8', workdir: Optional[TransportPath] = None, **kwargs ): """Executes the specified command and waits for it to finish. @@ -496,7 +496,7 @@ def exec_command_wait( :type command: str :type encoding: str - :type workdir: _TransportPath + :type workdir: TransportPath :return: a tuple with (return_value, stdout, stderr) where stdout and stderr are both strings, decoded with the specified encoding. @@ -508,7 +508,7 @@ def exec_command_wait( return (retval, stdout_bytes.decode(encoding), stderr_bytes.decode(encoding)) @abc.abstractmethod - def get(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): + def get(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Retrieve a file or folder from remote source to local destination both localpath and remotepath must be an absolute path. 
@@ -518,32 +518,32 @@ def get(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kw :param remotepath: remote_folder_path :param localpath: (local_folder_path - :type remotepath: _TransportPath - :type localpath: _TransportPath + :type remotepath: TransportPath + :type localpath: TransportPath """ @abc.abstractmethod - def getfile(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): + def getfile(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Retrieve a file from remote source to local destination both localpath and remotepath must be an absolute path. :param remotepath: remote_folder_path :param localpath: local_folder_path - :type remotepath: _TransportPath - :type localpath: _TransportPath + :type remotepath: TransportPath + :type localpath: TransportPath """ @abc.abstractmethod - def gettree(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): + def gettree(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Retrieve a folder recursively from remote source to local destination both localpath and remotepath must be an absolute path. 
:param remotepath: remote_folder_path :param localpath: local_folder_path - :type remotepath: _TransportPath - :type localpath: _TransportPath + :type remotepath: TransportPath + :type localpath: TransportPath """ @abc.abstractmethod @@ -562,7 +562,7 @@ def getcwd(self): ) @abc.abstractmethod - def get_attribute(self, path: _TransportPath): + def get_attribute(self, path: TransportPath): """Return an object FixedFieldsAttributeDict for file in a given path, as defined in aiida.common.extendeddicts Each attribute object consists in a dictionary with the following keys: @@ -581,17 +581,17 @@ def get_attribute(self, path: _TransportPath): :param path: path to file - :type path: _TransportPath + :type path: TransportPath :return: object FixedFieldsAttributeDict """ - def get_mode(self, path: _TransportPath): + def get_mode(self, path: TransportPath): """Return the portion of the file's mode that can be set by chmod(). :param path: path to file - :type path: _TransportPath + :type path: TransportPath :return: the portion of the file's mode that can be set by chmod() """ @@ -600,31 +600,31 @@ def get_mode(self, path: _TransportPath): return stat.S_IMODE(self.get_attribute(path).st_mode) @abc.abstractmethod - def isdir(self, path: _TransportPath): + def isdir(self, path: TransportPath): """True if path is an existing directory. Return False also if the path does not exist. :param path: path to directory - :type path: _TransportPath + :type path: TransportPath :return: boolean """ @abc.abstractmethod - def isfile(self, path: _TransportPath): + def isfile(self, path: TransportPath): """Return True if path is an existing file. Return False also if the path does not exist. 
:param path: path to file - :type path: _TransportPath + :type path: TransportPath :return: boolean """ @abc.abstractmethod - def listdir(self, path: _TransportPath = '.', pattern=None): + def listdir(self, path: TransportPath = '.', pattern=None): """Return a list of the names of the entries in the given path. The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory. @@ -634,12 +634,12 @@ def listdir(self, path: _TransportPath = '.', pattern=None): :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. - :type path: _TransportPath + :type path: TransportPath :return: a list of strings """ - def listdir_withattributes(self, path: _TransportPath = '.', pattern: Optional[str] = None): + def listdir_withattributes(self, path: TransportPath = '.', pattern: Optional[str] = None): """Return a list of the names of the entries in the given path. The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory. @@ -649,7 +649,7 @@ def listdir_withattributes(self, path: _TransportPath = '.', pattern: Optional[s taken from DEPRECATED `self.getcwd()`. :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. - :type path: _TransportPath + :type path: TransportPath :type pattern: str :return: a list of dictionaries, one per entry. The schema of the dictionary is @@ -665,7 +665,7 @@ def listdir_withattributes(self, path: _TransportPath = '.', pattern: Optional[s (if the file is a folder, a directory, ...). 'attributes' behaves as the output of transport.get_attribute(); isdir is a boolean indicating if the object is a directory or not. 
""" - path = path_2_str(path) + path = path_to_str(path) retlist = [] if path.startswith('/'): cwd = Path(path).resolve().as_posix() @@ -683,7 +683,7 @@ def listdir_withattributes(self, path: _TransportPath = '.', pattern: Optional[s return retlist @abc.abstractmethod - def makedirs(self, path: _TransportPath, ignore_existing=False): + def makedirs(self, path: TransportPath, ignore_existing=False): """Super-mkdir; create a leaf directory and all intermediate ones. Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. @@ -691,38 +691,38 @@ def makedirs(self, path: _TransportPath, ignore_existing=False): :param path: directory to create :param bool ignore_existing: if set to true, it doesn't give any error if the leaf directory does already exist - :type path: _TransportPath + :type path: TransportPath :raises: OSError, if directory at path already exists """ @abc.abstractmethod - def mkdir(self, path: _TransportPath, ignore_existing=False): + def mkdir(self, path: TransportPath, ignore_existing=False): """Create a folder (directory) named path. :param path: name of the folder to create :param bool ignore_existing: if True, does not give any error if the directory already exists - :type path: _TransportPath + :type path: TransportPath :raises: OSError, if directory at path already exists """ @abc.abstractmethod - def normalize(self, path: _TransportPath = '.'): + def normalize(self, path: TransportPath = '.'): """Return the normalized path (on the server) of a given path. This can be used to quickly resolve symbolic links or determine what the server is considering to be the "current folder". 
:param path: path to be normalized - :type path: _TransportPath + :type path: TransportPath :raise OSError: if the path can't be resolved on the server """ @abc.abstractmethod - def put(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + def put(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Put a file or a directory from local src to remote dst. both localpath and remotepath must be an absolute path. Redirects to putfile and puttree. @@ -733,83 +733,83 @@ def put(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kw :param localpath: absolute path to local source :param remotepath: path to remote destination - :type localpath: _TransportPath - :type remotepath: _TransportPath + :type localpath: TransportPath + :type remotepath: TransportPath """ @abc.abstractmethod - def putfile(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + def putfile(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Put a file from local src to remote dst. both localpath and remotepath must be an absolute path. :param localpath: absolute path to local file :param remotepath: path to remote file - :type localpath: _TransportPath - :type remotepath: _TransportPath + :type localpath: TransportPath + :type remotepath: TransportPath """ @abc.abstractmethod - def puttree(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + def puttree(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Put a folder recursively from local src to remote dst. both localpath and remotepath must be an absolute path. 
:param localpath: absolute path to local folder :param remotepath: path to remote folder - :type localpath: _TransportPath - :type remotepath: _TransportPath + :type localpath: TransportPath + :type remotepath: TransportPath """ @abc.abstractmethod - def remove(self, path: _TransportPath): + def remove(self, path: TransportPath): """Remove the file at the given path. This only works on files; for removing folders (directories), use rmdir. :param path: path to file to remove - :type path: _TransportPath + :type path: TransportPath :raise OSError: if the path is a directory """ @abc.abstractmethod - def rename(self, oldpath: _TransportPath, newpath: _TransportPath): + def rename(self, oldpath: TransportPath, newpath: TransportPath): """Rename a file or folder from oldpath to newpath. :param oldpath: existing name of the file or folder :param newpath: new name for the file or folder - :type oldpath: _TransportPath - :type newpath: _TransportPath + :type oldpath: TransportPath + :type newpath: TransportPath :raises OSError: if oldpath/newpath is not found :raises ValueError: if oldpath/newpath is not a valid string """ @abc.abstractmethod - def rmdir(self, path: _TransportPath): + def rmdir(self, path: TransportPath): """Remove the folder named path. This works only for empty folders. For recursive remove, use rmtree. :param path: absolute path to the folder to remove - :type path: _TransportPath + :type path: TransportPath """ @abc.abstractmethod - def rmtree(self, path: _TransportPath): + def rmtree(self, path: TransportPath): """Remove recursively the content at path :param path: absolute path to remove - :type path: _TransportPath + :type path: TransportPath :raise OSError: if the rm execution failed. """ @abc.abstractmethod - def gotocomputer_command(self, remotedir: _TransportPath): + def gotocomputer_command(self, remotedir: TransportPath): """Return a string to be run using os.system in order to connect via the transport to the remote directory. 
@@ -821,19 +821,19 @@ def gotocomputer_command(self, remotedir: _TransportPath): :param remotedir: the full path of the remote directory - :type remotedir: _TransportPath + :type remotedir: TransportPath """ @abc.abstractmethod - def symlink(self, remotesource: _TransportPath, remotedestination: _TransportPath): + def symlink(self, remotesource: TransportPath, remotedestination: TransportPath): """Create a symbolic link between the remote source and the remote destination. :param remotesource: remote source :param remotedestination: remote destination - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type remotesource: TransportPath + :type remotedestination: TransportPath """ def whoami(self): @@ -856,16 +856,16 @@ def whoami(self): raise OSError(f'Error while executing whoami. Exit code: {retval}') @abc.abstractmethod - def path_exists(self, path: _TransportPath): + def path_exists(self, path: TransportPath): """Returns True if path exists, False otherwise. :param path: path to check for existence - :type path: _TransportPath""" + :type path: TransportPath""" # The following definitions are almost copied and pasted # from the python module glob. - def glob(self, pathname: _TransportPath): + def glob(self, pathname: TransportPath): """Return a list of paths matching a pathname pattern. The pattern may contain simple shell-style wildcards a la fnmatch. @@ -874,11 +874,11 @@ def glob(self, pathname: _TransportPath): It should only be an absolute path. DEPRECATED: using relative path is deprecated. - :type pathname: _TransportPath + :type pathname: TransportPath :return: a list of paths matching the pattern. 
""" - pathname = path_2_str(pathname) + pathname = path_to_str(pathname) if not pathname.startswith('/'): warn_deprecation( 'Using relative paths across transport in `glob` is deprecated ' @@ -1038,7 +1038,7 @@ async def listdir_async(self, path, pattern=None): """Counterpart to listdir() that is async.""" return self.listdir(path, pattern) - async def listdir_withattributes_async(self, path: _TransportPath, pattern=None): + async def listdir_withattributes_async(self, path: TransportPath, pattern=None): """Counterpart to listdir_withattributes() that is async.""" return self.listdir_withattributes(path, pattern) @@ -1129,25 +1129,25 @@ async def close_async(self): """ @abc.abstractmethod - async def chmod_async(self, path: _TransportPath, mode: int): + async def chmod_async(self, path: TransportPath, mode: int): """Change permissions of a path. :param path: path to file or directory :param mode: new permissions - :type path: _TransportPath + :type path: TransportPath :type mode: int """ @abc.abstractmethod - async def chown_async(self, path: _TransportPath, uid: int, gid: int): + async def chown_async(self, path: TransportPath, uid: int, gid: int): """Change the owner (uid) and group (gid) of a file. 
:param path: path to file :param uid: user id of the new owner :param gid: group id of the new owner - :type path: _TransportPath + :type path: TransportPath :type uid: int :type gid: int """ @@ -1162,8 +1162,8 @@ async def copy_async(self, remotesource, remotedestination, dereference=False, r :param dereference: follow symbolic links :param recursive: copy recursively - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type remotesource: TransportPath + :type remotedestination: TransportPath :type dereference: bool :type recursive: bool @@ -1171,7 +1171,7 @@ async def copy_async(self, remotesource, remotedestination, dereference=False, r """ @abc.abstractmethod - async def copyfile_async(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): + async def copyfile_async(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False): """Copy a file from remote source to remote destination (On the same remote machine) @@ -1179,14 +1179,14 @@ async def copyfile_async(self, remotesource: _TransportPath, remotedestination: :param remotedestination: path to the remote destination file :param dereference: follow symbolic links - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type remotesource: TransportPath + :type remotedestination: TransportPath :type dereference: bool :raises: OSError, src does not exist or if the copy execution failed.""" @abc.abstractmethod - async def copytree_async(self, remotesource: _TransportPath, remotedestination: _TransportPath, dereference=False): + async def copytree_async(self, remotesource: TransportPath, remotedestination: TransportPath, dereference=False): """Copy a folder from remote source to remote destination (On the same remote machine) @@ -1194,8 +1194,8 @@ async def copytree_async(self, remotesource: _TransportPath, remotedestination: :param remotedestination: path to the remote destination folder :param 
dereference: follow symbolic links - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type remotesource: TransportPath + :type remotedestination: TransportPath :type dereference: bool :raises: OSError, src does not exist or if the copy execution failed.""" @@ -1203,9 +1203,9 @@ async def copytree_async(self, remotesource: _TransportPath, remotedestination: @abc.abstractmethod async def copy_from_remote_to_remote_async( self, - transportdestination: Union['BlockingTransport', 'AsyncTransport'], - remotesource: _TransportPath, - remotedestination: _TransportPath, + transportdestination: Union['Transport', 'AsyncTransport'], + remotesource: TransportPath, + remotedestination: TransportPath, **kwargs, ): """Copy files or folders from a remote computer to another remote computer. @@ -1216,9 +1216,9 @@ async def copy_from_remote_to_remote_async( :param kwargs: keyword parameters passed to the call to transportdestination.put, except for 'dereference' that is passed to self.get - :type transportdestination: Union['BlockingTransport', 'AsyncTransport'] - :type remotesource: _TransportPath - :type remotedestination: _TransportPath + :type transportdestination: Union['Transport', 'AsyncTransport'] + :type remotesource: TransportPath + :type remotedestination: TransportPath """ @abc.abstractmethod @@ -1227,27 +1227,27 @@ async def exec_command_wait_async( command: str, stdin: Optional[str] = None, encoding: str = 'utf-8', - workdir: Optional[_TransportPath] = None, + workdir: Optional[TransportPath] = None, **kwargs, ): """Executes the specified command and waits for it to finish. 
:param command: the command to execute :param stdin: input to the command - :param encoding: (IGNORED) this is here just to keep the same signature as the one in `BlockingTransport` class + :param encoding: (IGNORED) this is here just to keep the same signature as the one in `Transport` class :param workdir: working directory where the command will be executed :type command: str :type stdin: str :type encoding: str - :type workdir: Union[_TransportPath, None] + :type workdir: Union[TransportPath, None] :return: a tuple with (return_value, stdout, stderr) where stdout and stderr are both strings. :rtype: Tuple[int, str, str] """ @abc.abstractmethod - async def get_async(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): + async def get_async(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Retrieve a file or folder from remote source to local destination both remotepath and localpath must be absolute paths @@ -1257,36 +1257,36 @@ async def get_async(self, remotepath: _TransportPath, localpath: _TransportPath, :param remotepath: remote_folder_path :param localpath: local_folder_path - :type remotepath: _TransportPath - :type localpath: _TransportPath + :type remotepath: TransportPath + :type localpath: TransportPath """ @abc.abstractmethod - async def getfile_async(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): + async def getfile_async(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Retrieve a file from remote source to local destination both remotepath and localpath must be absolute paths :param remotepath: remote_folder_path :param localpath: local_folder_path - :type remotepath: _TransportPath - :type localpath: _TransportPath + :type remotepath: TransportPath + :type localpath: TransportPath """ @abc.abstractmethod - async def gettree_async(self, remotepath: _TransportPath, localpath: _TransportPath, *args, **kwargs): + async def 
gettree_async(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwargs): """Retrieve a folder recursively from remote source to local destination both remotepath and localpath must be absolute paths :param remotepath: remote_folder_path :param localpath: local_folder_path - :type remotepath: _TransportPath - :type localpath: _TransportPath + :type remotepath: TransportPath + :type localpath: TransportPath """ @abc.abstractmethod - async def get_attribute_async(self, path: _TransportPath): + async def get_attribute_async(self, path: TransportPath): """Return an object FixedFieldsAttributeDict for file in a given path, as defined in aiida.common.extendeddicts Each attribute object consists in a dictionary with the following keys: @@ -1305,17 +1305,17 @@ async def get_attribute_async(self, path: _TransportPath): :param path: path to file - :type path: _TransportPath + :type path: TransportPath :return: object FixedFieldsAttributeDict """ - async def get_mode_async(self, path: _TransportPath): + async def get_mode_async(self, path: TransportPath): """Return the portion of the file's mode that can be set by chmod(). :param str path: path to file - :type path: _TransportPath + :type path: TransportPath :return: the portion of the file's mode that can be set by chmod() """ @@ -1325,31 +1325,31 @@ async def get_mode_async(self, path: _TransportPath): return stat.S_IMODE(attr.st_mode) @abc.abstractmethod - async def isdir_async(self, path: _TransportPath): + async def isdir_async(self, path: TransportPath): """True if path is an existing directory. Return False also if the path does not exist. :param path: path to directory - :type path: _TransportPath + :type path: TransportPath :return: boolean """ @abc.abstractmethod - async def isfile_async(self, path: _TransportPath): + async def isfile_async(self, path: TransportPath): """Return True if path is an existing file. Return False also if the path does not exist. 
:param path: path to file - :type path: _TransportPath + :type path: TransportPath :return: boolean """ @abc.abstractmethod - async def listdir_async(self, path: _TransportPath, pattern: Optional[str] = None): + async def listdir_async(self, path: TransportPath, pattern: Optional[str] = None): """Return a list of the names of the entries in the given path. The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory. @@ -1358,7 +1358,7 @@ async def listdir_async(self, path: _TransportPath, pattern: Optional[str] = Non :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. - :type path: _TransportPath + :type path: TransportPath :return: a list of strings """ @@ -1366,7 +1366,7 @@ async def listdir_async(self, path: _TransportPath, pattern: Optional[str] = Non @abc.abstractmethod async def listdir_withattributes_async( self, - path: _TransportPath, + path: TransportPath, pattern: Optional[str] = None, ): """Return a list of the names of the entries in the given path. @@ -1377,7 +1377,7 @@ async def listdir_withattributes_async( :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. - :type path: _TransportPath + :type path: TransportPath :type pattern: str :return: a list of dictionaries, one per entry. The schema of the dictionary is @@ -1395,7 +1395,7 @@ async def listdir_withattributes_async( """ @abc.abstractmethod - async def makedirs_async(self, path: _TransportPath, ignore_existing=False): + async def makedirs_async(self, path: TransportPath, ignore_existing=False): """Super-mkdir; create a leaf directory and all intermediate ones. Works like mkdir, except that any intermediate path segment (not just the rightmost) will be created if it does not exist. 
@@ -1403,38 +1403,38 @@ async def makedirs_async(self, path: _TransportPath, ignore_existing=False): :param path: directory to create :param bool ignore_existing: if set to true, it doesn't give any error if the leaf directory does already exist - :type path: _TransportPath + :type path: TransportPath :raises: OSError, if directory at path already exists """ @abc.abstractmethod - async def mkdir_async(self, path: _TransportPath, ignore_existing=False): + async def mkdir_async(self, path: TransportPath, ignore_existing=False): """Create a folder (directory) named path. :param path: name of the folder to create :param bool ignore_existing: if True, does not give any error if the directory already exists. - :type path: _TransportPath + :type path: TransportPath :raises: OSError, if directory at path already exists """ @abc.abstractmethod - async def normalize_async(self, path: _TransportPath): + async def normalize_async(self, path: TransportPath): """Return the normalized path (on the server) of a given path. This can be used to quickly resolve symbolic links or determine what the server is considering to be the "current folder". :param path: path to be normalized - :type path: _TransportPath + :type path: TransportPath :raise OSError: if the path can't be resolved on the server """ @abc.abstractmethod - async def put_async(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + async def put_async(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Put a file or a directory from local src to remote dst. both localpath and remotepath must be absolute paths. Redirects to putfile and puttree. 
@@ -1445,83 +1445,83 @@ async def put_async(self, localpath: _TransportPath, remotepath: _TransportPath, :param localpath: absolute path to local source :param remotepath: path to remote destination - :type localpath: _TransportPath - :type remotepath: _TransportPath + :type localpath: TransportPath + :type remotepath: TransportPath """ @abc.abstractmethod - async def putfile_async(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + async def putfile_async(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Put a file from local src to remote dst. both localpath and remotepath must be absolute paths. :param localpath: absolute path to local file :param remotepath: path to remote file - :type localpath: _TransportPath - :type remotepath: _TransportPath + :type localpath: TransportPath + :type remotepath: TransportPath """ @abc.abstractmethod - async def puttree_async(self, localpath: _TransportPath, remotepath: _TransportPath, *args, **kwargs): + async def puttree_async(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwargs): """Put a folder recursively from local src to remote dst. both localpath and remotepath must be absolute paths. :param localpath: absolute path to local folder :param remotepath: path to remote folder - :type localpath: _TransportPath - :type remotepath: _TransportPath + :type localpath: TransportPath + :type remotepath: TransportPath """ @abc.abstractmethod - async def remove_async(self, path: _TransportPath): + async def remove_async(self, path: TransportPath): """Remove the file at the given path. This only works on files; for removing folders (directories), use rmdir. 
:param path: path to file to remove - :type path: _TransportPath + :type path: TransportPath :raise OSError: if the path is a directory """ @abc.abstractmethod - async def rename_async(self, oldpath: _TransportPath, newpath: _TransportPath): + async def rename_async(self, oldpath: TransportPath, newpath: TransportPath): """Rename a file or folder from oldpath to newpath. :param oldpath: existing name of the file or folder :param newpath: new name for the file or folder - :type oldpath: _TransportPath - :type newpath: _TransportPath + :type oldpath: TransportPath + :type newpath: TransportPath :raises OSError: if oldpath/newpath is not found :raises ValueError: if oldpath/newpath is not a valid string """ @abc.abstractmethod - async def rmdir_async(self, path: _TransportPath): + async def rmdir_async(self, path: TransportPath): """Remove the folder named path. This works only for empty folders. For recursive remove, use rmtree. :param path: absolute path to the folder to remove - :type path: _TransportPath + :type path: TransportPath """ @abc.abstractmethod - async def rmtree_async(self, path: _TransportPath): + async def rmtree_async(self, path: TransportPath): """Remove recursively the content at path :param path: absolute path to remove - :type path: _TransportPath + :type path: TransportPath :raise OSError: if the rm execution failed. """ @abc.abstractmethod - def gotocomputer_command(self, remotedir: _TransportPath): + def gotocomputer_command(self, remotedir: TransportPath): """Return a string to be run using os.system in order to connect via the transport to the remote directory. 
@@ -1536,11 +1536,11 @@ def gotocomputer_command(self, remotedir: _TransportPath): :param remotedir: the full path of the remote directory - :type remotedir: _TransportPath + :type remotedir: TransportPath """ @abc.abstractmethod - async def symlink_async(self, remotesource: _TransportPath, remotedestination: _TransportPath): + async def symlink_async(self, remotesource: TransportPath, remotedestination: TransportPath): """Create a symbolic link between the remote source and the remote destination. :param remotesource: remote source @@ -1560,16 +1560,16 @@ async def whoami_async(self): """ @abc.abstractmethod - async def path_exists_async(self, path: _TransportPath): + async def path_exists_async(self, path: TransportPath): """Returns True if path exists, False otherwise. :param path: path to check for existence - :type path: _TransportPath + :type path: TransportPath """ @abc.abstractmethod - async def glob_async(self, pathname: _TransportPath): + async def glob_async(self, pathname: TransportPath): """Return a list of paths matching a pathname pattern. The pattern may contain simple shell-style wildcards a la fnmatch. @@ -1577,7 +1577,7 @@ async def glob_async(self, pathname: _TransportPath): :param pathname: the pathname pattern to match. It should only be absolute path. - :type pathname: _TransportPath + :type pathname: TransportPath :return: a list of paths matching the pattern. """ @@ -1683,10 +1683,6 @@ def normalize(self, *args, **kwargs): return self.run_command_blocking(self.normalize_async, *args, **kwargs) -# This is here for backwards compatibility -Transport = BlockingTransport - - class TransportInternalError(InternalError): """Raised if there is a transport error that is raised to an internal error (e.g. a transport method called without opening the channel first). 
diff --git a/tests/engine/daemon/test_execmanager.py b/tests/engine/daemon/test_execmanager.py index 9ced2a0cd4..7ff83fc1ed 100644 --- a/tests/engine/daemon/test_execmanager.py +++ b/tests/engine/daemon/test_execmanager.py @@ -129,7 +129,6 @@ def test_retrieve_files_from_list( with LocalTransport() as transport: node = generate_calcjob_node(workdir=source) runner.loop.run_until_complete(execmanager.retrieve_files_from_list(node, transport, target, retrieve_list)) - # await execmanager.retrieve_files_from_list(node, transport, target, retrieve_list) assert serialize_file_hierarchy(target, read_bytes=False) == expected_hierarchy diff --git a/tests/manage/tests/test_pytest_fixtures.py b/tests/manage/tests/test_pytest_fixtures.py index c3f4de39dc..763fd0c721 100644 --- a/tests/manage/tests/test_pytest_fixtures.py +++ b/tests/manage/tests/test_pytest_fixtures.py @@ -6,7 +6,7 @@ from aiida.manage.configuration import get_config from aiida.manage.configuration.config import Config from aiida.orm import Computer -from aiida.transports import AsyncTransport, BlockingTransport +from aiida.transports import AsyncTransport, Transport def test_profile_config(): @@ -29,7 +29,7 @@ def test_aiida_computer_local(aiida_computer_local): assert computer.transport_type == 'core.local' with computer.get_transport() as transport: - assert isinstance(transport, BlockingTransport) + assert isinstance(transport, Transport) # Calling it again with the same label should simply return the existing computer computer_alt = aiida_computer_local(label=computer.label) @@ -52,7 +52,7 @@ def test_aiida_computer_ssh(aiida_computer_ssh): assert computer.transport_type == 'core.ssh' with computer.get_transport() as transport: - assert isinstance(transport, BlockingTransport) + assert isinstance(transport, Transport) # Calling it again with the same label should simply return the existing computer computer_alt = aiida_computer_ssh(label=computer.label) diff --git a/tests/orm/test_computers.py 
b/tests/orm/test_computers.py index 572cfa9c7c..adb1cbaae8 100644 --- a/tests/orm/test_computers.py +++ b/tests/orm/test_computers.py @@ -67,7 +67,7 @@ def test_get_minimum_job_poll_interval(self): # No transport class defined: fall back on class default. assert computer.get_minimum_job_poll_interval() == Computer.PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL__DEFAULT - # BlockingTransport class defined: use default of the transport class. + # Transport class defined: use default of the transport class. transport = TransportFactory('core.local') computer.transport_type = 'core.local' assert computer.get_minimum_job_poll_interval() == transport.DEFAULT_MINIMUM_JOB_POLL_INTERVAL diff --git a/tests/plugins/test_factories.py b/tests/plugins/test_factories.py index f077e5ba6d..a8d82478b4 100644 --- a/tests/plugins/test_factories.py +++ b/tests/plugins/test_factories.py @@ -18,7 +18,7 @@ from aiida.schedulers import Scheduler from aiida.tools.data.orbital import Orbital from aiida.tools.dbimporters import DbImporter -from aiida.transports import AsyncTransport, BlockingTransport +from aiida.transports import AsyncTransport, Transport def custom_load_entry_point(group, name): @@ -69,7 +69,7 @@ def work_function(): }, 'aiida.transports': { 'valid_A': AsyncTransport, - 'valid_B': BlockingTransport, + 'valid_B': Transport, 'invalid': Node, }, 'aiida.workflows': { @@ -191,7 +191,7 @@ def test_storage_factory(self): def test_transport_factory(self): """Test the ``TransportFactory``.""" plugin = factories.TransportFactory('valid_B') - assert plugin is BlockingTransport + assert plugin is Transport plugin = factories.TransportFactory('valid_A') assert plugin is AsyncTransport diff --git a/tests/test_calculation_node.py b/tests/test_calculation_node.py index d845de6df4..1c0af3a2bb 100644 --- a/tests/test_calculation_node.py +++ b/tests/test_calculation_node.py @@ -120,7 +120,7 @@ def test_get_authinfo(self): def test_get_transport(self): """Test that we can get the Transport object 
from the calculation instance.""" - from aiida.transports import AsyncTransport, BlockingTransport + from aiida.transports import AsyncTransport, Transport transport = self.calcjob.get_transport() - assert isinstance(transport, BlockingTransport) or isinstance(transport, AsyncTransport) + assert isinstance(transport, Transport) or isinstance(transport, AsyncTransport) diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index 707a38020b..a6de95f236 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -26,7 +26,7 @@ import psutil import pytest from aiida.plugins import SchedulerFactory, TransportFactory, entry_point -from aiida.transports import AsyncTransport, BlockingTransport +from aiida.transports import AsyncTransport, Transport # TODO : test for copy with pattern # TODO : test for copy with/without patterns, overwriting folder @@ -35,7 +35,7 @@ @pytest.fixture(scope='function', params=entry_point.get_entry_point_names('aiida.transports')) -def custom_transport(request, tmp_path, monkeypatch) -> Union['BlockingTransport', 'AsyncTransport']: +def custom_transport(request, tmp_path, monkeypatch) -> Union['Transport', 'AsyncTransport']: """Fixture that parametrizes over all the registered implementations of the ``CommonRelaxWorkChain``.""" plugin = TransportFactory(request.param) diff --git a/utils/dependency_management.py b/utils/dependency_management.py index 14d4c6c862..0d750bb442 100644 --- a/utils/dependency_management.py +++ b/utils/dependency_management.py @@ -295,8 +295,11 @@ def check_requirements(extras, github_annotate): for requirement_abstract in requirements_abstract: for requirement_concrete in requirements_concrete: if '@' in str(requirement_concrete): + # `@` is not listed as a valid `specifier` in `class Specifier`. 
version = str(requirement_concrete).split('@')[1] abstract_contains = version in str(requirement_abstract) + # `abstract_contains` is a boolean indicating whether the requirement abstract + # contains the same version as in the concrete requirement. else: version = Specifier(str(requirement_concrete.specifier)).version abstract_contains = requirement_abstract.specifier.contains(version) From 178bf7be5762a75abed535e68f8df957c14b994c Mon Sep 17 00:00:00 2001 From: Ali Date: Wed, 27 Nov 2024 17:58:40 +0100 Subject: [PATCH 11/29] chnage from machine to machine_ --- src/aiida/transports/plugins/ssh_async.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 5050516577..07b0d0e54f 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -61,7 +61,7 @@ class AsyncSshTransport(AsyncTransport): # note, I intentionally wanted to keep connection parameters as simple as possible. 
_valid_auth_options = [ ( - 'machine', + 'machine_', { 'type': str, 'prompt': 'machine as in `ssh machine` command', @@ -100,7 +100,7 @@ class AsyncSshTransport(AsyncTransport): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.machine = kwargs.pop('machine') + self.machine = kwargs.pop('machine_') self.script_before = kwargs.pop('script_before', 'None') self.script_during = kwargs.pop('script_during', 'None') From cc0bc5c0f787761d1417021c8a2cf2a619da9177 Mon Sep 17 00:00:00 2001 From: Ali Date: Fri, 29 Nov 2024 14:30:55 +0100 Subject: [PATCH 12/29] review applied --- tests/plugins/test_factories.py | 8 ++++---- utils/dependency_management.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/plugins/test_factories.py b/tests/plugins/test_factories.py index a8d82478b4..524f387e7c 100644 --- a/tests/plugins/test_factories.py +++ b/tests/plugins/test_factories.py @@ -68,8 +68,8 @@ def work_function(): 'invalid': Node, }, 'aiida.transports': { - 'valid_A': AsyncTransport, - 'valid_B': Transport, + 'valid_AsyncTransport': AsyncTransport, + 'valid_Transport': Transport, 'invalid': Node, }, 'aiida.workflows': { @@ -190,10 +190,10 @@ def test_storage_factory(self): @pytest.mark.usefixtures('mock_load_entry_point') def test_transport_factory(self): """Test the ``TransportFactory``.""" - plugin = factories.TransportFactory('valid_B') + plugin = factories.TransportFactory('valid_Transport') assert plugin is Transport - plugin = factories.TransportFactory('valid_A') + plugin = factories.TransportFactory('valid_AsyncTransport') assert plugin is AsyncTransport with pytest.raises(InvalidEntryPointTypeError): diff --git a/utils/dependency_management.py b/utils/dependency_management.py index 0d750bb442..865ffd4ecc 100644 --- a/utils/dependency_management.py +++ b/utils/dependency_management.py @@ -295,8 +295,9 @@ def check_requirements(extras, github_annotate): for requirement_abstract in requirements_abstract: for 
requirement_concrete in requirements_concrete: if '@' in str(requirement_concrete): - # `@` is not listed as a valid `specifier` in `class Specifier`. - version = str(requirement_concrete).split('@')[1] + # For versions that are using `@` specifier as for git repos, + # since `@` is not listed as a valid `specifier` in class `Specifier`. + version = '@'.join(str(requirement_concrete).split('@')[1:]) abstract_contains = version in str(requirement_abstract) # `abstract_contains` is a boolean indicating whether the requirement abstract # contains the same version as in the concrete requirement. From 3210c27f2345fee85cbd3ab0d5405c9d7a114fb6 Mon Sep 17 00:00:00 2001 From: Ali Date: Wed, 4 Dec 2024 14:30:52 +0100 Subject: [PATCH 13/29] copy-remote adopted with behaviour of asyncssh --- src/aiida/transports/plugins/ssh_async.py | 103 +++++++++++++++++----- 1 file changed, 82 insertions(+), 21 deletions(-) diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 07b0d0e54f..346f497139 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -503,8 +503,8 @@ async def copy_async( ): """Copy a file or a folder from remote to remote. - :param remotesource: path to the remote source directory / file - :param remotedestination: path to the remote destination directory / file + :param remotesource: abs path to the remote source directory / file + :param remotedestination: abs path to the remote destination directory / file :param dereference: follow symbolic links :param recursive: copy recursively :param preserve: preserve file attributes @@ -528,27 +528,88 @@ async def copy_async( if not remotesource: raise ValueError('remotesource must be a non empty string') - try: + # SFTP.copy() supports remote copy only in very recent version OpenSSH 9.0 and later. + # For the older versions, it downloads the file and uploads it again! 
+ # For performance reasons, we should check if the remote copy is supported, if so use + # self._sftp.mcopy() & self._sftp.copy() otherwise send a `cp` command to the remote machine. + # This is a temporary solution until the feature is implemented in asyncssh: + # See here: https://github.com/ronf/asyncssh/issues/724 + if False: + # self._sftp._supports_copy_data: + try: # type: ignore[unreachable] + if self.has_magic(remotesource): + await self._sftp.mcopy( + remotesource, + remotedestination, + preserve=preserve, + recurse=recursive, + follow_symlinks=dereference, + ) + else: + if not await self.path_exists_async(remotesource): + raise OSError(f'The remote path {remotesource} does not exist') + await self._sftp.copy( + remotesource, + remotedestination, + preserve=preserve, + recurse=recursive, + follow_symlinks=dereference, + ) + except asyncssh.sftp.SFTPFailure as exc: + raise OSError(f'Error while copying {remotesource} to {remotedestination}: {exc}') + else: + # I copy pasted the whole logic below from SshTransport class: + + async def _exec_cp(cp_exe: str, cp_flags: str, src: str, dst: str): + """Execute the ``cp`` command on the remote machine.""" + # to simplify writing the above copy function + command = f'{cp_exe} {cp_flags} {escape_for_bash(src)} {escape_for_bash(dst)}' + + retval, stdout, stderr = await self.exec_command_wait_async(command) + + if retval == 0: + if stderr.strip(): + self.logger.warning(f'There was nonempty stderr in the cp command: {stderr}') + else: + self.logger.error( + "Problem executing cp. Exit code: {}, stdout: '{}', " "stderr: '{}', command: '{}'".format( + retval, stdout, stderr, command + ) + ) + if 'No such file or directory' in str(stderr): + raise FileNotFoundError(f'Error while executing cp: {stderr}') + + raise OSError( + 'Error while executing cp. 
Exit code: {}, ' + "stdout: '{}', stderr: '{}', " + "command: '{}'".format(retval, stdout, stderr, command) + ) + + cp_exe = 'cp' + cp_flags = '-f' + + if recursive: + cp_flags += ' -r' + + if preserve: + cp_flags += ' -p' + + if dereference: + # use -L; --dereference is not supported on mac + cp_flags += ' -L' + if self.has_magic(remotesource): - await self._sftp.mcopy( - remotesource, - remotedestination, - preserve=preserve, - recurse=recursive, - follow_symlinks=dereference, - ) + to_copy_list = await self.glob_async(remotesource) + + if len(to_copy_list) > 1: + if not self.path_exists(remotedestination) or self.isfile(remotedestination): + raise OSError("Can't copy more than one file in the same destination file") + + for file in to_copy_list: + await _exec_cp(cp_exe, cp_flags, file, remotedestination) + else: - if not await self.path_exists_async(remotesource): - raise OSError(f'The remote path {remotesource} does not exist') - await self._sftp.copy( - remotesource, - remotedestination, - preserve=preserve, - recurse=recursive, - follow_symlinks=dereference, - ) - except asyncssh.sftp.SFTPFailure as exc: - raise OSError(f'Error while copying {remotesource} to {remotedestination}: {exc}') + await _exec_cp(cp_exe, cp_flags, remotesource, remotedestination) async def copyfile_async( self, From 65f0663e2d73eb799ffc1ac3fc0fe640b9c5baf6 Mon Sep 17 00:00:00 2001 From: Ali Date: Wed, 4 Dec 2024 18:23:41 +0100 Subject: [PATCH 14/29] remove str() use from test_all_plugins --- tests/transports/test_all_plugins.py | 199 +++++++++++++-------------- 1 file changed, 95 insertions(+), 104 deletions(-) diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index bc66b311b7..aad923d404 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -13,7 +13,6 @@ import io import os -import random import shutil import signal import tempfile @@ -21,7 +20,6 @@ import uuid from pathlib import Path from typing 
import Union -from pathlib import Path import psutil import pytest @@ -97,6 +95,7 @@ def test_makedirs(custom_transport, tmpdir): with pytest.raises(OSError): transport.mkdir(tmpdir / 'sampledir') + def test_is_dir(custom_transport, tmpdir): with custom_transport as transport: _scratch = tmpdir / 'sampledir' @@ -106,8 +105,6 @@ def test_is_dir(custom_transport, tmpdir): assert not transport.isdir(_scratch / 'does_not_exist') - - def test_rmtree(custom_transport, tmp_path_remote, tmp_path_local): """Verify the functioning of rmtree command""" with custom_transport as transport: @@ -137,6 +134,7 @@ def test_rmtree(custom_transport, tmp_path_remote, tmp_path_local): transport.rmdir(_remote / 'sampledir') assert not _scratch.exists() + def test_listdir(custom_transport, tmp_path_remote): """Create directories, verify listdir, delete a folder with subfolders""" with custom_transport as transport: @@ -154,7 +152,7 @@ def test_listdir(custom_transport, tmp_path_remote): assert sorted(list_found) == sorted(list_of_dir + list_of_files) - assert sorted(transport.listdir(tmp_path_remote, 'a*')), sorted(['as', 'a2', 'a4f']) + assert sorted(transport.listdir(tmp_path_remote, 'a*')), sorted(['as', 'a2', 'a4f']) assert sorted(transport.listdir(tmp_path_remote, 'a?')), sorted(['as', 'a2']) assert sorted(transport.listdir(tmp_path_remote, 'a[2-4]*')), sorted(['a2', 'a4f']) @@ -210,7 +208,7 @@ def test_dir_copy(custom_transport, tmp_path_remote): src_dir = tmp_path_remote / 'copy_src' transport.mkdir(src_dir) - dst_dir = tmp_path_remote / 'copy_dst' + dst_dir = tmp_path_remote / 'copy_dst' transport.copy(src_dir, dst_dir) with pytest.raises(ValueError): @@ -225,12 +223,12 @@ def test_dir_permissions_creation_modification(custom_transport, tmp_path_remote non-existing folder """ with custom_transport as transport: - directory = tmp_path_remote / 'test' + directory = tmp_path_remote / 'test' transport.makedirs(directory) # change permissions - transport.chmod(directory, 0o777) + 
transport.chmod(directory, 0o777) # test if the security bits have changed assert transport.get_mode(directory) == 0o777 @@ -257,13 +255,13 @@ def test_dir_permissions_creation_modification(custom_transport, tmp_path_remote fake_dir = 'pippo' with pytest.raises(OSError): # chmod to a non existing folder - transport.chmod(tmp_path_remote / fake_dir, 0o777) + transport.chmod(tmp_path_remote / fake_dir, 0o777) def test_dir_reading_permissions(custom_transport, tmp_path_remote): - """Try to enter a directory with no read & write permissions.""" + """Try to enter a directory with no read & write permissions.""" with custom_transport as transport: - directory = tmp_path_remote / 'test' + directory = tmp_path_remote / 'test' # create directory with non default permissions transport.mkdir(directory) @@ -272,22 +270,22 @@ def test_dir_reading_permissions(custom_transport, tmp_path_remote): transport.chmod(directory, 0) # test if the security bits have changed - assert transport.get_mode(directory) == 0 + assert transport.get_mode(directory) == 0 - # TODO : the test leaves a directory even if it is successful + # TODO : the test leaves a directory even if it is successful # The bug is in paramiko. 
After lowering the permissions, # I cannot restore them to higher values # transport.rmdir(directory) -def test_isfile_isdir(custom_transport, tmp_path_remote): +def test_isfile_isdir(custom_transport, tmp_path_remote): with custom_transport as transport: - # return False on empty string + # return False on empty string assert not transport.isdir('') assert not transport.isfile('') # return False on non-existing files assert not transport.isfile(tmp_path_remote / 'does_not_exist') - assert not transport.isdir(tmp_path_remote / 'does_not_exist') + assert not transport.isdir(tmp_path_remote / 'does_not_exist') # isfile and isdir should not confuse files and directories Path(tmp_path_remote / 'samplefile').touch() @@ -317,14 +315,13 @@ def test_chdir_to_empty_string(custom_transport): assert new_dir == transport.getcwd() - def test_put_and_get(custom_transport, tmp_path_remote, tmp_path_local): """Test putting and getting files.""" directory = 'tmp_try' with custom_transport as transport: (tmp_path_local / directory).mkdir() - transport.mkdir(str(tmp_path_remote / directory)) + transport.mkdir(tmp_path_remote / directory) local_file_name = 'file.txt' retrieved_file_name = 'file_retrieved.txt' @@ -332,9 +329,9 @@ def test_put_and_get(custom_transport, tmp_path_remote, tmp_path_local): remote_file_name = 'file_remote.txt' # here use full path in src and dst - local_file_abs_path = str(tmp_path_local / directory / local_file_name) - retrieved_file_abs_path = str(tmp_path_local / directory / retrieved_file_name) - remote_file_abs_path = str(tmp_path_remote / directory / remote_file_name) + local_file_abs_path = tmp_path_local / directory / local_file_name + retrieved_file_abs_path = tmp_path_local / directory / retrieved_file_name + remote_file_abs_path = tmp_path_remote / directory / remote_file_name text = 'Viva Verdi\n' with open(local_file_abs_path, 'w', encoding='utf8') as fhandle: @@ -343,7 +340,7 @@ def test_put_and_get(custom_transport, tmp_path_remote, 
tmp_path_local): transport.put(local_file_abs_path, remote_file_abs_path) transport.get(remote_file_abs_path, retrieved_file_abs_path) - list_of_files = transport.listdir(str(tmp_path_remote / directory)) + list_of_files = transport.listdir((tmp_path_remote / directory)) # it is False because local_file_name has the full path, # while list_of_files has not assert local_file_name not in list_of_files @@ -360,7 +357,7 @@ def test_putfile_and_getfile(custom_transport, tmp_path_remote, tmp_path_local): with custom_transport as transport: (local_dir / directory).mkdir() - transport.mkdir(str(remote_dir / directory)) + transport.mkdir((remote_dir / directory)) local_file_name = 'file.txt' retrieved_file_name = 'file_retrieved.txt' @@ -368,9 +365,9 @@ def test_putfile_and_getfile(custom_transport, tmp_path_remote, tmp_path_local): remote_file_name = 'file_remote.txt' # here use full path in src and dst - local_file_abs_path = str(local_dir / directory / local_file_name) - retrieved_file_abs_path = str(local_dir / directory / retrieved_file_name) - remote_file_abs_path = str(remote_dir / directory / remote_file_name) + local_file_abs_path = local_dir / directory / local_file_name + retrieved_file_abs_path = local_dir / directory / retrieved_file_name + remote_file_abs_path = remote_dir / directory / remote_file_name text = 'Viva Verdi\n' with open(local_file_abs_path, 'w', encoding='utf8') as fhandle: @@ -396,7 +393,7 @@ def test_put_get_abs_path_file(custom_transport, tmp_path_remote, tmp_path_local with custom_transport as transport: (local_dir / directory).mkdir() - transport.mkdir(str(remote_dir / directory)) + transport.mkdir((remote_dir / directory)) local_file_name = 'file.txt' retrieved_file_name = 'file_retrieved.txt' @@ -405,8 +402,8 @@ def test_put_get_abs_path_file(custom_transport, tmp_path_remote, tmp_path_local local_file_rel_path = local_file_name remote_file_rel_path = remote_file_name - retrieved_file_abs_path = str(local_dir / directory / 
retrieved_file_name) - remote_file_abs_path = str(remote_dir / directory / remote_file_name) + retrieved_file_abs_path = local_dir / directory / retrieved_file_name + remote_file_abs_path = remote_dir / directory / remote_file_name # partial_file_name is not an abs path with pytest.raises(ValueError): @@ -442,7 +439,7 @@ def test_put_get_empty_string_file(custom_transport, tmp_path_remote, tmp_path_l with custom_transport as transport: (local_dir / directory).mkdir() - transport.mkdir(str(remote_dir / directory)) + transport.mkdir((remote_dir / directory)) local_file_name = 'file.txt' retrieved_file_name = 'file_retrieved.txt' @@ -450,9 +447,9 @@ def test_put_get_empty_string_file(custom_transport, tmp_path_remote, tmp_path_l remote_file_name = 'file_remote.txt' # here use full path in src and dst - local_file_abs_path = str(local_dir / directory / local_file_name) - retrieved_file_abs_path = str(local_dir / directory / retrieved_file_name) - remote_file_abs_path = str(remote_dir / directory / remote_file_name) + local_file_abs_path = local_dir / directory / local_file_name + retrieved_file_abs_path = local_dir / directory / retrieved_file_name + remote_file_abs_path = remote_dir / directory / remote_file_name text = 'Viva Verdi\n' with open(local_file_abs_path, 'w', encoding='utf8') as fhandle: @@ -525,8 +522,8 @@ def test_put_and_get_tree(custom_transport, tmp_path_remote, tmp_path_local): fhandle.write(text) # here use full path in src and dst - transport.puttree(str(local_subfolder), str(remote_subfolder)) - transport.gettree(str(remote_subfolder), str(retrieved_subfolder)) + transport.puttree((local_subfolder), (remote_subfolder)) + transport.gettree((remote_subfolder), (retrieved_subfolder)) list_of_dirs = [p.name for p in (local_dir / directory).iterdir()] @@ -659,25 +656,25 @@ def test_copy(custom_transport, tmp_path_remote): transport.copy(base_dir / '*.txt', workdir / 'prova') # fifth test, copying one file into a folder - transport.mkdir(str(workdir / 
'prova')) - transport.copy(str(base_dir / 'a.txt'), str(workdir / 'prova')) - assert set(transport.listdir(str(workdir / 'prova'))) == set(['a.txt']) - transport.rmtree(str(workdir / 'prova')) + transport.mkdir((workdir / 'prova')) + transport.copy((base_dir / 'a.txt'), (workdir / 'prova')) + assert set(transport.listdir((workdir / 'prova'))) == set(['a.txt']) + transport.rmtree((workdir / 'prova')) # sixth test, copying one file into a file - transport.copy(str(base_dir / 'a.txt'), str(workdir / 'prova')) - assert transport.isfile(str(workdir / 'prova')) - transport.remove(str(workdir / 'prova')) + transport.copy((base_dir / 'a.txt'), (workdir / 'prova')) + assert transport.isfile((workdir / 'prova')) + transport.remove((workdir / 'prova')) # copy of folder into an existing folder # NOTE: the command cp has a different behavior on Mac vs Ubuntu # tests performed locally on a Mac may result in a failure. - transport.mkdir(str(workdir / 'prova')) - transport.copy(str(base_dir), str(workdir / 'prova')) - assert set(['origin']) == set(transport.listdir(str(workdir / 'prova'))) - assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir(str(workdir / 'prova' / 'origin'))) - transport.rmtree(str(workdir / 'prova')) + transport.mkdir((workdir / 'prova')) + transport.copy((base_dir), (workdir / 'prova')) + assert set(['origin']) == set(transport.listdir((workdir / 'prova'))) + assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir((workdir / 'prova' / 'origin'))) + transport.rmtree((workdir / 'prova')) # exit - transport.rmtree(str(workdir)) + transport.rmtree((workdir)) def test_put(custom_transport, tmp_path_remote, tmp_path_local): @@ -708,55 +705,55 @@ def test_put(custom_transport, tmp_path_remote, tmp_path_local): fhandle.write(text) # first test the put. 
Copy of two files matching patterns, into a folder - transport.put(str(local_base_dir / '*.txt'), str(remote_workdir)) - assert set(['a.txt', 'c.txt']) == set(transport.listdir(str(remote_workdir))) - transport.remove(str(remote_workdir / 'a.txt')) - transport.remove(str(remote_workdir / 'c.txt')) + transport.put((local_base_dir / '*.txt'), (remote_workdir)) + assert set(['a.txt', 'c.txt']) == set(transport.listdir((remote_workdir))) + transport.remove((remote_workdir / 'a.txt')) + transport.remove((remote_workdir / 'c.txt')) # second test put. Put of two folders - transport.put(str(local_base_dir), str(remote_workdir / 'prova')) - assert set(['prova']) == set(transport.listdir(str(remote_workdir))) - assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir(str(remote_workdir / 'prova'))) - transport.rmtree(str(remote_workdir / 'prova')) + transport.put((local_base_dir), (remote_workdir / 'prova')) + assert set(['prova']) == set(transport.listdir((remote_workdir))) + assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir((remote_workdir / 'prova'))) + transport.rmtree((remote_workdir / 'prova')) # third test put. 
Can copy one file into a new file - transport.put(str(local_base_dir / '*.tmp'), str(remote_workdir / 'prova')) - assert transport.isfile(str(remote_workdir / 'prova')) - transport.remove(str(remote_workdir / 'prova')) + transport.put((local_base_dir / '*.tmp'), (remote_workdir / 'prova')) + assert transport.isfile((remote_workdir / 'prova')) + transport.remove((remote_workdir / 'prova')) # fourth test put: can't copy more than one file to the same file, # i.e., the destination should be a folder with pytest.raises(OSError): - transport.put(str(local_base_dir / '*.txt'), str(remote_workdir / 'prova')) + transport.put((local_base_dir / '*.txt'), (remote_workdir / 'prova')) # can't copy folder to an exist file with open(remote_workdir / 'existing.txt', 'w', encoding='utf8') as fhandle: fhandle.write(text) with pytest.raises(OSError): - transport.put(str(local_base_dir), str(remote_workdir / 'existing.txt')) - transport.remove(str(remote_workdir / 'existing.txt')) + transport.put((local_base_dir), (remote_workdir / 'existing.txt')) + transport.remove((remote_workdir / 'existing.txt')) # fifth test, copying one file into a folder - transport.mkdir(str(remote_workdir / 'prova')) - transport.put(str(local_base_dir / 'a.txt'), str(remote_workdir / 'prova')) - assert set(transport.listdir(str(remote_workdir / 'prova'))) == set(['a.txt']) - transport.rmtree(str(remote_workdir / 'prova')) + transport.mkdir((remote_workdir / 'prova')) + transport.put((local_base_dir / 'a.txt'), (remote_workdir / 'prova')) + assert set(transport.listdir((remote_workdir / 'prova'))) == set(['a.txt']) + transport.rmtree((remote_workdir / 'prova')) # sixth test, copying one file into a file - transport.put(str(local_base_dir / 'a.txt'), str(remote_workdir / 'prova')) - assert transport.isfile(str(remote_workdir / 'prova')) - transport.remove(str(remote_workdir / 'prova')) + transport.put((local_base_dir / 'a.txt'), (remote_workdir / 'prova')) + assert transport.isfile((remote_workdir / 'prova')) 
+ transport.remove((remote_workdir / 'prova')) # put of folder into an existing folder # NOTE: the command cp has a different behavior on Mac vs Ubuntu # tests performed locally on a Mac may result in a failure. - transport.mkdir(str(remote_workdir / 'prova')) - transport.put(str(local_base_dir), str(remote_workdir / 'prova')) - assert set(['origin']) == set(transport.listdir(str(remote_workdir / 'prova'))) - assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir(str(remote_workdir / 'prova' / 'origin'))) - transport.rmtree(str(remote_workdir / 'prova')) + transport.mkdir((remote_workdir / 'prova')) + transport.put((local_base_dir), (remote_workdir / 'prova')) + assert set(['origin']) == set(transport.listdir((remote_workdir / 'prova'))) + assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir((remote_workdir / 'prova' / 'origin'))) + transport.rmtree((remote_workdir / 'prova')) # exit - transport.rmtree(str(remote_workdir)) + transport.rmtree((remote_workdir)) def test_get(custom_transport, tmp_path_remote, tmp_path_local): @@ -786,27 +783,27 @@ def test_get(custom_transport, tmp_path_remote, tmp_path_local): fhandle.write(text) # first test get. Get two files matching patterns, from mocked remote folder into a local folder - transport.get(str(remote_base_dir / '*.txt'), str(local_workdir)) + transport.get((remote_base_dir / '*.txt'), (local_workdir)) assert set(['a.txt', 'c.txt']) == set([p.name for p in (local_workdir).iterdir()]) (local_workdir / 'a.txt').unlink() (local_workdir / 'c.txt').unlink() # second. Copy of folder into a non existing folder - transport.get(str(remote_base_dir), str(local_workdir / 'prova')) + transport.get((remote_base_dir), (local_workdir / 'prova')) assert set(['prova']) == set([p.name for p in local_workdir.iterdir()]) assert set(['a.txt', 'b.tmp', 'c.txt']) == set([p.name for p in (local_workdir / 'prova').iterdir()]) shutil.rmtree(local_workdir / 'prova') # third. 
copy of folder into an existing folder (local_workdir / 'prova').mkdir() - transport.get(str(remote_base_dir), str(local_workdir / 'prova')) + transport.get((remote_base_dir), (local_workdir / 'prova')) assert set(['prova']) == set([p.name for p in local_workdir.iterdir()]) assert set(['origin']) == set([p.name for p in (local_workdir / 'prova').iterdir()]) assert set(['a.txt', 'b.tmp', 'c.txt']) == set([p.name for p in (local_workdir / 'prova' / 'origin').iterdir()]) shutil.rmtree(local_workdir / 'prova') # test get one file into a new file prova - transport.get(str(remote_base_dir / '*.tmp'), str(local_workdir / 'prova')) + transport.get((remote_base_dir / '*.tmp'), (local_workdir / 'prova')) assert set(['prova']) == set([p.name for p in local_workdir.iterdir()]) assert (local_workdir / 'prova').is_file() (local_workdir / 'prova').unlink() @@ -814,22 +811,22 @@ def test_get(custom_transport, tmp_path_remote, tmp_path_local): # fourth test copy: can't copy more than one file on the same file, # i.e., the destination should be a folder with pytest.raises(OSError): - transport.get(str(remote_base_dir / '*.txt'), str(local_workdir / 'prova')) + transport.get((remote_base_dir / '*.txt'), (local_workdir / 'prova')) # copy of folder into file with open(local_workdir / 'existing.txt', 'w', encoding='utf8') as fhandle: fhandle.write(text) with pytest.raises(OSError): - transport.get(str(remote_base_dir), str(local_workdir / 'existing.txt')) + transport.get((remote_base_dir), (local_workdir / 'existing.txt')) (local_workdir / 'existing.txt').unlink() # fifth test, copying one file into a folder (local_workdir / 'prova').mkdir() - transport.get(str(remote_base_dir / 'a.txt'), str(local_workdir / 'prova')) + transport.get((remote_base_dir / 'a.txt'), (local_workdir / 'prova')) assert set(['a.txt']) == set([p.name for p in (local_workdir / 'prova').iterdir()]) shutil.rmtree(local_workdir / 'prova') # sixth test, copying one file into a file - transport.get(str(remote_base_dir 
/ 'a.txt'), str(local_workdir / 'prova')) + transport.get((remote_base_dir / 'a.txt'), (local_workdir / 'prova')) assert (local_workdir / 'prova').is_file() (local_workdir / 'prova').unlink() @@ -841,9 +838,9 @@ def test_put_get_abs_path_tree(custom_transport, tmp_path_remote, tmp_path_local directory = 'tmp_try' with custom_transport as transport: - local_subfolder = str(local_dir / directory / 'tmp1') - remote_subfolder = str(remote_dir / 'tmp2') - retrieved_subfolder = str(local_dir / directory / 'tmp3') + local_subfolder = local_dir / directory / 'tmp1' + remote_subfolder = remote_dir / 'tmp2' + retrieved_subfolder = local_dir / directory / 'tmp3' (local_dir / directory / local_subfolder).mkdir(parents=True) @@ -889,9 +886,6 @@ def test_put_get_abs_path_tree(custom_transport, tmp_path_remote, tmp_path_local transport.gettree(remote_subfolder, 'delete_me_tree') - - - def test_put_get_empty_string_tree(custom_transport, tmp_path_remote, tmp_path_local): """Test of exception put/get of empty strings""" local_dir = tmp_path_local @@ -933,7 +927,7 @@ def test_put_get_empty_string_tree(custom_transport, tmp_path_remote, tmp_path_l # TODO : get doesn't retrieve empty files. # Is it what we want? 
- transport.gettree(str(remote_subfolder), str(retrieved_subfolder)) + transport.gettree((remote_subfolder), (retrieved_subfolder)) assert 'file.txt' in [p.name for p in retrieved_subfolder.iterdir()] @@ -944,13 +938,13 @@ def test_gettree_nested_directory(custom_transport, tmp_path_remote, tmp_path_lo dir_path = tmp_path_remote / 'sub' / 'path' dir_path.mkdir(parents=True) - file_path = str(dir_path / 'filename.txt') + file_path = dir_path / 'filename.txt' with open(file_path, 'wb') as handle: handle.write(content) with custom_transport as transport: - transport.gettree(str(tmp_path_remote), str(tmp_path_local)) + transport.gettree((tmp_path_remote), (tmp_path_local)) assert (tmp_path_local / 'sub' / 'path' / 'filename.txt').is_file @@ -966,8 +960,7 @@ def test_exec_pwd(custom_transport, tmp_path_remote): """ # Start value if not hasattr(custom_transport, 'chdir'): - return - + return with custom_transport as transport: # To compare with: getcwd uses the normalized ('realpath') path @@ -991,7 +984,6 @@ def test_exec_pwd(custom_transport, tmp_path_remote): assert stdout.strip() == subfolder_fullpath assert stderr == '' - def test_exec_with_stdin_string(custom_transport): """Test command execution with a stdin string.""" @@ -1120,7 +1112,7 @@ def test_transfer_big_stdout(custom_transport, tmp_path_remote): tmpf.flush() # I put a file with specific content there at the right file name - transport.putfile(tmpf.name, directory_path / fname) + transport.putfile(tmpf.name, directory_path / fname) python_code = r"""import sys @@ -1144,12 +1136,12 @@ def test_transfer_big_stdout(custom_transport, tmp_path_remote): tmpf.flush() # I put a file with specific content there at the right file name - transport.putfile(tmpf.name, directory_path / script_fname) + transport.putfile(tmpf.name, directory_path / script_fname) # I get its content via the stdout; emulate also network slowness (note I cat twice) retcode, stdout, stderr = transport.exec_command_wait( f'cat {fname} ; sleep 
1 ; cat {fname}', workdir=directory_path - ) + ) assert stderr == '' assert stdout == fcontent + fcontent assert retcode == 0 @@ -1157,7 +1149,7 @@ def test_transfer_big_stdout(custom_transport, tmp_path_remote): # I get its content via the stderr; emulate also network slowness (note I cat twice) retcode, stdout, stderr = transport.exec_command_wait( f'cat {fname} >&2 ; sleep 1 ; cat {fname} >&2', workdir=directory_path - ) + ) assert stderr == fcontent + fcontent assert stdout == '' assert retcode == 0 @@ -1170,15 +1162,14 @@ def test_transfer_big_stdout(custom_transport, tmp_path_remote): # line_repetitions, file_line, file_line)) # However this is pretty slow (and using 'cat' of a file containing only one line is even slower) - retcode, stdout, stderr = transport.exec_command_wait(f'python3 {script_fname}', workdir=directory_path) + retcode, stdout, stderr = transport.exec_command_wait(f'python3 {script_fname}', workdir=directory_path) assert stderr == fcontent assert stdout == fcontent assert retcode == 0 - -def test_asynchronous_execution(custom_transport, tmp_path): +def test_asynchronous_execution(custom_transport, tmp_path): """Test that the execution of a long(ish) command via the direct scheduler does not block. This is a regression test for #3094, where running a long job on the direct scheduler @@ -1199,10 +1190,10 @@ def test_asynchronous_execution(custom_transport, tmp_path): tmpf.write(b'#!/bin/bash\nsleep 10\n') tmpf.flush() - transport.putfile(tmpf.name, str(tmp_path / script_fname)) + transport.putfile(tmpf.name, tmp_path / script_fname) timestamp_before = time.time() - job_id_string = scheduler.submit_job(str(tmp_path), script_fname) + job_id_string = scheduler.submit_job(tmp_path, script_fname) elapsed_time = time.time() - timestamp_before # We want to get back control. 
If it takes < 5 seconds, it means that it is not blocking From a809b9868ba34d559ba95f5b64c086e4c2bae529 Mon Sep 17 00:00:00 2001 From: Ali Date: Thu, 5 Dec 2024 13:33:56 +0100 Subject: [PATCH 15/29] copy() are now aligned with fresh development on asyncssh --- environment.yml | 2 +- pyproject.toml | 2 +- requirements/requirements-py-3.10.txt | 2 +- requirements/requirements-py-3.11.txt | 2 +- requirements/requirements-py-3.12.txt | 2 +- requirements/requirements-py-3.9.txt | 2 +- src/aiida/tools/pytest_fixtures/__init__.py | 1 - src/aiida/transports/__init__.py | 4 +- src/aiida/transports/cli.py | 2 +- src/aiida/transports/plugins/ssh.py | 4 + src/aiida/transports/plugins/ssh_async.py | 197 +++++++++++++------- src/aiida/transports/util.py | 6 +- tests/transports/test_all_plugins.py | 11 +- 13 files changed, 149 insertions(+), 88 deletions(-) diff --git a/environment.yml b/environment.yml index abbe6282f0..96028588a3 100644 --- a/environment.yml +++ b/environment.yml @@ -8,7 +8,7 @@ dependencies: - python~=3.9 - alembic~=1.2 - archive-path~=0.4.2 -- asyncssh~=2.18.0 +- asyncssh@ git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh - circus~=0.18.0 - click-spinner~=0.1.8 - click~=8.1 diff --git a/pyproject.toml b/pyproject.toml index 27f89146b2..7dfefa90a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ dependencies = [ 'alembic~=1.2', 'archive-path~=0.4.2', - 'asyncssh~=2.18.0', + 'asyncssh@git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh', 'circus~=0.18.0', 'click-spinner~=0.1.8', 'click~=8.1', diff --git a/requirements/requirements-py-3.10.txt b/requirements/requirements-py-3.10.txt index caa3b3c8d4..7dbf31bc4b 100644 --- a/requirements/requirements-py-3.10.txt +++ b/requirements/requirements-py-3.10.txt @@ -20,7 +20,7 @@ ase==3.22.1 asn1crypto==1.5.1 asttokens==2.2.1 async-generator==1.10 -asyncssh~=2.18.0 
+asyncssh@git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh attrs==23.1.0 babel==2.12.1 backcall==0.2.0 diff --git a/requirements/requirements-py-3.11.txt b/requirements/requirements-py-3.11.txt index cf79aedbde..2c39bcb2be 100644 --- a/requirements/requirements-py-3.11.txt +++ b/requirements/requirements-py-3.11.txt @@ -20,7 +20,7 @@ ase==3.22.1 asn1crypto==1.5.1 asttokens==2.2.1 async-generator==1.10 -asyncssh~=2.18.0 +asyncssh@git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh attrs==23.1.0 babel==2.12.1 backcall==0.2.0 diff --git a/requirements/requirements-py-3.12.txt b/requirements/requirements-py-3.12.txt index 428089b7a4..b900063e78 100644 --- a/requirements/requirements-py-3.12.txt +++ b/requirements/requirements-py-3.12.txt @@ -20,7 +20,7 @@ ase==3.22.1 asn1crypto==1.5.1 asttokens==2.4.0 async-generator==1.10 -asyncssh~=2.18.0 +asyncssh@git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh attrs==23.1.0 babel==2.13.1 backcall==0.2.0 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt index a19a3b730b..67e81bc655 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -20,7 +20,7 @@ ase==3.22.1 asn1crypto==1.5.1 asttokens==2.2.1 async-generator==1.10 -asyncssh~=2.18.0 +asyncssh@git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh attrs==23.1.0 babel==2.12.1 backcall==0.2.0 diff --git a/src/aiida/tools/pytest_fixtures/__init__.py b/src/aiida/tools/pytest_fixtures/__init__.py index a400a03ec9..e19d4c455e 100644 --- a/src/aiida/tools/pytest_fixtures/__init__.py +++ b/src/aiida/tools/pytest_fixtures/__init__.py @@ -35,7 +35,6 @@ 'aiida_computer_local', 'aiida_computer_ssh', 'aiida_computer_ssh_async', - 'aiida_computer', 'aiida_config', 'aiida_config_factory', 'aiida_config_tmp', diff --git 
a/src/aiida/transports/__init__.py b/src/aiida/transports/__init__.py index f3427ff5e3..7a7d472869 100644 --- a/src/aiida/transports/__init__.py +++ b/src/aiida/transports/__init__.py @@ -16,9 +16,9 @@ from .transport import * __all__ = ( - 'Transport', - 'SshTransport', 'AsyncTransport', + 'SshTransport', + 'Transport', 'convert_to_bool', 'parse_sshconfig', ) diff --git a/src/aiida/transports/cli.py b/src/aiida/transports/cli.py index 6088eb08f6..5faa2d6f80 100644 --- a/src/aiida/transports/cli.py +++ b/src/aiida/transports/cli.py @@ -140,7 +140,7 @@ def transport_options(transport_type): """Decorate a command with all options for a computer configure subcommand for transport_type.""" def apply_options(func): - """Decorate the command functionn with the appropriate options for the transport type.""" + """Decorate the command function with the appropriate options for the transport type.""" options_list = list_transport_options(transport_type) options_list.reverse() func = arguments.COMPUTER(callback=partial(match_comp_transport, transport_type=transport_type))(func) diff --git a/src/aiida/transports/plugins/ssh.py b/src/aiida/transports/plugins/ssh.py index b4ab80e968..8cfe607a34 100644 --- a/src/aiida/transports/plugins/ssh.py +++ b/src/aiida/transports/plugins/ssh.py @@ -231,6 +231,10 @@ class SshTransport(Transport): # if too large commands are sent, clogging the outputs or logs _MAX_EXEC_COMMAND_LOG_SIZE = None + # NOTE: all the methods that start with _get_ are class methods that + # return a suggestion for the specific field. They are being used in + # a function called transport_option_default in transports/cli.py, + # during an interactive `verdi computer configure` command. 
@classmethod def _get_username_suggestion_string(cls, computer): """Return a suggestion for the specific field.""" diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 346f497139..6faace2e08 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -17,7 +17,7 @@ import asyncssh import click -from asyncssh import SFTPFileAlreadyExists +from asyncssh import SFTPFileAlreadyExists, SFTPOpUnsupported from aiida.common.escaping import escape_for_bash from aiida.common.exceptions import InvalidOperation @@ -61,7 +61,7 @@ class AsyncSshTransport(AsyncTransport): # note, I intentionally wanted to keep connection parameters as simple as possible. _valid_auth_options = [ ( - 'machine_', + 'machine', { 'type': str, 'prompt': 'machine as in `ssh machine` command', @@ -84,25 +84,21 @@ class AsyncSshTransport(AsyncTransport): 'callback': _validate_script, }, ), - ( - 'script_during', - { - 'type': str, - 'default': 'None', - 'prompt': 'Local script to run *during* opening connection (path)', - 'help': '(optional) Specify a script to run *during* opening SSH connection. 
' - 'The script should be executable', - 'non_interactive_default': True, - 'callback': _validate_script, - }, - ), ] + @classmethod + def _get_machine_suggestion_string(cls, computer): + """Return a suggestion for the parameter machine.""" + # Originally set as 'Hostname' during `verdi computer setup` + # and is passed as `machine=computer.hostname` in the codebase + # unfortunately, name of hostname and machine are used interchangeably in the aiida-core codebase + # TODO: open an issue to unify the naming + return computer.hostname + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.machine = kwargs.pop('machine_') + self.machine = kwargs.pop('machine') self.script_before = kwargs.pop('script_before', 'None') - self.script_during = kwargs.pop('script_during', 'None') async def open_async(self): """Open the transport. @@ -118,10 +114,6 @@ async def open_async(self): os.system(f'{self.script_before}') self._conn = await asyncssh.connect(self.machine) - - if self.script_during != 'None': - os.system(f'{self.script_during}') - self._sftp = await self._conn.start_sftp_client() self._is_open = True @@ -150,6 +142,7 @@ async def get_async( dereference=True, overwrite=True, ignore_nonexisting=False, + preserve=False, *args, **kwargs, ): @@ -162,12 +155,17 @@ async def get_async( Default = True :param overwrite: if True overwrites files and folders. 
Default = False + :param ignore_nonexisting: if True, does not raise an error if the remotepath does not exist + Default = False + :param preserve: preserve file attributes + Default = False :type remotepath: TransportPath :type localpath: TransportPath :type dereference: bool :type overwrite: bool :type ignore_nonexisting: bool + :type preserve: bool :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found @@ -202,41 +200,51 @@ async def get_async( if rename_local: # copying more than one file in one directory # here is the case isfile and more than one file remote = os.path.join(localpath, os.path.split(file)[1]) - await self.getfile_async(file, remote, dereference, overwrite) + await self.getfile_async(file, remote, dereference, overwrite, preserve) else: # one file to copy on one file - await self.getfile_async(file, localpath, dereference, overwrite) + await self.getfile_async(file, localpath, dereference, overwrite, preserve) else: - await self.gettree_async(file, localpath, dereference, overwrite) + await self.gettree_async(file, localpath, dereference, overwrite, preserve) elif await self.isdir_async(remotepath): - await self.gettree_async(remotepath, localpath, dereference, overwrite) + await self.gettree_async(remotepath, localpath, dereference, overwrite, preserve) elif await self.isfile_async(remotepath): if os.path.isdir(localpath): remote = os.path.join(localpath, os.path.split(remotepath)[1]) - await self.getfile_async(remotepath, remote, dereference, overwrite) + await self.getfile_async(remotepath, remote, dereference, overwrite, preserve) else: - await self.getfile_async(remotepath, localpath, dereference, overwrite) + await self.getfile_async(remotepath, localpath, dereference, overwrite, preserve) elif ignore_nonexisting: pass else: raise OSError(f'The remote path {remotepath} does not exist') async def getfile_async( - self, remotepath: TransportPath, localpath: TransportPath, dereference=True, overwrite=True, 
*args, **kwargs + self, + remotepath: TransportPath, + localpath: TransportPath, + dereference=True, + overwrite=True, + preserve=False, + *args, + **kwargs, ): """Get a file from remote to local. :param remotepath: an absolute remote path :param localpath: an absolute local path - :param overwrite: if True overwrites files and folders. - Default = False - :param dereference: follow symbolic links. - Default = True + :param overwrite: if True overwrites files and folders. + Default = False + :param dereference: follow symbolic links. + Default = True + :param preserve: preserve file attributes + Default = False :type remotepath: TransportPath :type localpath: TransportPath :type dereference: bool :type overwrite: bool + :type preserve: bool :raise ValueError: if local path is invalid :raise OSError: if unintentionally overwriting @@ -252,13 +260,24 @@ async def getfile_async( try: await self._sftp.get( - remotepaths=remotepath, localpath=localpath, preserve=True, recurse=False, follow_symlinks=dereference + remotepaths=remotepath, + localpath=localpath, + preserve=preserve, + recurse=False, + follow_symlinks=dereference, ) except (OSError, asyncssh.Error) as exc: raise OSError(f'Error while uploading file {localpath}: {exc}') async def gettree_async( - self, remotepath: TransportPath, localpath: TransportPath, dereference=True, overwrite=True, *args, **kwargs + self, + remotepath: TransportPath, + localpath: TransportPath, + dereference=True, + overwrite=True, + preserve=False, + *args, + **kwargs, ): """Get a folder recursively from remote to local. @@ -268,11 +287,14 @@ async def gettree_async( Default = True :param overwrite: if True overwrites files and folders. 
Default = True + :param preserve: preserve file attributes + Default = False :type remotepath: TransportPath :type localpath: TransportPath :type dereference: bool :type overwrite: bool + :type preserve: bool :raise ValueError: if local path is invalid :raise OSError: if the remotepath is not found @@ -309,7 +331,7 @@ async def gettree_async( await self._sftp.get( remotepaths=PurePath(remotepath) / content_, localpath=localpath, - preserve=True, + preserve=preserve, recurse=True, follow_symlinks=dereference, ) @@ -323,6 +345,7 @@ async def put_async( dereference=True, overwrite=True, ignore_nonexisting=False, + preserve=False, *args, **kwargs, ): @@ -335,12 +358,17 @@ async def put_async( Default = True :param overwrite: if True overwrites files and folders Default = False + :param ignore_nonexisting: if True, does not raise an error if the localpath does not exist + Default = False + :param preserve: preserve file attributes + Default = False :type remotepath: TransportPath :type localpath: TransportPath :type dereference: bool :type overwrite: bool :type ignore_nonexisting: bool + :type preserve: bool :raise ValueError: if local path is invalid :raise OSError: if the localpath does not exist @@ -365,7 +393,7 @@ async def put_async( if await self.isfile_async(remotepath): raise OSError('Remote destination is not a directory') # I can't scp more than one file in a non existing directory - elif not await self.path_exists_async(remotepath): # questo dovrebbe valere solo per file + elif not await self.path_exists_async(remotepath): raise OSError('Remote directory does not exist') else: # the remote path is a directory rename_remote = True @@ -375,29 +403,36 @@ async def put_async( if rename_remote: # copying more than one file in one directory # here is the case isfile and more than one file remotefile = os.path.join(remotepath, os.path.split(file)[1]) - await self.putfile_async(file, remotefile, dereference, overwrite) + await self.putfile_async(file, remotefile, 
dereference, overwrite, preserve) elif await self.isdir_async(remotepath): # one file to copy in '.' remotefile = os.path.join(remotepath, os.path.split(file)[1]) - await self.putfile_async(file, remotefile, dereference, overwrite) + await self.putfile_async(file, remotefile, dereference, overwrite, preserve) else: # one file to copy on one file - await self.putfile_async(file, remotepath, dereference, overwrite) + await self.putfile_async(file, remotepath, dereference, overwrite, preserve) else: - await self.puttree_async(file, remotepath, dereference, overwrite) + await self.puttree_async(file, remotepath, dereference, overwrite, preserve) elif os.path.isdir(localpath): - await self.puttree_async(localpath, remotepath, dereference, overwrite) + await self.puttree_async(localpath, remotepath, dereference, overwrite, preserve) elif os.path.isfile(localpath): if await self.isdir_async(remotepath): remote = os.path.join(remotepath, os.path.split(localpath)[1]) - await self.putfile_async(localpath, remote, dereference, overwrite) + await self.putfile_async(localpath, remote, dereference, overwrite, preserve) else: - await self.putfile_async(localpath, remotepath, dereference, overwrite) + await self.putfile_async(localpath, remotepath, dereference, overwrite, preserve) elif not ignore_nonexisting: raise OSError(f'The local path {localpath} does not exist') async def putfile_async( - self, localpath: TransportPath, remotepath: TransportPath, dereference=True, overwrite=True, *args, **kwargs + self, + localpath: TransportPath, + remotepath: TransportPath, + dereference=True, + overwrite=True, + preserve=False, + *args, + **kwargs, ): """Put a file from local to remote. 
@@ -405,11 +440,14 @@ async def putfile_async( :param localpath: an absolute local path :param overwrite: if True overwrites files and folders Default = True + :param preserve: preserve file attributes + Default = False :type remotepath: TransportPath :type localpath: TransportPath :type dereference: bool :type overwrite: bool + :type preserve: bool :raise ValueError: if local path is invalid :raise OSError: if the localpath does not exist, @@ -426,13 +464,24 @@ async def putfile_async( try: await self._sftp.put( - localpaths=localpath, remotepath=remotepath, preserve=True, recurse=False, follow_symlinks=dereference + localpaths=localpath, + remotepath=remotepath, + preserve=preserve, + recurse=False, + follow_symlinks=dereference, ) except (OSError, asyncssh.Error) as exc: raise OSError(f'Error while uploading file {localpath}: {exc}') async def puttree_async( - self, localpath: TransportPath, remotepath: TransportPath, dereference=True, overwrite=True, *args, **kwargs + self, + localpath: TransportPath, + remotepath: TransportPath, + dereference=True, + overwrite=True, + preserve=False, + *args, + **kwargs, ): """Put a folder recursively from local to remote. @@ -442,11 +491,14 @@ async def puttree_async( Default = True :param overwrite: if True overwrites files and folders (boolean). 
Default = True + :param preserve: preserve file attributes + Default = False :type localpath: TransportPath :type remotepath: TransportPath :type dereference: bool :type overwrite: bool + :type preserve: bool :raise ValueError: if local path is invalid :raise OSError: if the localpath does not exist, or trying to overwrite @@ -486,7 +538,7 @@ async def puttree_async( await self._sftp.put( localpaths=PurePath(localpath) / content_, remotepath=remotepath, - preserve=True, + preserve=preserve, recurse=True, follow_symlinks=dereference, ) @@ -508,6 +560,7 @@ async def copy_async( :param dereference: follow symbolic links :param recursive: copy recursively :param preserve: preserve file attributes + Default = False :type remotesource: TransportPath :type remotedestination: TransportPath @@ -532,32 +585,32 @@ async def copy_async( # For the older versions, it downloads the file and uploads it again! # For performance reasons, we should check if the remote copy is supported, if so use # self._sftp.mcopy() & self._sftp.copy() otherwise send a `cp` command to the remote machine. 
- # This is a temporary solution until the feature is implemented in asyncssh: # See here: https://github.com/ronf/asyncssh/issues/724 - if False: - # self._sftp._supports_copy_data: - try: # type: ignore[unreachable] - if self.has_magic(remotesource): - await self._sftp.mcopy( - remotesource, - remotedestination, - preserve=preserve, - recurse=recursive, - follow_symlinks=dereference, - ) - else: - if not await self.path_exists_async(remotesource): - raise OSError(f'The remote path {remotesource} does not exist') - await self._sftp.copy( - remotesource, - remotedestination, - preserve=preserve, - recurse=recursive, - follow_symlinks=dereference, - ) - except asyncssh.sftp.SFTPFailure as exc: - raise OSError(f'Error while copying {remotesource} to {remotedestination}: {exc}') - else: + try: + if self.has_magic(remotesource): + await self._sftp.mcopy( + remotesource, + remotedestination, + preserve=preserve, + recurse=recursive, + follow_symlinks=dereference, + remote_only=True, + ) + else: + if not await self.path_exists_async(remotesource): + raise OSError(f'The remote path {remotesource} does not exist') + await self._sftp.copy( + remotesource, + remotedestination, + preserve=preserve, + recurse=recursive, + follow_symlinks=dereference, + remote_only=True, + ) + except asyncssh.sftp.SFTPFailure as exc: + raise OSError(f'Error while copying {remotesource} to {remotedestination}: {exc}') + except SFTPOpUnsupported: + self.logger.warning('The remote copy is not supported, using the `cp` command to copy the file/folder') # I copy pasted the whole logic below from SshTransport class: async def _exec_cp(cp_exe: str, cp_flags: str, src: str, dst: str): @@ -624,6 +677,7 @@ async def copyfile_async( :param remotedestination: path to the remote destination file :param dereference: follow symbolic links :param preserve: preserve file attributes + Default = False :type remotesource: TransportPath :type remotedestination: TransportPath @@ -647,6 +701,7 @@ async def 
copytree_async( :param remotedestination: path to the remote destination directory :param dereference: follow symbolic links :param preserve: preserve file attributes + Default = False :type remotesource: TransportPath :type remotedestination: TransportPath diff --git a/src/aiida/transports/util.py b/src/aiida/transports/util.py index dfda089e95..9e338df97d 100644 --- a/src/aiida/transports/util.py +++ b/src/aiida/transports/util.py @@ -93,9 +93,9 @@ async def copy_from_remote_to_remote_async( ): """Copy files or folders from a remote computer to another remote computer. Note: To have a proper async performance, - both transports should be instance `core.async_ssh`. - Even if either or both are not async, the function will work, - but the performance might be lower than the sync version. + both transports should be instance `core.async_ssh`. + Even if either or both are not async, the function will work, + but the performance might be lower than the sync version. :param transportsource: transport to be used for the source computer :param transportdestination: transport to be used for the destination computer diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index aad923d404..eb055c7f0f 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -24,7 +24,7 @@ import psutil import pytest -from aiida.plugins import SchedulerFactory, TransportFactory, entry_point +from aiida.plugins import SchedulerFactory, TransportFactory from aiida.transports import AsyncTransport, Transport # TODO : test for copy with pattern @@ -45,7 +45,8 @@ def tmp_path_local(tmp_path_factory): return tmp_path_factory.mktemp('local') -@pytest.fixture(scope='function', params=entry_point.get_entry_point_names('aiida.transports')) +# @pytest.fixture(scope='function', params=entry_point.get_entry_point_names('aiida.transports')) +@pytest.fixture(scope='function', params=['core.ssh', 'core.ssh_auto', 'core.ssh_async']) def 
custom_transport(request, tmp_path_factory, monkeypatch) -> Union['Transport', 'AsyncTransport']: """Fixture that parametrizes over all the registered implementations of the ``CommonRelaxWorkChain``.""" plugin = TransportFactory(request.param) @@ -492,12 +493,14 @@ def test_put_get_empty_string_file(custom_transport, tmp_path_remote, tmp_path_l t1 = Path(retrieved_file_abs_path).stat().st_mtime_ns # overwrite retrieved_file_name in 0.01 s - time.sleep(0.01) + time.sleep(1) transport.getfile(remote_file_abs_path, retrieved_file_abs_path) assert Path(retrieved_file_abs_path).exists() t2 = Path(retrieved_file_abs_path).stat().st_mtime_ns - # Check st_mtime_ns to sure it is override + # Check st_mtime_ns to sure it is overwritten + # Note: this test will fail if getfile() would preserve the remote timestamp, + # this is supported by core.ssh_async, but the default value is False assert t2 > t1 From 38cfc248f9770cdeafc1ea8d52da0a7870439b32 Mon Sep 17 00:00:00 2001 From: Ali Date: Thu, 5 Dec 2024 14:27:55 +0100 Subject: [PATCH 16/29] fixed some stupid issues --- src/aiida/transports/plugins/local.py | 2 +- tests/transports/test_all_plugins.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/aiida/transports/plugins/local.py b/src/aiida/transports/plugins/local.py index e1f5bce5eb..56fa042734 100644 --- a/src/aiida/transports/plugins/local.py +++ b/src/aiida/transports/plugins/local.py @@ -489,7 +489,7 @@ def getfile(self, remotepath: TransportPath, localpath: TransportPath, *args, ** """ remotepath = path_to_str(remotepath) localpath = path_to_str(localpath) - + if not os.path.isabs(localpath): raise ValueError('localpath must be an absolute path') diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index eb055c7f0f..1290e871f3 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -24,7 +24,7 @@ import psutil import pytest -from aiida.plugins import 
SchedulerFactory, TransportFactory +from aiida.plugins import SchedulerFactory, TransportFactory, entry_point from aiida.transports import AsyncTransport, Transport # TODO : test for copy with pattern @@ -45,8 +45,11 @@ def tmp_path_local(tmp_path_factory): return tmp_path_factory.mktemp('local') -# @pytest.fixture(scope='function', params=entry_point.get_entry_point_names('aiida.transports')) -@pytest.fixture(scope='function', params=['core.ssh', 'core.ssh_auto', 'core.ssh_async']) +# Skip for any transport plugins that are locally installed but are not part of `aiida-core` +@pytest.fixture( + scope='function', + params=[name for name in entry_point.get_entry_point_names('aiida.transports') if name.startswith('core.')], +) def custom_transport(request, tmp_path_factory, monkeypatch) -> Union['Transport', 'AsyncTransport']: """Fixture that parametrizes over all the registered implementations of the ``CommonRelaxWorkChain``.""" plugin = TransportFactory(request.param) From 799e0f8f8f9abb9d0a1c5a5bbcda45f2a0594cb2 Mon Sep 17 00:00:00 2001 From: Ali Date: Thu, 5 Dec 2024 16:34:20 +0100 Subject: [PATCH 17/29] plumpy hook pointing to async-run branch, now --- environment.yml | 2 +- pyproject.toml | 2 +- requirements/requirements-py-3.10.txt | 2 +- requirements/requirements-py-3.11.txt | 2 +- requirements/requirements-py-3.12.txt | 2 +- requirements/requirements-py-3.9.txt | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/environment.yml b/environment.yml index 96028588a3..b16b2892ee 100644 --- a/environment.yml +++ b/environment.yml @@ -23,7 +23,7 @@ dependencies: - importlib-metadata~=6.0 - numpy~=1.21 - paramiko~=3.0 -- plumpy@ git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy +- plumpy@ git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy - pgsu~=0.3.0 - psutil~=5.6 - psycopg[binary]~=3.0 diff --git a/pyproject.toml b/pyproject.toml index 7dfefa90a7..4a0a702d35 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ 'importlib-metadata~=6.0', 'numpy~=1.21', 'paramiko~=3.0', - 'plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy', + 'plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy', 'pgsu~=0.3.0', 'psutil~=5.6', 'psycopg[binary]~=3.0', diff --git a/requirements/requirements-py-3.10.txt b/requirements/requirements-py-3.10.txt index 7dbf31bc4b..8be5c6fbdb 100644 --- a/requirements/requirements-py-3.10.txt +++ b/requirements/requirements-py-3.10.txt @@ -121,7 +121,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy +plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 diff --git a/requirements/requirements-py-3.11.txt b/requirements/requirements-py-3.11.txt index 2c39bcb2be..4dd667c053 100644 --- a/requirements/requirements-py-3.11.txt +++ b/requirements/requirements-py-3.11.txt @@ -120,7 +120,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy +plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 diff --git a/requirements/requirements-py-3.12.txt b/requirements/requirements-py-3.12.txt index b900063e78..fabc5f1569 100644 --- a/requirements/requirements-py-3.12.txt +++ b/requirements/requirements-py-3.12.txt @@ -120,7 +120,7 @@ pillow==10.1.0 platformdirs==3.11.0 plotly==5.17.0 pluggy==1.3.0 -plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy +plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy prometheus-client==0.17.1 prompt-toolkit==3.0.39 psutil==5.9.6 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt 
index 67e81bc655..cf2da0c307 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -123,7 +123,7 @@ pillow==9.5.0 platformdirs==3.6.0 plotly==5.15.0 pluggy==1.0.0 -plumpy@git+https://github.com/khsrali/plumpy.git@allow-async-upload-download#egg=plumpy +plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy prometheus-client==0.17.0 prompt-toolkit==3.0.38 psutil==5.9.5 From 083719346b1153b251812d6db25c9d270bc2eaf9 Mon Sep 17 00:00:00 2001 From: Ali Date: Thu, 5 Dec 2024 18:47:57 +0100 Subject: [PATCH 18/29] updated uv lock --- uv.lock | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/uv.lock b/uv.lock index 01493fe251..b06b360634 100644 --- a/uv.lock +++ b/uv.lock @@ -25,6 +25,7 @@ source = { editable = "." } dependencies = [ { name = "alembic" }, { name = "archive-path" }, + { name = "asyncssh" }, { name = "circus" }, { name = "click" }, { name = "click-spinner" }, @@ -159,6 +160,7 @@ requires-dist = [ { name = "alembic", specifier = "~=1.2" }, { name = "archive-path", specifier = "~=0.4.2" }, { name = "ase", marker = "extra == 'atomic-tools'", specifier = "~=3.18" }, + { name = "asyncssh", git = "https://github.com/ronf/asyncssh.git?rev=033ef54302b2b09d496d68ccf39778b9e5fc89e2#033ef54302b2b09d496d68ccf39778b9e5fc89e2" }, { name = "bpython", marker = "extra == 'bpython'", specifier = "~=0.18.0" }, { name = "circus", specifier = "~=0.18.0" }, { name = "click", specifier = "~=8.1" }, @@ -190,7 +192,7 @@ requires-dist = [ { name = "pg8000", marker = "extra == 'tests'", specifier = "~=1.13" }, { name = "pgsu", specifier = "~=0.3.0" }, { name = "pgtest", marker = "extra == 'tests'", specifier = "~=1.3,>=1.3.1" }, - { name = "plumpy", specifier = "~=0.22.3" }, + { name = "plumpy", git = "https://github.com/aiidateam/plumpy.git?rev=async-run" }, { name = "pre-commit", marker = "extra == 'pre-commit'", specifier = "~=3.5" }, { name = "psutil", specifier = "~=5.6" }, { name = 
"psycopg", extras = ["binary"], specifier = "~=3.0" }, @@ -445,6 +447,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/9f/3c3503693386c4b0f245eaf5ca6198e3b28879ca0a40bde6b0e319793453/async_lru-2.0.4-py3-none-any.whl", hash = "sha256:ff02944ce3c288c5be660c42dbcca0742b32c3b279d6dceda655190240b99224", size = 6111 }, ] +[[package]] +name = "asyncssh" +version = "2.18.0" +source = { git = "https://github.com/ronf/asyncssh.git?rev=033ef54302b2b09d496d68ccf39778b9e5fc89e2#033ef54302b2b09d496d68ccf39778b9e5fc89e2" } +dependencies = [ + { name = "cryptography" }, + { name = "typing-extensions" }, +] + [[package]] name = "attrs" version = "24.2.0" @@ -2918,15 +2929,12 @@ wheels = [ [[package]] name = "plumpy" version = "0.22.3" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/aiidateam/plumpy.git?rev=async-run#2a003072ffe45e570bd0a76aafe049e559836cb6" } dependencies = [ { name = "kiwipy", extra = ["rmq"] }, { name = "nest-asyncio" }, { name = "pyyaml" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ab/99/6c931d3f4697acd34cf18eb3fbfe96ed55cd0408d9be7c0f316349117a8e/plumpy-0.22.3.tar.gz", hash = "sha256:e58f45e6360f173babf04e2a4abacae9867622768ce2a126c8260db3b46372c4", size = 73582 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/95/d9/12fd8281f494ca79d6a7a9d40099616d16415be5807959e5b024dffe8aed/plumpy-0.22.3-py3-none-any.whl", hash = "sha256:63ae6c90713f52483836a3b2b3e1941eab7ada920c303092facc27e78229bdc3", size = 74244 }, + { name = "typing-extensions" }, ] [[package]] From 1b96110c3d86ece14639e85dc3639f70255e8961 Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Thu, 5 Dec 2024 20:48:29 +0100 Subject: [PATCH 19/29] Fixing uv.lock file for the depedencies from a github repo uv add git+https://github.com/aiidateam/plumpy --branch async-run uv add git+https://github.com/ronf/asyncssh --rev 033ef54302b2b09d496d68ccf39778b9e5fc89e2 --- pyproject.toml | 10 +++++++--- uv.lock | 8 
++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4a0a702d35..3b4ed0539d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ dependencies = [ 'alembic~=1.2', 'archive-path~=0.4.2', - 'asyncssh@git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh', + "asyncssh", 'circus~=0.18.0', 'click-spinner~=0.1.8', 'click~=8.1', @@ -35,7 +35,7 @@ dependencies = [ 'importlib-metadata~=6.0', 'numpy~=1.21', 'paramiko~=3.0', - 'plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy', + "plumpy", 'pgsu~=0.3.0', 'psutil~=5.6', 'psycopg[binary]~=3.0', @@ -47,7 +47,7 @@ dependencies = [ 'tabulate>=0.8.0,<0.10.0', 'tqdm~=4.45', 'upf_to_json~=0.9.2', - 'wrapt~=1.11' + 'wrapt~=1.11', ] description = 'AiiDA is a workflow manager for computational science with a strong focus on provenance, performance and extensibility.' dynamic = ['version'] # read from aiida/__init__.py @@ -511,3 +511,7 @@ passenv = AIIDA_TEST_WORKERS commands = molecule {posargs:test} """ + +[tool.uv.sources] +plumpy = { git = "https://github.com/aiidateam/plumpy", branch = "async-run" } +asyncssh = { git = "https://github.com/ronf/asyncssh", rev = "033ef54302b2b09d496d68ccf39778b9e5fc89e2" } diff --git a/uv.lock b/uv.lock index b06b360634..1efece9bff 100644 --- a/uv.lock +++ b/uv.lock @@ -160,7 +160,7 @@ requires-dist = [ { name = "alembic", specifier = "~=1.2" }, { name = "archive-path", specifier = "~=0.4.2" }, { name = "ase", marker = "extra == 'atomic-tools'", specifier = "~=3.18" }, - { name = "asyncssh", git = "https://github.com/ronf/asyncssh.git?rev=033ef54302b2b09d496d68ccf39778b9e5fc89e2#033ef54302b2b09d496d68ccf39778b9e5fc89e2" }, + { name = "asyncssh", git = "https://github.com/ronf/asyncssh?rev=033ef54302b2b09d496d68ccf39778b9e5fc89e2" }, { name = "bpython", marker = "extra == 'bpython'", specifier = "~=0.18.0" }, { name = "circus", specifier = "~=0.18.0" }, { 
name = "click", specifier = "~=8.1" }, @@ -192,7 +192,7 @@ requires-dist = [ { name = "pg8000", marker = "extra == 'tests'", specifier = "~=1.13" }, { name = "pgsu", specifier = "~=0.3.0" }, { name = "pgtest", marker = "extra == 'tests'", specifier = "~=1.3,>=1.3.1" }, - { name = "plumpy", git = "https://github.com/aiidateam/plumpy.git?rev=async-run" }, + { name = "plumpy", git = "https://github.com/aiidateam/plumpy?branch=async-run" }, { name = "pre-commit", marker = "extra == 'pre-commit'", specifier = "~=3.5" }, { name = "psutil", specifier = "~=5.6" }, { name = "psycopg", extras = ["binary"], specifier = "~=3.0" }, @@ -450,7 +450,7 @@ wheels = [ [[package]] name = "asyncssh" version = "2.18.0" -source = { git = "https://github.com/ronf/asyncssh.git?rev=033ef54302b2b09d496d68ccf39778b9e5fc89e2#033ef54302b2b09d496d68ccf39778b9e5fc89e2" } +source = { git = "https://github.com/ronf/asyncssh?rev=033ef54302b2b09d496d68ccf39778b9e5fc89e2#033ef54302b2b09d496d68ccf39778b9e5fc89e2" } dependencies = [ { name = "cryptography" }, { name = "typing-extensions" }, @@ -2929,7 +2929,7 @@ wheels = [ [[package]] name = "plumpy" version = "0.22.3" -source = { git = "https://github.com/aiidateam/plumpy.git?rev=async-run#2a003072ffe45e570bd0a76aafe049e559836cb6" } +source = { git = "https://github.com/aiidateam/plumpy?branch=async-run#2a003072ffe45e570bd0a76aafe049e559836cb6" } dependencies = [ { name = "kiwipy", extra = ["rmq"] }, { name = "nest-asyncio" }, From a68240c2a890108a3cea9146a81dd6bffc0270de Mon Sep 17 00:00:00 2001 From: Ali Date: Fri, 6 Dec 2024 09:06:41 +0100 Subject: [PATCH 20/29] fix conflicts --- requirements/requirements-py-3.10.txt | 218 ------------------------- requirements/requirements-py-3.11.txt | 216 ------------------------- requirements/requirements-py-3.12.txt | 216 ------------------------- requirements/requirements-py-3.9.txt | 220 -------------------------- 4 files changed, 870 deletions(-) delete mode 100644 requirements/requirements-py-3.10.txt 
delete mode 100644 requirements/requirements-py-3.11.txt delete mode 100644 requirements/requirements-py-3.12.txt delete mode 100644 requirements/requirements-py-3.9.txt diff --git a/requirements/requirements-py-3.10.txt b/requirements/requirements-py-3.10.txt deleted file mode 100644 index 423aab6c38..0000000000 --- a/requirements/requirements-py-3.10.txt +++ /dev/null @@ -1,218 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# pip-compile --extra=atomic_tools --extra=docs --extra=notebook --extra=rest --extra=tests --no-annotate --output-file=requirements/requirements-py-3.10.txt pyproject.toml -# -accessible-pygments==0.0.5 -aiida-export-migration-tests==0.9.0 -aio-pika==9.4.0 -aiormq==6.8.0 -alabaster==0.7.13 -alembic==1.11.1 -aniso8601==9.0.1 -annotated-types==0.7.0 -anyio==3.7.0 -archive-path==0.4.2 -argon2-cffi==21.3.0 -argon2-cffi-bindings==21.2.0 -ase==3.22.1 -asn1crypto==1.5.1 -asttokens==2.2.1 -async-generator==1.10 -asyncssh@git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh -attrs==23.1.0 -babel==2.12.1 -backcall==0.2.0 -bcrypt==4.0.1 -beautifulsoup4==4.12.2 -bleach==6.0.0 -blinker==1.6.2 -certifi==2023.5.7 -cffi==1.15.1 -charset-normalizer==3.1.0 -circus==0.18.0 -click==8.1.3 -click-spinner==0.1.10 -comm==0.1.3 -contourpy==1.1.0 -coverage[toml]==7.4.1 -cryptography==41.0.1 -cycler==0.11.0 -debugpy==1.6.7 -decorator==5.1.1 -defusedxml==0.7.1 -deprecation==2.1.0 -disk-objectstore==1.2.0 -docstring-parser==0.15 -docutils==0.20.1 -exceptiongroup==1.1.1 -executing==1.2.0 -fastjsonschema==2.17.1 -flask==2.3.2 -flask-cors==3.0.10 -flask-restful==0.3.10 -fonttools==4.40.0 -future==0.18.3 -graphviz==0.20.1 -greenlet==2.0.2 -idna==3.4 -imagesize==1.4.1 -importlib-metadata==6.8.0 -iniconfig==2.0.0 -ipykernel==6.23.2 -ipython==8.14.0 -ipython-genutils==0.2.0 -ipywidgets==8.0.6 -itsdangerous==2.1.2 -jedi==0.18.2 -jinja2==3.1.2 -joblib==1.4.2 
-jsonschema[format-nongpl]==3.2.0 -jupyter==1.0.0 -jupyter-cache==0.6.1 -jupyter-client==8.2.0 -jupyter-console==6.6.3 -jupyter-core==5.3.1 -jupyter-events==0.6.3 -jupyter-server==2.6.0 -jupyter-server-terminals==0.4.4 -jupyterlab-pygments==0.2.2 -jupyterlab-widgets==3.0.7 -kiwipy[rmq]==0.8.4 -kiwisolver==1.4.4 -latexcodec==2.0.1 -mako==1.2.4 -markdown-it-py==3.0.0 -markupsafe==2.1.3 -matplotlib==3.7.1 -matplotlib-inline==0.1.6 -mdit-py-plugins==0.4.0 -mdurl==0.1.2 -mistune==3.0.1 -monty==2023.9.25 -mpmath==1.3.0 -multidict==6.0.4 -myst-nb==1.0.0 -myst-parser==2.0.0 -nbclassic==1.0.0 -nbclient==0.7.4 -nbconvert==7.6.0 -nbformat==5.9.0 -nest-asyncio==1.5.6 -networkx==3.1 -notebook==6.5.4 -notebook-shim==0.2.3 -numpy==1.25.0 -overrides==7.3.1 -packaging==23.1 -palettable==3.3.3 -pamqp==3.3.0 -pandas==2.0.2 -pandocfilters==1.5.0 -paramiko==3.4.1 -parso==0.8.3 -pexpect==4.8.0 -pg8000==1.29.8 -pgsu==0.3.0 -pgtest==1.3.2 -pickleshare==0.7.5 -pillow==9.5.0 -platformdirs==3.6.0 -plotly==5.15.0 -pluggy==1.0.0 -plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy -prometheus-client==0.17.0 -prompt-toolkit==3.0.38 -psutil==5.9.5 -psycopg[binary]==3.1.18 -ptyprocess==0.7.0 -pure-eval==0.2.2 -py-cpuinfo==9.0.0 -pybtex==0.24.0 -pycifrw==4.4.5 -pycparser==2.21 -pydantic==2.4.0 -pydantic-core==2.10.0 -pydata-sphinx-theme==0.15.1 -pygments==2.15.1 -pymatgen==2023.9.25 -pympler==1.0.1 -pymysql==0.9.3 -pynacl==1.5.0 -pyparsing==3.1.1 -pyrsistent==0.19.3 -pytest==7.3.2 -pytest-asyncio==0.16.0 -pytest-benchmark==4.0.0 -pytest-cov==4.1.0 -pytest-datadir==1.4.1 -pytest-regressions==2.4.2 -pytest-rerunfailures==12.0 -pytest-timeout==2.2.0 -pytest-xdist==3.6.1 -python-dateutil==2.8.2 -python-json-logger==2.0.7 -python-memcached==1.59 -pytray==0.3.4 -pytz==2021.3 -pyyaml==6.0 -pyzmq==25.1.0 -qtconsole==5.4.3 -qtpy==2.3.1 -requests==2.31.0 -rfc3339-validator==0.1.4 -rfc3986-validator==0.1.1 -ruamel-yaml==0.17.32 -ruamel-yaml-clib==0.2.7 -scipy==1.10.1 -scramp==1.4.4 
-seekpath==1.9.7 -send2trash==1.8.2 -shortuuid==1.0.11 -six==1.16.0 -sniffio==1.3.0 -snowballstemmer==2.2.0 -soupsieve==2.4.1 -spglib==2.0.2 -sphinx==7.2.6 -sphinx-copybutton==0.5.2 -sphinx-design==0.5.0 -sphinx-intl==2.1.0 -sphinx-notfound-page==1.0.0 -sphinx-sqlalchemy==0.2.0 -sphinxcontrib-applehelp==1.0.4 -sphinxcontrib-devhelp==1.0.2 -sphinxcontrib-htmlhelp==2.0.1 -sphinxcontrib-jsmath==1.0.1 -sphinxcontrib-qthelp==1.0.3 -sphinxcontrib-serializinghtml==1.1.9 -sphinxext-rediraffe==0.2.7 -sqlalchemy==2.0.23 -stack-data==0.6.2 -sympy==1.12 -tabulate==0.9.0 -tenacity==8.2.2 -terminado==0.17.1 -tinycss2==1.2.1 -tomli==2.0.1 -tornado==6.3.2 -tqdm==4.65.0 -traitlets==5.9.0 -typing-extensions==4.6.3 -tzdata==2023.3 -uncertainties==3.1.7 -upf-to-json==0.9.5 -urllib3==2.0.3 -wcwidth==0.2.6 -webencodings==0.5.1 -websocket-client==1.6.0 -werkzeug==2.3.6 -widgetsnbextension==4.0.7 -wrapt==1.15.0 -yarl==1.9.2 -zipp==3.15.0 - -# The following packages are considered to be unsafe in a requirements file: -# setuptools - diff --git a/requirements/requirements-py-3.11.txt b/requirements/requirements-py-3.11.txt deleted file mode 100644 index e8d2340420..0000000000 --- a/requirements/requirements-py-3.11.txt +++ /dev/null @@ -1,216 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.11 -# by the following command: -# -# pip-compile --extra=atomic_tools --extra=docs --extra=notebook --extra=rest --extra=tests --no-annotate --output-file=requirements/requirements-py-3.9.txt pyproject.toml -# -accessible-pygments==0.0.5 -aiida-export-migration-tests==0.9.0 -aio-pika==9.4.0 -aiormq==6.8.0 -alabaster==0.7.13 -alembic==1.11.1 -aniso8601==9.0.1 -annotated-types==0.7.0 -anyio==3.7.0 -archive-path==0.4.2 -argon2-cffi==21.3.0 -argon2-cffi-bindings==21.2.0 -ase==3.22.1 -asn1crypto==1.5.1 -asttokens==2.2.1 -async-generator==1.10 -asyncssh@git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh -attrs==23.1.0 -babel==2.12.1 
-backcall==0.2.0 -bcrypt==4.0.1 -beautifulsoup4==4.12.2 -bleach==6.0.0 -blinker==1.6.2 -certifi==2023.5.7 -cffi==1.15.1 -charset-normalizer==3.1.0 -circus==0.18.0 -click==8.1.3 -click-spinner==0.1.10 -comm==0.1.3 -contourpy==1.1.0 -coverage[toml]==7.4.1 -cryptography==41.0.1 -cycler==0.11.0 -debugpy==1.6.7 -decorator==5.1.1 -defusedxml==0.7.1 -deprecation==2.1.0 -disk-objectstore==1.2.0 -docstring-parser==0.15 -docutils==0.20.1 -executing==1.2.0 -fastjsonschema==2.17.1 -flask==2.3.2 -flask-cors==3.0.10 -flask-restful==0.3.10 -fonttools==4.40.0 -future==0.18.3 -graphviz==0.20.1 -greenlet==2.0.2 -idna==3.4 -imagesize==1.4.1 -importlib-metadata==6.8.0 -iniconfig==2.0.0 -ipykernel==6.23.2 -ipython==8.14.0 -ipython-genutils==0.2.0 -ipywidgets==8.0.6 -itsdangerous==2.1.2 -jedi==0.18.2 -jinja2==3.1.2 -joblib==1.4.2 -jsonschema[format-nongpl]==3.2.0 -jupyter==1.0.0 -jupyter-cache==0.6.1 -jupyter-client==8.2.0 -jupyter-console==6.6.3 -jupyter-core==5.3.1 -jupyter-events==0.6.3 -jupyter-server==2.6.0 -jupyter-server-terminals==0.4.4 -jupyterlab-pygments==0.2.2 -jupyterlab-widgets==3.0.7 -kiwipy[rmq]==0.8.4 -kiwisolver==1.4.4 -latexcodec==2.0.1 -mako==1.2.4 -markdown-it-py==3.0.0 -markupsafe==2.1.3 -matplotlib==3.7.1 -matplotlib-inline==0.1.6 -mdit-py-plugins==0.4.0 -mdurl==0.1.2 -mistune==3.0.1 -monty==2023.9.25 -mpmath==1.3.0 -multidict==6.0.4 -myst-nb==1.0.0 -myst-parser==2.0.0 -nbclassic==1.0.0 -nbclient==0.7.4 -nbconvert==7.6.0 -nbformat==5.9.0 -nest-asyncio==1.5.6 -networkx==3.1 -notebook==6.5.4 -notebook-shim==0.2.3 -numpy==1.25.0 -overrides==7.3.1 -packaging==23.1 -palettable==3.3.3 -pamqp==3.3.0 -pandas==2.0.2 -pandocfilters==1.5.0 -paramiko==3.4.1 -parso==0.8.3 -pexpect==4.8.0 -pg8000==1.29.8 -pgsu==0.3.0 -pgtest==1.3.2 -pickleshare==0.7.5 -pillow==9.5.0 -platformdirs==3.6.0 -plotly==5.15.0 -pluggy==1.0.0 -plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy -prometheus-client==0.17.0 -prompt-toolkit==3.0.38 -psutil==5.9.5 -psycopg[binary]==3.1.18 
-ptyprocess==0.7.0 -pure-eval==0.2.2 -py-cpuinfo==9.0.0 -pybtex==0.24.0 -pycifrw==4.4.5 -pycparser==2.21 -pydantic==2.4.0 -pydantic-core==2.10.0 -pydata-sphinx-theme==0.15.1 -pygments==2.15.1 -pymatgen==2023.9.25 -pympler==1.0.1 -pymysql==0.9.3 -pynacl==1.5.0 -pyparsing==3.1.1 -pyrsistent==0.19.3 -pytest==7.3.2 -pytest-asyncio==0.16.0 -pytest-benchmark==4.0.0 -pytest-cov==4.1.0 -pytest-datadir==1.4.1 -pytest-regressions==2.4.2 -pytest-rerunfailures==12.0 -pytest-timeout==2.2.0 -pytest-xdist==3.6.1 -python-dateutil==2.8.2 -python-json-logger==2.0.7 -python-memcached==1.59 -pytray==0.3.4 -pytz==2021.3 -pyyaml==6.0 -pyzmq==25.1.0 -qtconsole==5.4.3 -qtpy==2.3.1 -requests==2.31.0 -rfc3339-validator==0.1.4 -rfc3986-validator==0.1.1 -ruamel-yaml==0.17.32 -ruamel-yaml-clib==0.2.7 -scipy==1.10.1 -scramp==1.4.4 -seekpath==1.9.7 -send2trash==1.8.2 -shortuuid==1.0.11 -six==1.16.0 -sniffio==1.3.0 -snowballstemmer==2.2.0 -soupsieve==2.4.1 -spglib==2.0.2 -sphinx==7.2.6 -sphinx-copybutton==0.5.2 -sphinx-design==0.5.0 -sphinx-intl==2.1.0 -sphinx-notfound-page==1.0.0 -sphinx-sqlalchemy==0.2.0 -sphinxcontrib-applehelp==1.0.4 -sphinxcontrib-devhelp==1.0.2 -sphinxcontrib-htmlhelp==2.0.1 -sphinxcontrib-jsmath==1.0.1 -sphinxcontrib-qthelp==1.0.3 -sphinxcontrib-serializinghtml==1.1.9 -sphinxext-rediraffe==0.2.7 -sqlalchemy==2.0.23 -stack-data==0.6.2 -sympy==1.12 -tabulate==0.9.0 -tenacity==8.2.2 -terminado==0.17.1 -tinycss2==1.2.1 -tornado==6.3.2 -tqdm==4.65.0 -traitlets==5.9.0 -typing-extensions==4.6.3 -tzdata==2023.3 -uncertainties==3.1.7 -upf-to-json==0.9.5 -urllib3==2.0.3 -wcwidth==0.2.6 -webencodings==0.5.1 -websocket-client==1.6.0 -werkzeug==2.3.6 -widgetsnbextension==4.0.7 -wrapt==1.15.0 -yarl==1.9.2 -zipp==3.15.0 - -# The following packages are considered to be unsafe in a requirements file: -# setuptools - diff --git a/requirements/requirements-py-3.12.txt b/requirements/requirements-py-3.12.txt deleted file mode 100644 index ccff73b6cd..0000000000 --- 
a/requirements/requirements-py-3.12.txt +++ /dev/null @@ -1,216 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.12 -# by the following command: -# -# pip-compile --extra=atomic_tools --extra=docs --extra=notebook --extra=rest --extra=tests --no-annotate --output-file=requirements/requirements-py-3.12.txt pyproject.toml -# -accessible-pygments==0.0.4 -aiida-export-migration-tests==0.9.0 -aio-pika==9.4.0 -aiormq==6.8.0 -alabaster==0.7.13 -alembic==1.12.0 -aniso8601==9.0.1 -annotated-types==0.7.0 -anyio==4.0.0 -archive-path==0.4.2 -argon2-cffi==23.1.0 -argon2-cffi-bindings==21.2.0 -ase==3.22.1 -asn1crypto==1.5.1 -asttokens==2.4.0 -async-generator==1.10 -asyncssh@git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh -attrs==23.1.0 -babel==2.13.1 -backcall==0.2.0 -bcrypt==4.0.1 -beautifulsoup4==4.12.2 -bleach==6.1.0 -blinker==1.6.3 -certifi==2023.7.22 -cffi==1.16.0 -charset-normalizer==3.3.1 -circus==0.18.0 -click==8.1.7 -click-spinner==0.1.10 -comm==0.1.4 -contourpy==1.1.1 -coverage[toml]==7.4.1 -cryptography==41.0.5 -cycler==0.12.1 -debugpy==1.8.0 -decorator==5.1.1 -defusedxml==0.7.1 -deprecation==2.1.0 -disk-objectstore==1.2.0 -docstring-parser==0.15 -docutils==0.20.1 -executing==2.0.0 -fastjsonschema==2.18.1 -flask==2.3.3 -flask-cors==3.0.10 -flask-restful==0.3.10 -fonttools==4.43.1 -future==0.18.3 -graphviz==0.20.1 -greenlet==3.0.0 -idna==3.4 -imagesize==1.4.1 -importlib-metadata==6.8.0 -iniconfig==2.0.0 -ipykernel==6.25.2 -ipython==8.16.1 -ipython-genutils==0.2.0 -ipywidgets==8.1.1 -itsdangerous==2.1.2 -jedi==0.18.2 -jinja2==3.1.2 -joblib==1.3.2 -jsonschema[format-nongpl]==3.2.0 -jupyter==1.0.0 -jupyter-cache==0.6.1 -jupyter-client==8.4.0 -jupyter-console==6.6.3 -jupyter-core==5.4.0 -jupyter-events==0.6.3 -jupyter-server==2.8.0 -jupyter-server-terminals==0.4.4 -jupyterlab-pygments==0.2.2 -jupyterlab-widgets==3.0.9 -kiwipy[rmq]==0.8.4 -kiwisolver==1.4.5 -latexcodec==2.0.1 -mako==1.2.4 
-markdown-it-py==3.0.0 -markupsafe==2.1.3 -matplotlib==3.8.0 -matplotlib-inline==0.1.6 -mdit-py-plugins==0.4.0 -mdurl==0.1.2 -mistune==3.0.2 -monty==2023.9.25 -mpmath==1.3.0 -multidict==6.0.4 -myst-nb==1.0.0 -myst-parser==2.0.0 -nbclassic==1.0.0 -nbclient==0.7.4 -nbconvert==7.9.2 -nbformat==5.9.2 -nest-asyncio==1.5.8 -networkx==3.2 -notebook==6.5.4 -notebook-shim==0.2.3 -numpy==1.26.1 -overrides==7.4.0 -packaging==23.2 -palettable==3.3.3 -pamqp==3.3.0 -pandas==2.1.1 -pandocfilters==1.5.0 -paramiko==3.4.1 -parso==0.8.3 -pexpect==4.8.0 -pg8000==1.30.2 -pgsu==0.3.0 -pgtest==1.3.2 -pickleshare==0.7.5 -pillow==10.1.0 -platformdirs==3.11.0 -plotly==5.17.0 -pluggy==1.3.0 -plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy -prometheus-client==0.17.1 -prompt-toolkit==3.0.39 -psutil==5.9.6 -psycopg[binary]==3.1.18 -ptyprocess==0.7.0 -pure-eval==0.2.2 -py-cpuinfo==9.0.0 -pybtex==0.24.0 -pycifrw==4.4.5 -pycparser==2.21 -pydantic==2.4.0 -pydantic-core==2.10.0 -pydata-sphinx-theme==0.15.1 -pygments==2.16.1 -pymatgen==2023.10.11 -pympler==1.0.1 -pymysql==0.9.3 -pynacl==1.5.0 -pyparsing==3.1.1 -pyrsistent==0.19.3 -pytest==7.4.2 -pytest-asyncio==0.16.0 -pytest-benchmark==4.0.0 -pytest-cov==4.1.0 -pytest-datadir==1.5.0 -pytest-regressions==2.5.0 -pytest-rerunfailures==12.0 -pytest-timeout==2.2.0 -pytest-xdist==3.6.1 -python-dateutil==2.8.2 -python-json-logger==2.0.7 -python-memcached==1.59 -pytray==0.3.4 -pytz==2021.3 -pyyaml==6.0.1 -pyzmq==25.1.1 -qtconsole==5.4.4 -qtpy==2.4.1 -requests==2.31.0 -rfc3339-validator==0.1.4 -rfc3986-validator==0.1.1 -ruamel-yaml==0.18.2 -ruamel-yaml-clib==0.2.8 -scipy==1.11.3 -scramp==1.4.4 -seekpath==1.9.7 -send2trash==1.8.2 -shortuuid==1.0.11 -six==1.16.0 -sniffio==1.3.0 -snowballstemmer==2.2.0 -soupsieve==2.5 -spglib==2.1.0 -sphinx==7.2.6 -sphinx-copybutton==0.5.2 -sphinx-design==0.5.0 -sphinx-intl==2.1.0 -sphinx-notfound-page==1.0.0 -sphinx-sqlalchemy==0.2.0 -sphinxcontrib-applehelp==1.0.4 -sphinxcontrib-devhelp==1.0.2 
-sphinxcontrib-htmlhelp==2.0.1 -sphinxcontrib-jsmath==1.0.1 -sphinxcontrib-qthelp==1.0.3 -sphinxcontrib-serializinghtml==1.1.9 -sphinxext-rediraffe==0.2.7 -sqlalchemy==2.0.23 -stack-data==0.6.3 -sympy==1.12 -tabulate==0.9.0 -tenacity==8.2.3 -terminado==0.17.1 -tinycss2==1.2.1 -tornado==6.3.3 -tqdm==4.66.1 -traitlets==5.11.2 -typing-extensions==4.8.0 -tzdata==2023.3 -uncertainties==3.1.7 -upf-to-json==0.9.5 -urllib3==2.0.7 -wcwidth==0.2.8 -webencodings==0.5.1 -websocket-client==1.6.4 -werkzeug==3.0.0 -widgetsnbextension==4.0.9 -wrapt==1.15.0 -yarl==1.9.2 -zipp==3.17.0 - -# The following packages are considered to be unsafe in a requirements file: -# setuptools - diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt deleted file mode 100644 index 01f09b157b..0000000000 --- a/requirements/requirements-py-3.9.txt +++ /dev/null @@ -1,220 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the following command: -# -# pip-compile --extra=atomic_tools --extra=docs --extra=notebook --extra=rest --extra=tests --no-annotate --output-file=requirements/requirements-py-3.9.txt pyproject.toml -# -accessible-pygments==0.0.5 -aiida-export-migration-tests==0.9.0 -aio-pika==9.4.0 -aiormq==6.8.0 -alabaster==0.7.13 -alembic==1.11.1 -aniso8601==9.0.1 -annotated-types==0.7.0 -anyio==3.7.0 -archive-path==0.4.2 -argon2-cffi==21.3.0 -argon2-cffi-bindings==21.2.0 -ase==3.22.1 -asn1crypto==1.5.1 -asttokens==2.2.1 -async-generator==1.10 -asyncssh@git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh -attrs==23.1.0 -babel==2.12.1 -backcall==0.2.0 -bcrypt==4.0.1 -beautifulsoup4==4.12.2 -bleach==6.0.0 -blinker==1.6.2 -certifi==2023.5.7 -cffi==1.15.1 -charset-normalizer==3.1.0 -circus==0.18.0 -click==8.1.3 -click-spinner==0.1.10 -comm==0.1.3 -contourpy==1.1.0 -coverage[toml]==7.4.1 -cryptography==41.0.1 -cycler==0.11.0 -debugpy==1.6.7 -decorator==5.1.1 -defusedxml==0.7.1 -deprecation==2.1.0 
-disk-objectstore==1.2.0 -docstring-parser==0.15 -docutils==0.20.1 -exceptiongroup==1.2.1 -executing==1.2.0 -fastjsonschema==2.17.1 -flask==2.3.2 -flask-cors==3.0.10 -flask-restful==0.3.10 -fonttools==4.40.0 -future==0.18.3 -get-annotations==0.1.2 ; python_version < "3.10" -graphviz==0.20.1 -greenlet==2.0.2 -idna==3.4 -imagesize==1.4.1 -importlib-metadata==6.8.0 -importlib-resources==6.4.0 -iniconfig==2.0.0 -ipykernel==6.23.2 -ipython==8.14.0 -ipython-genutils==0.2.0 -ipywidgets==8.0.6 -itsdangerous==2.1.2 -jedi==0.18.2 -jinja2==3.1.2 -joblib==1.4.2 -jsonschema[format-nongpl]==3.2.0 -jupyter==1.0.0 -jupyter-cache==0.6.1 -jupyter-client==8.2.0 -jupyter-console==6.6.3 -jupyter-core==5.3.1 -jupyter-events==0.6.3 -jupyter-server==2.6.0 -jupyter-server-terminals==0.4.4 -jupyterlab-pygments==0.2.2 -jupyterlab-widgets==3.0.7 -kiwipy[rmq]==0.8.4 -kiwisolver==1.4.4 -latexcodec==2.0.1 -mako==1.2.4 -markdown-it-py==3.0.0 -markupsafe==2.1.3 -matplotlib==3.7.1 -matplotlib-inline==0.1.6 -mdit-py-plugins==0.4.0 -mdurl==0.1.2 -mistune==3.0.1 -monty==2023.9.25 -mpmath==1.3.0 -multidict==6.0.4 -myst-nb==1.0.0 -myst-parser==2.0.0 -nbclassic==1.0.0 -nbclient==0.7.4 -nbconvert==7.6.0 -nbformat==5.9.0 -nest-asyncio==1.5.6 -networkx==3.1 -notebook==6.5.4 -notebook-shim==0.2.3 -numpy==1.25.0 -overrides==7.3.1 -packaging==23.1 -palettable==3.3.3 -pamqp==3.3.0 -pandas==2.0.2 -pandocfilters==1.5.0 -paramiko==3.4.1 -parso==0.8.3 -pexpect==4.8.0 -pg8000==1.29.8 -pgsu==0.3.0 -pgtest==1.3.2 -pickleshare==0.7.5 -pillow==9.5.0 -platformdirs==3.6.0 -plotly==5.15.0 -pluggy==1.0.0 -plumpy@git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy -prometheus-client==0.17.0 -prompt-toolkit==3.0.38 -psutil==5.9.5 -psycopg[binary]==3.1.18 -ptyprocess==0.7.0 -pure-eval==0.2.2 -py-cpuinfo==9.0.0 -pybtex==0.24.0 -pycifrw==4.4.5 -pycparser==2.21 -pydantic==2.4.0 -pydantic-core==2.10.0 -pydata-sphinx-theme==0.15.1 -pygments==2.15.1 -pymatgen==2023.9.25 -pympler==1.0.1 -pymysql==0.9.3 -pynacl==1.5.0 
-pyparsing==3.1.1 -pyrsistent==0.19.3 -pytest==7.3.2 -pytest-asyncio==0.16.0 -pytest-benchmark==4.0.0 -pytest-cov==4.1.0 -pytest-datadir==1.4.1 -pytest-regressions==2.4.2 -pytest-rerunfailures==12.0 -pytest-timeout==2.2.0 -pytest-xdist==3.6.1 -python-dateutil==2.8.2 -python-json-logger==2.0.7 -python-memcached==1.59 -pytray==0.3.4 -pytz==2021.3 -pyyaml==6.0 -pyzmq==25.1.0 -qtconsole==5.4.3 -qtpy==2.3.1 -requests==2.31.0 -rfc3339-validator==0.1.4 -rfc3986-validator==0.1.1 -ruamel-yaml==0.17.32 -ruamel-yaml-clib==0.2.7 -scipy==1.10.1 -scramp==1.4.4 -seekpath==1.9.7 -send2trash==1.8.2 -shortuuid==1.0.11 -six==1.16.0 -sniffio==1.3.0 -snowballstemmer==2.2.0 -soupsieve==2.4.1 -spglib==2.0.2 -sphinx==7.2.6 -sphinx-copybutton==0.5.2 -sphinx-design==0.5.0 -sphinx-intl==2.1.0 -sphinx-notfound-page==1.0.0 -sphinx-sqlalchemy==0.2.0 -sphinxcontrib-applehelp==1.0.4 -sphinxcontrib-devhelp==1.0.2 -sphinxcontrib-htmlhelp==2.0.1 -sphinxcontrib-jsmath==1.0.1 -sphinxcontrib-qthelp==1.0.3 -sphinxcontrib-serializinghtml==1.1.9 -sphinxext-rediraffe==0.2.7 -sqlalchemy==2.0.23 -stack-data==0.6.2 -sympy==1.12 -tabulate==0.9.0 -tenacity==8.2.2 -terminado==0.17.1 -tinycss2==1.2.1 -tomli==2.0.1 -tornado==6.3.2 -tqdm==4.65.0 -traitlets==5.9.0 -typing-extensions==4.6.3 -tzdata==2023.3 -uncertainties==3.1.7 -upf-to-json==0.9.5 -urllib3==2.0.3 -wcwidth==0.2.6 -webencodings==0.5.1 -websocket-client==1.6.0 -werkzeug==2.3.6 -widgetsnbextension==4.0.7 -wrapt==1.15.0 -yarl==1.9.2 -zipp==3.15.0 - -# The following packages are considered to be unsafe in a requirements file: -# setuptools - From 520e58edd8395de33be3e60204716156a790c654 Mon Sep 17 00:00:00 2001 From: Ali Date: Tue, 10 Dec 2024 17:05:31 +0100 Subject: [PATCH 21/29] fixed afew self blocking calls in copy_async() --- environment.yml | 4 ++-- pyproject.toml | 6 +++--- src/aiida/transports/plugins/ssh_async.py | 14 ++++++++++---- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/environment.yml b/environment.yml index 
b16b2892ee..39f9cbbfae 100644 --- a/environment.yml +++ b/environment.yml @@ -8,7 +8,7 @@ dependencies: - python~=3.9 - alembic~=1.2 - archive-path~=0.4.2 -- asyncssh@ git+https://github.com/ronf/asyncssh.git@033ef54302b2b09d496d68ccf39778b9e5fc89e2#egg=asyncssh +- asyncssh - circus~=0.18.0 - click-spinner~=0.1.8 - click~=8.1 @@ -23,7 +23,7 @@ dependencies: - importlib-metadata~=6.0 - numpy~=1.21 - paramiko~=3.0 -- plumpy@ git+https://github.com/aiidateam/plumpy.git@async-run#egg=plumpy +- plumpy - pgsu~=0.3.0 - psutil~=5.6 - psycopg[binary]~=3.0 diff --git a/pyproject.toml b/pyproject.toml index 3b4ed0539d..e6ff19f241 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dependencies = [ 'tabulate>=0.8.0,<0.10.0', 'tqdm~=4.45', 'upf_to_json~=0.9.2', - 'wrapt~=1.11', + 'wrapt~=1.11' ] description = 'AiiDA is a workflow manager for computational science with a strong focus on provenance, performance and extensibility.' dynamic = ['version'] # read from aiida/__init__.py @@ -513,5 +513,5 @@ commands = molecule {posargs:test} """ [tool.uv.sources] -plumpy = { git = "https://github.com/aiidateam/plumpy", branch = "async-run" } -asyncssh = { git = "https://github.com/ronf/asyncssh", rev = "033ef54302b2b09d496d68ccf39778b9e5fc89e2" } +asyncssh = {git = "https://github.com/ronf/asyncssh", rev = "033ef54302b2b09d496d68ccf39778b9e5fc89e2"} +plumpy = {git = "https://github.com/aiidateam/plumpy", branch = "async-run"} diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 6faace2e08..775a9d36f2 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -655,7 +655,9 @@ async def _exec_cp(cp_exe: str, cp_flags: str, src: str, dst: str): to_copy_list = await self.glob_async(remotesource) if len(to_copy_list) > 1: - if not self.path_exists(remotedestination) or self.isfile(remotedestination): + if not await self.path_exists_async(remotedestination) or await 
self.isfile_async( + remotedestination + ): raise OSError("Can't copy more than one file in the same destination file") for file in to_copy_list: @@ -718,22 +720,26 @@ async def exec_command_wait_async( stdin: Optional[str] = None, encoding: str = 'utf-8', workdir: Optional[TransportPath] = None, - timeout: Optional[float] = 2, + timeout: Optional[float] = None, **kwargs, ): """Execute a command on the remote machine and wait for it to finish. :param command: the command to execute :param stdin: the input to pass to the command + Default = None :param encoding: (IGNORED) this is here just to keep the same signature as the one in `Transport` class + Default = 'utf-8' :param workdir: the working directory where to execute the command + Default = None :param timeout: the timeout in seconds + Default = None :type command: str :type stdin: str :type encoding: str - :type workdir: Union[TransportPath, None] - :type timeout: float + :type workdir: Optional[TransportPath] + :type timeout: Optional[float] :return: a tuple with the return code, the stdout and the stderr of the command :rtype: tuple(int, str, str) From 90718f42a2d622de4cf0b867fa52714905254468 Mon Sep 17 00:00:00 2001 From: Ali Date: Wed, 11 Dec 2024 12:20:17 +0100 Subject: [PATCH 22/29] fix rtd --- .readthedocs.yml | 2 +- docs/source/topics/transport.rst | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 8f1e3118d0..6b64320875 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -20,7 +20,7 @@ build: - asdf install uv 0.2.9 - asdf global uv 0.2.9 post_install: - - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH uv pip install .[docs,tests,rest,atomic_tools] + - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH uv pip install .[docs,tests,rest,atomic_tools] --preview # Let the build fail if there are any warnings sphinx: diff --git a/docs/source/topics/transport.rst b/docs/source/topics/transport.rst index 45706b94a7..d95cbeaac2 100644 --- 
a/docs/source/topics/transport.rst +++ b/docs/source/topics/transport.rst @@ -24,15 +24,15 @@ The generic transport class contains a set of minimal methods that an implementa If not, a ``NotImplementedError`` will be raised, interrupting the managing of the calculation or whatever is using the transport plugin. As for the general functioning of the plugin, the :py:meth:`~aiida.transports.transport.Transport.__init__` method is used only to initialize the class instance, without actually opening the transport channel. -The connection must be opened only by the :py:meth:`~aiida.transports.transport.Transport.__enter__` method, (and closed by :py:meth:`~aiida.transports.transport.Transport.__exit__`). -The :py:meth:`~aiida.transports.transport.Transport.__enter__` method lets you use the transport class using the ``with`` statement (see `python docs `_), in a way similar to the following: +The connection must be opened only by the :py:meth:`~aiida.transports.transport._BaseTransport.__enter__` method, (and closed by :py:meth:`~aiida.transports.transport._BaseTransport.__exit__`). +The :py:meth:`~aiida.transports.transport._BaseTransport.__enter__` method lets you use the transport class using the ``with`` statement (see `python docs `_), in a way similar to the following: .. code-block:: python with TransportPlugin() as transport: transport.some_method() -To ensure this, for example, the local plugin uses a hidden boolean variable ``_is_open`` that is set when the :py:meth:`~aiida.transports.transport.Transport.__enter__` and :py:meth:`~aiida.transports.transport.Transport.__exit__` methods are called. +To ensure this, for example, the local plugin uses a hidden boolean variable ``_is_open`` that is set when the :py:meth:`~aiida.transports.transport._BaseTransport.__enter__` and :py:meth:`~aiida.transports.transport._BaseTransport.__exit__` methods are called. The ``ssh`` logic is instead given by the property sftp. 
The other functions that require some care are the copying functions, called using the following terminology: From 343cf9c24732e1a4f9fcac6ea8d638e115946c66 Mon Sep 17 00:00:00 2001 From: Ali Date: Wed, 11 Dec 2024 12:26:52 +0100 Subject: [PATCH 23/29] fix uv --- pyproject.toml | 4 ++-- uv.lock | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e6ff19f241..c57820e0b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -513,5 +513,5 @@ commands = molecule {posargs:test} """ [tool.uv.sources] -asyncssh = {git = "https://github.com/ronf/asyncssh", rev = "033ef54302b2b09d496d68ccf39778b9e5fc89e2"} -plumpy = {git = "https://github.com/aiidateam/plumpy", branch = "async-run"} +asyncssh = {git = "https://github.com/ronf/asyncssh", branch = "develop"} +plumpy = {git = "https://github.com/aiidateam/plumpy", branch = "master"} diff --git a/uv.lock b/uv.lock index 1efece9bff..ab8343831a 100644 --- a/uv.lock +++ b/uv.lock @@ -160,7 +160,7 @@ requires-dist = [ { name = "alembic", specifier = "~=1.2" }, { name = "archive-path", specifier = "~=0.4.2" }, { name = "ase", marker = "extra == 'atomic-tools'", specifier = "~=3.18" }, - { name = "asyncssh", git = "https://github.com/ronf/asyncssh?rev=033ef54302b2b09d496d68ccf39778b9e5fc89e2" }, + { name = "asyncssh", git = "https://github.com/ronf/asyncssh?branch=develop" }, { name = "bpython", marker = "extra == 'bpython'", specifier = "~=0.18.0" }, { name = "circus", specifier = "~=0.18.0" }, { name = "click", specifier = "~=8.1" }, @@ -192,7 +192,7 @@ requires-dist = [ { name = "pg8000", marker = "extra == 'tests'", specifier = "~=1.13" }, { name = "pgsu", specifier = "~=0.3.0" }, { name = "pgtest", marker = "extra == 'tests'", specifier = "~=1.3,>=1.3.1" }, - { name = "plumpy", git = "https://github.com/aiidateam/plumpy?branch=async-run" }, + { name = "plumpy", git = "https://github.com/aiidateam/plumpy?branch=master" }, { name = "pre-commit", marker = "extra == 
'pre-commit'", specifier = "~=3.5" }, { name = "psutil", specifier = "~=5.6" }, { name = "psycopg", extras = ["binary"], specifier = "~=3.0" }, @@ -450,7 +450,7 @@ wheels = [ [[package]] name = "asyncssh" version = "2.18.0" -source = { git = "https://github.com/ronf/asyncssh?rev=033ef54302b2b09d496d68ccf39778b9e5fc89e2#033ef54302b2b09d496d68ccf39778b9e5fc89e2" } +source = { git = "https://github.com/ronf/asyncssh?branch=develop#e8169bf19c74e5dcaa184123764db403ad131a36" } dependencies = [ { name = "cryptography" }, { name = "typing-extensions" }, @@ -2929,7 +2929,7 @@ wheels = [ [[package]] name = "plumpy" version = "0.22.3" -source = { git = "https://github.com/aiidateam/plumpy?branch=async-run#2a003072ffe45e570bd0a76aafe049e559836cb6" } +source = { git = "https://github.com/aiidateam/plumpy?branch=master#4611154c76ac0991bcf7371b21488f4390648c28" } dependencies = [ { name = "kiwipy", extra = ["rmq"] }, { name = "nest-asyncio" }, From 482eeca874e94323d40aebc12c4a17e147082183 Mon Sep 17 00:00:00 2001 From: Ali Date: Wed, 11 Dec 2024 13:54:11 +0100 Subject: [PATCH 24/29] escape for bash on command --- src/aiida/transports/plugins/ssh_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 775a9d36f2..82ed07e633 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -747,7 +747,7 @@ async def exec_command_wait_async( if workdir: workdir = path_to_str(workdir) - command = f'cd {workdir} && {command}' + command = f'cd {workdir} && ( {command} )' bash_commmand = self._bash_command_str + '-c ' From 5e29e5b06b4d08dd2ea24da082c0f36e290fdf35 Mon Sep 17 00:00:00 2001 From: Ali Date: Wed, 11 Dec 2024 17:41:45 +0100 Subject: [PATCH 25/29] fixed many warnings of rtd --- src/aiida/transports/__init__.py | 1 + src/aiida/transports/plugins/ssh_async.py | 139 ++++++++--------- src/aiida/transports/transport.py | 176 
+++++++++++----------- tests/transports/test_all_plugins.py | 3 - 4 files changed, 160 insertions(+), 159 deletions(-) diff --git a/src/aiida/transports/__init__.py b/src/aiida/transports/__init__.py index 7a7d472869..0d36fe3980 100644 --- a/src/aiida/transports/__init__.py +++ b/src/aiida/transports/__init__.py @@ -19,6 +19,7 @@ 'AsyncTransport', 'SshTransport', 'Transport', + 'TransportPath', 'convert_to_bool', 'parse_sshconfig', ) diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 82ed07e633..81067d79ae 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -17,12 +17,11 @@ import asyncssh import click -from asyncssh import SFTPFileAlreadyExists, SFTPOpUnsupported +from asyncssh import SFTPFileAlreadyExists from aiida.common.escaping import escape_for_bash from aiida.common.exceptions import InvalidOperation - -from ..transport import AsyncTransport, Transport, TransportInternalError, TransportPath, path_to_str +from aiida.transports.transport import AsyncTransport, Transport, TransportInternalError, TransportPath, path_to_str __all__ = ('AsyncSshTransport',) @@ -160,8 +159,8 @@ async def get_async( :param preserve: preserve file attributes Default = False - :type remotepath: TransportPath - :type localpath: TransportPath + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :type overwrite: bool :type ignore_nonexisting: bool @@ -240,8 +239,8 @@ async def getfile_async( :param preserve: preserve file attributes Default = False - :type remotepath: TransportPath - :type localpath: TransportPath + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :type overwrite: bool :type preserve: bool @@ -290,8 +289,8 @@ async def gettree_async( :param 
preserve: preserve file attributes Default = False - :type remotepath: TransportPath - :type localpath: TransportPath + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :type overwrite: bool :type preserve: bool @@ -363,8 +362,8 @@ async def put_async( :param preserve: preserve file attributes Default = False - :type remotepath: TransportPath - :type localpath: TransportPath + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :type overwrite: bool :type ignore_nonexisting: bool @@ -443,8 +442,8 @@ async def putfile_async( :param preserve: preserve file attributes Default = False - :type remotepath: TransportPath - :type localpath: TransportPath + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :type overwrite: bool :type preserve: bool @@ -494,8 +493,8 @@ async def puttree_async( :param preserve: preserve file attributes Default = False - :type localpath: TransportPath - :type remotepath: TransportPath + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :type overwrite: bool :type preserve: bool @@ -562,8 +561,8 @@ async def copy_async( :param preserve: preserve file attributes Default = False - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :type recursive: bool :type preserve: bool @@ -586,30 +585,31 @@ async def copy_async( # For performance reasons, we should check if the remote copy is supported, if so use # 
self._sftp.mcopy() & self._sftp.copy() otherwise send a `cp` command to the remote machine. # See here: https://github.com/ronf/asyncssh/issues/724 - try: - if self.has_magic(remotesource): - await self._sftp.mcopy( - remotesource, - remotedestination, - preserve=preserve, - recurse=recursive, - follow_symlinks=dereference, - remote_only=True, - ) - else: - if not await self.path_exists_async(remotesource): - raise OSError(f'The remote path {remotesource} does not exist') - await self._sftp.copy( - remotesource, - remotedestination, - preserve=preserve, - recurse=recursive, - follow_symlinks=dereference, - remote_only=True, - ) - except asyncssh.sftp.SFTPFailure as exc: - raise OSError(f'Error while copying {remotesource} to {remotedestination}: {exc}') - except SFTPOpUnsupported: + if self._sftp.supports_remote_copy: + try: + if self.has_magic(remotesource): + await self._sftp.mcopy( + remotesource, + remotedestination, + preserve=preserve, + recurse=recursive, + follow_symlinks=dereference, + remote_only=True, + ) + else: + if not await self.path_exists_async(remotesource): + raise OSError(f'The remote path {remotesource} does not exist') + await self._sftp.copy( + remotesource, + remotedestination, + preserve=preserve, + recurse=recursive, + follow_symlinks=dereference, + remote_only=True, + ) + except asyncssh.sftp.SFTPFailure as exc: + raise OSError(f'Error while copying {remotesource} to {remotedestination}: {exc}') + else: self.logger.warning('The remote copy is not supported, using the `cp` command to copy the file/folder') # I copy pasted the whole logic below from SshTransport class: @@ -681,8 +681,8 @@ async def copyfile_async( :param preserve: preserve file attributes Default = False - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :type preserve: bool @@ 
-705,8 +705,8 @@ async def copytree_async( :param preserve: preserve file attributes Default = False - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :type preserve: bool @@ -738,8 +738,8 @@ async def exec_command_wait_async( :type command: str :type stdin: str :type encoding: str - :type workdir: Optional[TransportPath] - :type timeout: Optional[float] + :type workdir: :class:`Path `, :class:`PurePosixPath `, or `str` + :type timeout: float :return: a tuple with the return code, the stdout and the stderr of the command :rtype: tuple(int, str, str) @@ -776,7 +776,7 @@ async def get_attribute_async(self, path: TransportPath): :param path: path to file - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: object FixedFieldsAttributeDict """ @@ -809,7 +809,7 @@ async def isdir_async(self, path: TransportPath): :param path: the absolute path to check - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: True if the path is a directory, False otherwise """ @@ -827,7 +827,7 @@ async def isfile_async(self, path: TransportPath): :param path: the absolute path to check - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: True if the path is a file, False otherwise """ @@ -848,7 +848,7 @@ async def listdir_async(self, path: TransportPath, pattern=None): :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. 
- :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: a list of strings """ @@ -876,7 +876,7 @@ async def listdir_withattributes_async(self, path: TransportPath, pattern: Optio :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :type pattern: str :return: a list of dictionaries, one per entry. The schema of the dictionary is @@ -911,7 +911,7 @@ async def makedirs_async(self, path, ignore_existing=False): :param bool ignore_existing: if set to true, it doesn't give any error if the leaf directory does already exist - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raises: OSError, if directory at path already exists """ @@ -934,7 +934,7 @@ async def mkdir_async(self, path: TransportPath, ignore_existing=False): :param bool ignore_existing: if set to true, it doesn't give any error if the leaf directory does already exist - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raises: OSError, if directory at path already exists """ @@ -965,7 +965,7 @@ async def remove_async(self, path: TransportPath): :param path: path to file to remove - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raise OSError: if the path is a directory """ @@ -984,8 +984,8 @@ async def rename_async(self, oldpath: TransportPath, newpath: TransportPath): :param oldpath: existing name of the file or folder :param newpath: new name for the file or folder - :type oldpath: TransportPath - :type newpath: TransportPath + :type oldpath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type newpath: :class:`Path `, :class:`PurePosixPath `, or `str` :raises OSError: if oldpath/newpath is not found :raises ValueError: if oldpath/newpath is not a valid string @@ -1006,7 +1006,7 @@ async def 
rmdir_async(self, path: TransportPath): :param str path: absolute path to the folder to remove - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` """ path = path_to_str(path) try: @@ -1019,7 +1019,7 @@ async def rmtree_async(self, path: TransportPath): :param str path: absolute path to the folder to remove - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raises OSError: if the operation fails """ @@ -1034,7 +1034,7 @@ async def path_exists_async(self, path: TransportPath): :param path: path to check - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` """ path = path_to_str(path) return await self._sftp.exists(path) @@ -1065,8 +1065,8 @@ async def symlink_async(self, remotesource: TransportPath, remotedestination: Tr :param remotesource: absolute path to remote source :param remotedestination: absolute path to remote destination - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` :raises ValueError: if remotedestination has patterns """ @@ -1095,7 +1095,7 @@ async def glob_async(self, pathname: TransportPath): :param pathname: the pathname pattern to match. It should only be absolute path. - :type pathname: TransportPath + :type pathname: :class:`Path `, :class:`PurePosixPath `, or `str` :return: a list of paths matching the pattern. 
""" @@ -1109,7 +1109,7 @@ async def chmod_async(self, path: TransportPath, mode: int, follow_symlinks: boo :param mode: the new permissions :param bool follow_symlinks: if True, follow symbolic links - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :type mode: int :type follow_symlinks: bool @@ -1130,7 +1130,7 @@ async def chown_async(self, path: TransportPath, uid: int, gid: int): :param uid: the new owner id :param gid: the new group id - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :type uid: int :type gid: int @@ -1159,9 +1159,10 @@ async def copy_from_remote_to_remote_async( :param kwargs: keyword parameters passed to the call to transportdestination.put, except for 'dereference' that is passed to self.get - :type transportdestination: Union['Transport', 'AsyncTransport'] - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type transportdestination: :class:`Transport `, + or :class:`AsyncTransport ` + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` .. 
note:: the keyword 'dereference' SHOULD be set to False for the final put (onto the destination), while it can be set to the @@ -1210,7 +1211,7 @@ def gotocomputer_command(self, remotedir: TransportPath): :param remotedir: the remote directory to connect to - :type remotedir: TransportPath + :type remotedir: :class:`Path `, :class:`PurePosixPath `, or `str` """ connect_string = self._gotocomputer_string(remotedir) cmd = f'ssh -t {self.machine} {connect_string}' diff --git a/src/aiida/transports/transport.py b/src/aiida/transports/transport.py index fea7ee4bfa..5a59dbcf09 100644 --- a/src/aiida/transports/transport.py +++ b/src/aiida/transports/transport.py @@ -22,7 +22,7 @@ from aiida.common.lang import classproperty from aiida.common.warnings import warn_deprecation -__all__ = ('AsyncTransport', 'Transport') +__all__ = ('AsyncTransport', 'Transport', 'TransportPath') TransportPath = Union[str, Path, PurePosixPath] @@ -302,7 +302,7 @@ def chmod(self, path: TransportPath, mode): :param path: path to file :param mode: new permissions - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :type mode: int """ @@ -317,7 +317,7 @@ def chown(self, path: TransportPath, uid: int, gid: int): :param uid: new owner's uid :param gid: new group id - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :type uid: int :type gid: int """ @@ -332,8 +332,8 @@ def copy(self, remotesource: TransportPath, remotedestination: TransportPath, de :param dereference: if True copy the contents of any symlinks found, otherwise copy the symlinks themselves :param recursive: if True copy directories recursively, otherwise only copy the specified file(s) - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :type recursive: bool @@ 
-349,8 +349,8 @@ def copyfile(self, remotesource: TransportPath, remotedestination: TransportPath :param remotedestination: path of the remote destination directory / file :param dereference: if True copy the contents of any symlinks found, otherwise copy the symlinks themselves - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :raises OSError: if one of src or dst does not exist @@ -365,8 +365,8 @@ def copytree(self, remotesource: TransportPath, remotedestination: TransportPath :param remotedestination: path of the remote destination directory / file :param dereference: if True copy the contents of any symlinks found, otherwise copy the symlinks themselves - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :raise OSError: if one of src or dst does not exist @@ -388,9 +388,10 @@ def copy_from_remote_to_remote( :param kwargs: keyword parameters passed to the call to transportdestination.put, except for 'dereference' that is passed to self.get - :type transportdestination: Union['Transport', 'AsyncTransport'] - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type transportdestination: :class:`Transport `, + or :class:`AsyncTransport ` + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` .. note:: the keyword 'dereference' SHOULD be set to False for the final put (onto the destination), while it can be set to the @@ -448,7 +449,7 @@ def _exec_command_internal(self, command: str, workdir: Optional[TransportPath] in the specified working directory. 
:type command: str - :type workdir: TransportPath + :type workdir: :class:`Path `, :class:`PurePosixPath `, or `str` :return: stdin, stdout, stderr and the session, when this exists \ (can be None). @@ -469,7 +470,7 @@ def exec_command_wait_bytes(self, command: str, stdin=None, workdir: Optional[Tr in the specified working directory. :type command: str - :type workdir: TransportPath + :type workdir: :class:`Path `, :class:`PurePosixPath `, or `str` :return: a tuple: the retcode (int), stdout (bytes) and stderr (bytes). """ @@ -496,7 +497,7 @@ def exec_command_wait( :type command: str :type encoding: str - :type workdir: TransportPath + :type workdir: :class:`Path `, :class:`PurePosixPath `, or `str` :return: a tuple with (return_value, stdout, stderr) where stdout and stderr are both strings, decoded with the specified encoding. @@ -518,8 +519,8 @@ def get(self, remotepath: TransportPath, localpath: TransportPath, *args, **kwar :param remotepath: remote_folder_path :param localpath: (local_folder_path - :type remotepath: TransportPath - :type localpath: TransportPath + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -530,8 +531,8 @@ def getfile(self, remotepath: TransportPath, localpath: TransportPath, *args, ** :param remotepath: remote_folder_path :param localpath: local_folder_path - :type remotepath: TransportPath - :type localpath: TransportPath + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -542,8 +543,8 @@ def gettree(self, remotepath: TransportPath, localpath: TransportPath, *args, ** :param remotepath: remote_folder_path :param localpath: local_folder_path - :type remotepath: TransportPath - :type localpath: TransportPath + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type localpath: :class:`Path `, 
:class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -581,7 +582,7 @@ def get_attribute(self, path: TransportPath): :param path: path to file - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: object FixedFieldsAttributeDict """ @@ -591,7 +592,7 @@ def get_mode(self, path: TransportPath): :param path: path to file - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: the portion of the file's mode that can be set by chmod() """ @@ -606,7 +607,7 @@ def isdir(self, path: TransportPath): :param path: path to directory - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: boolean """ @@ -618,7 +619,7 @@ def isfile(self, path: TransportPath): :param path: path to file - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: boolean """ @@ -634,7 +635,7 @@ def listdir(self, path: TransportPath = '.', pattern=None): :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: a list of strings """ @@ -649,7 +650,7 @@ def listdir_withattributes(self, path: TransportPath = '.', pattern: Optional[st taken from DEPRECATED `self.getcwd()`. :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :type pattern: str :return: a list of dictionaries, one per entry. 
The schema of the dictionary is @@ -691,7 +692,7 @@ def makedirs(self, path: TransportPath, ignore_existing=False): :param path: directory to create :param bool ignore_existing: if set to true, it doesn't give any error if the leaf directory does already exist - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raises: OSError, if directory at path already exists """ @@ -703,7 +704,7 @@ def mkdir(self, path: TransportPath, ignore_existing=False): :param path: name of the folder to create :param bool ignore_existing: if True, does not give any error if the directory already exists - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raises: OSError, if directory at path already exists """ @@ -716,7 +717,7 @@ def normalize(self, path: TransportPath = '.'): :param path: path to be normalized - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raise OSError: if the path can't be resolved on the server """ @@ -733,8 +734,8 @@ def put(self, localpath: TransportPath, remotepath: TransportPath, *args, **kwar :param localpath: absolute path to local source :param remotepath: path to remote destination - :type localpath: TransportPath - :type remotepath: TransportPath + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -745,8 +746,8 @@ def putfile(self, localpath: TransportPath, remotepath: TransportPath, *args, ** :param localpath: absolute path to local file :param remotepath: path to remote file - :type localpath: TransportPath - :type remotepath: TransportPath + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -757,8 +758,8 @@ def puttree(self, localpath: TransportPath, remotepath: TransportPath, *args, ** :param localpath: 
absolute path to local folder :param remotepath: path to remote folder - :type localpath: TransportPath - :type remotepath: TransportPath + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -768,7 +769,7 @@ def remove(self, path: TransportPath): :param path: path to file to remove - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raise OSError: if the path is a directory """ @@ -780,8 +781,8 @@ def rename(self, oldpath: TransportPath, newpath: TransportPath): :param oldpath: existing name of the file or folder :param newpath: new name for the file or folder - :type oldpath: TransportPath - :type newpath: TransportPath + :type oldpath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type newpath: :class:`Path `, :class:`PurePosixPath `, or `str` :raises OSError: if oldpath/newpath is not found :raises ValueError: if oldpath/newpath is not a valid string @@ -794,7 +795,7 @@ def rmdir(self, path: TransportPath): :param path: absolute path to the folder to remove - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -803,7 +804,7 @@ def rmtree(self, path: TransportPath): :param path: absolute path to remove - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raise OSError: if the rm execution failed. 
""" @@ -821,7 +822,7 @@ def gotocomputer_command(self, remotedir: TransportPath): :param remotedir: the full path of the remote directory - :type remotedir: TransportPath + :type remotedir: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -832,8 +833,8 @@ def symlink(self, remotesource: TransportPath, remotedestination: TransportPath) :param remotesource: remote source :param remotedestination: remote destination - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` """ def whoami(self): @@ -861,7 +862,7 @@ def path_exists(self, path: TransportPath): :param path: path to check for existence - :type path: TransportPath""" + :type path: :class:`Path `, :class:`PurePosixPath `, or `str`""" # The following definitions are almost copied and pasted # from the python module glob. @@ -874,7 +875,7 @@ def glob(self, pathname: TransportPath): It should only be an absolute path. DEPRECATED: using relative path is deprecated. - :type pathname: TransportPath + :type pathname: :class:`Path `, :class:`PurePosixPath `, or `str` :return: a list of paths matching the pattern. 
""" @@ -1135,7 +1136,7 @@ async def chmod_async(self, path: TransportPath, mode: int): :param path: path to file or directory :param mode: new permissions - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :type mode: int """ @@ -1147,7 +1148,7 @@ async def chown_async(self, path: TransportPath, uid: int, gid: int): :param uid: user id of the new owner :param gid: group id of the new owner - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :type uid: int :type gid: int """ @@ -1162,8 +1163,8 @@ async def copy_async(self, remotesource, remotedestination, dereference=False, r :param dereference: follow symbolic links :param recursive: copy recursively - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :type recursive: bool @@ -1179,8 +1180,8 @@ async def copyfile_async(self, remotesource: TransportPath, remotedestination: T :param remotedestination: path to the remote destination file :param dereference: follow symbolic links - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: bool :raises: OSError, src does not exist or if the copy execution failed.""" @@ -1194,8 +1195,8 @@ async def copytree_async(self, remotesource: TransportPath, remotedestination: T :param remotedestination: path to the remote destination folder :param dereference: follow symbolic links - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` :type dereference: 
bool :raises: OSError, src does not exist or if the copy execution failed.""" @@ -1216,9 +1217,10 @@ async def copy_from_remote_to_remote_async( :param kwargs: keyword parameters passed to the call to transportdestination.put, except for 'dereference' that is passed to self.get - :type transportdestination: Union['Transport', 'AsyncTransport'] - :type remotesource: TransportPath - :type remotedestination: TransportPath + :type transportdestination: :class:`Transport `, + or :class:`AsyncTransport ` + :type remotesource: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotedestination: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -1240,7 +1242,7 @@ async def exec_command_wait_async( :type command: str :type stdin: str :type encoding: str - :type workdir: Union[TransportPath, None] + :type workdir: :class:`Path `, :class:`PurePosixPath `, or `str` :return: a tuple with (return_value, stdout, stderr) where stdout and stderr are both strings. :rtype: Tuple[int, str, str] @@ -1257,8 +1259,8 @@ async def get_async(self, remotepath: TransportPath, localpath: TransportPath, * :param remotepath: remote_folder_path :param localpath: local_folder_path - :type remotepath: TransportPath - :type localpath: TransportPath + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -1269,8 +1271,8 @@ async def getfile_async(self, remotepath: TransportPath, localpath: TransportPat :param remotepath: remote_folder_path :param localpath: local_folder_path - :type remotepath: TransportPath - :type localpath: TransportPath + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -1281,8 +1283,8 @@ async def gettree_async(self, remotepath: TransportPath, localpath: TransportPat :param remotepath: remote_folder_path :param localpath: 
local_folder_path - :type remotepath: TransportPath - :type localpath: TransportPath + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -1305,7 +1307,7 @@ async def get_attribute_async(self, path: TransportPath): :param path: path to file - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: object FixedFieldsAttributeDict """ @@ -1315,7 +1317,7 @@ async def get_mode_async(self, path: TransportPath): :param str path: path to file - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: the portion of the file's mode that can be set by chmod() """ @@ -1331,7 +1333,7 @@ async def isdir_async(self, path: TransportPath): :param path: path to directory - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: boolean """ @@ -1343,7 +1345,7 @@ async def isfile_async(self, path: TransportPath): :param path: path to file - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: boolean """ @@ -1358,7 +1360,7 @@ async def listdir_async(self, path: TransportPath, pattern: Optional[str] = None :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :return: a list of strings """ @@ -1377,7 +1379,7 @@ async def listdir_withattributes_async( :param pattern: if used, listdir returns a list of files matching filters in Unix style. Unix only. - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :type pattern: str :return: a list of dictionaries, one per entry. 
The schema of the dictionary is @@ -1403,7 +1405,7 @@ async def makedirs_async(self, path: TransportPath, ignore_existing=False): :param path: directory to create :param bool ignore_existing: if set to true, it doesn't give any error if the leaf directory does already exist - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raises: OSError, if directory at path already exists """ @@ -1415,7 +1417,7 @@ async def mkdir_async(self, path: TransportPath, ignore_existing=False): :param path: name of the folder to create :param bool ignore_existing: if True, does not give any error if the directory already exists. - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raises: OSError, if directory at path already exists """ @@ -1428,7 +1430,7 @@ async def normalize_async(self, path: TransportPath): :param path: path to be normalized - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raise OSError: if the path can't be resolved on the server """ @@ -1445,8 +1447,8 @@ async def put_async(self, localpath: TransportPath, remotepath: TransportPath, * :param localpath: absolute path to local source :param remotepath: path to remote destination - :type localpath: TransportPath - :type remotepath: TransportPath + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -1457,8 +1459,8 @@ async def putfile_async(self, localpath: TransportPath, remotepath: TransportPat :param localpath: absolute path to local file :param remotepath: path to remote file - :type localpath: TransportPath - :type remotepath: TransportPath + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -1469,8 +1471,8 @@ async def puttree_async(self, localpath: TransportPath, 
remotepath: TransportPat :param localpath: absolute path to local folder :param remotepath: path to remote folder - :type localpath: TransportPath - :type remotepath: TransportPath + :type localpath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type remotepath: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -1480,7 +1482,7 @@ async def remove_async(self, path: TransportPath): :param path: path to file to remove - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raise OSError: if the path is a directory """ @@ -1492,8 +1494,8 @@ async def rename_async(self, oldpath: TransportPath, newpath: TransportPath): :param oldpath: existing name of the file or folder :param newpath: new name for the file or folder - :type oldpath: TransportPath - :type newpath: TransportPath + :type oldpath: :class:`Path `, :class:`PurePosixPath `, or `str` + :type newpath: :class:`Path `, :class:`PurePosixPath `, or `str` :raises OSError: if oldpath/newpath is not found :raises ValueError: if oldpath/newpath is not a valid string @@ -1506,7 +1508,7 @@ async def rmdir_async(self, path: TransportPath): :param path: absolute path to the folder to remove - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` """ @abc.abstractmethod @@ -1515,7 +1517,7 @@ async def rmtree_async(self, path: TransportPath): :param path: absolute path to remove - :type path: TransportPath + :type path: :class:`Path `, :class:`PurePosixPath `, or `str` :raise OSError: if the rm execution failed. 
""" @@ -1536,7 +1538,7 @@ def gotocomputer_command(self, remotedir: TransportPath): :param remotedir: the full path of the remote directory - :type remotedir: TransportPath + :type remotedir: :class:`Path <pathlib.Path>`, :class:`PurePosixPath <pathlib.PurePosixPath>`, or `str` """ @abc.abstractmethod @@ -1565,7 +1567,7 @@ async def path_exists_async(self, path: TransportPath): :param path: path to check for existence - :type path: TransportPath + :type path: :class:`Path <pathlib.Path>`, :class:`PurePosixPath <pathlib.PurePosixPath>`, or `str` """ @abc.abstractmethod @@ -1577,7 +1579,7 @@ async def glob_async(self, pathname: TransportPath): :param pathname: the pathname pattern to match. It should only be absolute path. - :type pathname: TransportPath + :type pathname: :class:`Path <pathlib.Path>`, :class:`PurePosixPath <pathlib.PurePosixPath>`, or `str` :return: a list of paths matching the pattern. """ diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index 1290e871f3..aff10c4941 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -678,9 +678,6 @@ def test_copy(custom_transport, tmp_path_remote): transport.copy((base_dir), (workdir / 'prova')) assert set(['origin']) == set(transport.listdir((workdir / 'prova'))) assert set(['a.txt', 'b.tmp', 'c.txt']) == set(transport.listdir((workdir / 'prova' / 'origin'))) - transport.rmtree((workdir / 'prova')) - # exit - transport.rmtree((workdir)) def test_put(custom_transport, tmp_path_remote, tmp_path_local): From 1761d949ef3d7e30a54c2fe2d12a8236b32450c8 Mon Sep 17 00:00:00 2001 From: Ali Date: Fri, 13 Dec 2024 11:33:40 +0100 Subject: [PATCH 26/29] implement max_io_allowed --- src/aiida/transports/plugins/ssh_async.py | 66 +++++++++++++++++++---- tests/transports/test_all_plugins.py | 2 +- 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 81067d79ae..c0d875b8ac 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ 
b/src/aiida/transports/plugins/ssh_async.py @@ -21,12 +21,19 @@ from aiida.common.escaping import escape_for_bash from aiida.common.exceptions import InvalidOperation -from aiida.transports.transport import AsyncTransport, Transport, TransportInternalError, TransportPath, path_to_str +from aiida.transports.transport import ( + AsyncTransport, + Transport, + TransportInternalError, + TransportPath, + path_to_str, + validate_positive_number, +) __all__ = ('AsyncSshTransport',) -def _validate_script(ctx, param, value: str): +def validate_script(ctx, param, value: str): if value == 'None': return value if not os.path.isabs(value): @@ -38,7 +45,7 @@ def _validate_script(ctx, param, value: str): return value -def _validate_machine(ctx, param, value: str): +def validate_machine(ctx, param, value: str): async def attempt_connection(): try: await asyncssh.connect(value) @@ -60,15 +67,29 @@ class AsyncSshTransport(AsyncTransport): # note, I intentionally wanted to keep connection parameters as simple as possible. _valid_auth_options = [ ( - 'machine', + # the underscore is added to avoid conflict with the machine property + # which is passed to __init__ as parameter `machine=computer.hostname` + 'machine_or_host', { 'type': str, - 'prompt': 'machine as in `ssh machine` command', - 'help': 'Password-less host-setup to connect, as in command `ssh machine`. ' - "You'll need to have a `Host machine` " - 'entry defined in your `~/.ssh/config` file. ', + 'prompt': 'Machine(or host) name as in `ssh <your-host-name>` command.' + ' (It should be a password-less setup)', + 'help': 'Password-less host-setup to connect, as in command `ssh <your-host-name>`. 
' + "You'll need to have a `Host <your-host-name>` entry defined in your `~/.ssh/config` file.", 'non_interactive_default': True, - 'callback': _validate_machine, + 'callback': validate_machine, + }, + ), + ( + 'max_io_allowed', + { + 'type': int, + 'default': 8, + 'prompt': 'Maximum number of concurrent I/O operations.', + 'help': 'Depends on various factors, such as your network bandwidth, the server load, etc.' + ' (An experimental number)', + 'non_interactive_default': True, + 'callback': validate_positive_number, }, ), ( @@ -80,7 +101,7 @@ class AsyncSshTransport(AsyncTransport): 'help': ' (optional) Specify a script to run *before* opening SSH connection. ' 'The script should be executable', 'non_interactive_default': True, - 'callback': _validate_script, + 'callback': validate_script, }, ), ] @@ -96,9 +117,24 @@ def _get_machine_suggestion_string(cls, computer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.machine = kwargs.pop('machine') + self.machine = kwargs.pop('machine_or_host') + self._max_io_allowed = kwargs.pop('max_io_allowed') self.script_before = kwargs.pop('script_before', 'None') + self._councurrent_io = 0 + + @property + def max_io_allowed(self): + return self._max_io_allowed + + async def _lock(self, sleep_time=0.5): + while self._councurrent_io >= self.max_io_allowed: + await asyncio.sleep(sleep_time) + self._councurrent_io += 1 + + async def _unlock(self): + self._councurrent_io -= 1 + async def open_async(self): """Open the transport. This plugin supports running scripts before and during the connection. 
@@ -258,6 +294,7 @@ async def getfile_async( raise OSError('Destination already exists: not overwriting it') try: + await self._lock() await self._sftp.get( remotepaths=remotepath, localpath=localpath, @@ -265,6 +302,7 @@ async def getfile_async( recurse=False, follow_symlinks=dereference, ) + await self._unlock() except (OSError, asyncssh.Error) as exc: raise OSError(f'Error while uploading file {localpath}: {exc}') @@ -327,6 +365,7 @@ async def gettree_async( content_list = await self.listdir_async(remotepath) for content_ in content_list: try: + await self._lock() await self._sftp.get( remotepaths=PurePath(remotepath) / content_, localpath=localpath, @@ -334,6 +373,7 @@ async def gettree_async( recurse=True, follow_symlinks=dereference, ) + await self._unlock() except (OSError, asyncssh.Error) as exc: raise OSError(f'Error while uploading file {localpath}: {exc}') @@ -462,6 +502,7 @@ async def putfile_async( raise OSError('Destination already exists: not overwriting it') try: + await self._lock() await self._sftp.put( localpaths=localpath, remotepath=remotepath, @@ -469,6 +510,7 @@ async def putfile_async( recurse=False, follow_symlinks=dereference, ) + await self._unlock() except (OSError, asyncssh.Error) as exc: raise OSError(f'Error while uploading file {localpath}: {exc}') @@ -534,6 +576,7 @@ async def puttree_async( content_list = os.listdir(localpath) for content_ in content_list: try: + await self._lock() await self._sftp.put( localpaths=PurePath(localpath) / content_, remotepath=remotepath, @@ -541,6 +584,7 @@ async def puttree_async( recurse=True, follow_symlinks=dereference, ) + await self._unlock() except (OSError, asyncssh.Error) as exc: raise OSError(f'Error while uploading file {PurePath(localpath)/content_}: {exc}') diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index aff10c4941..6923a31a4b 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -65,7 +65,7 @@ def 
custom_transport(request, tmp_path_factory, monkeypatch) -> Union['Transport if not filepath_config.exists(): filepath_config.write_text('Host localhost') elif request.param == 'core.ssh_async': - kwargs = {'machine_': 'localhost', 'machine': 'localhost'} + kwargs = {'machine_or_host': 'localhost', 'max_io_allowed': 8} else: kwargs = {} From 6627b21fe308cf4d727d1761b687d6315e8fc4fa Mon Sep 17 00:00:00 2001 From: Ali Date: Fri, 13 Dec 2024 12:35:32 +0100 Subject: [PATCH 27/29] update asyncssh dependency --- environment.yml | 2 +- pyproject.toml | 2 +- src/aiida/transports/plugins/ssh_async.py | 23 ++++++++++++++++------- tests/cmdline/commands/test_computer.py | 10 +++++++++- tests/transports/test_all_plugins.py | 4 +++- 5 files changed, 30 insertions(+), 11 deletions(-) diff --git a/environment.yml b/environment.yml index 8862430c45..9ce52da6f5 100644 --- a/environment.yml +++ b/environment.yml @@ -8,7 +8,7 @@ dependencies: - python~=3.9 - alembic~=1.2 - archive-path~=0.4.2 -- asyncssh +- asyncssh~=2.19.0 - circus~=0.18.0 - click-spinner~=0.1.8 - click~=8.1 diff --git a/pyproject.toml b/pyproject.toml index 971f338e89..92afa8986f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ dependencies = [ 'alembic~=1.2', 'archive-path~=0.4.2', - "asyncssh", + "asyncssh~=2.19.0", 'circus~=0.18.0', 'click-spinner~=0.1.8', 'click~=8.1', diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index c0d875b8ac..307cffbbfa 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -64,6 +64,8 @@ async def attempt_connection(): class AsyncSshTransport(AsyncTransport): """Transport plugin via SSH, asynchronously.""" + _DEFAULT_max_io_allowed = 8 + # note, I intentionally wanted to keep connection parameters as simple as possible. 
_valid_auth_options = [ ( @@ -84,7 +86,7 @@ class AsyncSshTransport(AsyncTransport): 'max_io_allowed', { 'type': int, - 'default': 8, + 'default': _DEFAULT_max_io_allowed, 'prompt': 'Maximum number of concurrent I/O operations.', 'help': 'Depends on various factors, such as your network bandwidth, the server load, etc.' ' (An experimental number)', @@ -117,23 +119,30 @@ def _get_machine_suggestion_string(cls, computer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.machine = kwargs.pop('machine_or_host') - self._max_io_allowed = kwargs.pop('max_io_allowed') + # the machine is passed as `machine=computer.hostname` in the codebase + # 'machine' is immutable. + # 'machine_or_host' is mutable, so it can be changed via command: + # 'verdi computer configure core.ssh_async