From 52c2dc1bc360d604081e0b981d69dc6c553a7b0a Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 5 Aug 2023 13:35:29 -0700 Subject: [PATCH] Fix unintended "wait forever" behavior with zero timeouts [#976]. In a few places we did "if timeout:" or "if expiration:" when we really meant "if timeout is not None:". This meant that in the zero timeout case we fell into the "wait forever" path instead of immediately timing out. In the case of UDP queries, we'd be waiting on recvfrom() and if a packet was lost, then the code would never wake up. (cherry picked from commit 0c183f10c78941a4e72046d4dcb2ecf20083b398) --- dns/_asyncio_backend.py | 2 +- dns/_trio_backend.py | 2 +- dns/asyncquery.py | 2 +- dns/resolver.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index 94f751b1b..0021f84fe 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -51,7 +51,7 @@ def close(self): async def _maybe_wait_for(awaitable, timeout): - if timeout: + if timeout is not None: try: return await asyncio.wait_for(awaitable, timeout) except asyncio.TimeoutError: diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index 14f05280a..d414f0b37 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -13,7 +13,7 @@ def _maybe_timeout(timeout): - if timeout: + if timeout is not None: return trio.move_on_after(timeout) else: return dns._asyncbackend.NullContext() diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 97295a29f..737e1c922 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -72,7 +72,7 @@ def _source_tuple(af, address, port): def _timeout(expiration, now=None): - if expiration: + if expiration is not None: if not now: now = time.time() return max(expiration - now, 0) diff --git a/dns/resolver.py b/dns/resolver.py index bbac49a50..f08f824d0 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -1697,7 +1697,7 @@ def zone_for_name( while 1: try: rlifetime: Optional[float] - if expiration: + if expiration is not None: rlifetime = expiration - time.time() if rlifetime <= 0: rlifetime = 0