From 95756fa4eb54826adb9156b87443d945fd1ed4f8 Mon Sep 17 00:00:00 2001 From: Ron Frederick Date: Sat, 14 Dec 2024 13:56:12 -0800 Subject: [PATCH] Change default on SFTP max_requests to avoid excessive memory usage This commit changes the way the default number of parallel SFTP requests is determined. Instead of always defaulting to 128, lower values are used for large block sizes, to reduce memory usage. With larger block sizes, there's no need for as many parallel requests to keep the pipe full. The minimum default is now 16, for block sizes of 256 KB or more. The maximum default is 128, for block sizes of 32 KB or below. --- asyncssh/sftp.py | 45 ++++++++++++++++++++++++++++++++------------- tests/test_sftp.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 13 deletions(-) diff --git a/asyncssh/sftp.py b/asyncssh/sftp.py index 209bee3..62238e9 100644 --- a/asyncssh/sftp.py +++ b/asyncssh/sftp.py @@ -3165,7 +3165,6 @@ def __init__(self, handler: SFTPClientHandler, handle: bytes, self._appending = appending self._encoding = encoding self._errors = errors - self._max_requests = max_requests self._offset = None if appending else 0 self.read_len = \ @@ -3173,6 +3172,15 @@ def __init__(self, handler: SFTPClientHandler, handle: bytes, self.write_len = \ handler.limits.max_write_len if block_size == -1 else block_size + if max_requests <= 0: + if self.read_len: + max_requests = max(16, min(MAX_SFTP_READ_LEN // + self.read_len, 128)) + else: + max_requests = 1 + + self._max_requests = max_requests + async def __aenter__(self) -> Self: """Allow SFTPClientFile to be used as an async context manager""" @@ -3859,6 +3867,9 @@ async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, block_size = min(srcfs.limits.max_read_len, dstfs.limits.max_write_len) + if max_requests <= 0: + max_requests = max(16, min(MAX_SFTP_READ_LEN // block_size, 128)) + if isinstance(srcpaths, (bytes, str, PurePath)): srcpaths = [srcpaths] elif not isinstance(srcpaths, list): @@ -3916,7 +3927,7 @@ async def get(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, - max_requests: int = _MAX_SFTP_REQUESTS, + max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Download remote files @@ -3957,7 +3968,9 @@ async def get(self, remotepaths: _SFTPPaths, doesn't advertise limits. The max_requests argument specifies the maximum number of - parallel read or write requests issued, defaulting to 128. + parallel read or write requests issued, defaulting to a + value between 16 and 128 depending on the selected block + size to avoid excessive memory usage. If progress_handler is specified, it will be called after each block of a file is successfully downloaded. The arguments @@ -4022,7 +4035,7 @@ async def put(self, localpaths: _SFTPPaths, remotepath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, - max_requests: int = _MAX_SFTP_REQUESTS, + max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Upload local files @@ -4063,7 +4076,9 @@ async def put(self, localpaths: _SFTPPaths, doesn't advertise limits. The max_requests argument specifies the maximum number of - parallel read or write requests issued, defaulting to 128. + parallel read or write requests issued, defaulting to a + value between 16 and 128 depending on the selected block + size to avoid excessive memory usage. If progress_handler is specified, it will be called after each block of a file is successfully uploaded. The arguments @@ -4128,7 +4143,7 @@ async def copy(self, srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, - max_requests: int = _MAX_SFTP_REQUESTS, + max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None, remote_only: bool = False) -> None: @@ -4170,7 +4185,9 @@ async def copy(self, srcpaths: _SFTPPaths, doesn't advertise limits. The max_requests argument specifies the maximum number of - parallel read or write requests issued, defaulting to 128. + parallel read or write requests issued, defaulting to a + value between 16 and 128 depending on the selected block + size to avoid excessive memory usage. If progress_handler is specified, it will be called after each block of a file is successfully copied. The arguments @@ -4238,7 +4255,7 @@ async def mget(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, - max_requests: int = _MAX_SFTP_REQUESTS, + max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Download remote files with glob pattern match @@ -4261,7 +4278,7 @@ async def mput(self, localpaths: _SFTPPaths, remotepath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, - max_requests: int = _MAX_SFTP_REQUESTS, + max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Upload local files with glob pattern match @@ -4284,7 +4301,7 @@ async def mcopy(self, srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, block_size: int = -1, - max_requests: int = _MAX_SFTP_REQUESTS, + max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None, remote_only: bool = False) -> None: @@ -4586,7 +4603,7 @@ async def open(self, path: _SFTPPath, attrs: SFTPAttrs = SFTPAttrs(), encoding: Optional[str] = 'utf-8', errors: str = 'strict', block_size: int = -1, - max_requests: int = _MAX_SFTP_REQUESTS) -> SFTPClientFile: + max_requests: int = -1) -> SFTPClientFile: """Open a remote file This method opens a remote file and returns an @@ -4662,7 +4679,9 @@ async def open(self, path: _SFTPPath, default of using the server-advertised limits. The max_requests argument specifies the maximum number of - parallel read or write requests issued, defaulting to 128. + parallel read or write requests issued, defaulting to a + value between 16 and 128 depending on the selected block + size to avoid excessive memory usage. :param path: The name of the remote file to open @@ -4718,7 +4737,7 @@ async def open56(self, path: _SFTPPath, attrs: SFTPAttrs = SFTPAttrs(), encoding: Optional[str] = 'utf-8', errors: str = 'strict', block_size: int = -1, - max_requests: int = _MAX_SFTP_REQUESTS) -> SFTPClientFile: + max_requests: int = -1) -> SFTPClientFile: """Open a remote file using SFTP v5/v6 flags This method is very similar to :meth:`open`, but the pflags_or_mode diff --git a/tests/test_sftp.py b/tests/test_sftp.py index a853508..82c284d 100644 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -748,6 +748,21 @@ async def test_copy(self, sftp): finally: remove('src dst') + @sftp_test + async def test_copy_max_requests(self, sftp): + """Test copying a file over SFTP with max requests set""" + + for method in ('get', 'put', 'copy'): + for src in ('src', b'src', Path('src')): + with self.subTest(method=method, src=type(src)): + try: + self._create_file('src', 16*1024*1024*'\0') + await getattr(sftp, method)(src, 'dst', + max_requests=4) + self._check_file('src', 'dst') + finally: + remove('src dst') + def test_copy_non_remote(self): """Test copying without using remote_copy function""" @@ -2305,6 +2320,23 @@ async def test_open_read_parallel(self, sftp): remove('file') + @sftp_test + async def test_open_read_max_requests(self, sftp): + """Test reading data from a file with max requests set""" + + f = None + + try: + self._create_file('file', 16*1024*1024*'\0') + + f = await sftp.open('file', max_requests=4) + self.assertEqual(len(await f.read()), 16*1024*1024) + finally: + if f: # pragma: no branch + await f.close() + + remove('file') + def test_open_read_out_of_order(self): """Test parallel read with out-of-order responses"""