Skip to content

Commit

Permalink
Allow sendto the same addr if remote_addr is set
Browse files Browse the repository at this point in the history
* Fixes #319.
  • Loading branch information
fantix committed Apr 11, 2020
1 parent e8eb502 commit 1d9267a
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 1 deletion.
37 changes: 36 additions & 1 deletion tests/test_udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def datagram_received(self, data, addr):

s_transport, server = self.loop.run_until_complete(coro)

host, port, *_ = s_transport.get_extra_info('sockname')
remote_addr = s_transport.get_extra_info('sockname')
host, port, *_ = remote_addr

self.assertIsInstance(server, TestMyDatagramProto)
self.assertEqual('INITIALIZED', server.state)
Expand Down Expand Up @@ -86,6 +87,36 @@ def datagram_received(self, data, addr):
# received
self.assertEqual(8, client.nbytes)

# https://github.com/MagicStack/uvloop/issues/319
# uvloop should behave the same as asyncio when given remote_addr
transport.sendto(b'xxx', remote_addr)
tb.run_until(
self.loop, lambda: server.nbytes > 3 or client.done.done())
self.assertEqual(6, server.nbytes)
tb.run_until(self.loop, lambda: client.nbytes > 8)

# received
self.assertEqual(16, client.nbytes)

# reject sendto with a different port
with self.assertRaisesRegex(
ValueError, "Invalid address.*" + repr(remote_addr)
):
bad_addr = list(remote_addr)
bad_addr[1] += 1
bad_addr = tuple(bad_addr)
transport.sendto(b"xxx", bad_addr)

# reject sento with unresolved hostname
if remote_addr[0] != lc_addr[0]:
with self.assertRaisesRegex(
ValueError, "Invalid address.*" + repr(remote_addr)
):
bad_addr = list(remote_addr)
bad_addr[0] = lc_addr[0]
bad_addr = tuple(bad_addr)
transport.sendto(b"xxx", bad_addr)

# extra info is available
self.assertIsNotNone(transport.get_extra_info('sockname'))

Expand All @@ -100,6 +131,10 @@ def test_create_datagram_endpoint_addrs_ipv4(self):
self._test_create_datagram_endpoint_addrs(
socket.AF_INET, ('127.0.0.1', 0))

def test_create_datagram_endpoint_addrs_ipv4_nameaddr(self):
self._test_create_datagram_endpoint_addrs(
socket.AF_INET, ('localhost', 0))

@unittest.skipUnless(tb.has_IPv6, 'no IPv6')
def test_create_datagram_endpoint_addrs_ipv6(self):
self._test_create_datagram_endpoint_addrs(
Expand Down
2 changes: 2 additions & 0 deletions uvloop/handles/udp.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ cdef class UDPTransport(UVBaseTransport):
cdef:
bint __receiving
int _family
object _address

cdef _init(self, Loop loop, unsigned int family)
cdef _set_address(self, system.addrinfo *addr)

cdef _connect(self, system.sockaddr* addr, size_t addr_len)

Expand Down
14 changes: 14 additions & 0 deletions uvloop/handles/udp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ cdef class UDPTransport(UVBaseTransport):
def __cinit__(self):
self._family = uv.AF_UNSPEC
self.__receiving = 0
self._address = None

cdef _init(self, Loop loop, unsigned int family):
cdef int err
Expand All @@ -78,6 +79,9 @@ cdef class UDPTransport(UVBaseTransport):

self._finish_init()

cdef _set_address(self, system.addrinfo *addr):
self._address = __convert_sockaddr_to_pyaddr(addr.ai_addr)

cdef _connect(self, system.sockaddr* addr, size_t addr_len):
cdef int err
err = uv.uv_udp_connect(<uv.uv_udp_t*>self._handle, addr)
Expand Down Expand Up @@ -279,6 +283,16 @@ cdef class UDPTransport(UVBaseTransport):
# Replicating asyncio logic here.
return

if self._address:
if addr not in (None, self._address):
# Replicating asyncio logic here.
raise ValueError(
'Invalid address: must be None or %s' % (self._address,))

# Instead of setting addr to self._address below like what asyncio
# does, we depend on previous uv_udp_connect() to set the address
addr = None

if self._conn_lost:
# Replicating asyncio logic here.
if self._conn_lost >= LOG_THRESHOLD_FOR_CONNLOST_WRITES:
Expand Down
2 changes: 2 additions & 0 deletions uvloop/loop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3004,6 +3004,7 @@ cdef class Loop:
rai = (<AddrInfo>rads).data
udp._init(self, rai.ai_family)
udp._connect(rai.ai_addr, rai.ai_addrlen)
udp._set_address(rai)
else:
if family not in (uv.AF_INET, uv.AF_INET6):
raise ValueError('unexpected address family')
Expand Down Expand Up @@ -3047,6 +3048,7 @@ cdef class Loop:
rai = rai.ai_next
continue
udp._connect(rai.ai_addr, rai.ai_addrlen)
udp._set_address(rai)
break
else:
raise OSError(
Expand Down

0 comments on commit 1d9267a

Please sign in to comment.