From 033ef54302b2b09d496d68ccf39778b9e5fc89e2 Mon Sep 17 00:00:00 2001 From: Ron Frederick Date: Wed, 4 Dec 2024 21:27:40 -0800 Subject: [PATCH] Add support for selecting to only allow remote copy on SFTP This commit adds a new "remote_only" argument to the SFTPClient copy() and mcopy() functions to request that the operation only be performed if it can be done using the "remote copy" feature. It also adds a "supports_remote_copy" property to SFTPClient for an application to test if the connected SFTP server supports this function. --- asyncssh/sftp.py | 78 ++++++++++++++++++++++++++++------------------ docs/api.rst | 7 +++-- tests/test_sftp.py | 42 +++++++++++++++++++++++-- 3 files changed, 91 insertions(+), 36 deletions(-) diff --git a/asyncssh/sftp.py b/asyncssh/sftp.py index 11b39e8..4b8f0ca 100644 --- a/asyncssh/sftp.py +++ b/asyncssh/sftp.py @@ -811,39 +811,34 @@ async def run(self) -> None: self._progress_handler(self._srcpath, self._dstpath, 0, 0) if self._srcfs == self._dstfs and \ - isinstance(self._srcfs, SFTPClient): - try: - await self._srcfs.remote_copy( - cast(SFTPClientFile, self._src), - cast(SFTPClientFile, self._dst)) - except SFTPOpUnsupported: - pass - else: - self._bytes_copied = self._total_bytes + isinstance(self._srcfs, SFTPClient) and \ + self._srcfs.supports_remote_copy: + await self._srcfs.remote_copy(cast(SFTPClientFile, self._src), + cast(SFTPClientFile, self._dst)) - if self._progress_handler: - self._progress_handler(self._srcpath, self._dstpath, - self._bytes_copied, - self._total_bytes) + self._bytes_copied = self._total_bytes - return - - async for _, datalen in self.iter(): - if datalen: - self._bytes_copied += datalen + if self._progress_handler: + self._progress_handler(self._srcpath, self._dstpath, + self._bytes_copied, + self._total_bytes) + else: + async for _, datalen in self.iter(): + if datalen: + self._bytes_copied += datalen - if self._progress_handler: - self._progress_handler(self._srcpath, self._dstpath, - self._bytes_copied, - self._total_bytes) + if self._progress_handler: + self._progress_handler(self._srcpath, self._dstpath, + self._bytes_copied, + self._total_bytes) - if self._bytes_copied != self._total_bytes: - exc = SFTPFailure('Unexpected EOF during file copy') + if self._bytes_copied != self._total_bytes: + exc = SFTPFailure('Unexpected EOF during file copy') - setattr(exc, 'filename', self._srcpath) - setattr(exc, 'offset', self._bytes_copied) + setattr(exc, 'filename', self._srcpath) + setattr(exc, 'offset', self._bytes_copied) - raise exc + raise exc finally: if self._src: # pragma: no branch await self._src.close() @@ -2500,6 +2495,12 @@ def version(self) -> int: return self._version + @property + def supports_copy_data(self) -> bool: + """Return whether or not SFTP remote copy is supported""" + + return self._supports_copy_data + async def _cleanup(self, exc: Optional[Exception]) -> None: """Clean up this SFTP client session""" @@ -3678,6 +3679,12 @@ def limits(self) -> SFTPLimits: return self._handler.limits + @property + def supports_remote_copy(self) -> bool: + """Return whether or not SFTP remote copy is supported""" + + return self._handler.supports_copy_data + @staticmethod def basename(path: bytes) -> bytes: """Return the final component of a POSIX-style path""" @@ -4116,7 +4123,8 @@ async def copy(self, srcpaths: _SFTPPaths, follow_symlinks: bool = False, block_size: int = -1, max_requests: int = _MAX_SFTP_REQUESTS, progress_handler: SFTPProgressHandler = None, - error_handler: SFTPErrorHandler = None) -> None: + error_handler: SFTPErrorHandler = None, + remote_only: bool = False) -> None: """Copy remote files to a new location This method copies one or more files or directories on the @@ -4193,6 +4201,8 @@ async def copy(self, srcpaths: _SFTPPaths, The function to call to report copy progress :param error_handler: (optional) The function to call when an error occurs + :param remote_only: (optional) + Whether or not to only allow this to be a remote copy :type srcpaths: :class:`PurePath `, `str`, or `bytes`, or a sequence of these @@ -4205,12 +4215,16 @@ async def copy(self, srcpaths: _SFTPPaths, :type max_requests: `int` :type progress_handler: `callable` :type error_handler: `callable` + :type remote_only: `bool` :raises: | :exc:`OSError` if a local file I/O error occurs | :exc:`SFTPError` if the server returns an error """ + if remote_only and not self.supports_remote_copy: + raise SFTPOpUnsupported('Remote copy not supported') + await self._begin_copy(self, self, srcpaths, dstpath, 'remote copy', False, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, @@ -4268,8 +4282,9 @@ async def mcopy(self, srcpaths: _SFTPPaths, follow_symlinks: bool = False, block_size: int = -1, max_requests: int = _MAX_SFTP_REQUESTS, progress_handler: SFTPProgressHandler = None, - error_handler: SFTPErrorHandler = None) -> None: - """Download remote files with glob pattern match + error_handler: SFTPErrorHandler = None, + remote_only: bool = False) -> None: + """Copy remote files with glob pattern match This method copies files and directories on the remote system matching one or more glob patterns. @@ -4280,6 +4295,9 @@ async def mcopy(self, srcpaths: _SFTPPaths, """ + if remote_only and not self.supports_remote_copy: + raise SFTPOpUnsupported('Remote copy not supported') + await self._begin_copy(self, self, srcpaths, dstpath, 'remote mcopy', True, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, diff --git a/docs/api.rst b/docs/api.rst index 434e14f..78de0d0 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1069,13 +1069,14 @@ SFTP Support .. autoclass:: SFTPClient() - ======================================================================= = + ======================================= = SFTP client attributes - ======================================================================= = + ======================================= = .. autoattribute:: logger .. autoattribute:: version .. autoattribute:: limits - ======================================================================= = + .. autoattribute:: supports_remote_copy + ======================================= = =========================== = File transfer methods diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 3a5f3d5..59b377d 100644 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -755,11 +755,11 @@ def test_copy_non_remote(self): async def _test_copy_non_remote(self, sftp): """Test copying without using remote_copy function""" - for src in ('src', b'src', Path('src')): - with self.subTest(src=type(src)): + for method in ('copy', 'mcopy'): + with self.subTest(method=method): try: self._create_file('src') - await sftp.copy(src, 'dst') + await sftp.copy('src', 'dst') self._check_file('src', 'dst') finally: remove('src dst') @@ -768,6 +768,23 @@ async def _test_copy_non_remote(self, sftp): # pylint: disable=no-value-for-parameter _test_copy_non_remote(self) + def test_copy_remote_only(self): + """Test copying while allowing only remote copy""" + + @sftp_test + async def _test_copy_remote_only(self, sftp): + """Test copying with only remote copy allowed""" + + for method in ('copy', 'mcopy'): + with self.subTest(method=method): + with self.assertRaises(SFTPOpUnsupported): + await getattr(sftp, method)('src', 'dst', + remote_only=True) + + with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): + # pylint: disable=no-value-for-parameter + _test_copy_remote_only(self) + @sftp_test async def test_copy_progress(self, sftp): """Test copying a file over SFTP with progress reporting""" @@ -1152,6 +1169,25 @@ def err_handler(exc): finally: remove('src1 src2 dst') + def test_remote_copy_unsupported(self): + """Test remote copy on a server which doesn't support it""" + + @sftp_test + async def _test_remote_copy_unsupported(self, sftp): + """Test remote copy not being supported""" + + try: + self._create_file('src') + + with self.assertRaises(SFTPOpUnsupported): + await sftp.remote_copy('src', 'dst') + finally: + remove('src') + + with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): + # pylint: disable=no-value-for-parameter + _test_remote_copy_unsupported(self) + @sftp_test async def test_remote_copy_arguments(self, sftp): """Test remote copy arguments"""