diff --git a/asyncssh/connection.py b/asyncssh/connection.py index 3aaa70c..2fb1ac2 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -179,6 +179,7 @@ _ProtocolFactory = Union[_ClientFactory, _ServerFactory] _Conn = TypeVar('_Conn', 'SSHClientConnection', 'SSHServerConnection') +_ConnSelf = TypeVar('_ConnSelf', bound='SSHConnection') class _TunnelProtocol(Protocol): """Base protocol for connections to tunnel SSH over""" @@ -893,7 +894,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._disable_trivial_auth = False - async def __aenter__(self) -> 'SSHConnection': + async def __aenter__(self: _ConnSelf) -> _ConnSelf: """Allow SSHConnection to be used as an async context manager""" return self